Files
gliner_glirel_tuning/playground/server.py
T
2026-05-04 23:44:11 +02:00

265 lines
9.5 KiB
Python

"""Playground server — GLiNER2 + post-filter typed sobre cualquier texto.
Aplica las recetas del notebook 08:
- snake_case verbal labels
- threshold=0.3
- post-filter por (head_type, tail_type)
- coreference simple normalize+substring
Run:
cd playground && ../.venv/bin/python3 server.py
Luego: http://localhost:7878
"""
from __future__ import annotations
import os
import re
import sys
import time
import warnings
from collections import defaultdict
from pathlib import Path
# Ignore every warning category and set HF_HUB_DISABLE_PROGRESS_BARS=1 unless
# the environment already defines it.
warnings.filterwarnings("ignore")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# sys.path cleanup (same workaround documented in notebook 08): drop every
# entry nested below the functions directory, then make sure the directory
# itself is importable, placing it first when it was missing.
_pf = "/home/lucas/fn_registry/python/functions"
sys.path = [entry for entry in sys.path if not entry.startswith(f"{_pf}/")]
if _pf not in sys.path:
    sys.path.insert(0, _pf)
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from gliner2 import GLiNER2
# Directory containing this file; index.html and the static/ folder are
# resolved relative to it further down.
HERE = Path(__file__).resolve().parent
# ── load the model once, at import time ──
# NOTE(review): the log line says "(CPU)" but no device argument is passed to
# from_pretrained here — confirm which device the model actually lands on.
print("[load] GLiNER2-large-v1 (CPU)...", flush=True)
t0 = time.time()
MODEL = GLiNER2.from_pretrained("fastino/gliner2-large-v1")
print(f"[load] done in {time.time()-t0:.1f}s", flush=True)
# ── recipes from notebook 08 ──
# Entity types requested from the model.
ENTITY_LABELS = ["person", "organization", "location"]
# Relation types requested from the model (snake_case verbal labels).
RELATION_LABELS = [
    "works_at", "located_in", "ceo_of", "president_of",
    "headquartered_in", "agreement_with", "subsidiary_of", "founded_by",
]
# Typed post-filter table: relation -> (allowed head types, allowed tail types).
# filter_typed() keeps a relation only when both endpoint types are listed
# here; a relation type absent from this table is not constrained at all.
ALLOWED = {
    "works_at": (["person"], ["organization"]),
    "ceo_of": (["person"], ["organization"]),
    "president_of": (["person"], ["organization"]),
    "headquartered_in": (["organization"], ["location"]),
    "located_in": (["organization", "person", "location"], ["location"]),
    "agreement_with": (["organization"], ["organization"]),
    "subsidiary_of": (["organization"], ["organization"]),
    "founded_by": (["organization"], ["person"]),
}
def normalize_name(s: str) -> str:
    """Return the comparison key for an entity name.

    The name is trimmed, the characters ``. , ; : " ' ` ( ) [ ]`` are
    deleted, every run of whitespace collapses to one space, and the result
    is lower-cased.
    """
    # One C-level pass deletes all punctuation characters of interest.
    without_punct = s.strip().translate(str.maketrans("", "", ".,;:\"'`()[]"))
    # split()/join collapses internal whitespace and trims both ends.
    return " ".join(without_punct.split()).lower()
