"""Extrae relaciones entre entidades usando mREBEL (seq2seq multilingue).""" from __future__ import annotations import os import re import sys from typing import Any sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) from python.types.datascience.entity_candidate import EntityCandidate from python.types.datascience.relation_candidate import RelationCandidate from python.functions.datascience.parse_rebel_output import parse_rebel_output from python.functions.datascience.align_relations_to_entities import align_relations_to_entities def extract_relations_mrebel( text: str, entities: list[EntityCandidate], tokenizer: Any, model: Any, src_lang: str = "es_XX", sentence_split_re: str = r"(?<=[.!?])\s+", min_sentence_chars: int = 20, num_beams: int = 4, max_length: int = 256, ) -> list[RelationCandidate]: """Extract relations from text using mREBEL, sentence by sentence. Orchestrates the full pipeline: 1. Split ``text`` into sentences using ``sentence_split_re``. 2. Filter out sentences shorter than ``min_sentence_chars``. 3. For each sentence: tokenize → generate → decode (with special tokens) → ``parse_rebel_output`` → accumulate raw triplets. 4. Collect all entity names from ``entities``, sorted DESC by length (so longer names win in substring matching). 5. Call ``align_relations_to_entities`` to resolve head/tail spans to canonical entity names and drop unresolved / self-loop triplets. 6. Wrap each aligned triplet in a ``RelationCandidate``. mREBEL does not produce a continuous confidence score — ``confidence`` is set to ``1.0`` as a marker meaning "the model emitted this triplet". Args: text: Source text (same language as ``src_lang``). entities: Entities already extracted from this text (e.g. via ``extract_entities_gliner``). Used to filter triplets to known entities only. tokenizer: mREBEL tokenizer loaded with ``mrebel_load_model``. model: mREBEL model loaded with ``mrebel_load_model``. src_lang: Informational — the language the tokenizer was loaded with. Not used at inference time (mBART lang tokens are set at load time). sentence_split_re: Regex pattern for sentence splitting. Default splits on whitespace that follows ``.``, ``!`` or ``?``. min_sentence_chars: Minimum character length for a sentence to be processed. Shorter fragments are skipped. num_beams: Beam search width for ``model.generate``. Default 4. max_length: Max token length for both tokenization and generation. Returns: List of ``RelationCandidate`` where ``from_name`` and ``to_name`` always correspond to names in ``entities``. Empty list if no aligned triplets are found or ``entities`` has fewer than 2 items. """ if len(entities) < 2: return [] if not text or not text.strip(): return [] split_re = re.compile(sentence_split_re) sentences = split_re.split(text.strip()) sentences = [s.strip() for s in sentences if s.strip() and len(s.strip()) >= min_sentence_chars] if not sentences: return [] # Step 1-3: gather raw triplets from all sentences. raw_triplets: list[dict] = [] for idx, sentence in enumerate(sentences): try: inputs = tokenizer( sentence, return_tensors="pt", max_length=max_length, truncation=True, ) generated = model.generate( **inputs, num_beams=num_beams, length_penalty=1.0, max_length=max_length, ) decoded = tokenizer.decode(generated[0], skip_special_tokens=False) except Exception: # Skip sentences that fail (e.g. tokenizer errors on special chars). continue sentence_triplets = parse_rebel_output(decoded) # Tag each triplet with the sentence index for source_chunk_index. for t in sentence_triplets: t["_sentence_idx"] = idx raw_triplets.extend(sentence_triplets) if not raw_triplets: return [] # Step 4-5: align to entity names (sorted DESC by length for substring match). entity_names = sorted([e.name for e in entities if e.name], key=len, reverse=True) aligned = align_relations_to_entities(raw_triplets, entity_names) # Step 6: wrap in RelationCandidate. candidates: list[RelationCandidate] = [] for item in aligned: # Recover sentence_idx from raw triplet — find matching raw by head/tail/type. sentence_idx = -1 for raw in raw_triplets: if ( raw.get("head", "").strip() and raw.get("type", "").strip() == item["kind"] ): sentence_idx = raw.get("_sentence_idx", -1) break candidates.append( RelationCandidate( from_name=item["from"], to_name=item["to"], relation_type=item["kind"], description="", confidence=1.0, source_chunk_index=sentence_idx, ) ) return candidates