Files
fn_registry/python/functions/datascience/extract_relations_glirel.py
egutierrez fa5bcca155 feat(datascience): GLiREL relation extractor (zero-shot triplets) drop-in con LLM
- glirel_load_model: cache por (model_name, device); device='auto' resuelve via torch
- extract_relations_glirel: tokeniza por whitespace, mapea spans char->token,
  llama predict_relations y devuelve RelationCandidate; fallback text.find si la
  entidad llega sin offsets; max_pairs=N -> top-N por score
- pyproject.toml: glirel en extra nlp

Closes #0039

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 16:41:09 +02:00

228 lines
8.0 KiB
Python

"""Extrae relaciones entre entidades usando GLiREL (zero-shot relation extraction)."""
from __future__ import annotations

import math
import numbers
import os
import re
import sys
import warnings
from typing import Any

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))

from python.types.datascience.entity_candidate import EntityCandidate
from python.types.datascience.relation_candidate import RelationCandidate
# A token is any maximal run of non-whitespace characters.
_TOKEN_RE = re.compile(r"\S+")


def _tokenize_with_offsets(text: str) -> list[tuple[str, int, int]]:
    """Split ``text`` on whitespace, keeping each token's character offsets.

    Returns:
        ``(token, char_start, char_end)`` triples in text order, with
        ``char_end`` exclusive so that ``text[char_start:char_end] == token``.
        Empty or whitespace-only input yields ``[]``.
    """
    triples: list[tuple[str, int, int]] = []
    for hit in _TOKEN_RE.finditer(text):
        lo, hi = hit.span()
        triples.append((hit.group(), lo, hi))
    return triples
def _char_span_to_token_span(
    char_start: int,
    char_end: int,
    tokens_with_offsets: list[tuple[str, int, int]],
) -> tuple[int, int] | None:
    """Map the character span ``[char_start, char_end)`` to token indices.

    A token belongs to the span when its own half-open character range
    intersects it, so partially covered tokens (e.g. ``"Madrid,"`` for the
    entity ``"Madrid"``) are included.

    Returns:
        ``(first_token_idx, last_token_idx)``, both inclusive, or ``None`` when
        no token overlaps the span.
    """
    overlapping = [
        idx
        for idx, (_token, tok_start, tok_end) in enumerate(tokens_with_offsets)
        if tok_start < char_end and tok_end > char_start
    ]
    if not overlapping:
        return None
    return (overlapping[0], overlapping[-1])
def _resolve_entity_char_span(
    entity: EntityCandidate,
    text: str,
) -> tuple[int, int] | None:
    """Return the ``(start, end)`` character span of ``entity`` in ``text``.

    The preferred source is ``entity.attributes['start']`` / ``['end']``. They
    are used when both are integral values and satisfy
    ``0 <= start < end <= len(text)``. ``bool`` is rejected even though it
    subclasses ``int``; other integral types such as numpy ints are accepted.
    Otherwise the first occurrence of ``entity.name`` found with
    ``text.find`` is used, and a warning is emitted.

    Args:
        entity: Entity to locate; only ``.attributes`` and ``.name`` are read.
        text: Text the entity was extracted from.

    Returns:
        ``(start, end)`` with ``end`` exclusive, or ``None`` when the entity
        has no name or its name does not occur in ``text`` (warned).
    """
    attributes = entity.attributes or {}
    start = attributes.get("start")
    end = attributes.get("end")
    offsets_are_integral = all(
        isinstance(value, numbers.Integral) and not isinstance(value, bool)
        for value in (start, end)
    )
    if offsets_are_integral and 0 <= start < end <= len(text):
        return (int(start), int(end))

    # Fallback: use the first match of the name in the text.
    if not entity.name:
        return None
    found = text.find(entity.name)
    # stacklevel=3 attributes the warning to whoever called
    # extract_relations_glirel, which calls this helper directly.
    if found < 0:
        warnings.warn(
            f"extract_relations_glirel: entidad '{entity.name}' sin offsets y no se "
            f"encuentra en text.find — descartando.",
            stacklevel=3,
        )
        return None
    warnings.warn(
        f"extract_relations_glirel: entidad '{entity.name}' sin offsets en attributes; "
        f"usando text.find como fallback.",
        stacklevel=3,
    )
    return (found, found + len(entity.name))
def _glirel_entity_at_pos(
    pos: Any,
    token_start_to_entity: dict[int, EntityCandidate],
) -> EntityCandidate | None:
    """Look up the entity whose token span starts at ``pos[0]``.

    ``pos`` is the ``head_pos``/``tail_pos`` value of one model output item.
    A value that is not a non-empty list/tuple with an int-convertible first
    element yields ``None`` instead of raising. One malformed item therefore
    cannot abort the whole extraction.
    """
    if not isinstance(pos, (list, tuple)) or not pos:
        return None
    try:
        start_tok = int(pos[0])
    except (TypeError, ValueError, OverflowError):
        return None
    return token_start_to_entity.get(start_tok)


def _glirel_confidence(score: Any) -> float:
    """Clamp a raw model score into ``[0.0, 1.0]``.

    Any ``numbers.Real`` is accepted, which includes numpy floating scalars.
    ``bool``, non-numeric values and NaN map to ``0.0``. NaN needs an explicit
    check because ``min(1.0, nan)`` returns ``1.0``, which would turn an
    invalid score into maximal confidence.
    """
    if isinstance(score, bool) or not isinstance(score, numbers.Real):
        return 0.0
    value = float(score)
    if math.isnan(value):
        return 0.0
    return max(0.0, min(1.0, value))


