feat(pipelines): extract_graph_hybrid (regex + GLiNER + GLiREL + LLM fallback)
Pipeline en cascada que combina extract_iocs (regex, coste 0), GLiNER (zero-shot NER), GLiREL (zero-shot RE) y un fallback LLM opcional para chunks con baja confianza o pocas entidades. Devuelve listas concatenadas listas para deduplicate_entities/deduplicate_relations. Cierra 0040.
This commit is contained in:
@@ -0,0 +1,293 @@
|
||||
"""Tests de integracion para extract_graph_hybrid.
|
||||
|
||||
Stubs duck-typed para gliner/glirel/LLM permiten ejercitar la cascada
|
||||
sin descargar modelos pesados.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
# Put the directory four levels above this file (presumably the repository
# root -- confirm against the project layout) on sys.path so that absolute
# `python.*` imports resolve when this module is collected on its own.
_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([os.pardir] * 4))
)
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
|
||||
|
||||
from python.functions.pipelines.extract_graph_hybrid import extract_graph_hybrid
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.types.datascience.relation_candidate import RelationCandidate
|
||||
|
||||
|
||||
# ── Stubs ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class StubGliner:
    """GLiNER stand-in that replays canned predictions.

    Every ``predict_entities`` call hands back the next entry of
    ``responses``; selection depends only on call order, never on the
    arguments. Once the canned entries run out, an empty list is returned.
    """

    # Canned predictions: one list of entity dicts per expected call.
    responses: list[list[dict]] = field(default_factory=list)
    # How many times predict_entities has been invoked so far.
    calls: int = 0

    def predict_entities(self, text, labels, threshold, flat_ner):
        """Return the canned response for this call (arguments are ignored)."""
        position = self.calls
        self.calls = position + 1
        return self.responses[position] if position < len(self.responses) else []
|
||||
|
||||
|
||||
@dataclass
class StubGlirel:
    """GLiREL stand-in that replays canned relation predictions.

    Every ``predict_relations`` call hands back the next entry of
    ``responses`` in call order, ignoring its arguments; an empty list is
    returned once the canned entries are exhausted.
    """

    # Canned predictions: one list of relation dicts per expected call.
    responses: list[list[dict]] = field(default_factory=list)
    # How many times predict_relations has been invoked so far.
    calls: int = 0

    def predict_relations(self, tokens, labels, threshold, ner, top_k=1):
        """Return the canned response for this call (arguments are ignored)."""
        position = self.calls
        self.calls = position + 1
        return self.responses[position] if position < len(self.responses) else []
|
||||
|
||||
|
||||
@dataclass
class StubLLM:
    """Callable LLM stand-in that routes on the first message's content.

    If the first message's ``content`` contains "relation extraction expert"
    (case-insensitive), the call is served from ``relation_responses``;
    every other call is served from ``entity_responses``. Each queue is
    consumed in call order and has its own counter.
    """

    # Canned payloads for entity-extraction calls, in call order.
    entity_responses: list[dict] = field(default_factory=list)
    # Canned payloads for relation-extraction calls, in call order.
    relation_responses: list[dict] = field(default_factory=list)
    # Number of calls routed to the entity queue.
    entity_calls: int = 0
    # Number of calls routed to the relation queue.
    relation_calls: int = 0

    def __call__(self, messages: list[dict]) -> dict:
        """Return the next canned payload for the kind of request detected.

        Falls back to ``{"relations": []}`` or ``{"entities": []}`` once the
        matching queue is exhausted. An empty ``messages`` list is treated as
        an entity request.
        """
        system = messages[0]["content"] if messages else ""
        wants_relations = "relation extraction expert" in system.lower()

        if wants_relations:
            position = self.relation_calls
            self.relation_calls = position + 1
            queue, exhausted = self.relation_responses, {"relations": []}
        else:
            position = self.entity_calls
            self.entity_calls = position + 1
            queue, exhausted = self.entity_responses, {"entities": []}

        return queue[position] if position < len(queue) else exhausted
|
||||
|
||||
|
||||
# Entity schema handed to the pipeline: each entry pairs a `type_ref` with the
# human-readable `label` that the GLiNER stub responses use.
SCHEMA = [
    dict(type_ref="osint_person_go_cybersecurity", label="Person"),
    dict(type_ref="osint_organization_go_cybersecurity", label="Organization"),
    dict(type_ref="osint_location_go_cybersecurity", label="Location"),
]

# Relation labels offered to the relation-extraction stages.
RELATION_TYPES = [
    "operates",
    "owns",
    "communicates_with",
    "employed_by",
    "related_to",
]
|
||||
|
||||
|
||||
# ── Tests ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_corpus_osint_devuelve_mezcla_regex_gliner():
    """An OSINT chunk with an IoC and semantic mentions yields regex + GLiNER entities."""
    text = "Alice Johnson works at OpenAI. Contact: alice@openai.com"
    person_span = {"start": 0, "end": 13, "text": "Alice Johnson", "label": "Person", "score": 0.92}
    org_span = {"start": 23, "end": 29, "text": "OpenAI", "label": "Organization", "score": 0.88}

    entities, relations = extract_graph_hybrid(
        chunks=[text],
        entity_schema=SCHEMA,
        relation_types=RELATION_TYPES,
        gliner_model=StubGliner(responses=[[person_span, org_span]]),
        glirel_model=StubGlirel(responses=[[]]),
        llm_chat_json=None,
    )

    found_types = {entity.type_ref for entity in entities}

    # Regex IoC stage: the e-mail address.
    email_hits = [
        entity
        for entity in entities
        if entity.type_ref == "ioc_email" and entity.name == "alice@openai.com"
    ]
    assert email_hits

    # GLiNER stage: person and organization mapped to their schema type_refs.
    assert "osint_person_go_cybersecurity" in found_types
    assert "osint_organization_go_cybersecurity" in found_types

    # Every entity is tagged with the index of the only chunk.
    for entity in entities:
        assert 0 in entity.source_chunk_indices

    assert relations == []
|
||||
|
||||
|
||||
def test_chunks_vacios_se_saltan():
    """Empty or whitespace-only chunks are skipped without calling either model."""
    blank_chunks = ["", " ", "\n\t"]
    ner_stub = StubGliner(responses=[])
    re_stub = StubGlirel(responses=[])

    entities, relations = extract_graph_hybrid(
        chunks=blank_chunks,
        entity_schema=SCHEMA,
        relation_types=RELATION_TYPES,
        gliner_model=ner_stub,
        glirel_model=re_stub,
    )

    assert entities == []
    assert relations == []
    # The call counters prove neither stub was ever invoked.
    assert ner_stub.calls == 0
    assert re_stub.calls == 0
|
||||
|
||||
|
||||
def test_entity_schema_vacio_lanza_value_error():
    """An empty entity_schema is rejected with ValueError."""
    call_kwargs = dict(
        chunks=["text"],
        entity_schema=[],
        relation_types=RELATION_TYPES,
        gliner_model=StubGliner(),
        glirel_model=StubGlirel(),
    )
    with pytest.raises(ValueError):
        extract_graph_hybrid(**call_kwargs)
|
||||
|
||||
|
||||
def test_chunks_no_lista_lanza_value_error():
    """Passing a non-list value as chunks is rejected with ValueError."""
    not_a_list = "no soy lista"
    with pytest.raises(ValueError):
        extract_graph_hybrid(
            chunks=not_a_list,  # type: ignore[arg-type]
            entity_schema=SCHEMA,
            relation_types=RELATION_TYPES,
            gliner_model=StubGliner(),
            glirel_model=StubGlirel(),
        )
|
||||
|
||||
|
||||
def test_gliner_pocas_entidades_dispara_fallback_llm():
    """When GLiNER yields too few entities, the LLM entity fallback is invoked."""
    llm_payload = {
        "entities": [
            {
                "name": "Acme Corp",
                "type_ref": "osint_organization_go_cybersecurity",
                "attributes": {},
                "confidence": 0.95,
            },
            {
                "name": "Bob",
                "type_ref": "osint_person_go_cybersecurity",
                "attributes": {},
                "confidence": 0.9,
            },
        ]
    }
    llm = StubLLM(entity_responses=[llm_payload])

    entities, _ = extract_graph_hybrid(
        chunks=["Texto complejo sin patrones obvios."],
        entity_schema=SCHEMA,
        relation_types=RELATION_TYPES,
        gliner_model=StubGliner(responses=[[]]),  # GLiNER finds nothing
        glirel_model=StubGlirel(responses=[[]]),
        llm_chat_json=llm,
        min_entities_per_chunk=2,
    )

    extracted_names = {entity.name for entity in entities}
    assert "Acme Corp" in extracted_names
    assert "Bob" in extracted_names
    # Exactly one entity-extraction request reached the LLM stub.
    assert llm.entity_calls == 1
|
||||
|
||||
|
||||
def test_sin_llm_no_se_invoca_fallback():
    """With llm_chat_json=None there is no LLM fallback, even if GLiNER finds nothing."""
    ner_stub = StubGliner(responses=[[]])
    re_stub = StubGlirel(responses=[[]])

    entities, relations = extract_graph_hybrid(
        chunks=["chunk dificil"],
        entity_schema=SCHEMA,
        relation_types=RELATION_TYPES,
        gliner_model=ner_stub,
        glirel_model=re_stub,
        llm_chat_json=None,
    )

    # No LLM involved: only regex output could remain, and this chunk has none.
    assert entities == []
    assert relations == []
|
||||
|
||||
|
||||
def test_glirel_sin_relaciones_dispara_fallback_llm_relations():
    """When GLiREL returns no relations, the LLM relation fallback is invoked."""
    gliner_spans = [
        {"start": 0, "end": 13, "text": "Alice Johnson", "label": "Person", "score": 0.95},
        {"start": 26, "end": 32, "text": "OpenAI", "label": "Organization", "score": 0.9},
    ]
    llm_relation = {
        "from_name": "Alice Johnson",
        "to_name": "OpenAI",
        "relation_type": "employed_by",
        "description": "...",
        "confidence": 0.9,
    }
    llm = StubLLM(relation_responses=[{"relations": [llm_relation]}])

    _, relations = extract_graph_hybrid(
        chunks=["Alice Johnson trabaja para OpenAI."],
        entity_schema=SCHEMA,
        relation_types=RELATION_TYPES,
        gliner_model=StubGliner(responses=[gliner_spans]),
        glirel_model=StubGlirel(responses=[[]]),  # GLiREL finds no relations
        llm_chat_json=llm,
        confidence_threshold=0.5,
        min_entities_per_chunk=2,
    )

    assert len(relations) == 1
    only_relation = relations[0]
    assert only_relation.from_name == "Alice Johnson"
    assert only_relation.to_name == "OpenAI"
    assert only_relation.relation_type == "employed_by"
    assert only_relation.source_chunk_index == 0
    # Exactly one relation-extraction request reached the LLM stub.
    assert llm.relation_calls == 1
|
||||
|
||||
|
||||
def test_ioc_types_acota_extractores():
    """ioc_types restricts which regex IoC extractors contribute entities."""
    text = "Email: x@y.com, IP: 192.168.0.1, MD5: 5d41402abc4b2a76b9719d911017c592."

    entities, _ = extract_graph_hybrid(
        chunks=[text],
        entity_schema=SCHEMA,
        relation_types=RELATION_TYPES,
        gliner_model=StubGliner(responses=[[]]),
        glirel_model=StubGlirel(responses=[[]]),
        llm_chat_json=None,
        ioc_types=["email"],  # e-mails only
    )

    found_types = {entity.type_ref for entity in entities}
    assert "ioc_email" in found_types
    # The IP address and the hash are present in the text but must not be extracted.
    assert "ioc_ip_address" not in found_types
    assert "ioc_file_hash" not in found_types
|
||||
|
||||
|
||||
def test_errores_se_capturan_con_warning():
    """Extractor failures surface as warnings and do not abort the pipeline."""

    class BoomGliner:
        """GLiNER stand-in whose prediction always raises."""

        def predict_entities(self, *args, **kwargs):
            raise RuntimeError("boom")

    class BoomGlirel:
        """GLiREL stand-in whose prediction always raises."""

        def predict_relations(self, *args, **kwargs):
            raise RuntimeError("boom")

    with pytest.warns(UserWarning):
        entities, relations = extract_graph_hybrid(
            chunks=["Email: contact@example.com"],
            entity_schema=SCHEMA,
            relation_types=RELATION_TYPES,
            gliner_model=BoomGliner(),
            glirel_model=BoomGlirel(),
            llm_chat_json=None,
        )

    # Even with both models failing, the e-mail entity must still be present.
    extracted_types = [entity.type_ref for entity in entities]
    assert "ioc_email" in extracted_types
    assert relations == []
|
||||
Reference in New Issue
Block a user