def merge_aliases(names: list[str]) -> dict[str, str]:
    """Map every entity name to a single canonical surface form.

    Two passes are applied:

    1. Names whose ``normalize_name`` keys are equal are grouped; the longest
       spelling (ties broken by the greatest string) represents the group.
    2. A representative whose normalized key has at least 4 characters and
       appears as a whole-word substring of another representative's key is
       absorbed into that other representative. Representatives are tried
       longest first, ties in alphabetical order, so the outcome does not
       depend on set iteration order.

    Args:
        names: Entity surface forms; duplicates are allowed.

    Returns:
        A dict with one entry per distinct input name giving its canonical
        name. A name that is its own canonical form maps to itself, and a
        returned canonical name is never itself absorbed by another one.
    """
    # Pass 1: group spellings that share a normalized key.
    norm_groups: dict[str, list[str]] = defaultdict(list)
    for name in names:
        norm_groups[normalize_name(name)].append(name)

    canonical: dict[str, str] = {}
    for group in norm_groups.values():
        winner = max(group, key=lambda x: (len(x), x))
        for name in group:
            canonical[name] = winner

    # Pass 2: whole-word substring absorption. The explicit secondary sort key
    # makes equal-length candidates deterministic (a plain sort over a set
    # inherits hash-randomized ordering for ties).
    candidates = sorted(set(canonical.values()), key=lambda c: (-len(c), c))
    normalized = {c: normalize_name(c) for c in candidates}
    # Compile each short-name pattern once instead of once per pair; keys
    # shorter than 4 characters are too ambiguous to be absorbed.
    patterns = {
        c: re.compile(r"\b" + re.escape(norm) + r"\b")
        for c, norm in normalized.items()
        if len(norm) >= 4
    }

    absorbed: dict[str, str] = {}
    for long_n in candidates:
        long_norm = normalized[long_n]
        for short_n, pattern in patterns.items():
            if short_n == long_n or short_n in absorbed:
                continue
            if pattern.search(long_norm):
                absorbed[short_n] = long_n

    # Follow absorption chains to the end (X -> Y -> Z yields X -> Z). Each hop
    # moves to a strictly longer normalized key, so the walk always terminates.
    final: dict[str, str] = {}
    for orig, canon in canonical.items():
        while canon in absorbed:
            canon = absorbed[canon]
        final[orig] = canon
    return final
def filter_typed(rels: dict, name_to_type: dict, allowed: dict) -> tuple[list, list]:
    """Split extracted relations into type-compatible and rejected ones.

    Args:
        rels: Relation type -> iterable of ``(head, tail)`` name pairs.
        name_to_type: Lower-cased, stripped entity name -> entity type.
        allowed: Relation type -> ``(allowed head types, allowed tail types)``.
            A relation type missing from this mapping is never rejected.

    Returns:
        ``(keep, drop)``: two lists of dicts with the keys ``from``, ``kind``,
        ``to``, ``head_type`` and ``tail_type``. Types are ``None`` when the
        endpoint is not present in ``name_to_type``.
    """
    keep: list = []
    drop: list = []
    for kind, pairs in rels.items():
        head_ok, tail_ok = allowed.get(kind, (None, None))
        unconstrained = head_ok is None
        for head, tail in pairs:
            head_type = name_to_type.get(head.lower().strip())
            tail_type = name_to_type.get(tail.lower().strip())
            record = {
                "from": head,
                "kind": kind,
                "to": tail,
                "head_type": head_type,
                "tail_type": tail_type,
            }
            accepted = unconstrained or (head_type in head_ok and tail_type in tail_ok)
            bucket = keep if accepted else drop
            bucket.append(record)
    return keep, drop
def chunk_text(text: str, max_chars: int = 1500, overlap_sentences: int = 2):
    """Split a long text into sentence-aligned chunks with a sliding window.

    Sentences are cut after ``.``, ``!`` or ``?`` followed by whitespace. Each
    chunk is filled greedily while its budget (sentence length + 1 per
    sentence) stays within ``max_chars``. A new chunk starts with the last
    ``overlap_sentences`` sentences of the previous chunk when those plus the
    next sentence fit in the budget. Every chunk receives at least one new
    sentence, so a single sentence longer than ``max_chars`` becomes an
    oversized chunk of its own. Same pattern as notebook 06.

    Returns:
        The chunks as strings, sentences joined by a single space.
    """
    pieces = (part.strip() for part in re.split(r"(?<=[\.!?])\s+", text))
    sentences = [piece for piece in pieces if piece]
    total = len(sentences)

    chunks: list[list[str]] = []
    cursor = 0
    while cursor < total:
        window: list[str] = []
        used = 0
        if chunks and overlap_sentences > 0:
            tail = chunks[-1][-overlap_sentences:]
            tail_cost = sum(len(s) + 1 for s in tail)
            # Carry the overlap only if the next sentence still fits after it.
            if tail_cost + len(sentences[cursor]) + 1 <= max_chars:
                window = list(tail)
                used = tail_cost

        # Always consume one new sentence so the loop makes progress.
        window.append(sentences[cursor])
        used += len(sentences[cursor]) + 1
        cursor += 1

        # Keep adding sentences while they fit in the remaining budget.
        while cursor < total and used + len(sentences[cursor]) + 1 <= max_chars:
            window.append(sentences[cursor])
            used += len(sentences[cursor]) + 1
            cursor += 1

        chunks.append(window)

    return [" ".join(window) for window in chunks]
