diff --git a/dev/issues/README.md b/dev/issues/README.md index 280b8418..0caf0c04 100644 --- a/dev/issues/README.md +++ b/dev/issues/README.md @@ -44,7 +44,7 @@ | [0036](0036-cpp-image-canvas-webcam.md) | C++ image_canvas + webcam_texture | pendiente | baja | feature | — | | [0037](completed/0037-ioc-regex-extractor.md) | IoC regex extractor (IP, email, dominio, hash, wallet, CVE, MAC) | completado | alta | feature | — | | [0038](completed/0038-gliner-entity-extractor.md) | GLiNER entity extractor (zero-shot NER multilingue) | completado | alta | feature | 0039, 0040 | -| [0039](0039-glirel-relation-extractor.md) | GLiREL relation extractor (zero-shot triplets) | pendiente | media | feature | 0040 | +| [0039](completed/0039-glirel-relation-extractor.md) | GLiREL relation extractor (zero-shot triplets) | completado | media | feature | 0040 | | [0040](0040-hybrid-extraction-pipeline.md) | Pipeline hibrido extraccion grafos (regex + GLiNER + GLiREL + LLM fallback) | pendiente | media | feature | — | | [0041](completed/0041-cpp-app-best-practices.md) | C++ app shell estandarizado (PATTERNS.md + AppConfig extendido) | completado | alta | feature | 0043 | | [0042](completed/0042-cpp-layout-storage-public.md) | C++ layout_storage publico (extraer de shaders_lab) | completado | alta | feature | 0043 | diff --git a/dev/issues/0039-glirel-relation-extractor.md b/dev/issues/completed/0039-glirel-relation-extractor.md similarity index 100% rename from dev/issues/0039-glirel-relation-extractor.md rename to dev/issues/completed/0039-glirel-relation-extractor.md diff --git a/python/functions/datascience/extract_relations_glirel.md b/python/functions/datascience/extract_relations_glirel.md new file mode 100644 index 00000000..242fc3de --- /dev/null +++ b/python/functions/datascience/extract_relations_glirel.md @@ -0,0 +1,131 @@ +--- +name: extract_relations_glirel +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def extract_relations_glirel(text: str, entities: list[EntityCandidate], relation_types: list[str], model: Any, threshold: float = 0.5, max_pairs: int = 0) -> list[RelationCandidate]" +description: "Extrae relaciones zero-shot con GLiREL. Drop-in del contrato de extract_relations_llm pero sin coste por token y mas rapido para corpus grandes. Tokeniza por whitespace, mapea spans de entidades (de attributes['start'/'end'] o fallback text.find) a indices de tokens, y devuelve RelationCandidate cuyos from_name/to_name siempre coinciden con entidades del input." +tags: [glirel, relation, nlp, extract, zero-shot, knowledge-graph, fuzzygraph, graph, datascience, python] +uses_functions: [glirel_load_model_py_datascience] +uses_types: + - entity_candidate_py_datascience + - relation_candidate_py_datascience +returns: + - relation_candidate_py_datascience +returns_optional: false +error_type: "error_go_core" +imports: [warnings, re] +params: + - name: text + desc: "mismo chunk de texto que se uso para extraer las entidades (parrafo, doc corto)" + - name: entities + desc: "lista de EntityCandidate ya extraidas (de extract_entities_gliner, extract_entities_llm o regex). Si tienen attributes['start'/'end'] se usan; si no, fallback a text.find(name) con warning." + - name: relation_types + desc: "tipos de relacion permitidos, ej: ['works_for','owns','communicated_with']. Vacio lanza ValueError." + - name: model + desc: "instancia GLiREL cargada con glirel_load_model. Inyectar para evitar penalty de carga en batch." + - name: threshold + desc: "score minimo para aceptar una relacion (0.0-1.0). Defecto 0.5." + - name: max_pairs + desc: "0 = todas las relaciones encontradas. >0 = top N por score (descarta el resto)." +output: "lista de RelationCandidate(from_name, to_name, relation_type, description='', confidence). from_name/to_name siempre coinciden con entidades del input." +tested: true +tests: + - "Schema basico y modelo stub retorna RelationCandidate triplets validos" + - "Threshold se propaga al modelo" + - "relation_types vacio lanza ValueError" + - "Menos de 2 entidades retorna vacio" + - "Entidad sin offsets usa fallback text.find con warning" + - "Entidad cuyo nombre no aparece en el texto se descarta" + - "Excepcion del modelo se captura y retorna vacio" + - "Relation_type fuera del set permitido se descarta" + - "max_pairs=N limita el output a top N por score" + - "head_pos/tail_pos resuelven entidades por posicion de token" + - "Fallback por head_text/tail_text si head_pos no esta presente" +test_file_path: "python/functions/datascience/tests/test_extract_relations_glirel.py" +file_path: "python/functions/datascience/extract_relations_glirel.py" +--- + +## Ejemplo + +```python +from python.functions.datascience import ( + glirel_load_model, + extract_relations_glirel, +) +from python.types.datascience.entity_candidate import EntityCandidate + +model = glirel_load_model(device="auto") + +text = "Alice Johnson works at OpenAI in San Francisco." +entities = [ + EntityCandidate(name="Alice Johnson", type_label="Person", + attributes={"start": 0, "end": 13}, confidence=0.92), + EntityCandidate(name="OpenAI", type_label="Organization", + attributes={"start": 23, "end": 29}, confidence=0.87), + EntityCandidate(name="San Francisco", type_label="Location", + attributes={"start": 33, "end": 46}, confidence=0.81), +] + +relations = extract_relations_glirel( + text=text, + entities=entities, + relation_types=["works_for", "located_in", "owns"], + model=model, + threshold=0.5, +) +# [RelationCandidate(from_name='Alice Johnson', to_name='OpenAI', +# relation_type='works_for', confidence=0.91), ...] +``` + +## Drop-in con extract_relations_llm + +El retorno es identico (`list[RelationCandidate]`) y `from_name`/`to_name` siempre +coinciden con entidades del input — `deduplicate_relations_py_datascience` lo +acepta sin cambios. Diferencias: + +- **Coste**: GLiREL = 0 USD/token. LLM = depende del modelo. +- **Latencia**: GLiREL es mucho mas rapido en GPU; en CPU depende del numero de + pares (entidades x relation_types). +- **Razonamiento implicito**: el LLM lo deduce ("CEO de la empresa" -> persona + works_for empresa); GLiREL solo extrae lo explicito en el texto. +- **Esquemas grandes**: GLiREL escala bien con muchos relation_types; el LLM + pierde foco con esquemas muy largos. +- **Idiomas**: GLiREL-large-v0 esta entrenado principalmente en ingles. Para ES + evaluar precision/recall caso a caso o caer al LLM. + +## Spans de entidades + +GLiREL necesita los spans (token indices) de cada entidad en el texto. Esta funcion: + +1. Lee `attributes["start"]` y `attributes["end"]` (offsets de caracteres) si + existen — el output natural de `extract_entities_gliner` y `extract_iocs`. +2. Si faltan, usa `text.find(entity.name)` como fallback (con warning). +3. Tokeniza por whitespace y mapea cada char span a un span de tokens + (`[start_token, end_token]`). +4. Pasa todo a `model.predict_relations(tokens, labels=..., ner=...)`. + +Si la entidad no se puede localizar en el texto, se descarta (no se le pueden +buscar relaciones sin saber donde esta). + +## Notas + +- impure: el modelo es estado externo. `error_type: error_go_core` segun la regla + de pureza del registry. +- Si dos entidades tienen el mismo nombre, GLiREL podria mezclarlas; el matcheo + por `head_pos`/`tail_pos` (token start) las distingue mejor que `head_text`. +- Una `relation_type` que no aparece en el output NO es un error — solo significa + que GLiREL no encontro evidencia. +- Combinar con LLM para razonamiento implicito: ver issue 0040 (pipeline hibrido). +- Para precision maxima, ajustar `threshold` por dominio: 0.3-0.4 = recall alto; + 0.6-0.8 = precision alta. + +## Limitacion + +GLiREL es bueno para relaciones explicitas en el texto (`X trabaja en Y`, +`A llamo a B`), malo para razonamiento implicito (`la nueva CEO`, `su empresa`). +Para razonamiento implicito seguir usando `extract_relations_llm`. El pipeline +hibrido (issue 0040) compone GLiREL para extraccion masiva + LLM para los casos +implicitos que GLiREL no cubre. diff --git a/python/functions/datascience/extract_relations_glirel.py b/python/functions/datascience/extract_relations_glirel.py new file mode 100644 index 00000000..59c3111d --- /dev/null +++ b/python/functions/datascience/extract_relations_glirel.py @@ -0,0 +1,227 @@ +"""Extrae relaciones entre entidades usando GLiREL (zero-shot relation extraction).""" + +from __future__ import annotations + +import os +import re +import sys +import warnings +from typing import Any + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) + +from python.types.datascience.entity_candidate import EntityCandidate +from python.types.datascience.relation_candidate import RelationCandidate + + +_TOKEN_RE = re.compile(r"\S+") + + +def _tokenize_with_offsets(text: str) -> list[tuple[str, int, int]]: + """Tokeniza por whitespace y devuelve [(token, char_start, char_end)].""" + return [(m.group(), m.start(), m.end()) for m in _TOKEN_RE.finditer(text)] + + +def _char_span_to_token_span( + char_start: int, + char_end: int, + tokens_with_offsets: list[tuple[str, int, int]], +) -> tuple[int, int] | None: + """Mapea un span de caracteres a indices de tokens [start_token, end_token] inclusivos. + + Retorna None si no hay tokens que solapen con el span. + """ + start_idx: int | None = None + end_idx: int | None = None + for i, (_tok, ts, te) in enumerate(tokens_with_offsets): + # Token solapa con [char_start, char_end) si su rango interseca. + if ts < char_end and te > char_start: + if start_idx is None: + start_idx = i + end_idx = i + if start_idx is None or end_idx is None: + return None + return (start_idx, end_idx) + + +def _resolve_entity_char_span( + entity: EntityCandidate, + text: str, +) -> tuple[int, int] | None: + """Devuelve (start, end) para una entidad, usando attributes o fallback text.find.""" + start = entity.attributes.get("start") if entity.attributes else None + end = entity.attributes.get("end") if entity.attributes else None + if isinstance(start, int) and isinstance(end, int) and 0 <= start < end <= len(text): + return (start, end) + + # Fallback: buscar el primer match del nombre en el texto. + if not entity.name: + return None + found = text.find(entity.name) + if found < 0: + warnings.warn( + f"extract_relations_glirel: entidad '{entity.name}' sin offsets y no se " + f"encuentra en text.find — descartando.", + stacklevel=3, + ) + return None + warnings.warn( + f"extract_relations_glirel: entidad '{entity.name}' sin offsets en attributes; " + f"usando text.find como fallback.", + stacklevel=3, + ) + return (found, found + len(entity.name)) + + +def extract_relations_glirel( + text: str, + entities: list[EntityCandidate], + relation_types: list[str], + model: Any, + threshold: float = 0.5, + max_pairs: int = 0, +) -> list[RelationCandidate]: + """Extrae relaciones zero-shot con GLiREL, contrato drop-in con `extract_relations_llm`. + + GLiREL recibe tokens + spans de entidades en indices de tokens. Esta funcion + se encarga de tokenizar el texto (whitespace), mapear los spans en caracteres + de cada `EntityCandidate` (de `attributes['start'/'end']` o fallback con + `text.find(name)`) y traducir el output a `RelationCandidate`. + + Args: + text: Mismo chunk que se uso para extraer las entidades. + entities: Entidades ya extraidas (de GLiNER, LLM o regex). Si tienen + `attributes['start']` y `['end']` se usan; si no, fallback a + `text.find(name)` con warning. + relation_types: Tipos de relacion permitidos, ej: `["works_for", "owns"]`. + model: Instancia GLiREL cargada con `glirel_load_model`. Inyectada por + el caller para evitar penalty de carga en batch. + threshold: Score minimo para aceptar una relacion (0.0-1.0). + max_pairs: 0 = todas las relaciones encontradas; >0 = top N por score. + + Returns: + Lista de RelationCandidate validados (from_name/to_name coinciden con + entidades del input). Vacia si hay menos de 2 entidades, si el modelo + no detecta nada, o si los relation_types o entidades quedan invalidos. + + Raises: + ValueError: Si `relation_types` esta vacio. + """ + if not relation_types: + raise ValueError("relation_types no puede estar vacio") + if len(entities) < 2: + return [] + + tokens_with_offsets = _tokenize_with_offsets(text) + if not tokens_with_offsets: + return [] + tokens = [tok for tok, _s, _e in tokens_with_offsets] + + # Mapa token_start_idx -> EntityCandidate (para resolver outputs por posicion). + token_start_to_entity: dict[int, EntityCandidate] = {} + ner_spans: list[list] = [] + entity_names_set = {e.name for e in entities if e.name} + + for ent in entities: + char_span = _resolve_entity_char_span(ent, text) + if char_span is None: + continue + token_span = _char_span_to_token_span(char_span[0], char_span[1], tokens_with_offsets) + if token_span is None: + continue + start_tok, end_tok = token_span + # GLiREL espera ner como [start_idx, end_idx, type_label] (token-level). + ner_spans.append([start_tok, end_tok, ent.type_label or ent.type_ref or "Entity"]) + # last-wins si dos entidades comparten token_start (poco probable). + token_start_to_entity[start_tok] = ent + + if len(ner_spans) < 2: + return [] + + try: + raw = model.predict_relations( + tokens, + labels=list(relation_types), + threshold=threshold, + ner=ner_spans, + top_k=1, + ) + except Exception as exc: + warnings.warn( + f"extract_relations_glirel: error invocando model.predict_relations: {exc}", + stacklevel=2, + ) + return [] + + if not isinstance(raw, list): + warnings.warn( + "extract_relations_glirel: predict_relations no retorno una lista; " + "retornando vacio.", + stacklevel=2, + ) + return [] + + relation_types_set = set(relation_types) + candidates: list[RelationCandidate] = [] + for item in raw: + if not isinstance(item, dict): + continue + + relation_type = item.get("label", "") + if relation_type not in relation_types_set: + continue + + score = item.get("score", 0.0) + if not isinstance(score, (int, float)): + score = 0.0 + confidence = float(max(0.0, min(1.0, score))) + + head_pos = item.get("head_pos") + tail_pos = item.get("tail_pos") + head_entity: EntityCandidate | None = None + tail_entity: EntityCandidate | None = None + + if isinstance(head_pos, (list, tuple)) and head_pos: + head_entity = token_start_to_entity.get(int(head_pos[0])) + if isinstance(tail_pos, (list, tuple)) and tail_pos: + tail_entity = token_start_to_entity.get(int(tail_pos[0])) + + # Fallback: matcheo por texto si el modelo no expone head_pos/tail_pos. + if head_entity is None: + head_text = _stringify_span(item.get("head_text")) + if head_text in entity_names_set: + head_entity = next((e for e in entities if e.name == head_text), None) + if tail_entity is None: + tail_text = _stringify_span(item.get("tail_text")) + if tail_text in entity_names_set: + tail_entity = next((e for e in entities if e.name == tail_text), None) + + if head_entity is None or tail_entity is None: + continue + if head_entity.name == tail_entity.name: + continue + + candidates.append( + RelationCandidate( + from_name=head_entity.name, + to_name=tail_entity.name, + relation_type=relation_type, + description="", + confidence=confidence, + ) + ) + + if max_pairs > 0 and len(candidates) > max_pairs: + candidates.sort(key=lambda r: r.confidence, reverse=True) + candidates = candidates[:max_pairs] + + return candidates + + +def _stringify_span(value: Any) -> str: + """Convierte el head_text/tail_text de GLiREL (str o list[str]) a un string plano.""" + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)): + return " ".join(str(v) for v in value) + return "" diff --git a/python/functions/datascience/glirel_load_model.md b/python/functions/datascience/glirel_load_model.md new file mode 100644 index 00000000..cf7d9c95 --- /dev/null +++ b/python/functions/datascience/glirel_load_model.md @@ -0,0 +1,72 @@ +--- +name: glirel_load_model +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def glirel_load_model(model_name: str = 'jackboyla/glirel-large-v0', device: str = 'auto') -> Any" +description: "Carga (y cachea por (model_name, device)) un modelo GLiREL zero-shot relation extraction. La primera llamada descarga ~500 MB desde HuggingFace; sucesivas devuelven la instancia cacheada. device='auto' usa CUDA si esta disponible, o CPU." +tags: [glirel, relation, nlp, model, huggingface, zero-shot, datascience, python] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: model_name + desc: "ID del modelo en HuggingFace Hub (defecto: jackboyla/glirel-large-v0)" + - name: device + desc: "'auto' (CUDA si disponible, sino CPU), 'cpu', 'cuda', 'cuda:N'" +output: "instancia GLiREL lista para predict_relations, cacheada por (model_name, device)" +tested: true +tests: + - "ImportError si glirel no esta instalado" + - "Cache devuelve la misma instancia con los mismos parametros" + - "device='auto' resuelve a cpu o cuda segun torch.cuda.is_available" +test_file_path: "python/functions/datascience/tests/test_extract_relations_glirel.py" +file_path: "python/functions/datascience/glirel_load_model.py" +--- + +## Ejemplo + +```python +from python.functions.datascience import glirel_load_model + +# Primera llamada descarga el modelo (~500 MB, una vez) +model = glirel_load_model(device="auto") + +# Llamadas sucesivas con mismos params devuelven el cache +model_again = glirel_load_model(device="auto") +assert model is model_again +``` + +## Instalacion + +GLiREL no esta en las dependencias principales del registry. Para usarlo: + +```bash +cd python && uv pip install glirel # solo glirel +cd python && uv pip install -e '.[nlp]' # extra completo (gliner + glirel) +``` + +## Tamaño y latencia + +- `jackboyla/glirel-large-v0`: ~500 MB en disco (modelo + tokenizer). +- Primera carga: 8-20 s en CPU, depende del disco y red. +- Inferencia CPU: depende del numero de pares entidad x relation_types. 5-20 pares/s + con esquema pequeño (5 relation types). +- Inferencia GPU (CUDA T4): 50-200x mas rapido que CPU. + +## Notas + +- El cache es por (model_name, device): cargar el mismo modelo en CPU y CUDA crea dos + instancias. Es intencional para permitir A/B. +- Si `torch` no esta instalado y `device='auto'`, cae a `'cpu'` sin error. +- Para limpiar el cache (memoria GPU): borrar entradas de `_MODEL_CACHE` directamente + o reiniciar el proceso. +- impure: lee disco/red la primera vez y mantiene estado en `_MODEL_CACHE`. +- GLiREL es bueno para relaciones explicitas en el texto (`X trabaja en Y`, `A llamo a B`), + malo para razonamiento implicito ("CEO de la empresa"). Para razonamiento implicito + seguir usando `extract_relations_llm`. diff --git a/python/functions/datascience/glirel_load_model.py b/python/functions/datascience/glirel_load_model.py new file mode 100644 index 00000000..8f83ae74 --- /dev/null +++ b/python/functions/datascience/glirel_load_model.py @@ -0,0 +1,63 @@ +"""Carga (y cachea) un modelo GLiREL en el device deseado.""" + +from __future__ import annotations + +from typing import Any + +# Cache global: (model_name, device) -> modelo cargado. +_MODEL_CACHE: dict[tuple[str, str], Any] = {} + + +def _resolve_device(device: str) -> str: + """Resuelve `device='auto'` a `cuda` o `cpu` segun disponibilidad.""" + if device != "auto": + return device + try: + import torch + except ImportError: + return "cpu" + return "cuda" if torch.cuda.is_available() else "cpu" + + +def glirel_load_model( + model_name: str = "jackboyla/glirel-large-v0", + device: str = "auto", +) -> Any: + """Carga un modelo GLiREL con cache por (model_name, device). + + La primera llamada descarga el modelo desde HuggingFace (~500 MB para + `glirel-large-v0`). Llamadas sucesivas con los mismos parametros + devuelven la instancia cacheada. + + Args: + model_name: ID del modelo en HuggingFace Hub. + device: 'auto' usa CUDA si esta disponible, o 'cpu'/'cuda'/'cuda:N' + de forma explicita. + + Returns: + Instancia del modelo GLiREL lista para `predict_relations`. + + Raises: + ImportError: si la dependencia `glirel` no esta instalada. + Solucion: `uv pip install glirel` o instalar el extra `nlp` + del proyecto (`uv pip install -e '.[nlp]'`). + """ + resolved_device = _resolve_device(device) + cache_key = (model_name, resolved_device) + cached = _MODEL_CACHE.get(cache_key) + if cached is not None: + return cached + + try: + from glirel import GLiREL + except ImportError as exc: + raise ImportError( + "glirel no esta instalado. Instalalo con " + "`uv pip install glirel` o `uv pip install -e '.[nlp]'`." + ) from exc + + model = GLiREL.from_pretrained(model_name) + if hasattr(model, "to"): + model.to(resolved_device) + _MODEL_CACHE[cache_key] = model + return model diff --git a/python/functions/datascience/tests/test_extract_relations_glirel.py b/python/functions/datascience/tests/test_extract_relations_glirel.py new file mode 100644 index 00000000..43a23b97 --- /dev/null +++ b/python/functions/datascience/tests/test_extract_relations_glirel.py @@ -0,0 +1,314 @@ +"""Tests para extract_relations_glirel y glirel_load_model. + +El modelo real (glirel) es opcional y pesa ~500 MB. Estos tests usan un stub +duck-typed para validar el contrato sin descargar el modelo. +""" + +from __future__ import annotations + +import os +import sys +from dataclasses import dataclass + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + +from python.functions.datascience.extract_relations_glirel import ( + _char_span_to_token_span, + _tokenize_with_offsets, + extract_relations_glirel, +) +from python.functions.datascience.glirel_load_model import ( + _MODEL_CACHE, + _resolve_device, + glirel_load_model, +) +from python.types.datascience.entity_candidate import EntityCandidate +from python.types.datascience.relation_candidate import RelationCandidate + + +def _ent(name: str, type_label: str, start: int, end: int) -> EntityCandidate: + return EntityCandidate( + name=name, + type_label=type_label, + type_ref=f"{type_label.lower()}_ref", + attributes={"start": start, "end": end}, + confidence=0.9, + ) + + +@dataclass +class StubModel: + """Modelo stub que devuelve una lista preconfigurada.""" + + response: list[dict] + raise_exc: Exception | None = None + last_kwargs: dict | None = None + + def predict_relations(self, tokens, labels, threshold, ner, top_k): + self.last_kwargs = { + "tokens": list(tokens), + "labels": list(labels), + "threshold": threshold, + "ner": [list(s) for s in ner], + "top_k": top_k, + } + if self.raise_exc is not None: + raise self.raise_exc + return self.response + + +# ---------- helpers ---------- + + +def test_tokenize_with_offsets_devuelve_indices_correctos(): + text = "Alice Johnson works at OpenAI." + out = _tokenize_with_offsets(text) + assert [t for t, _, _ in out] == ["Alice", "Johnson", "works", "at", "OpenAI."] + assert out[0][1:] == (0, 5) + assert out[1][1:] == (6, 13) + assert out[4][1:] == (23, 30) + + +def test_char_span_to_token_span_solapa_correctamente(): + tokens = _tokenize_with_offsets("Alice Johnson works at OpenAI.") + # "Alice Johnson" (0..13) -> tokens 0..1 + assert _char_span_to_token_span(0, 13, tokens) == (0, 1) + # "OpenAI" (23..29) -> token 4 + assert _char_span_to_token_span(23, 29, tokens) == (4, 4) + # span fuera del texto -> None + assert _char_span_to_token_span(100, 200, tokens) is None + + +# ---------- extract_relations_glirel ---------- + + +def test_schema_basico_y_modelo_stub_retorna_relation_candidate(): + text = "Alice Johnson works at OpenAI in San Francisco." + entities = [ + _ent("Alice Johnson", "Person", 0, 13), + _ent("OpenAI", "Organization", 23, 29), + _ent("San Francisco", "Location", 33, 46), + ] + relation_types = ["works_for", "located_in", "owns"] + + # Tokens: [Alice, Johnson, works, at, OpenAI, in, San, Francisco.] + # Alice Johnson -> tokens 0..1, OpenAI -> token 4, San Francisco. -> tokens 6..7 + model = StubModel(response=[ + {"head_pos": [0, 1], "tail_pos": [4, 4], + "head_text": ["Alice", "Johnson"], "tail_text": ["OpenAI"], + "label": "works_for", "score": 0.91}, + {"head_pos": [4, 4], "tail_pos": [6, 7], + "head_text": ["OpenAI"], "tail_text": ["San", "Francisco."], + "label": "located_in", "score": 0.78}, + ]) + + out = extract_relations_glirel(text, entities, relation_types, model) + assert len(out) == 2 + assert all(isinstance(r, RelationCandidate) for r in out) + + works = next(r for r in out if r.relation_type == "works_for") + assert works.from_name == "Alice Johnson" + assert works.to_name == "OpenAI" + assert pytest.approx(works.confidence, 0.001) == 0.91 + + located = next(r for r in out if r.relation_type == "located_in") + assert located.from_name == "OpenAI" + # San Francisco entity name vs token "San Francisco." (con punto pegado). + # Como matcheamos por head_pos/tail_pos (token start = 6), debe resolver a + # la entidad EntityCandidate("San Francisco", start=33). + assert located.to_name == "San Francisco" + + +def test_threshold_se_propaga_al_modelo(): + text = "Alice works at OpenAI." + entities = [ + _ent("Alice", "Person", 0, 5), + _ent("OpenAI", "Organization", 15, 21), + ] + model = StubModel(response=[]) + extract_relations_glirel(text, entities, ["works_for"], model, threshold=0.7) + assert model.last_kwargs["threshold"] == 0.7 + assert model.last_kwargs["labels"] == ["works_for"] + assert model.last_kwargs["top_k"] == 1 + + +def test_relation_types_vacio_lanza_value_error(): + entities = [_ent("Alice", "Person", 0, 5), _ent("Bob", "Person", 6, 9)] + with pytest.raises(ValueError): + extract_relations_glirel("Alice y Bob", entities, [], StubModel(response=[])) + + +def test_menos_de_dos_entidades_retorna_vacio(): + entities = [_ent("Alice", "Person", 0, 5)] + out = extract_relations_glirel("Alice", entities, ["works_for"], StubModel(response=[])) + assert out == [] + + +def test_entidad_sin_offsets_usa_fallback_text_find_con_warning(): + text = "Alice works at OpenAI." + entities = [ + EntityCandidate(name="Alice", type_label="Person", confidence=0.9), + EntityCandidate(name="OpenAI", type_label="Organization", confidence=0.9), + ] + model = StubModel(response=[ + {"head_pos": [0, 0], "tail_pos": [3, 3], + "head_text": ["Alice"], "tail_text": ["OpenAI."], + "label": "works_for", "score": 0.85}, + ]) + with pytest.warns(UserWarning, match="sin offsets"): + out = extract_relations_glirel(text, entities, ["works_for"], model) + assert len(out) == 1 + assert out[0].from_name == "Alice" + assert out[0].to_name == "OpenAI" + + +def test_entidad_no_encontrada_en_texto_se_descarta(): + text = "Alice y Bob hablan." + entities = [ + EntityCandidate(name="Alice", type_label="Person", confidence=0.9), + EntityCandidate(name="Carmen", type_label="Person", confidence=0.9), # no esta + EntityCandidate(name="Bob", type_label="Person", confidence=0.9), + ] + model = StubModel(response=[ + {"head_pos": [0, 0], "tail_pos": [2, 2], + "head_text": ["Alice"], "tail_text": ["Bob"], + "label": "communicated_with", "score": 0.8}, + ]) + with pytest.warns(UserWarning): + out = extract_relations_glirel(text, entities, ["communicated_with"], model) + # Carmen se descarta del input al construir ner_spans, pero los otros 2 quedan. + # GLiREL recibe solo 2 spans validos. + assert len(out) == 1 + assert out[0].from_name == "Alice" + assert out[0].to_name == "Bob" + + +def test_excepcion_del_modelo_se_captura(): + entities = [_ent("Alice", "Person", 0, 5), _ent("Bob", "Person", 8, 11)] + model = StubModel(response=[], raise_exc=RuntimeError("model exploded")) + with pytest.warns(UserWarning): + out = extract_relations_glirel("Alice y Bob.", entities, ["works_for"], model) + assert out == [] + + +def test_relation_type_fuera_del_set_se_descarta(): + text = "Alice works at OpenAI." + entities = [ + _ent("Alice", "Person", 0, 5), + _ent("OpenAI", "Organization", 15, 21), + ] + model = StubModel(response=[ + {"head_pos": [0, 0], "tail_pos": [3, 3], + "head_text": ["Alice"], "tail_text": ["OpenAI."], + "label": "unknown_relation", "score": 0.95}, + ]) + out = extract_relations_glirel(text, entities, ["works_for"], model) + assert out == [] + + +def test_max_pairs_limita_top_n(): + text = "Alice works at OpenAI in San Francisco." + entities = [ + _ent("Alice", "Person", 0, 5), + _ent("OpenAI", "Organization", 15, 21), + _ent("San Francisco", "Location", 25, 38), + ] + relation_types = ["works_for", "located_in", "lived_in"] + model = StubModel(response=[ + {"head_pos": [0, 0], "tail_pos": [3, 3], "label": "works_for", "score": 0.55, + "head_text": ["Alice"], "tail_text": ["OpenAI"]}, + {"head_pos": [3, 3], "tail_pos": [5, 6], "label": "located_in", "score": 0.92, + "head_text": ["OpenAI"], "tail_text": ["San", "Francisco."]}, + {"head_pos": [0, 0], "tail_pos": [5, 6], "label": "lived_in", "score": 0.71, + "head_text": ["Alice"], "tail_text": ["San", "Francisco."]}, + ]) + out = extract_relations_glirel(text, entities, relation_types, model, max_pairs=2) + assert len(out) == 2 + confidences = [r.confidence for r in out] + # Top 2 por score: 0.92 y 0.71 + assert confidences == sorted(confidences, reverse=True) + assert max(confidences) == pytest.approx(0.92, 0.001) + assert min(confidences) == pytest.approx(0.71, 0.001) + + +def test_fallback_por_head_text_si_head_pos_no_esta(): + text = "Alice works at OpenAI." + entities = [ + _ent("Alice", "Person", 0, 5), + _ent("OpenAI", "Organization", 15, 21), + ] + model = StubModel(response=[ + # Sin head_pos/tail_pos, fallback por texto. + {"head_text": "Alice", "tail_text": "OpenAI", + "label": "works_for", "score": 0.8}, + ]) + out = extract_relations_glirel(text, entities, ["works_for"], model) + assert len(out) == 1 + assert out[0].from_name == "Alice" + assert out[0].to_name == "OpenAI" + + +def test_self_loops_se_descartan(): + """head y tail apuntan a la misma entidad -> se descarta.""" + text = "Alice talks to Alice." + entities = [_ent("Alice", "Person", 0, 5), _ent("Alice", "Person", 15, 20)] + model = StubModel(response=[ + {"head_pos": [0, 0], "tail_pos": [0, 0], + "head_text": ["Alice"], "tail_text": ["Alice"], + "label": "communicated_with", "score": 0.9}, + ]) + out = extract_relations_glirel(text, entities, ["communicated_with"], model) + assert out == [] + + +# ---------- glirel_load_model ---------- + + +def test_import_error_si_glirel_no_esta_instalado(monkeypatch): + """ImportError si glirel no esta instalado.""" + _MODEL_CACHE.clear() + + real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def fake_import(name, *args, **kwargs): + if name == "glirel" or name.startswith("glirel."): + raise ImportError("glirel not installed (simulated)") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", fake_import) + + with pytest.raises(ImportError, match="glirel no esta instalado"): + glirel_load_model(model_name="dummy/model", device="cpu") + + +def test_cache_devuelve_la_misma_instancia(): + """Cache devuelve la misma instancia con los mismos parametros.""" + _MODEL_CACHE.clear() + sentinel = object() + _MODEL_CACHE[("dummy/model", "cpu")] = sentinel + + out = glirel_load_model(model_name="dummy/model", device="cpu") + assert out is sentinel + + _MODEL_CACHE.clear() + + +def test_resolve_device_explicito_se_respeta(): + assert _resolve_device("cpu") == "cpu" + assert _resolve_device("cuda") == "cuda" + assert _resolve_device("cuda:0") == "cuda:0" + + +def test_resolve_device_auto_cae_a_cpu_sin_torch(monkeypatch): + """device='auto' resuelve a cpu si torch no esta disponible.""" + real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def fake_import(name, *args, **kwargs): + if name == "torch": + raise ImportError("torch missing") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", fake_import) + assert _resolve_device("auto") == "cpu" diff --git a/python/pyproject.toml b/python/pyproject.toml index 63f4fad7..ec166b3c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ [project.optional-dependencies] nlp = [ "gliner>=0.2.13", + "glirel>=1.0.0", ] [dependency-groups]