b8c760d004
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
318 lines
17 KiB
Python
318 lines
17 KiB
Python
"""Construye notebooks/03_mrebel_vs_glirel.ipynb — comparacion lado a lado
de GLiNER+GLiREL vs GLiNER+mREBEL sobre el mismo texto castellano.

mREBEL (Babelscape) es un seq2seq mBART que GENERA tripletas directamente
del texto, en lugar de enumerar pares×labels como GLiREL. Coste: 600M
params, latencia ~3s/frase. Calidad: muy superior en castellano (en el
texto de prueba).

Licencia mREBEL: CC BY-NC-SA 4.0 (no comercial).
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from pathlib import Path
|
||
|
||
import nbformat as nbf
|
||
|
||
# Directory containing this builder script (symlinks resolved).
HERE = Path(__file__).resolve().parent
# Destination of the generated notebook, relative to the builder script.
NB_PATH = HERE.joinpath("notebooks", "03_mrebel_vs_glirel.ipynb")
|
||
|
||
|
||
def _md(text: str):
    """Wrap *text* in a fresh nbformat v4 markdown cell and return that cell."""
    markdown_cell = nbf.v4.new_markdown_cell(source=text)
    return markdown_cell
|
||
|
||
|
||
def _code(src: str):
    """Return an nbformat v4 code cell holding *src*, in the never-executed state.

    The cell carries an empty ``outputs`` list and ``execution_count=None`` so the
    generated notebook is written without any results.
    """
    return nbf.v4.new_code_cell(source=src, outputs=[], execution_count=None)
|
||
|
||
|
||
# Spanish business-news paragraph (ASCII only, accents stripped), the same input
# used in notebook 02. It is embedded verbatim into the generated notebook.
# Sentences are joined with a single space.
SPANISH_TEXT = " ".join(
    (
        "Pablo Isla, expresidente de Inditex, ha sido nombrado consejero de Telefonica.",
        "La operacion fue anunciada por el presidente Jose Maria Alvarez-Pallete en Madrid el pasado lunes.",
        "Inditex factura mas de 30.000 millones anuales y tiene su sede en Arteixo, A Coruna.",
        "En paralelo, Iberdrola y Endesa firmaron un acuerdo de colaboracion en proyectos eolicos en Galicia.",
        "El presidente de Iberdrola, Ignacio Galan, se reunio con la CEO de Endesa, Marina Serrano, en Bilbao.",
        "El acuerdo movilizara 2.000 millones de euros en cinco anos.",
        "El BBVA, presidido por Carlos Torres, mostro interes en participar en la financiacion del proyecto.",
        "Su sede central esta en Bilbao.",
    )
)
|
||
|
||
|
||
def _src(block: str) -> str:
    """Dedent a triple-quoted cell body and trim its leading/trailing newlines.

    Lets code cells be written as indented, readable literals inside this builder
    while the notebook receives source starting at column 0 with its nested
    indentation intact.
    """
    import textwrap  # local import: keeps the module's import block untouched

    return textwrap.dedent(block).strip("\n")


def _intro_cells() -> list:
    """Title cell: motivation and the GLiREL vs mREBEL comparison table."""
    return [
        _md(
            "# GLiREL vs mREBEL — comparativo en castellano\n\n"
            "Tras el hallazgo del notebook 02 (GLiREL emite ~50 relaciones espurias en "
            "narrativa empresarial castellana), buscamos un modelo de relaciones mejor.\n\n"
            "**Candidato:** [`Babelscape/mrebel-large`](https://huggingface.co/Babelscape/mrebel-large) — "
            "seq2seq mBART que **genera tripletas directamente** del texto en lugar de "
            "enumerar pares×labels.\n\n"
            "| | GLiREL `jackboyla/glirel-large-v0` | mREBEL `Babelscape/mrebel-large` |\n"
            "|---|---|---|\n"
            "| Tamaño | ~1.5 GB | ~2.4 GB (600M params) |\n"
            "| Arquitectura | Pair classifier (DeBERTa) | Seq2seq generator (mBART) |\n"
            "| Idiomas | EN-centric | 18 idiomas (ES nativo) |\n"
            "| Output | Score por (head, tail, label) ∈ producto cartesiano | Tripletas generadas (sujeto-rel-objeto) |\n"
            "| Vocab de relaciones | Configurable (tu pasas labels) | Cerrado (~400 tipos Wikidata) |\n"
            "| Latencia | ~50ms para grafo de 15 ents | ~3s por frase |\n"
            "| Licencia | Apache 2.0 | **CC BY-NC-SA 4.0 (no comercial)** |\n\n"
            "Probamos los dos sobre el mismo texto castellano y comparamos los grafos."
        ),
    ]


def _setup_cells() -> list:
    """Section 1: imports and ``sys.path`` wiring for the registry functions."""
    return [
        _md("## 1. Setup"),
        # Raw literals so any backslash reaches the notebook verbatim.
        _code(_src(r'''
            import os, sys, json, time, warnings, re
            warnings.filterwarnings('ignore')
            os.environ.setdefault('HF_HUB_DISABLE_PROGRESS_BARS', '1')
            from pathlib import Path

            # Registry functions root; set FN_REGISTRY_FUNCTIONS to run on another machine.
            _pf = os.environ.get('FN_REGISTRY_FUNCTIONS', '/home/lucas/fn_registry/python/functions')
            sys.path = [p for p in sys.path if not p.startswith(_pf + '/')]
            if _pf not in sys.path:
                sys.path.insert(0, _pf)

            import pandas as pd
            import networkx as nx
            import matplotlib.pyplot as plt
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            from datascience.gliner_load_model import gliner_load_model
            from datascience.glirel_load_model import glirel_load_model
            from pipelines.extract_graph_hybrid import extract_graph_hybrid
            print('imports OK')
        ''')),
    ]


def _input_cells() -> list:
    """Section 2: embed ``SPANISH_TEXT`` as the notebook's ``TEXTO`` variable."""
    return [
        _md("## 2. Texto de entrada (mismo que notebook 02)"),
        # repr() produces a valid Python literal for the notebook source.
        _code(f"TEXTO = {SPANISH_TEXT!r}\nprint(TEXTO)"),
    ]