def extract_graph(text: str, threshold: float = 0.3, max_chars_per_chunk: int = 1500) -> dict:
    """Extract a typed entity/relation graph from *text*, laid out for display.

    Pipeline: optional chunking, ``MODEL.extract`` on every chunk, accumulation
    of entities and relations, typed post-filter (``filter_typed`` with
    ``ALLOWED``), alias merging (``merge_aliases``), node/edge construction and
    a server-side spring layout.

    Args:
        text: Input text. It is chunked with ``chunk_text`` only when longer
            than ``max_chars_per_chunk``.
        threshold: Value forwarded to ``MODEL.extract``.
        max_chars_per_chunk: Chunk budget forwarded to ``chunk_text``.

    Returns:
        A dict with the keys ``elapsed_s``, ``n_chunks``, ``n_nodes``,
        ``n_edges``, ``n_dropped_typed``, ``nodes`` (id, label, type, x, y),
        ``edges`` (from, to, label; sorted) and ``dropped`` (the first 10
        relations rejected by the typed filter).
    """
    schema = MODEL.create_schema().entities(ENTITY_LABELS).relations(RELATION_LABELS)

    # Automatic chunking when the text is long.
    if len(text) <= max_chars_per_chunk:
        chunks = [text]
    else:
        chunks = chunk_text(text, max_chars=max_chars_per_chunk, overlap_sentences=2)
    print(f"[extract] {len(text)}c → {len(chunks)} chunks", flush=True)
    t0 = time.time()

    # De-duplicated accumulators shared by all chunks.
    name_to_type: dict = {}    # lower-cased name -> entity type (first seen wins)
    name_canonical: dict = {}  # lower-cased name -> surface form used for display
    raw_relations: dict = {}   # relation type -> list of (head, tail)
    for idx, chunk in enumerate(chunks):
        r = MODEL.extract(chunk, schema=schema, threshold=threshold)
        for typ, names in r["entities"].items():
            for n in names:
                key = n.lower().strip()
                if not key:
                    continue
                surface = n.strip()
                if key not in name_to_type:
                    name_to_type[key] = typ
                    name_canonical[key] = surface
                # Same key seen again with a longer surface form: keep the longer.
                elif len(surface) > len(name_canonical[key]):
                    name_canonical[key] = surface
        for rt, pairs in r["relation_extraction"].items():
            raw_relations.setdefault(rt, []).extend(
                (h.strip(), t.strip()) for h, t in pairs
            )
        if (idx + 1) % 10 == 0:
            print(f"[extract] chunk {idx+1}/{len(chunks)} ents acum={len(name_to_type)} rels acum={sum(len(v) for v in raw_relations.values())}", flush=True)

    # Typed post-filter.
    keep, drop = filter_typed(raw_relations, name_to_type, ALLOWED)

    # Coreference: alias map over the canonical surface forms.
    alias = merge_aliases(list(name_canonical.values()))

    def resolve(surface: str) -> str:
        """Return the canonical node id for a relation endpoint."""
        # Relation endpoints may differ in casing from the entity mention that
        # was stored. Go through the same lower-cased key used for the type
        # lookup so both spellings land on one node instead of two.
        known = name_canonical.get(surface.lower().strip(), surface)
        return alias.get(known, known)

    # Nodes, with aliases applied; the first type seen for a node is kept.
    nodes_dict: dict = {}
    for key, typ in name_to_type.items():
        canon_orig = name_canonical[key]
        nodes_dict.setdefault(alias.get(canon_orig, canon_orig), typ)

    # Edges, de-duplicated after alias resolution.
    edges_set: set = set()
    for e in keep:
        h_canon = resolve(e["from"])
        t_canon = resolve(e["to"])
        if h_canon == t_canon:
            # Both endpoints collapsed into the same node: skip the self-loop.
            continue
        # Endpoints the entity pass never produced still need a node.
        nodes_dict.setdefault(h_canon, e.get("head_type") or "?")
        nodes_dict.setdefault(t_canon, e.get("tail_type") or "?")
        edges_set.add((h_canon, e["kind"], t_canon))
    # Sorted so the response does not depend on set iteration order.
    edges = sorted(edges_set)

    # Server-side layout (sigma only renders).
    import networkx as nx

    G = nx.DiGraph()
    G.add_nodes_from(nodes_dict)
    for h, k, t in edges:
        G.add_edge(h, t, kind=k)
    pos: dict = {}
    if G.number_of_nodes() > 0:
        try:
            pos = nx.spring_layout(G, k=2.0, iterations=80, seed=42)
        except Exception as exc:
            # Best effort: a layout failure must not fail the extraction.
            print(f"[extract] layout failed ({exc!r}); placing all nodes at the origin", flush=True)
            pos = {n: (0.0, 0.0) for n in G.nodes}

    elapsed = time.time() - t0
    print(f"[extract] done {elapsed:.2f}s nodos={len(nodes_dict)} aristas={len(edges)}", flush=True)

    nodes = []
    for n, t in nodes_dict.items():
        x, y = pos.get(n, (0.0, 0.0))
        nodes.append({"id": n, "label": n, "type": t, "x": float(x), "y": float(y)})

    return {
        "elapsed_s": round(elapsed, 2),
        "n_chunks": len(chunks),
        "n_nodes": len(nodes_dict),
        "n_edges": len(edges),
        "n_dropped_typed": len(drop),
        "nodes": nodes,
        "edges": [{"from": h, "to": t, "label": k} for h, k, t in edges],
        "dropped": drop[:10],
    }
