Files
gliner_glirel_tuning/run_improvements.py
T
2026-05-04 23:44:11 +02:00

328 lines
12 KiB
Python

"""Bateria de experimentos comparando configuraciones de GLiNER2 sobre el PDF.
Vuelca a improvements.json para que build_notebook_improvements.py construya
el notebook con outputs estaticos (sin volver a cargar el modelo).
"""
from __future__ import annotations
import gc
import json
import os
import re
import sys
import time
import warnings
from collections import Counter, defaultdict
from pathlib import Path
# Suppress every Python warning for the rest of the process.
warnings.filterwarnings("ignore")
# Set HF_HUB_DISABLE_PROGRESS_BARS to "1" unless the caller already defined it
# (presumably read by huggingface_hub while the model is downloaded/loaded).
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
# Directory that contains this script; main() writes improvements.json here.
HERE = Path(__file__).resolve().parent
# Local function-registry root; presumably where the `core` package imported
# below lives -- TODO confirm.
_pf = "/home/lucas/fn_registry/python/functions"
# Remove sys.path entries that point *inside* the registry root, then add the
# root itself at the front if it is not already present. This must happen
# before the project imports that follow.
sys.path = [p for p in sys.path if not p.startswith(_pf + "/")]
if _pf not in sys.path:
    sys.path.insert(0, _pf)
from gliner2 import GLiNER2
from core.extract_pdf_text import extract_pdf_text
# Hard-coded location of the test document processed by main().
VAULT = Path("/home/lucas/vaults/osint_nlp_models")
PDF_PATH = VAULT / "test_documents" / "politica_proteccion_datos.pdf"
def clean_pdf_text(text: str) -> str:
    """Normalize raw PDF text into single-spaced, non-empty lines.

    The steps run in a fixed order:

    1. Blank out ``N/M`` tokens made of one or two digits on each side
       (presumably page counters such as ``3/12``).
    2. Turn tabs into spaces.
    3. Rejoin words that were hyphenated across a line break.
    4. Replace newline runs that are not immediately preceded by ``.``,
       ``!`` or ``?`` with a single space.
    5. Collapse runs of spaces.
    6. Strip every line and drop the empty ones.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned text, with lines separated by ``"\\n"``.
    """
    without_counters = re.sub(r"\b\d{1,2}/\d{1,2}\b", " ", text)
    untabbed = without_counters.replace("\t", " ")
    dehyphenated = re.sub(r"-\s*\n\s*", "", untabbed)
    # Only line breaks that follow sentence-ending punctuation survive.
    unwrapped = re.sub(r"(?<![\.!?])\n+", " ", dehyphenated)
    single_spaced = re.sub(r" {2,}", " ", unwrapped)

    kept_lines = []
    for raw_line in single_spaced.split("\n"):
        stripped = raw_line.strip()
        if stripped:
            kept_lines.append(stripped)
    return "\n".join(kept_lines).strip()
def chunk_with_overlap(text: str, max_chars: int = 1500, overlap_sentences: int = 2):
    """Split ``text`` into sentence-aligned chunks with optional overlap.

    Sentences are obtained by splitting on whitespace that follows ``.``,
    ``!`` or ``?``. Each chunk is filled greedily; its size is measured as
    the sum of ``len(sentence) + 1`` over its sentences.

    Args:
        text: Text to split.
        max_chars: Size budget per chunk. A single sentence longer than the
            budget is still emitted, alone, as its own chunk.
        overlap_sentences: How many trailing sentences of the previous chunk
            are repeated at the start of the next one. ``0`` disables overlap.

    Returns:
        A list of dicts with ``"text"`` (sentences joined by a space) and
        ``"sentences"`` (the list of sentences in the chunk).
    """
    stripped_parts = (part.strip() for part in re.split(r"(?<=[\.!?])\s+", text))
    sentences = [part for part in stripped_parts if part]
    total = len(sentences)

    chunks = []
    pos = 0
    while pos < total:
        window: list[str] = []
        used = 0

        # Carry over the tail of the previous chunk, but only when it still
        # leaves room for the next new sentence; otherwise start clean so
        # every chunk is guaranteed to consume fresh input.
        if overlap_sentences > 0 and chunks:
            tail = chunks[-1]["sentences"][-overlap_sentences:]
            tail_len = sum(len(sentence) + 1 for sentence in tail)
            if tail_len + len(sentences[pos]) + 1 <= max_chars:
                window = list(tail)
                used = tail_len

        # Forced progress: always take one new sentence, even over budget.
        window.append(sentences[pos])
        used += len(sentences[pos]) + 1
        pos += 1

        # Keep adding sentences while they fit in the remaining budget.
        while pos < total:
            cost = len(sentences[pos]) + 1
            if used + cost > max_chars:
                break
            window.append(sentences[pos])
            used += cost
            pos += 1

        chunks.append({"text": " ".join(window), "sentences": window})
    return chunks
