fa5bcca155
- glirel_load_model: cache por (model_name, device); device='auto' resuelve via torch - extract_relations_glirel: tokeniza por whitespace, mapea spans char->token, llama predict_relations y devuelve RelationCandidate; fallback text.find si la entidad llega sin offsets; max_pairs=N -> top-N por score - pyproject.toml: glirel en extra nlp Closes #0039 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
228 lines
8.0 KiB
Python
228 lines
8.0 KiB
Python
"""Extrae relaciones entre entidades usando GLiREL (zero-shot relation extraction)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import warnings
|
|
from typing import Any
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
|
|
from python.types.datascience.entity_candidate import EntityCandidate
|
|
from python.types.datascience.relation_candidate import RelationCandidate
|
|
|
|
|
|
# Whitespace tokenizer: each match is a maximal run of non-whitespace characters,
# i.e. one token in the token list handed to GLiREL.
_TOKEN_RE = re.compile(r"\S+")
|
|
|
|
|
|
def _tokenize_with_offsets(text: str) -> list[tuple[str, int, int]]:
    """Split `text` on whitespace into (token, char_start, char_end) triples.

    Offsets are half-open, so ``text[char_start:char_end] == token``.
    """
    return [(match.group(0), *match.span()) for match in _TOKEN_RE.finditer(text)]
|
|
|
|
|
|
def _char_span_to_token_span(
|
|
char_start: int,
|
|
char_end: int,
|
|
tokens_with_offsets: list[tuple[str, int, int]],
|
|
) -> tuple[int, int] | None:
|
|
"""Mapea un span de caracteres a indices de tokens [start_token, end_token] inclusivos.
|
|
|
|
Retorna None si no hay tokens que solapen con el span.
|
|
"""
|
|
start_idx: int | None = None
|
|
end_idx: int | None = None
|
|
for i, (_tok, ts, te) in enumerate(tokens_with_offsets):
|
|
# Token solapa con [char_start, char_end) si su rango interseca.
|
|
if ts < char_end and te > char_start:
|
|
if start_idx is None:
|
|
start_idx = i
|
|
end_idx = i
|
|
if start_idx is None or end_idx is None:
|
|
return None
|
|
return (start_idx, end_idx)
|
|
|
|
|
|
def _resolve_entity_char_span(
|
|
entity: EntityCandidate,
|
|
text: str,
|
|
) -> tuple[int, int] | None:
|
|
"""Devuelve (start, end) para una entidad, usando attributes o fallback text.find."""
|
|
start = entity.attributes.get("start") if entity.attributes else None
|
|
end = entity.attributes.get("end") if entity.attributes else None
|
|
if isinstance(start, int) and isinstance(end, int) and 0 <= start < end <= len(text):
|
|
return (start, end)
|
|
|
|
# Fallback: buscar el primer match del nombre en el texto.
|
|
if not entity.name:
|
|
return None
|
|
found = text.find(entity.name)
|
|
if found < 0:
|
|
warnings.warn(
|
|
f"extract_relations_glirel: entidad '{entity.name}' sin offsets y no se "
|
|
f"encuentra en text.find — descartando.",
|
|
stacklevel=3,
|
|
)
|
|
return None
|
|
warnings.warn(
|
|
f"extract_relations_glirel: entidad '{entity.name}' sin offsets en attributes; "
|
|
f"usando text.find como fallback.",
|
|
stacklevel=3,
|
|
)
|
|
return (found, found + len(entity.name))
|
|
|
|
|
|
def _score_to_confidence(score: Any) -> float:
    """Clamp a raw model score to [0.0, 1.0]; non-numeric or NaN scores become 0.0."""
    # `score != score` is true only for NaN. NaN must be rejected before clamping:
    # min(1.0, nan) evaluates to 1.0, which would turn a NaN score into the
    # maximum confidence.
    if not isinstance(score, (int, float)) or score != score:
        return 0.0
    return float(max(0.0, min(1.0, score)))


def _entity_at_token_start(
    pos: Any,
    token_start_to_entity: dict[int, EntityCandidate],
) -> EntityCandidate | None:
    """Return the entity whose token span starts at ``pos[0]``.

    Returns None when ``pos`` is not a non-empty list/tuple, when ``pos[0]``
    cannot be converted with ``int()``, or when no entity starts at that token.
    """
    if not isinstance(pos, (list, tuple)) or not pos:
        return None
    try:
        start = int(pos[0])
    except (TypeError, ValueError, OverflowError):
        # Malformed model output: let the caller fall back to text matching.
        return None
    return token_start_to_entity.get(start)


def extract_relations_glirel(
    text: str,
    entities: list[EntityCandidate],
    relation_types: list[str],
    model: Any,
    threshold: float = 0.5,
    max_pairs: int = 0,
) -> list[RelationCandidate]:
    """Extract zero-shot relations with GLiREL (drop-in contract with `extract_relations_llm`).

    GLiREL consumes tokens plus entity spans expressed as token indices. This
    function tokenizes `text` on whitespace and maps each `EntityCandidate`'s
    character span to token indices. The span comes from
    `attributes['start'/'end']`, or from a `text.find(name)` fallback. The
    model output is then translated into `RelationCandidate` objects.

    Args:
        text: The same chunk that was used to extract the entities.
        entities: Already-extracted entities (GLiNER, LLM or regex). Their
            `attributes['start']` and `['end']` are used when present.
            Otherwise `text.find(name)` is used and a warning is emitted.
        relation_types: Allowed relation types, e.g. `["works_for", "owns"]`.
            Model items whose label is not in this list are discarded.
        model: GLiREL instance loaded with `glirel_load_model`. It is injected
            by the caller to avoid paying the load cost on every call in
            batch runs.
        threshold: Minimum score (0.0-1.0), passed to `model.predict_relations`.
        max_pairs: 0 = keep every relation found; >0 = keep the top N by
            confidence. Relations are only re-sorted when truncation happens.

    Returns:
        A list of RelationCandidate objects whose from_name/to_name are names
        of input entities. Confidence is the model score clamped to [0, 1];
        non-numeric or NaN scores map to 0.0. The list is empty in these cases:
        - fewer than 2 entities are given, or fewer than 2 can be placed in
          the tokenized text;
        - the model call fails or returns a non-list (a warning is emitted
          instead of raising);
        - no prediction survives filtering.

    Raises:
        ValueError: If `relation_types` is empty.
    """
    if not relation_types:
        raise ValueError("relation_types no puede estar vacio")
    if len(entities) < 2:
        return []

    tokens_with_offsets = _tokenize_with_offsets(text)
    if not tokens_with_offsets:
        return []
    tokens = [tok for tok, _start, _end in tokens_with_offsets]

    # First entity for each non-empty name. This serves the text-based fallback
    # below with an O(1) lookup instead of rescanning `entities` per prediction.
    entities_by_name: dict[str, EntityCandidate] = {}
    for ent in entities:
        if ent.name:
            entities_by_name.setdefault(ent.name, ent)

    # Token start index -> EntityCandidate, used to resolve model output by position.
    token_start_to_entity: dict[int, EntityCandidate] = {}
    ner_spans: list[list] = []
    for ent in entities:
        char_span = _resolve_entity_char_span(ent, text)
        if char_span is None:
            continue
        token_span = _char_span_to_token_span(char_span[0], char_span[1], tokens_with_offsets)
        if token_span is None:
            continue
        start_tok, end_tok = token_span
        # GLiREL expects ner entries as [start_idx, end_idx, type_label] (token level).
        ner_spans.append([start_tok, end_tok, ent.type_label or ent.type_ref or "Entity"])
        # Last one wins if two entities share the same start token (unlikely).
        token_start_to_entity[start_tok] = ent

    if len(ner_spans) < 2:
        return []

    try:
        raw = model.predict_relations(
            tokens,
            labels=list(relation_types),
            threshold=threshold,
            ner=ner_spans,
            top_k=1,
        )
    except Exception as exc:
        # Deliberate best-effort boundary: a model failure degrades to "no
        # relations" (with a warning) instead of aborting the caller's batch.
        warnings.warn(
            f"extract_relations_glirel: error invocando model.predict_relations: {exc}",
            stacklevel=2,
        )
        return []

    if not isinstance(raw, list):
        warnings.warn(
            "extract_relations_glirel: predict_relations no retorno una lista; "
            "retornando vacio.",
            stacklevel=2,
        )
        return []

    relation_types_set = set(relation_types)
    candidates: list[RelationCandidate] = []
    for item in raw:
        if not isinstance(item, dict):
            continue

        relation_type = item.get("label", "")
        if relation_type not in relation_types_set:
            continue

        confidence = _score_to_confidence(item.get("score", 0.0))

        head_entity = _entity_at_token_start(item.get("head_pos"), token_start_to_entity)
        tail_entity = _entity_at_token_start(item.get("tail_pos"), token_start_to_entity)

        # Fallback: match by surface text when head_pos/tail_pos are missing,
        # malformed, or do not start at a known entity.
        if head_entity is None:
            head_entity = entities_by_name.get(_stringify_span(item.get("head_text")))
        if tail_entity is None:
            tail_entity = entities_by_name.get(_stringify_span(item.get("tail_text")))

        if head_entity is None or tail_entity is None:
            continue
        # Drop self-relations (same surface name on both ends).
        if head_entity.name == tail_entity.name:
            continue

        candidates.append(
            RelationCandidate(
                from_name=head_entity.name,
                to_name=tail_entity.name,
                relation_type=relation_type,
                description="",
                confidence=confidence,
            )
        )

    if max_pairs > 0 and len(candidates) > max_pairs:
        # Stable sort: ties keep the model's original order.
        candidates.sort(key=lambda rel: rel.confidence, reverse=True)
        candidates = candidates[:max_pairs]

    return candidates
|
|
|
|
|
|
def _stringify_span(value: Any) -> str:
|
|
"""Convierte el head_text/tail_text de GLiREL (str o list[str]) a un string plano."""
|
|
if isinstance(value, str):
|
|
return value
|
|
if isinstance(value, (list, tuple)):
|
|
return " ".join(str(v) for v in value)
|
|
return ""
|