def _model_cells() -> list:
    """Section 3: load GLiNER, GLiREL and the mREBEL tokenizer/model, timing each."""
    return [
        _md(
            "## 3. Carga modelos: GLiNER + GLiREL + mREBEL\n\n"
            "GLiNER y GLiREL warm. mREBEL cold ~60s la primera vez (descarga 2.4 GB)."
        ),
        _code(_src(r'''
            t0 = time.time(); gliner = gliner_load_model(); print(f'GLiNER {time.time()-t0:.1f}s')
            t0 = time.time(); glirel = glirel_load_model(); print(f'GLiREL {time.time()-t0:.1f}s')
            t0 = time.time()
            mrebel_tok = AutoTokenizer.from_pretrained('Babelscape/mrebel-large', src_lang='es_XX', tgt_lang='tp_XX')
            mrebel = AutoModelForSeq2SeqLM.from_pretrained('Babelscape/mrebel-large')
            print(f'mREBEL {time.time()-t0:.1f}s')
        ''')),
    ]


def _pipeline_a_cells() -> list:
    """Section 4: GLiNER + GLiREL baseline via ``extract_graph_hybrid`` at t=0.30."""
    return [
        _md("## 4. Pipeline A: GLiNER + GLiREL (notebook 02 baseline, t=0.30)"),
        _code(_src(r'''
            entity_schema = [
                {'type_ref': 'Person', 'label': 'person'},
                {'type_ref': 'Organization', 'label': 'organization'},
                {'type_ref': 'Location', 'label': 'location'},
            ]
            relation_types = [
                'works_at', 'located_in', 'appointed_as', 'headquartered_in',
                'ceo_of', 'president_of', 'agreement_with', 'met_with',
            ]
            ents_a, rels_a = extract_graph_hybrid(
                chunks=[TEXTO], entity_schema=entity_schema, relation_types=relation_types,
                gliner_model=gliner, glirel_model=glirel, llm_chat_json=None,
                confidence_threshold=0.30,
            )
            print(f'GLiNER+GLiREL: {len(ents_a)} ents, {len(rels_a)} rels')
        ''')),
    ]