def aggregate(extract_results):
    """Merge per-chunk extraction results into entity and relation tallies.

    Args:
        extract_results: Iterable of dicts. The optional ``"entities"`` key
            maps an entity type to a list of mention strings; the optional
            ``"relation_extraction"`` key maps a relation type to a list of
            ``(head, tail)`` string pairs.

    Returns:
        ``(entities, relations)`` where ``entities`` maps
        ``(type, lowercased name)`` to ``{"type", "name", "count"}`` (the
        name keeps the spelling of its first occurrence) and ``relations``
        is a ``Counter`` keyed by ``(head, relation type, tail)``.
    """
    entity_index: dict = {}
    relation_counts: Counter = Counter()

    for result in extract_results:
        for label, mentions in result.get("entities", {}).items():
            for mention in mentions:
                surface = mention.strip()
                if surface:
                    # Mentions are deduplicated case-insensitively per type.
                    record = entity_index.setdefault(
                        (label, surface.lower()),
                        {"type": label, "name": surface, "count": 0},
                    )
                    record["count"] += 1

        for rel_label, pairs in result.get("relation_extraction", {}).items():
            relation_counts.update(
                (head.strip(), rel_label, tail.strip()) for head, tail in pairs
            )

    return entity_index, relation_counts
def graph_stats(ents_dict, rels_counter):
    """Compute size and connectivity figures for the entity/relation graph.

    Nodes are entity names plus every relation endpoint; edges are the
    distinct ``(head, tail, relation)`` triples. A node is isolated when it
    is not an endpoint of any edge.

    Args:
        ents_dict: Mapping whose values are dicts with a ``"name"`` key.
        rels_counter: Mapping keyed by ``(head, relation, tail)``.

    Returns:
        Dict with ``n_ents``, ``n_rels``, ``n_nodes``, ``n_edges``,
        ``n_isolates``, ``connected`` and ``connect_pct`` (percentage of
        nodes that have an edge, rounded to one decimal).
    """
    edges = {(head, tail, kind) for head, kind, tail in rels_counter}
    linked = {endpoint for head, tail, _ in edges for endpoint in (head, tail)}

    nodes = {entry["name"] for entry in ents_dict.values()}
    nodes |= linked

    isolated = nodes - linked
    connected = len(nodes) - len(isolated)
    return {
        "n_ents": len(ents_dict),
        "n_rels": len(rels_counter),
        "n_nodes": len(nodes),
        "n_edges": len(edges),
        "n_isolates": len(isolated),
        "connected": connected,
        # max(1, ...) avoids a division by zero on an empty graph.
        "connect_pct": round(connected / max(1, len(nodes)) * 100, 1),
    }
def normalize_name(s: str) -> str:
    """Return a comparison key for an entity name.

    The key is the name without the punctuation characters
    ``. , ; : " ' ` ( ) [ ]``, with whitespace runs collapsed to one space,
    trimmed and lowercased.
    """
    without_punct = re.sub(r"[\.,;:\"'`()\[\]]", "", s.strip())
    single_spaced = re.sub(r"\s+", " ", without_punct)
    return single_spaced.strip().lower()
