"""Tests for extract_relations_glirel and glirel_load_model.

The real model (glirel) is optional and weighs ~500 MB. These tests use a
duck-typed stub to validate the contract without downloading the model.
"""

from __future__ import annotations

import builtins
import os
import sys
from dataclasses import dataclass

import pytest

# Make the repository root importable so the ``python.*`` packages resolve
# when the tests are run from this directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))

from python.functions.datascience.extract_relations_glirel import (
    _char_span_to_token_span,
    _tokenize_with_offsets,
    extract_relations_glirel,
)
from python.functions.datascience.glirel_load_model import (
    _MODEL_CACHE,
    _resolve_device,
    glirel_load_model,
)
from python.types.datascience.entity_candidate import EntityCandidate
from python.types.datascience.relation_candidate import RelationCandidate


def _ent(name: str, type_label: str, start: int, end: int) -> EntityCandidate:
    """Build an EntityCandidate whose attributes carry character offsets."""
    return EntityCandidate(
        name=name,
        type_label=type_label,
        type_ref=f"{type_label.lower()}_ref",
        attributes={"start": start, "end": end},
        confidence=0.9,
    )


def _block_imports(monkeypatch, package: str, message: str) -> None:
    """Make importing ``package`` (or any of its submodules) raise ImportError.

    Every other import is delegated to the real ``__import__``. The patch is
    undone automatically by ``monkeypatch`` at the end of the test.
    """
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == package or name.startswith(f"{package}."):
            raise ImportError(message)
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", fake_import)


@pytest.fixture
def isolated_model_cache():
    """Yield an emptied ``_MODEL_CACHE`` and restore its previous content.

    Restoring in the fixture teardown guarantees that a failing assertion
    cannot leak test entries into other tests.
    """
    saved = dict(_MODEL_CACHE)
    _MODEL_CACHE.clear()
    yield _MODEL_CACHE
    _MODEL_CACHE.clear()
    _MODEL_CACHE.update(saved)


@dataclass
class StubModel:
    """Stub model that returns a preconfigured list of relations.

    ``predict_relations`` stores the arguments of the most recent call in
    ``last_kwargs`` so tests can assert on what was passed to the model.
    """

    # Value handed back by ``predict_relations``.
    response: list[dict]
    # When set, ``predict_relations`` raises it instead of returning.
    raise_exc: Exception | None = None
    # Copy of the arguments received by the latest ``predict_relations`` call.
    last_kwargs: dict | None = None

    def predict_relations(self, tokens, labels, threshold, ner, top_k):
        """Record the call arguments, then raise ``raise_exc`` or return ``response``."""
        self.last_kwargs = {
            "tokens": list(tokens),
            "labels": list(labels),
            "threshold": threshold,
            "ner": [list(s) for s in ner],
            "top_k": top_k,
        }
        if self.raise_exc is not None:
            raise self.raise_exc
        return self.response


# ---------- helpers ----------


def test_tokenize_with_offsets_devuelve_indices_correctos():
    """Tokens come back with their (start, end) character offsets."""
    text = "Alice Johnson works at OpenAI."
    out = _tokenize_with_offsets(text)
    assert [t for t, _, _ in out] == ["Alice", "Johnson", "works", "at", "OpenAI."]
    assert out[0][1:] == (0, 5)
    assert out[1][1:] == (6, 13)
    assert out[4][1:] == (23, 30)


def test_char_span_to_token_span_solapa_correctamente():
    """Character spans map to the inclusive range of overlapping tokens."""
    tokens = _tokenize_with_offsets("Alice Johnson works at OpenAI.")
    # "Alice Johnson" (0..13) -> tokens 0..1
    assert _char_span_to_token_span(0, 13, tokens) == (0, 1)
    # "OpenAI" (23..29) -> token 4
    assert _char_span_to_token_span(23, 29, tokens) == (4, 4)
    # Span outside the text -> None
    assert _char_span_to_token_span(100, 200, tokens) is None


# ---------- extract_relations_glirel ----------