def _pipeline_b_cells() -> list:
    """Section 5: per-sentence mREBEL extraction, raw triplets, and alignment to GLiNER."""
    return [
        _md(
            "## 5. Pipeline B: GLiNER + mREBEL\n\n"
            "Estrategia hibrida:\n"
            "1. **GLiNER** sigue extrayendo entidades tipadas (es excelente).\n"
            "2. **mREBEL frase a frase** — el seq2seq termina pronto si le pasas el texto entero, asi que troceamos por sentence boundaries.\n"
            "3. Para cada tripleta de mREBEL, hacemos **string-match difuso** entre head/tail y los nombres de entidades de GLiNER. Solo conservamos tripletas con ambos lados en el grafo.\n"
            "4. Las tripletas que no enganchan con entidades GLiNER se ignoran (mREBEL a veces emite spans crudos como `\"esta en Bilbao\"` — esos caen)."
        ),
        _code(_src(r'''
            # 5.1 Entidades GLiNER (mismas que pipeline A)
            ents_b = ents_a  # GLiNER es identico
            ent_names = sorted({e.name for e in ents_b}, key=len, reverse=True)
            name_to_ent = {e.name: e for e in ents_b}
            print(f'GLiNER ents: {len(ent_names)}')

            # 5.2 mREBEL frase por frase
            def mrebel_extract_triplets(decoded_text):
                """Parser oficial del README adaptado."""
                triplets = []
                text = decoded_text.replace('<s>','').replace('<pad>','').replace('</s>','').replace('tp_XX','').replace('__en__','').strip()
                current = 'x'
                subject, relation, object_, object_type, subject_type = '', '', '', '', ''
                for token in text.split():
                    if token == '<triplet>' or token == '<relation>':
                        current = 't'
                        if relation:
                            triplets.append({'head':subject.strip(),'head_type':subject_type,'type':relation.strip(),'tail':object_.strip(),'tail_type':object_type})
                            relation = ''
                        subject = ''
                    elif token.startswith('<') and token.endswith('>'):
                        if current in ('t','o'):
                            current = 's'
                            if relation:
                                triplets.append({'head':subject.strip(),'head_type':subject_type,'type':relation.strip(),'tail':object_.strip(),'tail_type':object_type})
                            object_ = ''
                            subject_type = token[1:-1]
                        else:
                            current = 'o'
                            object_type = token[1:-1]
                            relation = ''
                    else:
                        if current == 't': subject += ' ' + token
                        elif current == 's': object_ += ' ' + token
                        elif current == 'o': relation += ' ' + token
                if subject and relation and object_ and object_type and subject_type:
                    triplets.append({'head':subject.strip(),'head_type':subject_type,'type':relation.strip(),'tail':object_.strip(),'tail_type':object_type})
                return triplets

            sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', TEXTO) if len(s.strip()) > 20]
            tp_xx_id = mrebel_tok.convert_tokens_to_ids('tp_XX')  # loop-invariant decoder start token
            raw_triplets = []
            t0 = time.time()
            for s in sentences:
                inputs = mrebel_tok(s, max_length=256, padding=True, truncation=True, return_tensors='pt')
                out = mrebel.generate(
                    inputs['input_ids'], attention_mask=inputs['attention_mask'],
                    decoder_start_token_id=tp_xx_id,
                    max_length=256, num_beams=4, length_penalty=1.0,
                )
                decoded = mrebel_tok.batch_decode(out, skip_special_tokens=False)[0]
                raw_triplets.extend(mrebel_extract_triplets(decoded))
            print(f'mREBEL: {len(raw_triplets)} tripletas en {time.time()-t0:.1f}s ({len(sentences)} frases)')
        ''')),
        _md("### 5.3 Tripletas crudas de mREBEL (antes del match)"),
        _code(_src(r'''
            df_raw = pd.DataFrame(raw_triplets)
            df_raw
        ''')),
        # Describes the matcher below: exact match first, then substring in either direction.
        _md(
            "### 5.4 Match con entidades GLiNER\n\n"
            "Para cada tripleta de mREBEL, busco primero una coincidencia exacta "
            "(case-insensitive) de head y tail con algun nombre de entidad GLiNER; si no la hay, "
            "acepto substring en cualquier direccion (span dentro del nombre o nombre dentro del span), "
            "ganando la entidad mas larga. Solo conservo tripletas donde ambos lados enganchan "
            "en entidades distintas."
        ),
        _code(_src(r'''
            def match_to_ent(span: str):
                s = span.strip().lower()
                if not s:
                    return None
                # exact match first
                for n in ent_names:
                    if n.lower() == s:
                        return n
                # substring (longest entity wins, ent_names ya esta sorted desc by len)
                for n in ent_names:
                    if n.lower() in s or s in n.lower():
                        return n
                return None

            rels_b_dicts = []
            for t in raw_triplets:
                h = match_to_ent(t['head'])
                tail = match_to_ent(t['tail'])
                if h and tail and h != tail:
                    rels_b_dicts.append({'from': h, 'kind': t['type'], 'to': tail,
                                         'head_type': t['head_type'], 'tail_type': t['tail_type']})
            df_b = pd.DataFrame(rels_b_dicts)
            print(f'tripletas alineadas con GLiNER: {len(rels_b_dicts)} de {len(raw_triplets)}')
            df_b
        ''')),
    ]