def merge_aliases(ents_dict, rels_counter):
    """Collapse entity aliases and re-point relations at canonical names.

    Two passes are applied to the entities:

    1. Entities whose names are equal after ``normalize_name`` are grouped
       (the entity type is not part of the grouping key). The member with
       the highest count becomes the canonical entry and supplies the type;
       the other spellings become its aliases and their counts are summed.
    2. A canonical name is absorbed into another canonical name of the same
       type when its normalized form has at least four characters and
       occurs, delimited by regex word boundaries, inside the other's
       normalized form.

    Absorption chains (``a -> b -> c``) are flattened so that every absorbed
    name maps directly to an entry that still exists in the result, and a
    name that has already been absorbed never absorbs others itself.

    Args:
        ents_dict: Mapping whose values are dicts with ``"type"``,
            ``"name"`` and ``"count"`` keys.
        rels_counter: Mapping ``(head, relation, tail) -> count``.

    Returns:
        ``(canonical_data, new_rels, absorbed)``:
        ``canonical_data`` maps canonical name to
        ``{"type", "name", "count", "aliases"}``; ``new_rels`` is a
        ``Counter`` of relations rewritten onto canonical names, with
        self-loops dropped; ``absorbed`` maps each absorbed name to the
        surviving canonical name.
    """
    # Pass 1: group spellings that normalize to the same key.
    norm_groups: dict = defaultdict(list)
    for v in ents_dict.values():
        norm_groups[normalize_name(v["name"])].append(v)

    canonical: dict = {}       # normalized key -> canonical display name
    canonical_data: dict = {}  # canonical display name -> merged record
    for nrm, group in norm_groups.items():
        winner = max(group, key=lambda v: v["count"])
        canonical[nrm] = winner["name"]
        canonical_data[winner["name"]] = {
            "type": winner["type"],
            "name": winner["name"],
            "count": sum(v["count"] for v in group),
            "aliases": [v["name"] for v in group if v["name"] != winner["name"]],
        }

    # Pass 2: absorb names contained in another canonical name.
    canon_names = sorted(canonical_data, key=len, reverse=True)
    # Normalize once per name instead of once per pair.
    norm_of = {name: normalize_name(name) for name in canon_names}
    absorbed: dict = {}
    for long_n in canon_names:
        # An absorbed entry is removed below and its totals were already
        # transferred, so anything it absorbed now would be lost.
        if long_n in absorbed:
            continue
        long_norm = norm_of[long_n]
        long_entry = canonical_data[long_n]
        for short_n in canon_names:
            if short_n == long_n or short_n in absorbed:
                continue
            short_norm = norm_of[short_n]
            # Very short names match too many unrelated longer names.
            if len(short_norm) < 4:
                continue
            short_entry = canonical_data[short_n]
            if short_entry["type"] != long_entry["type"]:
                continue
            if re.search(r"\b" + re.escape(short_norm) + r"\b", long_norm):
                absorbed[short_n] = long_n
                long_entry["count"] += short_entry["count"]
                long_entry["aliases"].append(short_n)
                long_entry["aliases"].extend(short_entry.get("aliases", []))

    # Names are ordered by raw length but matched on normalized text, so a
    # name that absorbed others can itself be absorbed later. Point every
    # absorbed name at the final survivor. Containment strictly shortens the
    # normalized form, so chains cannot cycle.
    for name in absorbed:
        target = absorbed[name]
        while target in absorbed:
            target = absorbed[target]
        absorbed[name] = target

    for short_n in list(absorbed):
        canonical_data.pop(short_n, None)

    def resolve(name):
        """Map a relation endpoint to its surviving canonical name."""
        c = canonical.get(normalize_name(name), name)
        return absorbed.get(c, c)

    new_rels: Counter = Counter()
    for (h, rt, t), c in rels_counter.items():
        h_canon = resolve(h)
        t_canon = resolve(t)
        # Relations whose two ends merged into one entity are dropped.
        if h_canon == t_canon:
            continue
        new_rels[(h_canon, rt, t_canon)] += c
    return canonical_data, new_rels, absorbed
