"""Extrae relaciones entre entidades usando GLiREL (zero-shot relation extraction).""" from __future__ import annotations import os import re import sys import warnings from typing import Any sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) from python.types.datascience.entity_candidate import EntityCandidate from python.types.datascience.relation_candidate import RelationCandidate _TOKEN_RE = re.compile(r"\S+") def _tokenize_with_offsets(text: str) -> list[tuple[str, int, int]]: """Tokeniza por whitespace y devuelve [(token, char_start, char_end)].""" return [(m.group(), m.start(), m.end()) for m in _TOKEN_RE.finditer(text)] def _char_span_to_token_span( char_start: int, char_end: int, tokens_with_offsets: list[tuple[str, int, int]], ) -> tuple[int, int] | None: """Mapea un span de caracteres a indices de tokens [start_token, end_token] inclusivos. Retorna None si no hay tokens que solapen con el span. """ start_idx: int | None = None end_idx: int | None = None for i, (_tok, ts, te) in enumerate(tokens_with_offsets): # Token solapa con [char_start, char_end) si su rango interseca. if ts < char_end and te > char_start: if start_idx is None: start_idx = i end_idx = i if start_idx is None or end_idx is None: return None return (start_idx, end_idx) def _resolve_entity_char_span( entity: EntityCandidate, text: str, ) -> tuple[int, int] | None: """Devuelve (start, end) para una entidad, usando attributes o fallback text.find.""" start = entity.attributes.get("start") if entity.attributes else None end = entity.attributes.get("end") if entity.attributes else None if isinstance(start, int) and isinstance(end, int) and 0 <= start < end <= len(text): return (start, end) # Fallback: buscar el primer match del nombre en el texto. if not entity.name: return None found = text.find(entity.name) if found < 0: warnings.warn( f"extract_relations_glirel: entidad '{entity.name}' sin offsets y no se " f"encuentra en text.find — descartando.", stacklevel=3, ) return None warnings.warn( f"extract_relations_glirel: entidad '{entity.name}' sin offsets en attributes; " f"usando text.find como fallback.", stacklevel=3, ) return (found, found + len(entity.name)) def extract_relations_glirel( text: str, entities: list[EntityCandidate], relation_types: list[str], model: Any, threshold: float = 0.5, max_pairs: int = 0, ) -> list[RelationCandidate]: """Extrae relaciones zero-shot con GLiREL, contrato drop-in con `extract_relations_llm`. GLiREL recibe tokens + spans de entidades en indices de tokens. Esta funcion se encarga de tokenizar el texto (whitespace), mapear los spans en caracteres de cada `EntityCandidate` (de `attributes['start'/'end']` o fallback con `text.find(name)`) y traducir el output a `RelationCandidate`. Args: text: Mismo chunk que se uso para extraer las entidades. entities: Entidades ya extraidas (de GLiNER, LLM o regex). Si tienen `attributes['start']` y `['end']` se usan; si no, fallback a `text.find(name)` con warning. relation_types: Tipos de relacion permitidos, ej: `["works_for", "owns"]`. model: Instancia GLiREL cargada con `glirel_load_model`. Inyectada por el caller para evitar penalty de carga en batch. threshold: Score minimo para aceptar una relacion (0.0-1.0). max_pairs: 0 = todas las relaciones encontradas; >0 = top N por score. Returns: Lista de RelationCandidate validados (from_name/to_name coinciden con entidades del input). Vacia si hay menos de 2 entidades, si el modelo no detecta nada, o si los relation_types o entidades quedan invalidos. Raises: ValueError: Si `relation_types` esta vacio. """ if not relation_types: raise ValueError("relation_types no puede estar vacio") if len(entities) < 2: return [] tokens_with_offsets = _tokenize_with_offsets(text) if not tokens_with_offsets: return [] tokens = [tok for tok, _s, _e in tokens_with_offsets] # Mapa token_start_idx -> EntityCandidate (para resolver outputs por posicion). token_start_to_entity: dict[int, EntityCandidate] = {} ner_spans: list[list] = [] entity_names_set = {e.name for e in entities if e.name} for ent in entities: char_span = _resolve_entity_char_span(ent, text) if char_span is None: continue token_span = _char_span_to_token_span(char_span[0], char_span[1], tokens_with_offsets) if token_span is None: continue start_tok, end_tok = token_span # GLiREL espera ner como [start_idx, end_idx, type_label] (token-level). ner_spans.append([start_tok, end_tok, ent.type_label or ent.type_ref or "Entity"]) # last-wins si dos entidades comparten token_start (poco probable). token_start_to_entity[start_tok] = ent if len(ner_spans) < 2: return [] try: raw = model.predict_relations( tokens, labels=list(relation_types), threshold=threshold, ner=ner_spans, top_k=1, ) except Exception as exc: warnings.warn( f"extract_relations_glirel: error invocando model.predict_relations: {exc}", stacklevel=2, ) return [] if not isinstance(raw, list): warnings.warn( "extract_relations_glirel: predict_relations no retorno una lista; " "retornando vacio.", stacklevel=2, ) return [] relation_types_set = set(relation_types) candidates: list[RelationCandidate] = [] for item in raw: if not isinstance(item, dict): continue relation_type = item.get("label", "") if relation_type not in relation_types_set: continue score = item.get("score", 0.0) if not isinstance(score, (int, float)): score = 0.0 confidence = float(max(0.0, min(1.0, score))) head_pos = item.get("head_pos") tail_pos = item.get("tail_pos") head_entity: EntityCandidate | None = None tail_entity: EntityCandidate | None = None if isinstance(head_pos, (list, tuple)) and head_pos: head_entity = token_start_to_entity.get(int(head_pos[0])) if isinstance(tail_pos, (list, tuple)) and tail_pos: tail_entity = token_start_to_entity.get(int(tail_pos[0])) # Fallback: matcheo por texto si el modelo no expone head_pos/tail_pos. if head_entity is None: head_text = _stringify_span(item.get("head_text")) if head_text in entity_names_set: head_entity = next((e for e in entities if e.name == head_text), None) if tail_entity is None: tail_text = _stringify_span(item.get("tail_text")) if tail_text in entity_names_set: tail_entity = next((e for e in entities if e.name == tail_text), None) if head_entity is None or tail_entity is None: continue if head_entity.name == tail_entity.name: continue candidates.append( RelationCandidate( from_name=head_entity.name, to_name=tail_entity.name, relation_type=relation_type, description="", confidence=confidence, ) ) if max_pairs > 0 and len(candidates) > max_pairs: candidates.sort(key=lambda r: r.confidence, reverse=True) candidates = candidates[:max_pairs] return candidates def _stringify_span(value: Any) -> str: """Convierte el head_text/tail_text de GLiREL (str o list[str]) a un string plano.""" if isinstance(value, str): return value if isinstance(value, (list, tuple)): return " ".join(str(v) for v in value) return ""