def _visualization_cells() -> list:
    """Section 6: side-by-side graphs; both panels share one drawing routine."""
    return [
        _md("## 6. Visualizacion comparativa"),
        _code(_src(r'''
            from matplotlib.patches import Patch

            TYPE_COLOR = {'Person': '#5DA5DA', 'Organization': '#F17CB0', 'Location': '#60BD68'}

            def build_graph(ents, edges):
                # One node per entity (typed) and one edge per (from, to, kind) triple.
                G = nx.DiGraph()
                for e in ents:
                    G.add_node(e.name, type=e.type_ref)
                for u, v, kind in edges:
                    G.add_edge(u, v, kind=kind)
                return G

            def draw_graph(ax, G, title):
                pos = nx.spring_layout(G, k=2.2, iterations=80, seed=42)
                cols = [TYPE_COLOR.get(G.nodes[n].get('type'), '#bbb') for n in G.nodes]
                nx.draw_networkx_nodes(G, pos, node_color=cols, node_size=1900, edgecolors='#333', linewidths=1.4, ax=ax)
                nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold', ax=ax)
                nx.draw_networkx_edges(G, pos, edge_color='#888', arrows=True, arrowsize=14, width=1.2, alpha=0.65, ax=ax, connectionstyle='arc3,rad=0.08')
                el = {(u, v): d['kind'] for u, v, d in G.edges(data=True)}
                nx.draw_networkx_edge_labels(G, pos, edge_labels=el, font_size=6.5, ax=ax,
                                             bbox=dict(boxstyle='round,pad=0.1', fc='white', ec='none', alpha=0.85))
                ax.set_title(f'{title}: {G.number_of_nodes()} ents, {G.number_of_edges()} rels', fontsize=11)
                ax.axis('off')

            def draw_a(ax, ents, rels, title):
                G = build_graph(ents, ((r.from_name, r.to_name, r.relation_type) for r in rels))
                draw_graph(ax, G, title)

            def draw_b(ax, ents, rel_dicts, title):
                G = build_graph(ents, ((d['from'], d['to'], d['kind']) for d in rel_dicts))
                # quita nodos sin grado para que el grafo se vea
                G.remove_nodes_from(list(nx.isolates(G)))
                draw_graph(ax, G, title)

            fig, axes = plt.subplots(1, 2, figsize=(20, 9))
            draw_a(axes[0], ents_a, rels_a, 'A: GLiNER + GLiREL (t=0.30)')
            draw_b(axes[1], ents_b, rels_b_dicts, 'B: GLiNER + mREBEL (alineado)')
            legend = [Patch(facecolor=c, edgecolor='#333', label=t) for t, c in TYPE_COLOR.items()]
            axes[0].legend(handles=legend, loc='upper left', frameon=True, fontsize=10)
            plt.tight_layout(); plt.show()
        ''')),
    ]


