test(datascience): corpus stub para gliner_load_model + extract_entities_gliner
11 tests sin necesidad de descargar el modelo (200 MB): - StubModel duck-typed que valida el contrato de predict_entities - Threshold y flat_ner se propagan al modelo - Schema vacio lanza ValueError; schema sin labels validos warning + [] - Excepcion del modelo se captura - Label desconocido se descarta - gliner_load_model: ImportError simulado, cache hit, _resolve_device auto cae a cpu si torch no esta presente Refs #0038 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,198 @@
|
||||
"""Tests para extract_entities_gliner y gliner_load_model.
|
||||
|
||||
El modelo real (gliner) es opcional. Estos tests usan un stub duck-typed
|
||||
para validar el contrato sin descargar 200 MB. Tests que requieran el
|
||||
modelo real se marcan con `pytest.importorskip('gliner')`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
|
||||
from python.functions.datascience.extract_entities_gliner import (
|
||||
extract_entities_gliner,
|
||||
)
|
||||
from python.functions.datascience.gliner_load_model import (
|
||||
_MODEL_CACHE,
|
||||
_resolve_device,
|
||||
gliner_load_model,
|
||||
)
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
|
||||
|
||||
# Minimal three-entry schema (person, organization, location) used by the
# extract_entities_gliner tests; each entry declares one metadata field.
SCHEMA_BASIC = [
    {"type_ref": type_ref, "label": label, "metadata_fields": [metadata_field]}
    for type_ref, label, metadata_field in (
        ("osint_person_go_cybersecurity", "Person", "full_name"),
        ("osint_organization_go_cybersecurity", "Organization", "name"),
        ("osint_location_go_cybersecurity", "Location", "name"),
    )
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubModel:
|
||||
"""Modelo stub que devuelve una lista preconfigurada."""
|
||||
|
||||
response: list[dict]
|
||||
raise_exc: Exception | None = None
|
||||
last_kwargs: dict | None = None
|
||||
|
||||
def predict_entities(self, text, labels, threshold, flat_ner):
|
||||
self.last_kwargs = {
|
||||
"text": text,
|
||||
"labels": list(labels),
|
||||
"threshold": threshold,
|
||||
"flat_ner": flat_ner,
|
||||
}
|
||||
if self.raise_exc is not None:
|
||||
raise self.raise_exc
|
||||
return self.response
|
||||
|
||||
|
||||
# ---------- extract_entities_gliner ----------
|
||||
|
||||
|
||||
def test_schema_basico_y_modelo_stub_retorna_entity_candidate():
    """A basic schema plus stubbed spans yields EntityCandidate objects with offsets."""
    text = "Alice Johnson works at OpenAI in San Francisco."
    model = StubModel(response=[
        {"start": 0, "end": 13, "text": "Alice Johnson", "label": "Person", "score": 0.92},
        {"start": 23, "end": 29, "text": "OpenAI", "label": "Organization", "score": 0.87},
        {"start": 33, "end": 46, "text": "San Francisco", "label": "Location", "score": 0.81},
    ])
    out = extract_entities_gliner(text, SCHEMA_BASIC, model, threshold=0.5)
    assert len(out) == 3
    assert all(isinstance(e, EntityCandidate) for e in out)

    person = next(e for e in out if e.name == "Alice Johnson")
    assert person.type_ref == "osint_person_go_cybersecurity"
    assert person.type_label == "Person"
    assert person.attributes["start"] == 0
    assert person.attributes["end"] == 13
    # Actual value on the left, approx(expected) on the right, with the
    # tolerance named explicitly: approx's second positional argument is the
    # relative tolerance, which is easy to misread when left unnamed.
    assert person.confidence == pytest.approx(0.92, rel=0.001)
|
||||
|
||||
|
||||
def test_threshold_filtra_spans_con_score_bajo():
    """The threshold and flat_ner arguments are handed to the model as given.

    The stub does no score filtering of its own (per the original author's
    note, that is the real model's job), so this test only inspects the
    arguments the stub recorded.
    """
    canned_spans = [
        {"start": 0, "end": 5, "text": "Alice", "label": "Person", "score": 0.95},
    ]
    model = StubModel(response=canned_spans)

    extract_entities_gliner("Alice", SCHEMA_BASIC, model, threshold=0.7, flat_ner=False)

    recorded = model.last_kwargs
    assert recorded["threshold"] == 0.7
    assert recorded["flat_ner"] is False
|
||||
|
||||
|
||||
def test_schema_vacio_lanza_value_error():
    """An empty schema is rejected with ValueError."""
    stub = StubModel(response=[])
    empty_schema = []
    with pytest.raises(ValueError):
        extract_entities_gliner("text", empty_schema, stub)
|
||||
|
||||
|
||||
def test_schema_sin_labels_validos_retorna_vacio():
    """A schema with no usable label/type_ref pair warns and yields no entities."""
    # The first entry has empty strings for both keys; the second has no
    # type_ref key at all.
    unusable_entries = [{"label": "", "type_ref": ""}, {"label": "X"}]
    stub = StubModel(response=[])
    with pytest.warns(UserWarning):
        result = extract_entities_gliner("text", unusable_entries, stub)
    assert result == []
|
||||
|
||||
|
||||
def test_excepcion_del_modelo_se_captura():
    """An exception raised by the model becomes a UserWarning and an empty result."""
    failure = RuntimeError("model exploded")
    stub = StubModel(response=[], raise_exc=failure)
    with pytest.warns(UserWarning):
        result = extract_entities_gliner("text", SCHEMA_BASIC, stub)
    assert result == []
|
||||
|
||||
|
||||
def test_label_desconocido_se_descarta():
    """Spans whose label is not in the schema are left out of the result."""
    known_span = {"start": 0, "end": 5, "text": "Alice", "label": "Person", "score": 0.9}
    unknown_span = {"start": 6, "end": 10, "text": "blob", "label": "UnknownLabel", "score": 0.9}
    model = StubModel(response=[known_span, unknown_span])

    candidates = extract_entities_gliner("Alice blob", SCHEMA_BASIC, model)

    names = [candidate.name for candidate in candidates]
    assert "Alice" in names
    assert "blob" not in names
|
||||
|
||||
|
||||
def test_flat_ner_se_propaga_al_modelo():
    """Both flat_ner values are forwarded to the model unchanged."""
    model = StubModel(response=[])
    # The same stub is reused for both calls; last_kwargs is overwritten on
    # each call, so every assertion sees the most recent invocation.
    for flag in (True, False):
        extract_entities_gliner("text", SCHEMA_BASIC, model, flat_ner=flag)
        assert model.last_kwargs["flat_ner"] is flag
|
||||
|
||||
|
||||
# ---------- gliner_load_model ----------
|
||||
|
||||
|
||||
def test_import_error_si_gliner_no_esta_instalado(monkeypatch):
    """gliner_load_model raises ImportError when gliner cannot be imported."""
    # Use the builtins module rather than __builtins__: the latter is a CPython
    # implementation detail that is a dict or a module depending on context.
    import builtins

    # Empty the cache first so a previously cached model cannot satisfy the
    # call before the import is attempted.
    _MODEL_CACHE.clear()

    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        """Fail for gliner and its submodules; defer everything else."""
        if name == "gliner" or name.startswith("gliner."):
            raise ImportError("gliner not installed (simulated)")
        return real_import(name, *args, **kwargs)

    # monkeypatch restores the real __import__ at teardown.
    monkeypatch.setattr(builtins, "__import__", fake_import)

    with pytest.raises(ImportError, match="gliner no esta instalado"):
        gliner_load_model(model_name="dummy/model", device="cpu")
|
||||
|
||||
|
||||
def test_cache_devuelve_la_misma_instancia(monkeypatch):
    """A cached entry is returned as-is for the same model name and device."""
    sentinel = object()
    # monkeypatch.setitem is undone at teardown, so the fake entry is removed
    # even if the assertion below fails, and unrelated cache entries are left
    # untouched (no blanket clear() of the shared cache).
    # NOTE(review): assumes the cache key is (model_name, device), as the
    # original version of this test did; confirm against gliner_load_model.
    monkeypatch.setitem(_MODEL_CACHE, ("dummy/model", "cpu"), sentinel)

    out = gliner_load_model(model_name="dummy/model", device="cpu")
    assert out is sentinel
|
||||
|
||||
|
||||
def test_resolve_device_explicito_se_respeta():
    """An explicitly given device string is returned unchanged."""
    for device in ("cpu", "cuda", "cuda:0"):
        assert _resolve_device(device) == device
|
||||
|
||||
|
||||
def test_resolve_device_auto_cae_a_cpu_sin_torch(monkeypatch):
    """device='auto' falls back to 'cpu' when torch cannot be imported."""
    # Use the builtins module rather than __builtins__: the latter is a CPython
    # implementation detail that is a dict or a module depending on context.
    import builtins

    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        """Fail for torch and its submodules; defer everything else."""
        # Submodules are blocked too, so the simulation holds whichever import
        # form _resolve_device uses (e.g. `import torch.cuda`).
        if name == "torch" or name.startswith("torch."):
            raise ImportError("torch missing")
        return real_import(name, *args, **kwargs)

    # monkeypatch restores the real __import__ at teardown.
    monkeypatch.setattr(builtins, "__import__", fake_import)
    assert _resolve_device("auto") == "cpu"
|
||||
Reference in New Issue
Block a user