Files
gliner_glirel_tuning/run_mrebel_test.py
T
2026-05-04 23:44:11 +02:00

155 lines
5.6 KiB
Python

"""Quick test of Babelscape/mREBEL on Spanish business text.
Compares directly with GLiREL on the same text. If mREBEL produces
semantically correct triplets in Spanish, we propose it as a
replacement/complement for GLiREL in the `extract_graph_hybrid` pipeline.

mREBEL license: CC BY-NC-SA 4.0 (non-commercial). Fine for personal/
research use; review it if this moves to commercial production.
"""
from __future__ import annotations
import sys
import time
import warnings
from pathlib import Path
warnings.filterwarnings("ignore")
# Same sys.path cleanup as the notebook (avoid bigquery/datasets.py shadow)
import os
# Keep only the functions directory itself on sys.path: entries nested below
# it are removed (so modules inside its sub-packages cannot shadow top-level
# imports), then the directory is put first if it is not already listed.
_pf = "/home/lucas/fn_registry/python/functions"
sys.path = [entry for entry in sys.path if not entry.startswith(f"{_pf}/")]
if _pf not in sys.path:
    sys.path.insert(0, _pf)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Sample Spanish business-news paragraph used as the model input in main().
# The text is written without diacritics (e.g. "Telefonica", "operacion") and
# is assembled from adjacent string literals, one sentence per literal.
TEXT_ES = (
    "Pablo Isla, expresidente de Inditex, ha sido nombrado consejero de Telefonica. "
    "La operacion fue anunciada por el presidente Jose Maria Alvarez-Pallete en Madrid el pasado lunes. "
    "Inditex factura mas de 30.000 millones anuales y tiene su sede en Arteixo, A Coruna. "
    "En paralelo, Iberdrola y Endesa firmaron un acuerdo de colaboracion en proyectos eolicos en Galicia. "
    "El presidente de Iberdrola, Ignacio Galan, se reunio con la CEO de Endesa, Marina Serrano, en Bilbao. "
    "El acuerdo movilizara 2.000 millones de euros en cinco anos. "
    "El BBVA, presidido por Carlos Torres, mostro interes en participar en la financiacion del proyecto. "
    "Su sede central esta en Bilbao."
)
def extract_triplets_typed(text: str) -> list[dict]:
    """Parse a decoded mREBEL sequence into typed relation triplets.

    ``text`` is expected to be model output decoded with
    ``skip_special_tokens=False``. The markers ``<s>``, ``<pad>``, ``</s>``,
    ``tp_XX`` and ``__en__`` are deleted wherever they occur; the remainder is
    split on whitespace and read by a small state machine that accepts::

        <triplet> head <head_type> tail <tail_type> relation
                  [<head_type> tail <tail_type> relation ...]

    ``<triplet>`` (or ``<relation>``) starts a new head. Any other token of
    the form ``<...>`` is a type tag: after head text or after relation text
    it sets the head type and opens the tail; after tail text it sets the
    tail type and opens the relation. Several tail/relation groups may follow
    one head. Plain tokens seen before the first marker are ignored.

    Adapted from the README example.

    Args:
        text: Decoded model output, special tokens included.

    Returns:
        A list of dicts with the keys ``head``, ``head_type``, ``type``,
        ``tail`` and ``tail_type``, in order of appearance. A triplet is
        emitted mid-stream whenever a new group starts and relation text has
        been collected; the trailing triplet is emitted only when all five
        fields are non-empty.
    """
    triplets: list[dict] = []
    # Parser state: "x" = nothing seen yet, "t" = collecting head text,
    # "s" = collecting tail text, "o" = collecting relation text.
    current = "x"
    subject = relation = object_ = ""
    subject_type = object_type = ""

    def flush() -> None:
        # Record the fields collected so far as one triplet.
        triplets.append(
            {
                "head": subject.strip(),
                "head_type": subject_type,
                "type": relation.strip(),
                "tail": object_.strip(),
                "tail_type": object_type,
            }
        )

    cleaned = text
    for marker in ("<s>", "<pad>", "</s>", "tp_XX", "__en__"):
        cleaned = cleaned.replace(marker, "")

    for token in cleaned.split():
        if token in ("<triplet>", "<relation>"):
            # New head: close the pending triplet first, if there is one.
            current = "t"
            if relation:
                flush()
                relation = ""
            subject = ""
        elif token.startswith("<") and token.endswith(">"):
            if current in ("t", "o"):
                # Tag following head text or a finished relation: it is the
                # head type, and the next plain tokens are the tail.
                current = "s"
                if relation:
                    flush()
                object_ = ""
                subject_type = token[1:-1]
            else:
                # Tag following tail text: it is the tail type, and the next
                # plain tokens are the relation label.
                current = "o"
                object_type = token[1:-1]
                relation = ""
        elif current == "t":
            subject += " " + token
        elif current == "s":
            object_ += " " + token
        elif current == "o":
            relation += " " + token

    # The last triplet has no following marker to trigger a flush.
    if subject and relation and object_ and object_type and subject_type:
        flush()
    return triplets
def main() -> None:
    """Run mREBEL once on ``TEXT_ES``, print the triplets and save them.

    Steps: load the ``Babelscape/mrebel-large`` tokenizer and model, tokenize
    ``TEXT_ES`` (truncated to 512 tokens), generate with beam search, decode
    with special tokens kept, parse the first decoded sequence with
    ``extract_triplets_typed`` and print one line per triplet. The input
    text, raw decoded string and parsed triplets are written to
    ``mrebel_results.json`` next to this script.
    """
    import json

    print("[load] mREBEL...", flush=True)
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(
        "Babelscape/mrebel-large", src_lang="es_XX", tgt_lang="tp_XX"
    )
    model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large")
    print(f"[load] mREBEL ready in {time.time()-t0:.1f}s")

    print(f"\n[input ES] {len(TEXT_ES)} chars")
    inputs = tokenizer(TEXT_ES, max_length=512, padding=True, truncation=True, return_tensors="pt")

    print("[generate]")
    t0 = time.time()
    # Decoding starts from the "tp_XX" token, matching the tokenizer's
    # tgt_lang above.
    # NOTE(review): the model card example appears to also pass
    # forced_bos_token_id=None to generate() - confirm whether it is needed.
    out = model.generate(
        inputs["input_ids"].to(model.device),
        attention_mask=inputs["attention_mask"].to(model.device),
        decoder_start_token_id=tokenizer.convert_tokens_to_ids("tp_XX"),
        max_length=512,
        num_beams=4,
        length_penalty=0.0,
    )
    print(f"[generate] {time.time()-t0:.1f}s")

    # Special tokens are kept because the parser relies on the <...> markers.
    decoded = tokenizer.batch_decode(out, skip_special_tokens=False)
    print("\n=== RAW DECODED ===")
    print(decoded[0][:2000])

    print("\n=== TRIPLETS ===")
    triplets = extract_triplets_typed(decoded[0])
    print(f"n={len(triplets)}\n")
    for t in triplets:
        print(f" ({t['head']:32s} : {t['head_type']:15s}) --[{t['type']:25s}]--> ({t['tail']:32s} : {t['tail_type']:15s})")

    # Save for the notebook
    out_path = Path(__file__).resolve().parent / "mrebel_results.json"
    payload = {
        "text": TEXT_ES,
        "raw_decoded": decoded[0],
        "triplets": triplets,
    }
    # ensure_ascii=False emits non-ASCII characters verbatim, so the file
    # encoding is pinned to UTF-8 instead of depending on the locale default.
    out_path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\n[saved] {out_path}")


if __name__ == "__main__":
    main()