def _conclusion_cells() -> list:
    """Section 7: written reading of the results, trade-offs and next steps."""
    return [
        _md(
            "## 7. Lectura\n\n"
            "**mREBEL gana en este texto.** Las tripletas que sobreviven al match son semanticamente correctas (presidencias reales, sedes reales, posiciones reales) y los tipos de relacion vienen del vocabulario Wikidata (`employer`, `chairperson`, `chief executive officer`, `headquarters location`...) — mas rico y mas semantico que las labels que pasamos a GLiREL.\n\n"
            "GLiREL a `t=0.30` queda con 1 relacion (falsa). Subiendo a `t=0.15` produce 51 con mayoria espuria. **No hay sweet spot util.**\n\n"
            "### Trade-offs operativos\n\n"
            "| Aspecto | Verdict |\n"
            "|---|---|\n"
            "| Calidad semantica ES | mREBEL >> GLiREL (no comparable) |\n"
            # 3s per sentence vs 50ms for the whole text: ~60x for one sentence, ~500x for 8.
            "| Latencia | mREBEL ~3s/frase, GLiREL ~50ms total. mREBEL es ~60× mas lento ya en una sola frase (~500× sobre las 8 frases de este texto), pero las relaciones son utiles. |\n"
            "| Tamaño en disco | mREBEL 2.4 GB, GLiREL 1.5 GB |\n"
            "| Vocabulario relaciones | mREBEL fijo (~400 Wikidata types). GLiREL libre. Para narrativa empresarial Wikidata cubre todo. |\n"
            "| Licencia | mREBEL CC BY-NC-SA 4.0 (no comercial). GLiREL Apache 2.0. **Bloqueante si esto pasa a producto comercial.** |\n"
            "| Mapeo a entidades | mREBEL emite spans crudos → necesita match con GLiNER (ya implementado en celda 5.4). GLiREL ya devuelve nombres. |\n\n"
            "### Implicacion para el pipeline\n\n"
            "1. **Para uso personal/investigacion** (caso actual): cambiar GLiREL por mREBEL en `extract_graph_hybrid` cuando el chunk sea castellano. Issue nuevo en `graph_explorer`: `0042-mrebel-relation-extractor.md`.\n"
            "2. **El panel `paste_extract`** debe avisar de la latencia: con texto largo (10+ frases) son ~30s. UI: barra de progreso por frase.\n"
            "3. **Para uso comercial** (futuro): no se puede usar mREBEL tal cual. Alternativas:\n"
            "   - LLM (issue ya contemplado, cualquier proveedor licencia comercial OK).\n"
            "   - Fine-tunear REBEL monolingue (Apache 2.0) en castellano si tienes datos.\n"
            "   - Buscar otro modelo abierto (REDFM tiene licencia distinta — comprobar).\n"
            "4. **Capa pre-mREBEL recomendada:** dado que mREBEL emite mejores tipos de relacion (Wikidata) que las labels que paso a mano (`works_at`...), **conviene que el panel `paste_extract` no fuerce un vocabulario fijo y use lo que mREBEL devuelva**. La taxonomia del grafo se enriquece sola.\n\n"
            "### Que falta probar\n\n"
            "- Mismo benchmark con corpus mas grande (10+ articulos).\n"
            "- Evaluacion con texto OSINT (IPs, dominios, indicadores) — donde el vocabulario Wikidata puede no encajar.\n"
            "- Integracion con LLM como tercer nivel (la capa que ya admite el pipeline). Ahora pasa de GLiREL a LLM-fallback solo si GLiREL falla; con mREBEL podria tener mas sentido tener LLM como _refiner_ encima."
        ),
    ]


def build():
    """Assemble the GLiREL vs mREBEL comparison notebook and write it to ``NB_PATH``.

    Cells are produced section by section by the ``_*_cells`` helpers, and the
    code cells come from ``_code``, so they carry no outputs. The parent
    directory of ``NB_PATH`` is created if missing. When done, the function
    prints the output path and the number of cells written.
    """
    cells = [
        *_intro_cells(),
        *_setup_cells(),
        *_input_cells(),
        *_model_cells(),
        *_pipeline_a_cells(),
        *_pipeline_b_cells(),
        *_visualization_cells(),
        *_conclusion_cells(),
    ]

    nb = nbf.v4.new_notebook()
    nb.cells = cells
    nb.metadata.update({
        "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
        "language_info": {"name": "python"},
    })
    NB_PATH.parent.mkdir(parents=True, exist_ok=True)
    nbf.write(nb, NB_PATH)
    print(f"[done] {NB_PATH} cells={len(cells)}")
|
||
|
||
|
||
# Script entry point: regenerate the notebook only when run directly, not on import.
if __name__ == "__main__":
    build()
|