# ── API ──
# ASGI application object; the route handlers below register themselves on it.
app = FastAPI(title="GLiNER2 Playground")
# Serve the files in the static/ folder next to this file under /static.
app.mount("/static", StaticFiles(directory=HERE / "static"), name="static")
# Request body accepted by POST /extract.
class ExtractReq(BaseModel):
    # Raw text to analyse; the handler rejects text that is blank after stripping.
    text: str
    # Threshold forwarded to extract_graph(); defaults to the notebook 08 recipe.
    threshold: float = 0.3
# GET /: serve the playground page stored as index.html next to this file.
@app.get("/")
def index():
    page = HERE / "index.html"
    return FileResponse(page)
# POST /extract: run the extraction pipeline on the submitted text.
@app.post("/extract")
def extract(req: ExtractReq):
    has_content = bool(req.text.strip())
    if has_content:
        return extract_graph(req.text, threshold=req.threshold)
    # Blank or whitespace-only text is a client error.
    return JSONResponse({"error": "empty text"}, status_code=400)
# Script entry point: start the server when this file is executed directly.
if __name__ == "__main__":
    # Imported here so that importing this module does not require uvicorn.
    import uvicorn
    print("\nServing at http://localhost:7878\n", flush=True)
    # NOTE(review): host="0.0.0.0" listens on every network interface although
    # the message above advertises localhost — confirm this exposure is intended.
    uvicorn.run(app, host="0.0.0.0", port=7878, log_level="warning")