b8c760d004
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
328 lines
12 KiB
Python
328 lines
12 KiB
Python
"""Bateria de experimentos comparando configuraciones de GLiNER2 sobre el PDF.
|
|
|
|
Vuelca a improvements.json para que build_notebook_improvements.py construya
|
|
el notebook con outputs estaticos (sin volver a cargar el modelo).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import gc
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from collections import Counter, defaultdict
|
|
from pathlib import Path
|
|
|
|
# Silence every warning for the rest of the run.
warnings.filterwarnings("ignore")
# Turn off Hugging Face Hub progress bars unless the variable is already set.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# Directory that contains this script.
HERE = Path(__file__).resolve().parent
# Root directory inserted at the front of sys.path before the project imports.
_pf = "/home/lucas/fn_registry/python/functions"
# Drop sys.path entries that live *inside* that root, then make sure the root
# itself is importable.
# NOTE(review): presumably this prevents a subdirectory entry from shadowing
# the top-level `core` package imported below - confirm.
sys.path = [p for p in sys.path if not p.startswith(_pf + "/")]
if _pf not in sys.path:
    sys.path.insert(0, _pf)

# These imports must stay below the sys.path adjustment above.
from gliner2 import GLiNER2
from core.extract_pdf_text import extract_pdf_text


# Base folder of the test material and the PDF used as input.
VAULT = Path("/home/lucas/vaults/osint_nlp_models")
PDF_PATH = VAULT / "test_documents" / "politica_proteccion_datos.pdf"
|
|
|
|
|
|
def clean_pdf_text(text: str) -> str:
    """Normalize raw PDF text into clean, sentence-oriented lines.

    Applied in order:

    1. ``N/M`` tokens with one or two digits on each side (presumably page
       counters) are replaced by a space.
    2. Tabs become spaces.
    3. A hyphen followed by a line break is removed so split words re-join.
    4. Line breaks that do not directly follow ``.``, ``!`` or ``?`` are
       folded into a single space.
    5. Runs of two or more spaces are squeezed to one.
    6. Every remaining line is stripped and empty lines are discarded.
    """
    text = re.sub(r"\b\d{1,2}/\d{1,2}\b", " ", text).replace("\t", " ")

    # Order matters: de-hyphenate first, then fold soft breaks, then squeeze.
    rewrite_rules = (
        (r"-\s*\n\s*", ""),
        (r"(?<![\.!?])\n+", " "),
        (r" {2,}", " "),
    )
    for pattern, replacement in rewrite_rules:
        text = re.sub(pattern, replacement, text)

    kept_lines = []
    for raw_line in text.split("\n"):
        stripped = raw_line.strip()
        if stripped:
            kept_lines.append(stripped)
    return "\n".join(kept_lines).strip()
|
|
|
|
|
|
def chunk_with_overlap(text: str, max_chars: int = 1500, overlap_sentences: int = 2):
    """Split *text* into sentence-aligned chunks, optionally overlapping.

    Sentences are obtained by splitting on whitespace that follows ``.``,
    ``!`` or ``?``. Each chunk is a dict with ``"text"`` (the sentences joined
    by a space) and ``"sentences"`` (the list itself). Sizes are counted as
    ``len(sentence) + 1`` per sentence.

    Every chunk after the first tries to start with the last
    ``overlap_sentences`` sentences of the previous chunk. That overlap is
    only used when it still leaves room for at least one new sentence within
    ``max_chars``. One new sentence is always consumed per chunk, even if it
    alone exceeds ``max_chars``, so the loop always makes progress.
    """
    stripped = (piece.strip() for piece in re.split(r"(?<=[\.!?])\s+", text))
    sentences = [piece for piece in stripped if piece]
    total = len(sentences)

    chunks = []
    pos = 0
    while pos < total:
        members = []
        size = 0

        # Carry over the tail of the previous chunk when it fits together
        # with the next unseen sentence.
        if chunks and overlap_sentences > 0:
            carried = chunks[-1]["sentences"][-overlap_sentences:]
            carried_size = sum(len(s) + 1 for s in carried)
            if carried_size + len(sentences[pos]) + 1 <= max_chars:
                members = list(carried)
                size = carried_size

        # Forced progress: always take one new sentence.
        members.append(sentences[pos])
        size += len(sentences[pos]) + 1
        pos += 1

        # Keep filling while the next sentence still fits.
        while pos < total:
            cost = len(sentences[pos]) + 1
            if size + cost > max_chars:
                break
            members.append(sentences[pos])
            size += cost
            pos += 1

        chunks.append({"text": " ".join(members), "sentences": members})
    return chunks
|
|
|
|
|
|
def aggregate(extract_results):
    """Fold per-chunk extraction results into global entity and relation counts.

    Each item of *extract_results* is read as a mapping with optional keys
    ``"entities"`` (label -> iterable of name strings) and
    ``"relation_extraction"`` (label -> iterable of ``(head, tail)`` pairs).

    Returns ``(entities, relations)``:

    * ``entities`` maps ``(label, lowercased name)`` to a dict with ``type``,
      ``name`` (the first spelling seen, stripped) and ``count``.
    * ``relations`` is a ``Counter`` keyed by ``(head, label, tail)`` with
      both endpoints stripped (case is preserved).
    """
    entities: dict = {}
    relations: Counter = Counter()
    for result in extract_results:
        for label, mentions in result.get("entities", {}).items():
            for mention in mentions:
                surface = mention.strip()
                if surface:
                    record = entities.setdefault(
                        (label, surface.lower()),
                        {"type": label, "name": surface, "count": 0},
                    )
                    record["count"] += 1
        for rel_label, pairs in result.get("relation_extraction", {}).items():
            relations.update(
                (head.strip(), rel_label, tail.strip()) for head, tail in pairs
            )
    return entities, relations
|
|
|
|
|
|
def graph_stats(ents_dict, rels_counter):
    """Compute size and connectivity figures for the entity/relation graph.

    Nodes are the ``name`` of every entity plus every relation endpoint;
    edges are the distinct ``(head, tail, kind)`` triples. A node is an
    isolate when it is not an endpoint of any edge.

    Returns a dict with ``n_ents``, ``n_rels``, ``n_nodes``, ``n_edges``,
    ``n_isolates``, ``connected`` and ``connect_pct`` (percentage of connected
    nodes, one decimal; 0 when there are no nodes).
    """
    edges = {(head, tail, kind) for (head, kind, tail), _count in rels_counter.items()}

    linked = set()
    for head, tail, _kind in edges:
        linked.update((head, tail))

    nodes = {entry["name"] for entry in ents_dict.values()} | linked
    isolates = nodes - linked
    connected = len(nodes) - len(isolates)

    return {
        "n_ents": len(ents_dict),
        "n_rels": len(rels_counter),
        "n_nodes": len(nodes),
        "n_edges": len(edges),
        "n_isolates": len(isolates),
        "connected": connected,
        "connect_pct": round((len(nodes) - len(isolates)) / max(1, len(nodes)) * 100, 1),
    }
|
|
|
|
|
|
def normalize_name(s: str) -> str:
    """Return the comparison key for an entity name.

    Strips surrounding whitespace, removes the punctuation characters
    ``. , ; : " ' ` ( ) [ ]``, collapses whitespace runs into single spaces
    and lowercases the result.
    """
    s = s.strip()
    s = re.sub(r"[\.,;:\"'`()\[\]]", "", s)
    s = re.sub(r"\s+", " ", s)
    return s.strip().lower()


def merge_aliases(ents_dict, rels_counter):
    """Merge entity spellings that refer to the same thing and rewrite relations.

    Two passes are applied:

    1. Entities whose names share the same ``normalize_name`` key are grouped;
       the most frequent spelling becomes the canonical name and the others
       are recorded as aliases.
    2. A canonical name of the same type whose normalized form (at least four
       characters) occurs as a whole-word substring of another canonical
       name's normalized form is absorbed into that longer name.

    Args:
        ents_dict: mapping whose values are dicts with ``type``, ``name`` and
            ``count`` (as produced by ``aggregate``).
        rels_counter: mapping ``(head, relation, tail) -> count``.

    Returns:
        ``(canonical_data, new_rels, absorbed)`` where ``canonical_data`` maps
        each surviving canonical name to ``type``/``name``/``count``/``aliases``,
        ``new_rels`` is a ``Counter`` of relations with endpoints rewritten to
        surviving names (self-loops dropped), and ``absorbed`` maps every
        absorbed canonical name to the name that finally survives it.
    """
    # Pass 1: group by normalized key and elect the most frequent spelling.
    norm_groups: dict = defaultdict(list)
    for v in ents_dict.values():
        norm_groups[normalize_name(v["name"])].append(v)

    canonical: dict = {}
    canonical_data: dict = {}
    for nrm, group in norm_groups.items():
        winner = max(group, key=lambda v: v["count"])
        canonical[nrm] = winner["name"]
        canonical_data[winner["name"]] = {
            "type": winner["type"],
            "name": winner["name"],
            "count": sum(v["count"] for v in group),
            "aliases": [v["name"] for v in group if v["name"] != winner["name"]],
        }

    absorbed: dict = {}

    def survivor(name):
        # Follow absorption links to the entry that is still alive. Each link
        # points to a strictly longer normalized name, so this terminates.
        while name in absorbed:
            name = absorbed[name]
        return name

    # Pass 2: absorb short names contained in longer ones.
    canon_names = sorted(canonical_data.keys(), key=len, reverse=True)
    norms = {name: normalize_name(name) for name in canon_names}
    for long_n in canon_names:
        long_norm = norms[long_n]
        long_type = canonical_data[long_n]["type"]
        for short_n in canon_names:
            if short_n == long_n or short_n in absorbed:
                continue
            short_norm = norms[short_n]
            if len(short_norm) < 4:
                continue
            if canonical_data[short_n]["type"] != long_type:
                continue
            if re.search(r"\b" + re.escape(short_norm) + r"\b", long_norm):
                # The container may itself have been absorbed already. Merge
                # into whoever finally survives, so counts are not dropped
                # with an entry that is removed below.
                target = survivor(long_n)
                absorbed[short_n] = target
                canonical_data[target]["count"] += canonical_data[short_n]["count"]
                canonical_data[target]["aliases"].append(short_n)
                canonical_data[target]["aliases"].extend(
                    canonical_data[short_n].get("aliases", [])
                )

    # Names are ordered by raw length while containment is tested on the
    # normalized form, so a punctuation-heavy name can absorb something and
    # later be absorbed itself. Flatten such chains so that every entry points
    # straight at a name that still exists in canonical_data.
    for short_n in list(absorbed):
        absorbed[short_n] = survivor(short_n)
        canonical_data.pop(short_n, None)

    def resolve(name):
        # Unknown endpoints are kept verbatim.
        c = canonical.get(normalize_name(name), name)
        return absorbed.get(c, c)

    new_rels: Counter = Counter()
    for (h, rt, t), c in rels_counter.items():
        h_canon = resolve(h)
        t_canon = resolve(t)
        if h_canon == t_canon:
            # Merging can collapse both endpoints into one entity.
            continue
        new_rels[(h_canon, rt, t_canon)] += c
    return canonical_data, new_rels, absorbed
|
|
|
|
|
|
# Entity types passed to the schema's entities() call.
ENTITY_LABELS = [
    "person",
    "organization",
    "location",
    "email",
    "right",
    "data_category",
    "authority",
    "law",
]

# Relation types, each with a natural-language description.
RELATION_LABELS_DESC = {
    "located_in": "organization or person is located in a place or address",
    "governed_by": "entity is governed or supervised by an authority or law",
    "subject_to": "data category or process is subject to a law or regulation",
    "protected_by": "right or data is protected by a law or authority",
    "contact_for": "email or address is the contact channel for an authority or right",
    "rights_against": "person has rights to exercise against an organization",
    "subsidiary_of": "organization is a subsidiary of a parent organization",
    "controlled_by": "organization or data is controlled by another organization",
}

# The same relation names without descriptions, in the same order.
RELATION_LABELS_FLAT = list(RELATION_LABELS_DESC)
|
|
|
|
|
|
def main():
    """Run the GLiNER2 experiment battery and dump the results to JSON.

    Steps:

    1. Extract the text of ``PDF_PATH``, clean it and chunk it (with and
       without sentence overlap).
    2. Load the ``fastino/gliner2-large-v1`` model and build two schemas: one
       with bare relation names and one with described relations.
    3. Run configurations A-D (one ``extract`` call per chunk, different
       thresholds/schemas) and E (``batch_extract``), recording elapsed time
       and graph statistics for each.
    4. Apply ``merge_aliases`` to the entities/relations of configuration E
       and record pre/post statistics, top entities and top relations.
    5. Write everything to ``improvements.json`` next to this script.
    """
    out: dict = {}

    # --- prepare text + chunks (CPU only)
    print("[prep] extract + clean + chunk...")
    raw = extract_pdf_text(str(PDF_PATH))
    clean = clean_pdf_text(raw)
    chunks = chunk_with_overlap(clean, max_chars=1500, overlap_sentences=2)
    chunks_no_overlap = chunk_with_overlap(clean, max_chars=1500, overlap_sentences=0)
    out["meta"] = {
        "raw_chars": len(raw),
        "clean_chars": len(clean),
        "n_chunks_overlap": len(chunks),
        "n_chunks_no_overlap": len(chunks_no_overlap),
        "first_clean_600": clean[:600],
    }
    print(f" raw {len(raw):,} → clean {len(clean):,} → {len(chunks)} chunks (overlap=2)")

    print("[load] GLiNER2...")
    t0 = time.time()
    model = GLiNER2.from_pretrained("fastino/gliner2-large-v1")
    print(f" load: {time.time()-t0:.1f}s")

    schema_flat = model.create_schema().entities(ENTITY_LABELS).relations(RELATION_LABELS_FLAT)
    schema_desc = model.create_schema().entities(ENTITY_LABELS).relations(RELATION_LABELS_DESC)

    configs: list = []

    # A-D: one extract() call per chunk; only schema and threshold vary.
    loop_runs = (
        ("A", "t=0.5 flat loop", schema_flat, 0.5),
        ("B", "t=0.3 flat loop", schema_flat, 0.3),
        ("C", "t=0.2 flat loop", schema_flat, 0.2),
        ("D", "t=0.3 desc loop", schema_desc, 0.3),
    )
    for tag, label, schema, threshold in loop_runs:
        print(f"[{tag}] {label}...")
        t0 = time.time()
        results = [model.extract(c["text"], schema=schema, threshold=threshold) for c in chunks]
        elapsed = time.time() - t0
        ents, rels = aggregate(results)
        configs.append({"name": f"{tag}: {label}", "elapsed": round(elapsed, 1),
                        "stats": graph_stats(ents, rels)})
        # Release the raw per-chunk results before the next run.
        del results
        gc.collect()
        print(f" {elapsed:.1f}s stats={configs[-1]['stats']}")

    # E: t=0.3 desc batch_extract
    print("[E] t=0.3 desc batch_extract...")
    t0 = time.time()
    texts = [c["text"] for c in chunks]
    res_e = model.batch_extract(texts, schemas=schema_desc, batch_size=8, threshold=0.3)
    elapsed_e = time.time() - t0
    ents_e, rels_e = aggregate(res_e)
    configs.append({"name": "E: t=0.3 desc batch", "elapsed": round(elapsed_e, 1),
                    "stats": graph_stats(ents_e, rels_e)})
    print(f" {elapsed_e:.1f}s stats={configs[-1]['stats']}")
    out["configs"] = configs

    # --- coreference on configuration E ---
    print("[coref] applying alias merge to config E...")
    t0 = time.time()
    ents_merged, rels_merged, absorbed = merge_aliases(ents_e, rels_e)
    # graph_stats only reads the values; re-key to the (type, name) shape
    # used by aggregate().
    ents_merged_dict = {(v["type"], v["name"].lower()): v for v in ents_merged.values()}
    stats_post = graph_stats(ents_merged_dict, rels_merged)
    elapsed_coref = time.time() - t0
    out["coref"] = {
        "elapsed": round(elapsed_coref, 2),
        "pre_stats": graph_stats(ents_e, rels_e),
        "post_stats": stats_post,
        "n_absorbed": len(absorbed),
        "absorbed_sample": list(absorbed.items())[:8],
    }
    print(f" pre: {out['coref']['pre_stats']}")
    print(f" post: {out['coref']['post_stats']}")
    print(f" absorbed: {len(absorbed)} e.g. {list(absorbed.items())[:3]}")

    # --- top entities post-coref ---
    out["top_entities_post_coref"] = [
        {
            "type": v["type"],
            "canonical": v["name"],
            "mentions": v["count"],
            "n_aliases": len(v.get("aliases", [])),
            "aliases_sample": v.get("aliases", [])[:3],
        }
        for v in sorted(ents_merged.values(), key=lambda x: -x["count"])[:25]
    ]

    # --- relations top ---
    out["top_relations_post_coref"] = [
        {"from": h, "kind": rt, "to": t, "count": c}
        for (h, rt, t), c in sorted(rels_merged.items(), key=lambda x: -x[1])[:25]
    ]

    # --- save ents_merged + rels_merged for graph rendering ---
    out["ents_merged"] = [{"name": v["name"], "type": v["type"], "count": v["count"]}
                          for v in ents_merged.values()]
    out["rels_merged"] = [{"from": h, "kind": rt, "to": t, "count": c}
                          for (h, rt, t), c in rels_merged.items()]

    out_path = HERE / "improvements.json"
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be fixed explicitly instead of depending on the locale default.
    out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\n[saved] {out_path} ({out_path.stat().st_size:,} bytes)")
|
|
|
|
|
|
# Run the experiments only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|