b8c760d004
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
214 lines
8.5 KiB
Python
214 lines
8.5 KiB
Python
"""Experimentos GLiNER + GLiREL — corpus EN/ES, barridos de threshold/labels/top_k.
|
|
|
|
Ejecutar con el venv del analysis: ./.venv/bin/python3 run_experiments.py

Genera:
- results.json (todos los experimentos, listos para tablas/plots)
- notebooks/01_gliner_glirel_tuning.ipynb (reconstruido con los outputs; el main() de
  este archivo solo escribe results.json)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from pathlib import Path
|
|
|
|
# Directory holding this script; main() writes results.json here.
HERE = Path(__file__).resolve().parent

# Root of the function registry (override with FN_REGISTRY_ROOT). Its
# python/functions directory is put first on sys.path so the `datascience.*`
# loaders imported below can be found.
REGISTRY_ROOT = Path(os.getenv("FN_REGISTRY_ROOT", "/home/lucas/fn_registry"))
sys.path.insert(0, str(REGISTRY_ROOT.joinpath("python", "functions")))

# Silence all Python warnings for the whole run.
warnings.filterwarnings("ignore")

# Quiet Hugging Face / transformers console output, but respect any value the
# caller already exported.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
|
|
|
|
from datascience.gliner_load_model import gliner_load_model
|
|
from datascience.glirel_load_model import glirel_load_model
|
|
|
|
# Short benchmark passages keyed "<lang>_<domain>". The sweep functions branch on
# the key ("es_*" vs "en_osint" vs anything else) to choose entity label sets.
# Spanish passages are written without accents. Each passage is its sentences
# joined by single spaces.
CORPUS = {
    "es_corporate": " ".join((
        "Pablo Isla, expresidente de Inditex, ha sido nombrado consejero de Telefonica.",
        "La operacion fue anunciada por el presidente Jose Maria Alvarez-Pallete en Madrid el pasado lunes.",
        "Inditex factura mas de 30.000 millones anuales y tiene su sede en Arteixo, A Coruna.",
    )),
    "en_corporate": " ".join((
        "Pablo Isla, the former chairman of Inditex, has been appointed as a director of Telefonica.",
        "The announcement was made by Jose Maria Alvarez-Pallete, the chairman of Telefonica, in Madrid last Monday.",
        "Inditex has its headquarters in Arteixo, A Coruna.",
    )),
    "en_osint": " ".join((
        "On 2024-08-15, attacker IP 185.220.101.45 connected to victim host 10.0.5.22 over TLS.",
        "Reverse DNS pointed to tor-exit-relay-3.onionrouter.net.",
        "Operator handle @phantomzero claimed responsibility on a forum.",
        "The C2 panel was hosted on hxxps://malwareops[.]biz/control behind Cloudflare.",
    )),
    "es_journalism": " ".join((
        "Iberdrola y Endesa firmaron un acuerdo de colaboracion en proyectos eolicos en Galicia.",
        "El presidente de Iberdrola, Ignacio Galan, se reunio con la CEO de Endesa, Marina Serrano, en Bilbao.",
        "El acuerdo movilizara 2.000 millones de euros en cinco anos.",
    )),
}
|
|
|
|
# Candidate GLiNER entity label sets, keyed by set name. The sweeps pick sets by
# these exact keys.
ENTITY_LABELS = {
    # Coarse English labels, used as the baseline for every corpus.
    "generic_en": [
        "person",
        "organization",
        "location",
    ],
    # The same coarse labels in Spanish (written without accents).
    "generic_es": [
        "persona",
        "organizacion",
        "lugar",
    ],
    # Finer-grained English labels for the corporate passages.
    "specific_en": [
        "executive",
        "company",
        "city",
        "country",
    ],
    # Indicator-style labels for the OSINT passage.
    "osint_en": [
        "ip_address",
        "domain",
        "url",
        "username",
        "date",
        "person",
        "organization",
    ],
}
|
|
|
|
# GLiREL relation label vocabularies in two phrasing styles. The score-distribution
# experiment runs every style, so terse labels can be compared against
# natural-language ones.
RELATION_LABELS = {
    # Terse snake_case identifiers.
    "snake_short": [
        "works_at",
        "located_in",
        "appointed_as",
        "headquartered_in",
        "ceo_of",
    ],
    # The same five relations, in the same order, spelled out as short phrases.
    "natural_long": [
        "person works at organization",
        "organization is located in location",
        "person appointed as role at organization",
        "organization headquartered in location",
        "person is ceo of organization",
    ],
}
|
|
|
|
|
|
def _ensure_models():
    """Load GLiNER and GLiREL, printing how long each load took.

    This function keeps no cache of its own: every call invokes
    ``gliner_load_model()`` and ``glirel_load_model()``.
    NOTE(review): whether a repeat call actually reloads weights depends on those
    registry loaders — confirm before calling this repeatedly (e.g. from a notebook).

    Returns:
        ``(gliner, glirel)`` — the two loaded model objects, in that order.
    """
    gliner = _timed_load("GLiNER", gliner_load_model)
    glirel = _timed_load("GLiREL", glirel_load_model)
    return gliner, glirel


def _timed_load(name, loader):
    """Call ``loader()`` and print ``[load]`` progress/timing lines labelled ``name``."""
    # perf_counter is monotonic, so the elapsed time cannot go negative or jump
    # if the wall clock is adjusted mid-load (unlike time.time()).
    started = time.perf_counter()
    print(f"[load] {name}...")
    model = loader()
    print(f"[load] {name} ready in {time.perf_counter() - started:.1f}s")
    return model
|
|
|
|
|
|
def gliner_threshold_sweep(
    gliner,
    thresholds: tuple[float, ...] = (0.1, 0.3, 0.5, 0.7, 0.9),
    *,
    corpus: dict[str, str] | None = None,
    entity_label_sets: dict[str, list[str]] | None = None,
) -> dict:
    """Sweep GLiNER's entity threshold over every corpus text and candidate label set.

    ``predict_entities`` is called once per (corpus entry, label set) at
    ``threshold=0.0``. Each threshold in ``thresholds`` is then applied by
    filtering those scores, so the model is not re-run per threshold.

    Label sets chosen per corpus key:
      - ``es_*``: generic_en + generic_es
      - ``en_osint``: generic_en + osint_en
      - anything else: generic_en + specific_en

    Args:
        gliner: Loaded GLiNER model. Anything with ``predict_entities(text, labels,
            threshold=...)`` returning dicts with text/label/score/start/end works.
        thresholds: Score cut-offs to report; each becomes a ``"t=<value>"`` key.
        corpus: ``{corpus_key: text}``; defaults to the module-level ``CORPUS``.
        entity_label_sets: ``{label_set_key: labels}``; defaults to ``ENTITY_LABELS``.
            Must contain every label-set key selected above.

    Returns:
        ``{corpus_key: {label_set_key: {"scored_at_t0": [...], "t=0.1": [...], ...}}}``.
        Every list holds ``(text, label, score, start, end)`` tuples. Each
        ``t=...`` list keeps only the entries with ``score >= t``.

    Raises:
        KeyError: if a selected label-set key is missing from ``entity_label_sets``.
    """
    corpus = CORPUS if corpus is None else corpus
    label_sets = ENTITY_LABELS if entity_label_sets is None else entity_label_sets

    out = {}
    for corpus_key, text in corpus.items():
        if corpus_key.startswith("es_"):
            label_set_keys = ("generic_en", "generic_es")
        elif corpus_key == "en_osint":
            label_set_keys = ("generic_en", "osint_en")
        else:
            label_set_keys = ("generic_en", "specific_en")

        per_corpus = out[corpus_key] = {}
        for ls_key in label_set_keys:
            # One call at threshold 0.0 exposes every raw score.
            # NOTE(review): filtering this single run matches re-running at each t
            # only if GLiNER resolves overlapping spans greedily by score — confirm
            # for the installed gliner version.
            base = gliner.predict_entities(text, label_sets[ls_key], threshold=0.0)
            scored = [
                (e["text"], e["label"], float(e["score"]), e["start"], e["end"])
                for e in base
            ]
            sweep = {"scored_at_t0": scored}
            for t in thresholds:
                sweep[f"t={t}"] = [e for e in scored if e[2] >= t]
            per_corpus[ls_key] = sweep
    return out
|
|
|
|
|
|
def _glirel_inputs(text, ents):
    """Whitespace-tokenize ``text`` and convert GLiNER char spans into GLiREL ``ner`` spans.

    Args:
        text: The passage GLiNER ran on.
        ents: GLiNER entity dicts. Only ``start``/``end`` (char offsets; assumed
            end-exclusive, i.e. ``text[start:end]`` is the entity — TODO confirm)
            and ``label`` are read.

    Returns:
        ``(tokens, ner)``. ``tokens == text.split()``, and each ``ner`` item is
        ``[first_token, last_token, label]``. A token belongs to an entity when it
        overlaps the entity's character range, so an entity starting inside a token
        (e.g. ``phantomzero`` in ``@phantomzero``) maps to that token rather than
        the next one. Empty entities and entities covering no token are dropped.
    """
    tokens = text.split()
    offsets = []
    cursor = 0
    for tok in tokens:
        # split() tokens appear in order and contain no whitespace, so the first
        # occurrence at/after the cursor is exactly this token.
        begin = text.index(tok, cursor)
        cursor = begin + len(tok)
        offsets.append((begin, cursor))

    ner = []
    for e in ents:
        c_start, c_end = e["start"], e["end"]
        if c_end <= c_start:
            continue
        covered = [i for i, (b, f) in enumerate(offsets) if b < c_end and f > c_start]
        if covered:
            # GLiREL takes token spans with an INCLUSIVE end index.
            # NOTE(review): per the GLiREL README example (ner=[[26, 27, ...]] for the
            # two-token "Marco Polo"); the previous exclusive-end conversion made every
            # span one token too long. Confirm against the installed glirel version.
            ner.append([covered[0], covered[-1], e["label"]])
    return tokens, ner


def _format_relations(raw, head_key, tail_key):
    """Flatten GLiREL relation dicts (token-list head/tail) into JSON-friendly rows, best score first."""
    rels = [
        {
            "label": r.get("label", ""),
            "score": round(float(r.get("score", 0.0)), 4),
            head_key: " ".join(r.get("head_text", [])),
            tail_key: " ".join(r.get("tail_text", [])),
        }
        for r in raw
    ]
    rels.sort(key=lambda x: x["score"], reverse=True)
    return rels


def glirel_score_distribution(
    gliner,
    glirel,
    *,
    corpus: dict[str, str] | None = None,
    entity_label_sets: dict[str, list[str]] | None = None,
    relation_label_sets: dict[str, list[str]] | None = None,
) -> dict:
    """For each (corpus, relation label style), collect GLiREL relations at threshold=0, top_k=5.

    Entities come from GLiNER at threshold 0.5. Spanish (``es_*``) corpora use the
    ``generic_es`` label set; all others use ``generic_en``. Corpora with fewer than
    two entities are skipped.

    Args:
        gliner: Loaded GLiNER model (``predict_entities``).
        glirel: Loaded GLiREL model (``predict_relations``).
        corpus: ``{corpus_key: text}``; defaults to ``CORPUS``.
        entity_label_sets: Must hold ``generic_en``/``generic_es``; defaults to ``ENTITY_LABELS``.
        relation_label_sets: ``{style_key: labels}``; defaults to ``RELATION_LABELS``.

    Returns:
        ``{corpus_key: {...}}``:
          - skipped corpora: ``{"entities": [], "note": "too few entities"}``
          - other corpora: ``{"entities": [(text, label, score)], "ner": [...],
            "styles": {style_key: [relation rows] | {"error": msg}}}``
    """
    corpus = CORPUS if corpus is None else corpus
    entity_label_sets = ENTITY_LABELS if entity_label_sets is None else entity_label_sets
    relation_label_sets = RELATION_LABELS if relation_label_sets is None else relation_label_sets

    out = {}
    for corpus_key, text in corpus.items():
        per_corpus = out[corpus_key] = {}
        label_key = "generic_es" if corpus_key.startswith("es_") else "generic_en"
        ents = gliner.predict_entities(text, entity_label_sets[label_key], threshold=0.5)
        if len(ents) < 2:
            per_corpus["entities"] = []
            per_corpus["note"] = "too few entities"
            continue

        # float() mirrors gliner_threshold_sweep and keeps the value a plain JSON number.
        per_corpus["entities"] = [
            (e["text"], e["label"], round(float(e["score"]), 3)) for e in ents
        ]
        tokens, ner = _glirel_inputs(text, ents)
        per_corpus["ner"] = ner

        per_corpus["styles"] = {}
        for style_key, rel_labels in relation_label_sets.items():
            # Deliberate best-effort: a failing style is recorded in results.json
            # instead of aborting the whole experiment run.
            try:
                raw = glirel.predict_relations(
                    tokens, labels=list(rel_labels), threshold=0.0, ner=ner, top_k=5
                )
                per_corpus["styles"][style_key] = _format_relations(raw, "head_text", "tail_text")
            except Exception as exc:
                per_corpus["styles"][style_key] = {"error": str(exc)}
    return out


def glirel_topk_sweep(
    gliner,
    glirel,
    *,
    text: str | None = None,
    entity_labels: list[str] | None = None,
    relation_labels: list[str] | None = None,
    top_ks: tuple[int, ...] = (1, 3, 5, 10),
) -> dict:
    """On one English passage, vary GLiREL's top_k at threshold=0.0.

    Args:
        gliner: Loaded GLiNER model (``predict_entities``), run at threshold 0.5.
        glirel: Loaded GLiREL model (``predict_relations``).
        text: Passage to analyse; defaults to ``CORPUS["en_corporate"]``.
        entity_labels: GLiNER labels; defaults to ``ENTITY_LABELS["generic_en"]``.
        relation_labels: GLiREL labels; defaults to ``RELATION_LABELS["snake_short"]``.
        top_ks: top_k values to try; each becomes a ``"top_k=<k>"`` key.

    Returns:
        ``{"entities": [(text, label)], "ner": [...], "by_topk": {"top_k=<k>": [rows]}}``.
        Rows use ``head``/``tail`` keys and are sorted by score, descending.
    """
    text = CORPUS["en_corporate"] if text is None else text
    entity_labels = ENTITY_LABELS["generic_en"] if entity_labels is None else entity_labels
    relation_labels = RELATION_LABELS["snake_short"] if relation_labels is None else relation_labels

    ents = gliner.predict_entities(text, entity_labels, threshold=0.5)
    tokens, ner = _glirel_inputs(text, ents)

    out = {"entities": [(e["text"], e["label"]) for e in ents], "ner": ner, "by_topk": {}}
    for topk in top_ks:
        raw = glirel.predict_relations(
            tokens, labels=relation_labels, threshold=0.0, ner=ner, top_k=topk
        )
        out["by_topk"][f"top_k={topk}"] = _format_relations(raw, "head", "tail")
    return out
|
|
|
|
|
|
def main():
    """Run every experiment and write the combined results to ``HERE / "results.json"``.

    Loads both models, runs three experiments in order, then dumps their results
    together with ``CORPUS``, ``ENTITY_LABELS`` and ``RELATION_LABELS`` as indented
    JSON. The experiments are:
      1. the GLiNER threshold sweep
      2. the GLiREL score distribution
      3. the GLiREL top_k sweep

    Returns:
        The same ``results`` dict that was written to disk.
    """
    gliner, glirel = _ensure_models()

    print("\n=== GLINER threshold sweep ===")
    gliner_results = gliner_threshold_sweep(gliner)

    print("\n=== GLIREL score distribution ===")
    glirel_results = glirel_score_distribution(gliner, glirel)

    print("\n=== GLIREL top_k sweep ===")
    topk_results = glirel_topk_sweep(gliner, glirel)

    results = {
        "gliner_threshold_sweep": gliner_results,
        "glirel_score_distribution": glirel_results,
        "glirel_topk_sweep": topk_results,
        "corpus": CORPUS,
        "entity_labels": ENTITY_LABELS,
        "relation_labels": RELATION_LABELS,
    }

    out_path = HERE / "results.json"
    # ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding
    # instead of relying on the platform's locale default.
    out_path.write_text(
        json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\n[done] {out_path}")
    return results
|
|
|
|
|
|
# Script entry point: run all experiments and write results.json next to this file.
if __name__ == "__main__":
    main()
|