def test_schema_basico_y_modelo_stub_retorna_relation_candidate():
    """Stub predictions are turned into RelationCandidate objects."""
    text = "Alice Johnson works at OpenAI in San Francisco."
    entities = [
        _ent("Alice Johnson", "Person", 0, 13),
        _ent("OpenAI", "Organization", 23, 29),
        _ent("San Francisco", "Location", 33, 46),
    ]
    relation_types = ["works_for", "located_in", "owns"]

    # Tokens: [Alice, Johnson, works, at, OpenAI, in, San, Francisco.]
    # Alice Johnson -> tokens 0..1, OpenAI -> token 4,
    # San Francisco. -> tokens 6..7
    model = StubModel(response=[
        {"head_pos": [0, 1], "tail_pos": [4, 4],
         "head_text": ["Alice", "Johnson"], "tail_text": ["OpenAI"],
         "label": "works_for", "score": 0.91},
        {"head_pos": [4, 4], "tail_pos": [6, 7],
         "head_text": ["OpenAI"], "tail_text": ["San", "Francisco."],
         "label": "located_in", "score": 0.78},
    ])

    out = extract_relations_glirel(text, entities, relation_types, model)
    assert len(out) == 2
    assert all(isinstance(r, RelationCandidate) for r in out)

    works = next(r for r in out if r.relation_type == "works_for")
    assert works.from_name == "Alice Johnson"
    assert works.to_name == "OpenAI"
    assert works.confidence == pytest.approx(0.91, rel=0.001)

    located = next(r for r in out if r.relation_type == "located_in")
    assert located.from_name == "OpenAI"
    # Entity name "San Francisco" vs. token "San Francisco." (trailing dot
    # attached). Since matching is done by head_pos/tail_pos (token start = 6),
    # it must resolve to the entity EntityCandidate("San Francisco", start=33).
    assert located.to_name == "San Francisco"


def test_threshold_se_propaga_al_modelo():
    """The threshold and labels given by the caller reach the model."""
    text = "Alice works at OpenAI."
    entities = [
        _ent("Alice", "Person", 0, 5),
        _ent("OpenAI", "Organization", 15, 21),
    ]
    model = StubModel(response=[])
    extract_relations_glirel(text, entities, ["works_for"], model, threshold=0.7)
    assert model.last_kwargs["threshold"] == 0.7
    assert model.last_kwargs["labels"] == ["works_for"]
    assert model.last_kwargs["top_k"] == 1


def test_relation_types_vacio_lanza_value_error():
    """An empty relation_types list is rejected with ValueError."""
    # Offsets match the text below: "Alice" is 0..5 and "Bob" is 8..11.
    entities = [_ent("Alice", "Person", 0, 5), _ent("Bob", "Person", 8, 11)]
    with pytest.raises(ValueError):
        extract_relations_glirel("Alice y Bob", entities, [], StubModel(response=[]))


def test_menos_de_dos_entidades_retorna_vacio():
    """Fewer than two entities yields an empty result."""
    entities = [_ent("Alice", "Person", 0, 5)]
    out = extract_relations_glirel(
        "Alice", entities, ["works_for"], StubModel(response=[])
    )
    assert out == []


def test_entidad_sin_offsets_usa_fallback_text_find_con_warning():
    """Entities without offsets are still resolved, and a warning is emitted."""
    text = "Alice works at OpenAI."
    entities = [
        EntityCandidate(name="Alice", type_label="Person", confidence=0.9),
        EntityCandidate(name="OpenAI", type_label="Organization", confidence=0.9),
    ]
    model = StubModel(response=[
        {"head_pos": [0, 0], "tail_pos": [3, 3],
         "head_text": ["Alice"], "tail_text": ["OpenAI."],
         "label": "works_for", "score": 0.85},
    ])
    with pytest.warns(UserWarning, match="sin offsets"):
        out = extract_relations_glirel(text, entities, ["works_for"], model)
    assert len(out) == 1
    assert out[0].from_name == "Alice"
    assert out[0].to_name == "OpenAI"


def test_entidad_no_encontrada_en_texto_se_descarta():
    """An entity that does not appear in the text does not break extraction."""
    text = "Alice y Bob hablan."
    entities = [
        EntityCandidate(name="Alice", type_label="Person", confidence=0.9),
        # Not present in the text.
        EntityCandidate(name="Carmen", type_label="Person", confidence=0.9),
        EntityCandidate(name="Bob", type_label="Person", confidence=0.9),
    ]
    model = StubModel(response=[
        {"head_pos": [0, 0], "tail_pos": [2, 2],
         "head_text": ["Alice"], "tail_text": ["Bob"],
         "label": "communicated_with", "score": 0.8},
    ])
    with pytest.warns(UserWarning):
        out = extract_relations_glirel(text, entities, ["communicated_with"], model)
    # Carmen is dropped from the input when ner_spans is built, but the other
    # two remain. GLiREL receives only 2 valid spans.
    assert len(out) == 1
    assert out[0].from_name == "Alice"
    assert out[0].to_name == "Bob"


def test_excepcion_del_modelo_se_captura():
    """A model failure is reported as a warning and yields no relations."""
    entities = [_ent("Alice", "Person", 0, 5), _ent("Bob", "Person", 8, 11)]
    model = StubModel(response=[], raise_exc=RuntimeError("model exploded"))
    with pytest.warns(UserWarning):
        out = extract_relations_glirel("Alice y Bob.", entities, ["works_for"], model)
    assert out == []