# Entity types requested from the model; main() passes this list to
# `.entities(...)` when building both schemas.
ENTITY_LABELS = ["person", "organization", "location", "email", "right", "data_category", "authority", "law"]
# Relation types as bare label names; main() uses them for the "flat" schema.
RELATION_LABELS_FLAT = [
    "located_in", "governed_by", "subject_to", "protected_by",
    "contact_for", "rights_against", "subsidiary_of", "controlled_by",
]
# The same relation types, each mapped to a natural-language description;
# main() uses them for the "desc" schema.
RELATION_LABELS_DESC = {
    "located_in": "organization or person is located in a place or address",
    "governed_by": "entity is governed or supervised by an authority or law",
    "subject_to": "data category or process is subject to a law or regulation",
    "protected_by": "right or data is protected by a law or authority",
    "contact_for": "email or address is the contact channel for an authority or right",
    "rights_against": "person has rights to exercise against an organization",
    "subsidiary_of": "organization is a subsidiary of a parent organization",
    "controlled_by": "organization or data is controlled by another organization",
}
def _run_chunk_loop(model, chunks, schema, threshold, tag, desc):
    """Run ``model.extract`` on every chunk and summarize the resulting graph.

    Args:
        model: Loaded model exposing ``extract(text, schema=..., threshold=...)``.
        chunks: Chunk dicts; only the ``"text"`` key is read.
        schema: Schema object forwarded to ``model.extract``.
        threshold: Threshold forwarded to ``model.extract``.
        tag: Short config identifier used in the log lines and config name.
        desc: Human-readable config description.

    Returns:
        Dict with ``"name"``, ``"elapsed"`` (extraction seconds, one decimal)
        and ``"stats"`` (output of ``graph_stats``).
    """
    print(f"[{tag}] {desc}...")
    t0 = time.time()
    results = [model.extract(c["text"], schema=schema, threshold=threshold) for c in chunks]
    # Only the extraction itself is timed, not the aggregation.
    elapsed = time.time() - t0
    ents, rels = aggregate(results)
    summary = {
        "name": f"{tag}: {desc}",
        "elapsed": round(elapsed, 1),
        "stats": graph_stats(ents, rels),
    }
    # Release the raw per-chunk results before the next configuration runs.
    del results
    gc.collect()
    print(f" {elapsed:.1f}s stats={summary['stats']}")
    return summary