def extract_relations_glirel(
    text: str,
    entities: list[EntityCandidate],
    relation_types: list[str],
    model: Any,
    threshold: float = 0.5,
    max_pairs: int = 0,
) -> list[RelationCandidate]:
    """Extract zero-shot relations with GLiREL (drop-in contract with `extract_relations_llm`).

    GLiREL receives tokens plus entity spans expressed as token indices. This
    function does the following:

    - tokenizes the text on whitespace;
    - maps each `EntityCandidate`'s character span to token indices. The span
      comes from `attributes['start'/'end']`, or from a `text.find(name)`
      fallback;
    - translates the model output into `RelationCandidate` objects.

    Args:
        text: The same chunk that was used to extract the entities.
        entities: Already-extracted entities (GLiNER, LLM or regex). Their
            `attributes['start']` and `['end']` are used when present;
            otherwise `text.find(name)` is used as a fallback, with a warning.
            Entities resolving to an identical token span are sent to the
            model once; the first one keeps the span.
        relation_types: Allowed relation types, e.g. `["works_for", "owns"]`.
        model: GLiREL instance loaded with `glirel_load_model`. It is injected
            by the caller to avoid paying the load cost in batch use.
        threshold: Minimum score to accept a relation (0.0-1.0), forwarded to
            `model.predict_relations`.
        max_pairs: 0 returns every relation found; >0 returns only the top N
            by score.

    Returns:
        List of RelationCandidate whose from_name/to_name match input
        entities. The list is empty when fewer than 2 entities or distinct
        entity spans remain, when the model detects nothing, or when the
        model call raises (warned). It is also empty when the model returns
        something other than a list (warned). Malformed output items are
        skipped.

    Raises:
        ValueError: If `relation_types` is empty.
    """
    if not relation_types:
        raise ValueError("relation_types no puede estar vacio")
    if len(entities) < 2:
        return []

    tokens_with_offsets = _tokenize_with_offsets(text)
    if not tokens_with_offsets:
        return []
    tokens = [tok for tok, _s, _e in tokens_with_offsets]

    # token_start_idx -> EntityCandidate, used to resolve model outputs by position.
    token_start_to_entity: dict[int, EntityCandidate] = {}
    ner_spans: list[list] = []
    seen_token_spans: set[tuple[int, int]] = set()
    # Kept inline (not in a helper) so the stacklevel=3 warnings raised by
    # _resolve_entity_char_span still point at this function's caller.
    for ent in entities:
        char_span = _resolve_entity_char_span(ent, text)
        if char_span is None:
            continue
        token_span = _char_span_to_token_span(char_span[0], char_span[1], tokens_with_offsets)
        if token_span is None:
            continue
        if token_span in seen_token_spans:
            # The same token span is already registered. Examples: a repeated
            # mention, or the same name listed twice without offsets, which the
            # text.find fallback resolves to the same first occurrence. The
            # first entity keeps the span. Re-sending it would only add self
            # pairs and duplicated pairs.
            continue
        seen_token_spans.add(token_span)
        start_tok, end_tok = token_span
        # GLiREL expects ner items as [start_idx, end_idx, type_label] (token level).
        ner_spans.append([start_tok, end_tok, ent.type_label or ent.type_ref or "Entity"])
        # Different spans sharing the same start token: last one wins (unlikely).
        token_start_to_entity[start_tok] = ent

    if len(ner_spans) < 2:
        return []

    try:
        raw = model.predict_relations(
            tokens,
            labels=list(relation_types),
            threshold=threshold,
            ner=ner_spans,
            top_k=1,
        )
    except Exception as exc:
        # Best-effort: a model failure degrades to "no relations" plus a warning.
        warnings.warn(
            f"extract_relations_glirel: error invocando model.predict_relations: {exc}",
            stacklevel=2,
        )
        return []

    if not isinstance(raw, list):
        warnings.warn(
            "extract_relations_glirel: predict_relations no retorno una lista; "
            "retornando vacio.",
            stacklevel=2,
        )
        return []

    relation_types_set = set(relation_types)
    # Fallback name lookup for items without usable positions. The first
    # entity with a given (non-empty) name wins.
    name_to_entity: dict[str, EntityCandidate] = {}
    for ent in entities:
        if ent.name:
            name_to_entity.setdefault(ent.name, ent)

    candidates: list[RelationCandidate] = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        relation_type = item.get("label", "")
        # The str check also guards against an unhashable label, which would
        # raise TypeError on the set membership test.
        if not isinstance(relation_type, str) or relation_type not in relation_types_set:
            continue

        head_entity = _glirel_entity_at_pos(item.get("head_pos"), token_start_to_entity)
        if head_entity is None:
            head_entity = name_to_entity.get(_stringify_span(item.get("head_text")))
        tail_entity = _glirel_entity_at_pos(item.get("tail_pos"), token_start_to_entity)
        if tail_entity is None:
            tail_entity = name_to_entity.get(_stringify_span(item.get("tail_text")))

        if head_entity is None or tail_entity is None:
            continue
        if head_entity.name == tail_entity.name:
            continue

        candidates.append(
            RelationCandidate(
                from_name=head_entity.name,
                to_name=tail_entity.name,
                relation_type=relation_type,
                description="",
                confidence=_glirel_confidence(item.get("score", 0.0)),
            )
        )

    if max_pairs > 0 and len(candidates) > max_pairs:
        candidates.sort(key=lambda r: r.confidence, reverse=True)
        candidates = candidates[:max_pairs]
    return candidates
def _stringify_span(value: Any) -> str:
    """Flatten a head_text/tail_text value into one plain string.

    A string is returned unchanged. A list or tuple has each element converted
    with ``str`` and joined with single spaces. Anything else (``None``
    included) yields ``""``.
    """
    if isinstance(value, (list, tuple)):
        return " ".join(map(str, value))
    return value if isinstance(value, str) else ""