def test_relation_type_fuera_del_set_se_descarta():
    """Predictions whose label was not requested are discarded."""
    text = "Alice works at OpenAI."
    entities = [
        _ent("Alice", "Person", 0, 5),
        _ent("OpenAI", "Organization", 15, 21),
    ]
    model = StubModel(response=[
        {"head_pos": [0, 0], "tail_pos": [3, 3],
         "head_text": ["Alice"], "tail_text": ["OpenAI."],
         "label": "unknown_relation", "score": 0.95},
    ])
    out = extract_relations_glirel(text, entities, ["works_for"], model)
    assert out == []


def test_max_pairs_limita_top_n():
    """max_pairs keeps only the highest-scoring relations, best first."""
    text = "Alice works at OpenAI in San Francisco."
    entities = [
        _ent("Alice", "Person", 0, 5),
        _ent("OpenAI", "Organization", 15, 21),
        _ent("San Francisco", "Location", 25, 38),
    ]
    relation_types = ["works_for", "located_in", "lived_in"]
    model = StubModel(response=[
        {"head_pos": [0, 0], "tail_pos": [3, 3],
         "label": "works_for", "score": 0.55,
         "head_text": ["Alice"], "tail_text": ["OpenAI"]},
        {"head_pos": [3, 3], "tail_pos": [5, 6],
         "label": "located_in", "score": 0.92,
         "head_text": ["OpenAI"], "tail_text": ["San", "Francisco."]},
        {"head_pos": [0, 0], "tail_pos": [5, 6],
         "label": "lived_in", "score": 0.71,
         "head_text": ["Alice"], "tail_text": ["San", "Francisco."]},
    ])
    out = extract_relations_glirel(text, entities, relation_types, model, max_pairs=2)
    assert len(out) == 2
    confidences = [r.confidence for r in out]
    # Top 2 by score: 0.92 and 0.71
    assert confidences == sorted(confidences, reverse=True)
    assert max(confidences) == pytest.approx(0.92, rel=0.001)
    assert min(confidences) == pytest.approx(0.71, rel=0.001)


def test_fallback_por_head_text_si_head_pos_no_esta():
    """Without head_pos/tail_pos, entities are resolved from the texts."""
    text = "Alice works at OpenAI."
    entities = [
        _ent("Alice", "Person", 0, 5),
        _ent("OpenAI", "Organization", 15, 21),
    ]
    model = StubModel(response=[
        # No head_pos/tail_pos: fall back to matching by text.
        {"head_text": "Alice", "tail_text": "OpenAI",
         "label": "works_for", "score": 0.8},
    ])
    out = extract_relations_glirel(text, entities, ["works_for"], model)
    assert len(out) == 1
    assert out[0].from_name == "Alice"
    assert out[0].to_name == "OpenAI"


def test_self_loops_se_descartan():
    """Head and tail point to the same entity -> the relation is discarded."""
    text = "Alice talks to Alice."
    entities = [_ent("Alice", "Person", 0, 5), _ent("Alice", "Person", 15, 20)]
    model = StubModel(response=[
        {"head_pos": [0, 0], "tail_pos": [0, 0],
         "head_text": ["Alice"], "tail_text": ["Alice"],
         "label": "communicated_with", "score": 0.9},
    ])
    out = extract_relations_glirel(text, entities, ["communicated_with"], model)
    assert out == []


# ---------- glirel_load_model ----------


def test_import_error_si_glirel_no_esta_instalado(monkeypatch, isolated_model_cache):
    """ImportError is raised if glirel is not installed."""
    _block_imports(monkeypatch, "glirel", "glirel not installed (simulated)")

    with pytest.raises(ImportError, match="glirel no esta instalado"):
        glirel_load_model(model_name="dummy/model", device="cpu")


def test_cache_devuelve_la_misma_instancia(isolated_model_cache):
    """The cache returns the same instance for the same parameters."""
    sentinel = object()
    isolated_model_cache[("dummy/model", "cpu")] = sentinel

    out = glirel_load_model(model_name="dummy/model", device="cpu")
    assert out is sentinel


def test_resolve_device_explicito_se_respeta():
    """Explicit device strings are returned unchanged."""
    assert _resolve_device("cpu") == "cpu"
    assert _resolve_device("cuda") == "cuda"
    assert _resolve_device("cuda:0") == "cuda:0"


def test_resolve_device_auto_cae_a_cpu_sin_torch(monkeypatch):
    """device='auto' resolves to cpu if torch is not available."""
    _block_imports(monkeypatch, "torch", "torch missing")
    assert _resolve_device("auto") == "cpu"