def main():
    """Run the GLiNER2 configuration battery and write ``improvements.json``.

    Extracts and chunks the PDF at ``PDF_PATH``, runs five extraction
    configurations (A-D chunk by chunk, E through ``batch_extract``), applies
    ``merge_aliases`` to the results of config E, and saves all figures as
    UTF-8 JSON next to this script.
    """
    out: dict = {}

    # --- prepare text + chunks (CPU only)
    print("[prep] extract + clean + chunk...")
    raw = extract_pdf_text(str(PDF_PATH))
    clean = clean_pdf_text(raw)
    chunks = chunk_with_overlap(clean, max_chars=1500, overlap_sentences=2)
    chunks_no_overlap = chunk_with_overlap(clean, max_chars=1500, overlap_sentences=0)
    out["meta"] = {
        "raw_chars": len(raw),
        "clean_chars": len(clean),
        "n_chunks_overlap": len(chunks),
        "n_chunks_no_overlap": len(chunks_no_overlap),
        "first_clean_600": clean[:600],
    }
    # The separator before the chunk count keeps it from running into the
    # clean-character count in the log line.
    print(f" raw {len(raw):,} → clean {len(clean):,} → {len(chunks)} chunks (overlap=2)")

    print("[load] GLiNER2...")
    t0 = time.time()
    model = GLiNER2.from_pretrained("fastino/gliner2-large-v1")
    print(f" load: {time.time()-t0:.1f}s")

    schema_flat = model.create_schema().entities(ENTITY_LABELS).relations(RELATION_LABELS_FLAT)
    schema_desc = model.create_schema().entities(ENTITY_LABELS).relations(RELATION_LABELS_DESC)

    # A-D: one extract() call per chunk, varying threshold and schema.
    loop_runs = (
        ("A", "t=0.5 flat loop", schema_flat, 0.5),
        ("B", "t=0.3 flat loop", schema_flat, 0.3),
        ("C", "t=0.2 flat loop", schema_flat, 0.2),
        ("D", "t=0.3 desc loop", schema_desc, 0.3),
    )
    configs: list = []
    for tag, desc, schema, threshold in loop_runs:
        configs.append(_run_chunk_loop(model, chunks, schema, threshold, tag, desc))

    # E: same settings as D, but through batch_extract.
    print("[E] t=0.3 desc batch_extract...")
    t0 = time.time()
    texts = [c["text"] for c in chunks]
    res_e = model.batch_extract(texts, schemas=schema_desc, batch_size=8, threshold=0.3)
    elapsed_e = time.time() - t0
    ents_e, rels_e = aggregate(res_e)
    configs.append({"name": "E: t=0.3 desc batch", "elapsed": round(elapsed_e, 1),
                    "stats": graph_stats(ents_e, rels_e)})
    print(f" {elapsed_e:.1f}s stats={configs[-1]['stats']}")
    out["configs"] = configs

    # --- alias merging on config E (the config the script treats as best)
    print("[coref] applying alias merge to config E...")
    t0 = time.time()
    ents_merged, rels_merged, absorbed = merge_aliases(ents_e, rels_e)
    # graph_stats only reads the values, but re-key like aggregate() does.
    ents_merged_dict = {(v["type"], v["name"].lower()): v for v in ents_merged.values()}
    stats_post = graph_stats(ents_merged_dict, rels_merged)
    elapsed_coref = time.time() - t0
    out["coref"] = {
        "elapsed": round(elapsed_coref, 2),
        "pre_stats": graph_stats(ents_e, rels_e),
        "post_stats": stats_post,
        "n_absorbed": len(absorbed),
        "absorbed_sample": list(absorbed.items())[:8],
    }
    print(f" pre: {out['coref']['pre_stats']}")
    print(f" post: {out['coref']['post_stats']}")
    print(f" absorbed: {len(absorbed)} e.g. {list(absorbed.items())[:3]}")

    # --- top 25 entities after alias merging, by mention count
    top_rows = []
    for v in sorted(ents_merged.values(), key=lambda x: -x["count"])[:25]:
        top_rows.append({
            "type": v["type"],
            "canonical": v["name"],
            "mentions": v["count"],
            "n_aliases": len(v.get("aliases", [])),
            "aliases_sample": v.get("aliases", [])[:3],
        })
    out["top_entities_post_coref"] = top_rows

    # --- top 25 relations after alias merging, by count
    top_rels = []
    for (h, rt, t), c in sorted(rels_merged.items(), key=lambda x: -x[1])[:25]:
        top_rels.append({"from": h, "kind": rt, "to": t, "count": c})
    out["top_relations_post_coref"] = top_rels

    # --- full merged entities and relations, for graph rendering
    out["ents_merged"] = [{"name": v["name"], "type": v["type"], "count": v["count"]}
                          for v in ents_merged.values()]
    out["rels_merged"] = [{"from": h, "kind": rt, "to": t, "count": c}
                          for (h, rt, t), c in rels_merged.items()]

    out_path = HERE / "improvements.json"
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding is
    # pinned to UTF-8 instead of depending on the platform locale.
    out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\n[saved] {out_path} ({out_path.stat().st_size:,} bytes)")
# Run the experiment battery only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()