# gliner_glirel_tuning/build_notebook_mrebel.py, 2026-05-04 23:44:11 +02:00

"""Build notebooks/03_mrebel_vs_glirel.ipynb: a side-by-side comparison
of GLiNER+GLiREL vs GLiNER+mREBEL on the same Spanish text.

mREBEL (Babelscape) is an mBART seq2seq model that GENERATES triplets
directly from the text instead of enumerating pairs × labels as GLiREL
does. Cost: 600M params, latency ~3 s/sentence. Quality: far better on
Spanish. mREBEL licence: CC BY-NC-SA 4.0 (non-commercial).
"""
from __future__ import annotations
import json
from pathlib import Path
import nbformat as nbf
HERE = Path(__file__).resolve().parent
NB_PATH = HERE / "notebooks" / "03_mrebel_vs_glirel.ipynb"
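

# Illustrative sketch, not used by build(): the module docstring contrasts
# mREBEL's direct generation with GLiREL "enumerating pairs x labels".
# ASSUMPTION: GLiREL scores every ordered (head, tail) pair of distinct
# entities against every label. This file does not verify that behaviour.
# `glirel_candidate_count` is a hypothetical helper name.
def glirel_candidate_count(n_entities: int, n_labels: int) -> int:
    """Count (head, tail, label) candidates over ordered pairs of distinct entities."""
    if n_entities < 0 or n_labels < 0:
        raise ValueError("counts must be non-negative")
    # n * (n - 1) ordered pairs, each scored once per label.
    return n_entities * max(n_entities - 1, 0) * n_labels
# For the 15-entity graph and the 8 relation labels used in this notebook,
# glirel_candidate_count(15, 8) evaluates to 15 * 14 * 8 candidates, which is
# one way to see why a threshold sweep can swing from 1 relation to dozens.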


def _md(text: str):
    return nbf.v4.new_markdown_cell(text)


def _code(src: str):
    cell = nbf.v4.new_code_cell(src)
    cell.outputs = []
    cell.execution_count = None
    return cell


SPANISH_TEXT = (
    "Pablo Isla, expresidente de Inditex, ha sido nombrado consejero de Telefonica. "
    "La operacion fue anunciada por el presidente Jose Maria Alvarez-Pallete en Madrid el pasado lunes. "
    "Inditex factura mas de 30.000 millones anuales y tiene su sede en Arteixo, A Coruna. "
    "En paralelo, Iberdrola y Endesa firmaron un acuerdo de colaboracion en proyectos eolicos en Galicia. "
    "El presidente de Iberdrola, Ignacio Galan, se reunio con la CEO de Endesa, Marina Serrano, en Bilbao. "
    "El acuerdo movilizara 2.000 millones de euros en cinco anos. "
    "El BBVA, presidido por Carlos Torres, mostro interes en participar en la financiacion del proyecto. "
    "Su sede central esta en Bilbao."
)
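

# Illustrative sketch, not called by build(): the notebook runs mREBEL one
# sentence at a time and quotes ~3 s per sentence, so latency scales with the
# sentence count. The splitter below mirrors the regex and length filter used
# in notebook cell 5.2; `split_sentences` and `estimate_mrebel_seconds` are
# hypothetical names, and the 3.0 s default is the rough figure from the
# notebook text, not a measured constant.
import re


def split_sentences(text: str, min_chars: int = 20) -> list:
    """Split on whitespace that follows a period; drop fragments of min_chars or fewer."""
    parts = re.split(r"(?<=[.])\s+", text)
    return [p.strip() for p in parts if len(p.strip()) > min_chars]


def estimate_mrebel_seconds(n_sentences: int, seconds_per_sentence: float = 3.0) -> float:
    """Rough wall-clock estimate for sentence-by-sentence generation."""
    return n_sentences * seconds_per_sentence
# Thousands separators such as "30.000" are not split because no whitespace
# follows the period. A UI could call
# estimate_mrebel_seconds(len(split_sentences(text))) before starting a run.

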
def build():
    cells = []
    cells.append(_md(
        "# GLiREL vs mREBEL: a comparison on Spanish text\n\n"
        "After the finding in notebook 02 (GLiREL emits ~50 spurious relations on "
        "Spanish business narrative), we look for a better relation model.\n\n"
        "**Candidate:** [`Babelscape/mrebel-large`](https://huggingface.co/Babelscape/mrebel-large), "
        "an mBART seq2seq model that **generates triplets directly** from the text instead of "
        "enumerating pairs × labels.\n\n"
        "| | GLiREL `jackboyla/glirel-large-v0` | mREBEL `Babelscape/mrebel-large` |\n"
        "|---|---|---|\n"
        "| Size | ~1.5 GB | ~2.4 GB (600M params) |\n"
        "| Architecture | Pair classifier (DeBERTa) | Seq2seq generator (mBART) |\n"
        "| Languages | English-centric | 18 languages (Spanish included natively) |\n"
        "| Output | Score per (head, tail, label) over the Cartesian product | Generated triplets (subject, relation, object) |\n"
        "| Relation vocabulary | Configurable (you pass the labels) | Closed (~400 Wikidata types) |\n"
        "| Latency | ~50 ms for a 15-entity graph | ~3 s per sentence |\n"
        "| Licence | Apache 2.0 | **CC BY-NC-SA 4.0 (non-commercial)** |\n\n"
        "We run both on the same Spanish text and compare the resulting graphs."
    ))
    cells.append(_md("## 1. Setup"))
    cells.append(_code(
        "import os, sys, json, time, warnings, re\n"
        "warnings.filterwarnings('ignore')\n"
        "os.environ.setdefault('HF_HUB_DISABLE_PROGRESS_BARS', '1')\n"
        "from pathlib import Path\n"
        "\n"
        "_pf = '/home/lucas/fn_registry/python/functions'\n"
        "sys.path = [p for p in sys.path if not p.startswith(_pf + '/')]\n"
        "if _pf not in sys.path:\n"
        "    sys.path.insert(0, _pf)\n"
        "\n"
        "import pandas as pd\n"
        "import networkx as nx\n"
        "import matplotlib.pyplot as plt\n"
        "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n"
        "from datascience.gliner_load_model import gliner_load_model\n"
        "from datascience.glirel_load_model import glirel_load_model\n"
        "from pipelines.extract_graph_hybrid import extract_graph_hybrid\n"
        "print('imports OK')"
    ))
    cells.append(_md("## 2. Input text (same as notebook 02)"))
    cells.append(_code(
        f"TEXTO = {SPANISH_TEXT!r}\n"
        "print(TEXTO)"
    ))
    cells.append(_md(
        "## 3. Load models: GLiNER + GLiREL + mREBEL\n\n"
        "GLiNER and GLiREL load warm. mREBEL is cold: expect ~60 s the first time (2.4 GB download)."
    ))
    cells.append(_code(
        "t0 = time.time(); gliner = gliner_load_model(); print(f'GLiNER {time.time()-t0:.1f}s')\n"
        "t0 = time.time(); glirel = glirel_load_model(); print(f'GLiREL {time.time()-t0:.1f}s')\n"
        "t0 = time.time()\n"
        "mrebel_tok = AutoTokenizer.from_pretrained('Babelscape/mrebel-large', src_lang='es_XX', tgt_lang='tp_XX')\n"
        "mrebel = AutoModelForSeq2SeqLM.from_pretrained('Babelscape/mrebel-large')\n"
        "print(f'mREBEL {time.time()-t0:.1f}s')"
    ))
    cells.append(_md("## 4. Pipeline A: GLiNER + GLiREL (notebook 02 baseline, t=0.30)"))
    cells.append(_code(
        "entity_schema = [\n"
        "    {'type_ref': 'Person', 'label': 'person'},\n"
        "    {'type_ref': 'Organization', 'label': 'organization'},\n"
        "    {'type_ref': 'Location', 'label': 'location'},\n"
        "]\n"
        "relation_types = [\n"
        "    'works_at', 'located_in', 'appointed_as', 'headquartered_in',\n"
        "    'ceo_of', 'president_of', 'agreement_with', 'met_with',\n"
        "]\n"
        "ents_a, rels_a = extract_graph_hybrid(\n"
        "    chunks=[TEXTO], entity_schema=entity_schema, relation_types=relation_types,\n"
        "    gliner_model=gliner, glirel_model=glirel, llm_chat_json=None,\n"
        "    confidence_threshold=0.30,\n"
        ")\n"
        "print(f'GLiNER+GLiREL: {len(ents_a)} ents, {len(rels_a)} rels')"
    ))
    cells.append(_md(
        "## 5. Pipeline B: GLiNER + mREBEL\n\n"
        "Hybrid strategy:\n"
        "1. **GLiNER** still extracts the typed entities (it does this very well).\n"
        "2. **mREBEL sentence by sentence**: the seq2seq model stops generating early when given the whole text, so we split on sentence boundaries.\n"
        "3. For each mREBEL triplet, we match head and tail against the GLiNER entity names (exact match first, then case-insensitive substring). We keep only triplets whose two ends both map to entities in the graph.\n"
        "4. Triplets that do not map to GLiNER entities are dropped (mREBEL sometimes emits raw spans such as `\"esta en Bilbao\"`, and those fall out here)."
    ))
    cells.append(_code(
        "# 5.1 GLiNER entities (same as pipeline A)\n"
        "ents_b = ents_a  # the GLiNER output is identical\n"
        "ent_names = sorted({e.name for e in ents_b}, key=len, reverse=True)\n"
        "name_to_ent = {e.name: e for e in ents_b}\n"
        "print(f'GLiNER ents: {len(ent_names)}')\n"
        "\n"
        "# 5.2 mREBEL sentence by sentence\n"
        "def mrebel_extract_triplets(decoded_text):\n"
        "    \"\"\"Adapted from the official parser in the model README.\"\"\"\n"
        "    triplets = []\n"
        "    text = decoded_text.replace('<s>','').replace('<pad>','').replace('</s>','').replace('tp_XX','').replace('__en__','').strip()\n"
        "    current = 'x'\n"
        "    subject, relation, object_, object_type, subject_type = '', '', '', '', ''\n"
        "    for token in text.split():\n"
        "        if token == '<triplet>' or token == '<relation>':\n"
        "            # a new triplet starts: flush the pending one, then collect the subject\n"
        "            current = 't'\n"
        "            if relation:\n"
        "                triplets.append({'head':subject.strip(),'head_type':subject_type,'type':relation.strip(),'tail':object_.strip(),'tail_type':object_type})\n"
        "                relation = ''\n"
        "            subject = ''\n"
        "        elif token.startswith('<') and token.endswith('>'):\n"
        "            if current in ('t','o'):\n"
        "                # type tag after the subject: start collecting the object\n"
        "                current = 's'\n"
        "                if relation:\n"
        "                    triplets.append({'head':subject.strip(),'head_type':subject_type,'type':relation.strip(),'tail':object_.strip(),'tail_type':object_type})\n"
        "                object_ = ''\n"
        "                subject_type = token[1:-1]\n"
        "            else:\n"
        "                # type tag after the object: start collecting the relation\n"
        "                current = 'o'\n"
        "                object_type = token[1:-1]\n"
        "                relation = ''\n"
        "        else:\n"
        "            if current == 't': subject += ' ' + token\n"
        "            elif current == 's': object_ += ' ' + token\n"
        "            elif current == 'o': relation += ' ' + token\n"
        "    if subject and relation and object_ and object_type and subject_type:\n"
        "        triplets.append({'head':subject.strip(),'head_type':subject_type,'type':relation.strip(),'tail':object_.strip(),'tail_type':object_type})\n"
        "    return triplets\n"
        "\n"
        "sentences = [s.strip() for s in re.split(r'(?<=[\\.])\\s+', TEXTO) if len(s.strip()) > 20]\n"
        "raw_triplets = []\n"
        "t0 = time.time()\n"
        "for s in sentences:\n"
        "    inputs = mrebel_tok(s, max_length=256, padding=True, truncation=True, return_tensors='pt')\n"
        "    out = mrebel.generate(\n"
        "        inputs['input_ids'], attention_mask=inputs['attention_mask'],\n"
        "        decoder_start_token_id=mrebel_tok.convert_tokens_to_ids('tp_XX'),\n"
        "        max_length=256, num_beams=4, length_penalty=1.0,\n"
        "    )\n"
        "    decoded = mrebel_tok.batch_decode(out, skip_special_tokens=False)[0]\n"
        "    raw_triplets.extend(mrebel_extract_triplets(decoded))\n"
        "print(f'mREBEL: {len(raw_triplets)} triplets in {time.time()-t0:.1f}s ({len(sentences)} sentences)')"
    ))
    cells.append(_md("### 5.3 Raw mREBEL triplets (before matching)"))
    cells.append(_code(
        "df_raw = pd.DataFrame(raw_triplets)\n"
        "df_raw"
    ))
    cells.append(_md(
        "### 5.4 Match against GLiNER entities\n\n"
        "For each mREBEL triplet, check whether head and tail match a GLiNER entity name: "
        "exact match first, then case-insensitive substring in either direction. "
        "Keep only triplets where both ends match."
    ))
    cells.append(_code(
        "def match_to_ent(span: str):\n"
        "    s = span.strip().lower()\n"
        "    if not s: return None\n"
        "    # exact match first\n"
        "    for n in ent_names:\n"
        "        if n.lower() == s:\n"
        "            return n\n"
        "    # substring match (longest entity wins; ent_names is already sorted by length, descending)\n"
        "    for n in ent_names:\n"
        "        if n.lower() in s or s in n.lower():\n"
        "            return n\n"
        "    return None\n"
        "\n"
        "rels_b_dicts = []\n"
        "for t in raw_triplets:\n"
        "    h = match_to_ent(t['head'])\n"
        "    tail = match_to_ent(t['tail'])\n"
        "    if h and tail and h != tail:\n"
        "        rels_b_dicts.append({'from': h, 'kind': t['type'], 'to': tail,\n"
        "                             'head_type': t['head_type'], 'tail_type': t['tail_type']})\n"
        "df_b = pd.DataFrame(rels_b_dicts)\n"
        "print(f'triplets aligned with GLiNER: {len(rels_b_dicts)} of {len(raw_triplets)}')\n"
        "df_b"
    ))
    cells.append(_md("## 6. Comparative visualization"))
    cells.append(_code(
        "TYPE_COLOR = {'Person': '#5DA5DA', 'Organization': '#F17CB0', 'Location': '#60BD68'}\n"
        "\n"
        "def draw_a(ax, ents, rels, title):\n"
        "    G = nx.DiGraph()\n"
        "    for e in ents: G.add_node(e.name, type=e.type_ref)\n"
        "    for r in rels: G.add_edge(r.from_name, r.to_name, kind=r.relation_type)\n"
        "    pos = nx.spring_layout(G, k=2.2, iterations=80, seed=42)\n"
        "    cols = [TYPE_COLOR.get(G.nodes[n].get('type'), '#bbb') for n in G.nodes]\n"
        "    nx.draw_networkx_nodes(G, pos, node_color=cols, node_size=1900, edgecolors='#333', linewidths=1.4, ax=ax)\n"
        "    nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold', ax=ax)\n"
        "    nx.draw_networkx_edges(G, pos, edge_color='#888', arrows=True, arrowsize=14, width=1.2, alpha=0.65, ax=ax, connectionstyle='arc3,rad=0.08')\n"
        "    el = {(u,v): d['kind'] for u,v,d in G.edges(data=True)}\n"
        "    nx.draw_networkx_edge_labels(G, pos, edge_labels=el, font_size=6.5, ax=ax,\n"
        "        bbox=dict(boxstyle='round,pad=0.1', fc='white', ec='none', alpha=0.85))\n"
        "    ax.set_title(f'{title}: {G.number_of_nodes()} ents, {G.number_of_edges()} rels', fontsize=11)\n"
        "    ax.axis('off')\n"
        "\n"
        "def draw_b(ax, ents, rel_dicts, title):\n"
        "    G = nx.DiGraph()\n"
        "    for e in ents: G.add_node(e.name, type=e.type_ref)\n"
        "    for d in rel_dicts: G.add_edge(d['from'], d['to'], kind=d['kind'])\n"
        "    # drop isolated nodes so the graph stays readable\n"
        "    isolates = list(nx.isolates(G))\n"
        "    G.remove_nodes_from(isolates)\n"
        "    pos = nx.spring_layout(G, k=2.2, iterations=80, seed=42)\n"
        "    cols = [TYPE_COLOR.get(G.nodes[n].get('type'), '#bbb') for n in G.nodes]\n"
        "    nx.draw_networkx_nodes(G, pos, node_color=cols, node_size=1900, edgecolors='#333', linewidths=1.4, ax=ax)\n"
        "    nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold', ax=ax)\n"
        "    nx.draw_networkx_edges(G, pos, edge_color='#888', arrows=True, arrowsize=14, width=1.2, alpha=0.65, ax=ax, connectionstyle='arc3,rad=0.08')\n"
        "    el = {(u,v): d['kind'] for u,v,d in G.edges(data=True)}\n"
        "    nx.draw_networkx_edge_labels(G, pos, edge_labels=el, font_size=6.5, ax=ax,\n"
        "        bbox=dict(boxstyle='round,pad=0.1', fc='white', ec='none', alpha=0.85))\n"
        "    ax.set_title(f'{title}: {G.number_of_nodes()} ents, {G.number_of_edges()} rels', fontsize=11)\n"
        "    ax.axis('off')\n"
        "\n"
        "fig, axes = plt.subplots(1, 2, figsize=(20, 9))\n"
        "draw_a(axes[0], ents_a, rels_a, 'A: GLiNER + GLiREL (t=0.30)')\n"
        "draw_b(axes[1], ents_b, rels_b_dicts, 'B: GLiNER + mREBEL (aligned)')\n"
        "from matplotlib.patches import Patch\n"
        "legend = [Patch(facecolor=c, edgecolor='#333', label=t) for t, c in TYPE_COLOR.items()]\n"
        "axes[0].legend(handles=legend, loc='upper left', frameon=True, fontsize=10)\n"
        "plt.tight_layout(); plt.show()"
    ))
    cells.append(_md(
        "## 7. Interpretation\n\n"
        "**mREBEL wins on this text.** The triplets that survive the match are semantically correct (real presidencies, real headquarters, real positions), and the relation types come from the Wikidata vocabulary (`employer`, `chairperson`, `chief executive officer`, `headquarters location`...), which is richer and more semantic than the labels we pass to GLiREL.\n\n"
        "At `t=0.30` GLiREL is left with 1 relation (a false one). Lowering the threshold to `t=0.15` yields 51, most of them spurious. **There is no useful sweet spot.**\n\n"
        "### Operational trade-offs\n\n"
        "| Aspect | Verdict |\n"
        "|---|---|\n"
        "| Semantic quality (Spanish) | mREBEL >> GLiREL (not comparable) |\n"
        "| Latency | mREBEL ~3 s/sentence, GLiREL ~50 ms total. A single mREBEL sentence already costs ~60× GLiREL's whole-graph time, but the relations are usable. |\n"
        "| Size on disk | mREBEL 2.4 GB, GLiREL 1.5 GB |\n"
        "| Relation vocabulary | mREBEL fixed (~400 Wikidata types). GLiREL free-form. For business narrative, the Wikidata types covered what this text needed. |\n"
        "| Licence | mREBEL CC BY-NC-SA 4.0 (non-commercial). GLiREL Apache 2.0. **Blocking if this goes into a commercial product.** |\n"
        "| Entity mapping | mREBEL emits raw spans → needs matching against GLiNER (already implemented in cell 5.4). GLiREL already returns names. |\n\n"
        "### Implications for the pipeline\n\n"
        "1. **For personal/research use** (the current case): swap GLiREL for mREBEL in `extract_graph_hybrid` when the chunk is in Spanish. New issue in `graph_explorer`: `0042-mrebel-relation-extractor.md`.\n"
        "2. **The `paste_extract` panel** should warn about latency: a long text (10+ sentences) takes ~30 s. UI: per-sentence progress bar.\n"
        "3. **For commercial use** (future): mREBEL cannot be used as is. Alternatives:\n"
        "   - An LLM (issue already planned; any provider with a commercial licence works).\n"
        "   - Fine-tune monolingual REBEL on Spanish if you have data, but verify its licence first: the published `Babelscape/rebel-large` checkpoint appears to carry CC BY-NC-SA 4.0 as well, not Apache 2.0.\n"
        "   - Look for another open model (REDFM has a different licence; check it).\n"
        "4. **Recommended pre-mREBEL layer:** since mREBEL emits better relation types (Wikidata) than the labels passed by hand (`works_at`...), **the `paste_extract` panel should not force a fixed vocabulary and should use whatever mREBEL returns**. The graph taxonomy then enriches itself.\n\n"
        "### Still to test\n\n"
        "- The same benchmark on a larger corpus (10+ articles).\n"
        "- Evaluation on OSINT text (IPs, domains, indicators), where the Wikidata vocabulary may not fit.\n"
        "- Integration with an LLM as a third tier (the layer the pipeline already supports). Today it falls back from GLiREL to the LLM only if GLiREL fails; with mREBEL it may make more sense to run the LLM as a _refiner_ on top."
    ))
    nb = nbf.v4.new_notebook()
    nb.cells = cells
    nb.metadata = {
        "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
        "language_info": {"name": "python"},
    }
    NB_PATH.parent.mkdir(parents=True, exist_ok=True)
    nbf.write(nb, NB_PATH)
    print(f"[done] {NB_PATH} cells={len(cells)}")


if __name__ == "__main__":
    build()