b8c760d004
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
265 lines
9.5 KiB
Python
265 lines
9.5 KiB
Python
"""Playground server — GLiNER2 + post-filter typed sobre cualquier texto.
|
|
|
|
Aplica las recetas del notebook 08:
|
|
- snake_case verbal labels
|
|
- threshold=0.3
|
|
- post-filter por (head_type, tail_type)
|
|
- coreference simple normalize+substring
|
|
|
|
Run:
|
|
cd playground && ../.venv/bin/python3 server.py
|
|
Luego: http://localhost:7878
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
# Silence every Python warning and ask huggingface_hub not to draw download
# progress bars (unless the caller already set that variable).
warnings.filterwarnings("ignore")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# sys.path cleanup (same workaround documented in notebook 08): drop every
# sys.path entry that lives *inside* the functions registry, then prepend the
# registry root itself if it is not already importable.
# The registry location was hard-coded to one developer's home directory; it can
# now be overridden with FN_REGISTRY_FUNCTIONS (default unchanged). normpath()
# strips a trailing slash so the "<root>/" prefix test below still matches.
_pf = os.path.normpath(
    os.environ.get("FN_REGISTRY_FUNCTIONS", "/home/lucas/fn_registry/python/functions")
)
sys.path = [p for p in sys.path if not p.startswith(_pf + "/")]
if _pf not in sys.path:
    sys.path.insert(0, _pf)
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import FileResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
from gliner2 import GLiNER2
|
|
|
|
# Directory containing this script; page and static assets are resolved from it.
HERE = Path(__file__).resolve().parent

# ── load the model once, at import time ──
# Every extraction request reuses this single MODEL instance.
_GLINER2_CHECKPOINT = "fastino/gliner2-large-v1"
print("[load] GLiNER2-large-v1 (CPU)...", flush=True)
t0 = time.time()
MODEL = GLiNER2.from_pretrained(_GLINER2_CHECKPOINT)
_load_secs = time.time() - t0
print(f"[load] done in {_load_secs:.1f}s", flush=True)
|
|
|
|
# ── recipes from notebook 08 ──
# Entity types requested from GLiNER2.
ENTITY_LABELS = ["person", "organization", "location"]

# Relation types requested from GLiNER2 (snake_case verbal labels).
# Order is kept as-is because the list is handed straight to the schema.
RELATION_LABELS = [
    "works_at",
    "located_in",
    "ceo_of",
    "president_of",
    "headquartered_in",
    "agreement_with",
    "subsidiary_of",
    "founded_by",
]

# Typed post-filter: relation type -> (allowed head types, allowed tail types).
# filter_typed() puts relations whose endpoint types are not listed here in its
# "drop" list.
ALLOWED = dict(
    works_at=(["person"], ["organization"]),
    ceo_of=(["person"], ["organization"]),
    president_of=(["person"], ["organization"]),
    headquartered_in=(["organization"], ["location"]),
    located_in=(["organization", "person", "location"], ["location"]),
    agreement_with=(["organization"], ["organization"]),
    subsidiary_of=(["organization"], ["organization"]),
    founded_by=(["organization"], ["person"]),
)
|
|
|
|
|
|
def normalize_name(s: str) -> str:
    """Return a comparison key for an entity name.

    Trims the string, deletes the characters . , ; : " ' ` ( ) [ ],
    collapses every whitespace run to a single space, trims again and
    lowercases, so e.g. "Acme, Inc." and "acme inc" share one key.
    """
    drop_punct = str.maketrans("", "", ".,;:\"'`()[]")
    without_punct = s.strip().translate(drop_punct)
    single_spaced = re.sub(r"\s+", " ", without_punct)
    return single_spaced.strip().lower()
|
|
|
|
|
|
def merge_aliases(names: list[str]) -> dict[str, str]:
    """Build a simple coreference map from surface names to a canonical name.

    Two passes:
      1. Names sharing a ``normalize_name`` key collapse onto the longest
         variant (ties broken by the lexicographically greatest string).
      2. A canonical name whose normalized key (>= 4 chars) occurs as a whole
         token sequence inside a longer canonical key is absorbed into that
         longer name, e.g. "Acme" -> "Acme Corporation".

    Args:
        names: surface forms, in the order they were first seen. That order
            breaks ties deterministically in pass 2.

    Returns:
        Mapping from every input name to its final canonical name. A name
        that was not merged maps to itself.
    """
    # Pass 1: exact matches after normalization (defaultdict keeps first-seen order).
    norm_groups: dict[str, list[str]] = defaultdict(list)
    for n in names:
        norm_groups[normalize_name(n)].append(n)

    canonical: dict[str, str] = {}
    norm_of: dict[str, str] = {}  # canonical name -> its normalized key (computed once)
    for nrm, group in norm_groups.items():
        winner = max(group, key=lambda x: (len(x), x))
        norm_of[winner] = nrm
        for n in group:
            canonical[n] = winner

    # Pass 2: longest *normalized* keys first. Sorting by normalized length
    # keeps containment transitive along the iteration order.
    # ``norm_of`` is insertion-ordered and ``sorted`` is stable, so equal-length
    # candidates are tried in first-seen order. Iterating a set here made the
    # winner depend on PYTHONHASHSEED.
    by_length = sorted(norm_of, key=lambda c: len(norm_of[c]), reverse=True)

    absorbed: dict[str, str] = {}
    for long_n in by_length:
        long_norm = norm_of[long_n]
        for short_n in by_length:
            if short_n == long_n or short_n in absorbed:
                continue
            short_norm = norm_of[short_n]
            # Keys shorter than 4 chars are never absorbed (presumably to avoid
            # false merges on tiny tokens). A key at least as long as
            # ``long_norm`` cannot be a proper substring of it.
            if len(short_norm) < 4 or len(short_norm) >= len(long_norm):
                continue
            # Whole-token match: not glued to a word character on either side.
            # Unlike \b, the lookarounds also work when the key starts or ends
            # with a non-word character (e.g. "yahoo!" inside "yahoo! japan").
            pattern = r"(?<!\w)" + re.escape(short_norm) + r"(?!\w)"
            if re.search(pattern, long_norm):
                absorbed[short_n] = long_n

    def resolve(name: str) -> str:
        # Follow absorption links to the outermost name. Targets always have a
        # strictly longer key, so this cannot cycle.
        while name in absorbed:
            name = absorbed[name]
        return name

    return {orig: resolve(canon) for orig, canon in canonical.items()}
|
|
|
|
|
|
def filter_typed(rels: dict, name_to_type: dict, allowed: dict) -> tuple[list, list]:
    """Split raw relations into type-consistent and type-inconsistent ones.

    Args:
        rels: relation type -> iterable of (head, tail) name pairs.
        name_to_type: lowercased, stripped entity name -> entity type.
        allowed: relation type -> (allowed head types, allowed tail types).
            Relation types absent from ``allowed`` are kept unconditionally.

    Returns:
        ``(keep, drop)``: lists of dicts with keys ``from``, ``kind``, ``to``,
        ``head_type`` and ``tail_type``. A type is ``None`` when the name is
        not found in ``name_to_type``.
    """
    keep: list = []
    drop: list = []
    for kind, pairs in rels.items():
        heads_ok, tails_ok = allowed.get(kind, (None, None))
        unconstrained = heads_ok is None
        for head, tail in pairs:
            head_type = name_to_type.get(head.lower().strip())
            tail_type = name_to_type.get(tail.lower().strip())
            record = {
                "from": head,
                "kind": kind,
                "to": tail,
                "head_type": head_type,
                "tail_type": tail_type,
            }
            accepted = unconstrained or (head_type in heads_ok and tail_type in tails_ok)
            target = keep if accepted else drop
            target.append(record)
    return keep, drop
|
|
|
|
|
|
def chunk_text(text: str, max_chars: int = 1500, overlap_sentences: int = 2):
    """Split long text into sentence-aligned chunks with a sliding overlap.

    Same pattern as notebook 06.
    - Sentences are cut at whitespace that follows ``.``, ``!`` or ``?``.
    - Each chunk is packed greedily up to ``max_chars``, counting one
      separator char per sentence.
    - When it still fits, a chunk starts with the last ``overlap_sentences``
      sentences of the previous chunk.
    - A single sentence longer than ``max_chars`` becomes its own oversized
      chunk.

    Returns:
        List of chunk strings (sentences joined by single spaces); ``[]`` when
        the text contains no non-blank sentence.
    """
    sentences = [
        piece.strip()
        for piece in re.split(r"(?<=[\.!?])\s+", text)
        if piece.strip()
    ]
    total = len(sentences)

    windows: list[list[str]] = []
    pos = 0
    while pos < total:
        window: list[str] = []
        size = 0

        # Seed with the tail of the previous window, but only if the next
        # fresh sentence still fits alongside it.
        if windows and overlap_sentences > 0:
            carried = windows[-1][-overlap_sentences:]
            carried_size = sum(len(c) + 1 for c in carried)
            if carried_size + len(sentences[pos]) + 1 <= max_chars:
                window = list(carried)
                size = carried_size

        # Greedy fill. The first fresh sentence is always taken (even when
        # oversized) so each iteration advances ``pos``.
        fresh = 0
        while pos < total:
            cost = len(sentences[pos]) + 1
            if fresh and size + cost > max_chars:
                break
            window.append(sentences[pos])
            size += cost
            pos += 1
            fresh += 1

        windows.append(window)

    return [" ".join(w) for w in windows]
|
|
|
|
|
|
def extract_graph(text: str, threshold: float = 0.3, max_chars_per_chunk: int = 1500) -> dict:
    """Run GLiNER2 over ``text`` and return a typed, alias-merged graph for the UI.

    Pipeline:
      1. Chunk long input (sentence windows, 2-sentence overlap).
      2. Extract entities and relations per chunk.
      3. Apply the typed post-filter (``ALLOWED``).
      4. Merge aliases (``merge_aliases``).
      5. Compute a server-side spring layout.

    Args:
        text: raw input text.
        threshold: extraction threshold forwarded to ``MODEL.extract``.
        max_chars_per_chunk: texts longer than this are split with ``chunk_text``.

    Returns:
        JSON-serializable dict with these keys:
        - ``elapsed_s``, ``n_chunks``, ``n_nodes``, ``n_edges``
        - ``n_dropped_typed``: distinct relations rejected by the type filter
        - ``nodes``: dicts with ``id``/``label``/``type``/``x``/``y``
        - ``edges``: dicts with ``from``/``to``/``label``, in first-seen order
        - ``dropped``: the first 10 rejected relations
    """
    schema = MODEL.create_schema().entities(ENTITY_LABELS).relations(RELATION_LABELS)

    # Automatic chunking when the text is long.
    if len(text) <= max_chars_per_chunk:
        chunks = [text]
    else:
        chunks = chunk_text(text, max_chars=max_chars_per_chunk, overlap_sentences=2)
        print(f"[extract] {len(text)}c → {len(chunks)} chunks", flush=True)

    t0 = time.time()
    name_to_type, name_canonical, raw_relations = _collect_mentions(chunks, schema, threshold)

    # Typed post-filter.
    keep, drop = filter_typed(raw_relations, name_to_type, ALLOWED)

    # Coreference: alias map over the canonical surface names.
    alias = merge_aliases(list(name_canonical.values()))

    def resolve(mention: str) -> str:
        # Relation endpoints can differ in casing from the first-seen entity
        # mention. Map through the lowercase key before applying ``alias``;
        # looking up the raw string directly produced duplicate nodes.
        surface = name_canonical.get(mention.lower().strip(), mention)
        return alias.get(surface, surface)

    # Nodes, with aliases applied.
    nodes_dict: dict = {}
    for key, typ in name_to_type.items():
        surface = name_canonical[key]
        node_id = alias.get(surface, surface)
        if node_id not in nodes_dict:
            # A merged node keeps the canonical name's own entity type (when
            # it was extracted itself), not the type of whichever alias was
            # inserted first.
            nodes_dict[node_id] = name_to_type.get(node_id.lower().strip(), typ)

    # Edges deduplicated after aliasing. A dict serves as an insertion-ordered
    # set, so the response order is deterministic (a set's was not).
    edges: dict = {}
    for e in keep:
        head = resolve(e["from"])
        tail = resolve(e["to"])
        if head == tail:
            continue  # self-loop, e.g. both endpoints merged into one alias
        nodes_dict.setdefault(head, e.get("head_type") or "?")
        nodes_dict.setdefault(tail, e.get("tail_type") or "?")
        edges[(head, e["kind"], tail)] = None

    # Layout is computed server-side (the sigma front-end only renders it).
    pos = _spring_positions(list(nodes_dict), edges)

    elapsed = time.time() - t0
    print(f"[extract] done {elapsed:.2f}s nodos={len(nodes_dict)} aristas={len(edges)}", flush=True)

    return {
        "elapsed_s": round(elapsed, 2),
        "n_chunks": len(chunks),
        "n_nodes": len(nodes_dict),
        "n_edges": len(edges),
        "n_dropped_typed": len(drop),
        "nodes": [
            {
                "id": n,
                "label": n,
                "type": typ,
                "x": float(pos.get(n, (0.0, 0.0))[0]),
                "y": float(pos.get(n, (0.0, 0.0))[1]),
            }
            for n, typ in nodes_dict.items()
        ],
        "edges": [{"from": h, "to": t, "label": k} for h, k, t in edges],
        "dropped": drop[:10],
    }


def _collect_mentions(chunks: list[str], schema, threshold: float) -> tuple[dict, dict, dict]:
    """Run MODEL.extract on every chunk and merge the per-chunk outputs.

    Returns ``(name_to_type, name_canonical, raw_relations)``:
      - name_to_type: lowercased name -> entity type (the FIRST type seen wins)
      - name_canonical: lowercased name -> longest stripped surface form seen
      - raw_relations: relation type -> list of distinct stripped (head, tail)
        pairs in first-seen order. Overlapping chunks re-extract the same pairs,
        so duplicates are dropped here rather than inflating the drop counts.
    """
    name_to_type: dict = {}
    name_canonical: dict = {}
    seen_pairs: dict = {}  # rel_type -> dict used as an ordered set of (h, t)
    for idx, chunk in enumerate(chunks):
        r = MODEL.extract(chunk, schema=schema, threshold=threshold)
        for typ, names in r["entities"].items():
            for n in names:
                key = n.lower().strip()
                if not key:
                    continue
                surface = n.strip()
                if key not in name_to_type:
                    name_to_type[key] = typ
                    name_canonical[key] = surface
                elif len(surface) > len(name_canonical[key]):
                    # Same key seen again with a longer surface form: keep it.
                    name_canonical[key] = surface
        for rt, pairs in r["relation_extraction"].items():
            bucket = seen_pairs.setdefault(rt, {})
            for h, t in pairs:
                bucket[(h.strip(), t.strip())] = None
        if (idx + 1) % 10 == 0:
            n_rels = sum(len(v) for v in seen_pairs.values())
            print(f"[extract] chunk {idx+1}/{len(chunks)} ents acum={len(name_to_type)} rels acum={n_rels}", flush=True)
    raw_relations = {rt: list(bucket) for rt, bucket in seen_pairs.items()}
    return name_to_type, name_canonical, raw_relations


def _spring_positions(node_ids: list, edges) -> dict:
    """Return node id -> (x, y) from a seeded networkx spring layout.

    ``edges`` is an iterable of (head, kind, tail) triples. If networkx raises,
    every node is placed at the origin and the failure is logged. The layout
    is best effort and must not fail the request.
    """
    if not node_ids:
        return {}
    import networkx as nx

    G = nx.DiGraph()
    G.add_nodes_from(node_ids)
    G.add_edges_from((h, t) for h, _kind, t in edges)
    try:
        return nx.spring_layout(G, k=2.0, iterations=80, seed=42)
    except Exception as exc:  # best-effort boundary: log instead of swallowing
        print(f"[layout] spring_layout failed ({exc!r}); placing all nodes at (0, 0)", flush=True)
        return {n: (0.0, 0.0) for n in G.nodes}
|
|
|
|
|
|
# ── API ──
app = FastAPI(title="GLiNER2 Playground")

# Files under ./static (next to this script) are served at /static.
_static_dir = HERE / "static"
app.mount("/static", StaticFiles(directory=_static_dir), name="static")
|
|
|
|
|
|
# Request body accepted by POST /extract.
class ExtractReq(BaseModel):
    # Input text; the endpoint answers 400 when it is empty or whitespace-only.
    text: str
    # Extraction threshold passed through to extract_graph() (default 0.3).
    threshold: float = 0.3
|
|
|
|
|
|
@app.get("/")
def index():
    """Serve the playground page (index.html located next to this script)."""
    page = HERE / "index.html"
    return FileResponse(page)
|
|
|
|
|
|
@app.post("/extract")
def extract(req: ExtractReq):
    """Extract an entity/relation graph from ``req.text``.

    Returns a 400 JSON error for blank text or for a threshold outside
    [0, 1]. Otherwise returns the dict produced by ``extract_graph``.
    """
    if not req.text.strip():
        return JSONResponse({"error": "empty text"}, status_code=400)
    # The threshold comes straight from the client and scores are presumably
    # in [0, 1], so reject anything else. A chained comparison is False for
    # NaN, so NaN is rejected too.
    if not 0.0 <= req.threshold <= 1.0:
        return JSONResponse({"error": "threshold must be between 0 and 1"}, status_code=400)
    return extract_graph(req.text, threshold=req.threshold)
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Bind address and port are overridable; the defaults keep the previous
    # behaviour (all interfaces, port 7878). Set PLAYGROUND_HOST=127.0.0.1 to
    # keep the CPU-heavy /extract endpoint off the local network.
    host = os.environ.get("PLAYGROUND_HOST", "0.0.0.0")
    port = int(os.environ.get("PLAYGROUND_PORT", "7878"))
    print(f"\nServing at http://localhost:{port}\n", flush=True)
    uvicorn.run(app, host=host, port=port, log_level="warning")
|