fix(infra): gradle_run detecta android-sdk — issue 0076 #2
@@ -9,6 +9,8 @@ from .cybersecurity import (
|
||||
levenshtein_distance,
|
||||
jaccard_similarity,
|
||||
normalize_url,
|
||||
envelope_encrypt,
|
||||
envelope_decrypt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -22,4 +24,6 @@ __all__ = [
|
||||
"levenshtein_distance",
|
||||
"jaccard_similarity",
|
||||
"normalize_url",
|
||||
"envelope_encrypt",
|
||||
"envelope_decrypt",
|
||||
]
|
||||
|
||||
@@ -4,8 +4,11 @@ import hashlib
|
||||
import math
|
||||
import re
|
||||
import base64
|
||||
import secrets
|
||||
import struct
|
||||
from collections import Counter
|
||||
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
|
||||
def hash_sha256(data: bytes) -> str:
|
||||
@@ -165,3 +168,147 @@ def normalize_url(raw_url: str) -> str:
|
||||
sorted_query = urlencode(sorted(params.items()), doseq=True)
|
||||
# Drop fragment
|
||||
return urlunparse((scheme, netloc, path, parsed.params, sorted_query, ""))
|
||||
|
||||
|
||||
# --- Envelope Encryption (AES-256-GCM) ---
|
||||
|
||||
# Four-byte tag that opens every envelope; envelope_decrypt returns input
# that does not start with it unchanged.
_ENVELOPE_MAGIC = b"OVE1"
# Format version written into the header; _parse_envelope rejects any other value.
_ENVELOPE_VERSION = 0x01
# Fixed header length in bytes:
# magic(4) + version(1) + reserved(1) + efk_len(2) + kiv_len(2) + div_len(2)
_HEADER_SIZE = 12
|
||||
|
||||
|
||||
def _build_envelope(
    encrypted_file_key: bytes,
    key_iv: bytes,
    data_iv: bytes,
    encrypted_content: bytes,
) -> bytes:
    """Serialize the envelope parts into its binary layout (pure internal helper).

    Header (12 bytes):
        Magic    (4B): b"OVE1"
        Version  (1B): 0x01
        Reserved (1B): 0x00
        EFK_len  (2B): length of encrypted_file_key (big-endian)
        KIV_len  (2B): length of key_iv (big-endian)
        DIV_len  (2B): length of data_iv (big-endian)

    The header is followed by
    encrypted_file_key + key_iv + data_iv + encrypted_content.

    Args:
        encrypted_file_key: The file key after encryption with the master key.
        key_iv: IV used to encrypt the file key.
        data_iv: IV used to encrypt the content.
        encrypted_content: The encrypted payload; its length is not stored,
            it simply runs to the end of the envelope.

    Returns:
        The complete envelope as bytes.
    """
    fixed_fields = struct.pack(
        ">BBHHH",
        _ENVELOPE_VERSION,
        0x00,  # reserved byte
        len(encrypted_file_key),
        len(key_iv),
        len(data_iv),
    )
    return b"".join(
        (
            _ENVELOPE_MAGIC,
            fixed_fields,
            encrypted_file_key,
            key_iv,
            data_iv,
            encrypted_content,
        )
    )
|
||||
|
||||
|
||||
def _parse_envelope(ciphertext: bytes) -> tuple:
    """Split a binary envelope into its components (pure internal helper).

    Args:
        ciphertext: Full envelope, starting at the magic bytes.

    Returns:
        (encrypted_file_key, key_iv, data_iv, encrypted_content)

    Raises:
        ValueError: if the input is shorter than the header, the magic does
            not match, the version is not supported, or the lengths declared
            in the header run past the end of the data.
    """
    total = len(ciphertext)
    if total < _HEADER_SIZE:
        raise ValueError(
            f"Envelope truncado: se esperaban al menos {_HEADER_SIZE} bytes, "
            f"se recibieron {total}"
        )

    magic = ciphertext[:4]
    if magic != _ENVELOPE_MAGIC:
        raise ValueError(f"Magic invalido: se esperaba {_ENVELOPE_MAGIC!r}, se obtuvo {magic!r}")

    version, _reserved, efk_len, kiv_len, div_len = struct.unpack(">BBHHH", ciphertext[4:12])
    if version != _ENVELOPE_VERSION:
        raise ValueError(f"Version de envelope no soportada: {version}")

    # Walk the three length-prefixed fields in header order.
    declared_lengths = (efk_len, kiv_len, div_len)
    fields = []
    cursor = _HEADER_SIZE
    for declared in declared_lengths:
        fields.append(ciphertext[cursor : cursor + declared])
        cursor += declared

    # Slicing never raises on short input, so a truncated envelope shows up
    # as a field shorter than its declared length.
    if any(len(field) != declared for field, declared in zip(fields, declared_lengths)):
        raise ValueError("Envelope truncado: longitudes declaradas exceden los datos disponibles")

    encrypted_file_key, key_iv, data_iv = fields
    return encrypted_file_key, key_iv, data_iv, ciphertext[cursor:]
|
||||
|
||||
|
||||
def envelope_encrypt(plaintext: bytes, master_key: bytes) -> bytes:
    """Encrypt data with the envelope-encryption pattern using AES-GCM.

    A random 32-byte file key (DEK) encrypts the data; the file key is then
    encrypted with ``master_key`` (KEK). Both pieces, together with their
    IVs, are packed into a binary envelope that can be opened with the
    master key alone.

    Args:
        plaintext: Data to encrypt (may be empty).
        master_key: Master key; documented as 32 bytes (AES-256).
            NOTE(review): the length is not checked here — it is left to
            AESGCM to accept or reject the key.

    Returns:
        The encrypted binary envelope.

    Raises:
        Exception: Whatever the AESGCM layer raises on failure (for example
            a key of invalid length) propagates unchanged.
    """
    # Data Encryption Key: fresh for every call.
    dek = secrets.token_bytes(32)

    # Seal the payload under the DEK.
    content_nonce = secrets.token_bytes(12)
    sealed_content = AESGCM(dek).encrypt(content_nonce, plaintext, None)

    # Wrap the DEK under the caller's Key Encryption Key.
    wrap_nonce = secrets.token_bytes(12)
    wrapped_dek = AESGCM(master_key).encrypt(wrap_nonce, dek, None)

    return _build_envelope(
        encrypted_file_key=wrapped_dek,
        key_iv=wrap_nonce,
        data_iv=content_nonce,
        encrypted_content=sealed_content,
    )
|
||||
|
||||
|
||||
def envelope_decrypt(ciphertext: bytes, master_key: bytes) -> bytes:
    """Decrypt data produced by envelope_encrypt.

    Input that does not start with the magic b"OVE1" is assumed to be
    unencrypted and is returned as-is (passthrough), so the function can be
    applied to files that may or may not be encrypted.

    NOTE(review): because of the passthrough, data without the magic is
    returned without any authentication — confirm callers are fine with that.

    Args:
        ciphertext: Encrypted envelope, or plain data when the magic is absent.
        master_key: Master key; documented as 32 bytes (AES-256).

    Returns:
        The decrypted data, or ``ciphertext`` untouched when it has no magic.

    Raises:
        ValueError: If the envelope is corrupt or truncated.
        cryptography.exceptions.InvalidTag: If the master key is wrong or the
            data was tampered with (GCM authentication failure).
    """
    looks_encrypted = ciphertext.startswith(_ENVELOPE_MAGIC)
    if not looks_encrypted:
        # Passthrough: no magic means the data is treated as plaintext.
        return ciphertext

    wrapped_dek, wrap_nonce, content_nonce, sealed_content = _parse_envelope(ciphertext)

    # Unwrap the file key with the master key, then open the payload with it.
    dek = AESGCM(master_key).decrypt(wrap_nonce, wrapped_dek, None)
    return AESGCM(dek).decrypt(content_nonce, sealed_content, None)
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
---
|
||||
name: envelope_decrypt
|
||||
kind: function
|
||||
lang: py
|
||||
domain: cybersecurity
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def envelope_decrypt(ciphertext: bytes, master_key: bytes) -> bytes"
|
||||
description: "Descifra datos cifrados con envelope_encrypt. Si los datos no comienzan con el magic b'OVE1', los retorna sin modificar (passthrough). Soporta archivos que pueden o no estar cifrados sin necesidad de chequeo previo."
|
||||
tags: [decryption, aes, gcm, envelope-encryption, dek, kek, cryptography, cybersecurity, passthrough]
|
||||
uses_functions: [envelope_encrypt_py_cybersecurity]
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [cryptography, struct]
|
||||
tested: true
|
||||
tests:
|
||||
- "decrypt de datos cifrados"
|
||||
- "decrypt de datos no cifrados passthrough"
|
||||
- "key incorrecta"
|
||||
- "envelope truncado"
|
||||
- "magic invalido"
|
||||
test_file_path: "python/functions/cybersecurity/envelope_encrypt_test.py"
|
||||
file_path: "python/functions/cybersecurity/cybersecurity.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
import secrets
|
||||
from cybersecurity import envelope_encrypt, envelope_decrypt
|
||||
|
||||
master_key = secrets.token_bytes(32)
|
||||
|
||||
# Caso 1: descifrar datos cifrados
|
||||
ciphertext = envelope_encrypt(b"datos secretos", master_key)
|
||||
plaintext = envelope_decrypt(ciphertext, master_key)
|
||||
# plaintext == b"datos secretos"
|
||||
|
||||
# Caso 2: passthrough — datos no cifrados
|
||||
raw = b"archivo en plano"
|
||||
result = envelope_decrypt(raw, master_key)
|
||||
# result == b"archivo en plano" (sin modificar)
|
||||
|
||||
# Caso 3: key incorrecta — lanza InvalidTag
|
||||
wrong_key = secrets.token_bytes(32)
|
||||
# envelope_decrypt(ciphertext, wrong_key) → cryptography.exceptions.InvalidTag
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Implementacion original inspirada en OpenViking `openviking/crypto/encryptor.py` (AGPL-3.0). Reimplementada desde cero.
|
||||
|
||||
- **Passthrough**: si `ciphertext` no empieza con `b"OVE1"`, se retorna sin modificar. Permite usar la funcion indistintamente en archivos cifrados y no cifrados.
|
||||
- **Autenticacion GCM**: si la master_key es incorrecta o los datos fueron manipulados, `cryptography.exceptions.InvalidTag` es lanzado por la capa GCM — nunca se retorna texto corrupto.
|
||||
- **ValueError**: lanzado si el envelope tiene magic correcto pero estructura invalida (truncado o version no soportada).
|
||||
- `master_key` debe ser de exactamente 32 bytes para AES-256.
|
||||
- Requiere `cryptography` instalado: `uv add cryptography`.
|
||||
@@ -0,0 +1,68 @@
|
||||
---
|
||||
name: envelope_encrypt
|
||||
kind: function
|
||||
lang: py
|
||||
domain: cybersecurity
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def envelope_encrypt(plaintext: bytes, master_key: bytes) -> bytes"
|
||||
description: "Cifra datos usando patron Envelope Encryption con AES-256-GCM. Genera una file key aleatoria (DEK), cifra los datos con ella, luego cifra la file key con la master_key (KEK). Retorna un envelope binario con magic b'OVE1'."
|
||||
tags: [encryption, aes, gcm, envelope-encryption, dek, kek, cryptography, cybersecurity]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [cryptography, secrets, struct]
|
||||
tested: true
|
||||
tests:
|
||||
- "encrypt → decrypt roundtrip"
|
||||
- "datos vacios"
|
||||
- "datos grandes"
|
||||
- "ciphertext tiene magic correcto"
|
||||
- "ciphertext es distinto cada vez"
|
||||
test_file_path: "python/functions/cybersecurity/envelope_encrypt_test.py"
|
||||
file_path: "python/functions/cybersecurity/cybersecurity.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
import secrets
|
||||
from cybersecurity import envelope_encrypt, envelope_decrypt
|
||||
|
||||
master_key = secrets.token_bytes(32) # 256-bit KEK
|
||||
plaintext = b"datos confidenciales"
|
||||
|
||||
ciphertext = envelope_encrypt(plaintext, master_key)
|
||||
# ciphertext[:4] == b"OVE1"
|
||||
|
||||
recovered = envelope_decrypt(ciphertext, master_key)
|
||||
# recovered == plaintext
|
||||
```
|
||||
|
||||
## Formato del envelope
|
||||
|
||||
```
|
||||
Magic (4B): b"OVE1" identificador de formato
|
||||
Version (1B): 0x01 version del protocolo
|
||||
Reserved (1B): 0x00 reservado para uso futuro
|
||||
EFK_len (2B): big-endian longitud de encrypted_file_key
|
||||
KIV_len (2B): big-endian longitud de key_iv
|
||||
DIV_len (2B): big-endian longitud de data_iv
|
||||
--- header: 12 bytes total ---
|
||||
Encrypted File Key (variable, incluye GCM auth tag de 16B)
|
||||
Key IV (12B)
|
||||
Data IV (12B)
|
||||
Encrypted Content (variable, incluye GCM auth tag de 16B)
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Implementacion original inspirada en OpenViking `openviking/crypto/encryptor.py` (AGPL-3.0). Reimplementada desde cero.
|
||||
|
||||
- La file key (DEK) es de 32 bytes generados con `secrets.token_bytes` (CSPRNG).
|
||||
- Tanto el cifrado de datos como el de la file key usan AES-256-GCM con IVs de 12 bytes.
|
||||
- El GCM auth tag (16 bytes) garantiza autenticidad e integridad.
|
||||
- `master_key` debe ser de exactamente 32 bytes para AES-256.
|
||||
- Requiere `cryptography` instalado: `uv add cryptography`.
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Tests para envelope_encrypt y envelope_decrypt."""
|
||||
|
||||
import secrets
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from cybersecurity import envelope_encrypt, envelope_decrypt
|
||||
|
||||
|
||||
def test_encrypt_decrypt_roundtrip():
    """Encrypting then decrypting with the same key returns the original bytes."""
    key = secrets.token_bytes(32)
    message = b"datos de prueba para envelope encryption"
    envelope = envelope_encrypt(message, key)
    assert envelope_decrypt(envelope, key) == message
|
||||
|
||||
|
||||
def test_datos_vacios():
    """An empty payload survives the encrypt/decrypt round trip."""
    key = secrets.token_bytes(32)
    envelope = envelope_encrypt(b"", key)
    assert envelope_decrypt(envelope, key) == b""
|
||||
|
||||
|
||||
def test_datos_grandes():
    """A 1 MB random payload survives the encrypt/decrypt round trip."""
    key = secrets.token_bytes(32)
    payload = secrets.token_bytes(1 << 20)  # 1 MB
    envelope = envelope_encrypt(payload, key)
    assert envelope_decrypt(envelope, key) == payload
|
||||
|
||||
|
||||
def test_decrypt_datos_no_cifrados_passthrough():
    """Data without the magic prefix is returned unchanged by envelope_decrypt."""
    key = secrets.token_bytes(32)
    unencrypted = b"archivo no cifrado, sin magic bytes"
    assert envelope_decrypt(unencrypted, key) == unencrypted
|
||||
|
||||
|
||||
def test_key_incorrecta():
    """Decrypting with a different master key must fail GCM authentication.

    The failure signal is raised outside the ``try`` block: an ``assert``
    placed inside a ``try`` guarded by ``except Exception`` is swallowed
    (AssertionError is an Exception), which would make the test unable to fail.
    """
    # Imported locally so the module's top-level imports stay untouched.
    from cryptography.exceptions import InvalidTag

    master_key = secrets.token_bytes(32)
    wrong_key = secrets.token_bytes(32)
    ciphertext = envelope_encrypt(b"secreto", master_key)
    try:
        envelope_decrypt(ciphertext, wrong_key)
    except InvalidTag:
        return  # expected: unwrapping the file key fails authentication
    raise AssertionError("deberia haber lanzado excepcion")
|
||||
|
||||
|
||||
def test_envelope_truncado():
    """An envelope cut off inside the header is rejected with ValueError."""
    key = secrets.token_bytes(32)
    header_fragment = envelope_encrypt(b"datos", key)[:6]  # incomplete header
    try:
        envelope_decrypt(header_fragment, key)
    except ValueError:
        return
    assert False, "deberia haber lanzado ValueError"
|
||||
|
||||
|
||||
def test_magic_invalido():
    """A header with the right magic but invalid fields is rejected with ValueError.

    The input starts with the valid magic so it gets past the passthrough
    check; the all-zero remainder declares version 0, which the parser does
    not support.

    The failure signal is raised outside the ``try`` block: an ``assert``
    placed inside a ``try`` guarded by ``except Exception`` is swallowed
    (AssertionError is an Exception), which would make the test unable to fail.
    """
    master_key = secrets.token_bytes(32)
    bad_envelope = b"OVE1" + b"\x00" * 20  # correct magic, invalid header
    try:
        envelope_decrypt(bad_envelope, master_key)
    except ValueError:
        return  # expected: unsupported version
    raise AssertionError("deberia haber lanzado excepcion")
|
||||
|
||||
|
||||
def test_ciphertext_tiene_magic_correcto():
    """Every envelope starts with the four magic bytes."""
    key = secrets.token_bytes(32)
    prefix = envelope_encrypt(b"test", key)[:4]
    assert prefix == b"OVE1"
|
||||
|
||||
|
||||
def test_ciphertext_es_distinto_cada_vez():
    """Encrypting the same message twice yields different envelopes."""
    key = secrets.token_bytes(32)
    message = b"mismo mensaje"
    first = envelope_encrypt(message, key)
    second = envelope_encrypt(message, key)
    # Fresh random file key and IVs on every call make the outputs differ.
    assert first != second
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for run_case in (
        test_encrypt_decrypt_roundtrip,
        test_datos_vacios,
        test_datos_grandes,
        test_decrypt_datos_no_cifrados_passthrough,
        test_key_incorrecta,
        test_envelope_truncado,
        test_magic_invalido,
        test_ciphertext_tiene_magic_correcto,
        test_ciphertext_es_distinto_cada_vez,
    ):
        run_case()
    print("Todos los tests pasaron.")
|
||||
@@ -0,0 +1,45 @@
|
||||
---
|
||||
name: aggregate_by_group
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def aggregate_by_group(rows: list[dict], group_by: list[str], aggs: dict[str, str]) -> list[dict]"
|
||||
description: "GROUP BY + agregaciones sobre datos tabulares. aggs es un dict de columna a funcion (sum, mean, count, min, max, first, last, collect). collect acumula valores en lista. None se ignora en agregaciones numericas."
|
||||
tags: [datascience, tabular, groupby, aggregate, transform, python]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: ["collections"]
|
||||
tested: true
|
||||
tests:
|
||||
- "Group by una columna con sum"
|
||||
- "Group by multiples columnas"
|
||||
- "Agregacion mean count min max"
|
||||
- "collect acumula en lista"
|
||||
- "Grupo con una sola fila"
|
||||
- "Campo con None se ignora en agregaciones numericas"
|
||||
test_file_path: "python/functions/datascience/aggregate_by_group_test.py"
|
||||
file_path: "python/functions/datascience/aggregate_by_group.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
rows = [
|
||||
{"dept": "eng", "salary": 100},
|
||||
{"dept": "eng", "salary": 120},
|
||||
{"dept": "sales", "salary": 80},
|
||||
]
|
||||
aggregate_by_group(rows, group_by=["dept"], aggs={"salary": "mean"})
|
||||
# [{"dept": "eng", "salary": 110.0}, {"dept": "sales", "salary": 80.0}]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura sin dependencias externas (solo collections.defaultdict de stdlib).
|
||||
Preserva el orden de primera aparicion de cada grupo.
|
||||
La funcion 'collect' no filtra None — acumula todos los valores incluyendo None.
|
||||
@@ -0,0 +1,71 @@
|
||||
"""GROUP BY + agregaciones sobre datos tabulares list[dict]."""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def aggregate_by_group(
    rows: list[dict],
    group_by: list[str],
    aggs: dict[str, str],
) -> list[dict]:
    """Group rows by one or more columns and aggregate other columns per group.

    Equivalent to SQL ``GROUP BY`` with aggregate functions.

    Args:
        rows: Input records. A column missing from a row is read as None.
        group_by: Columns whose combined values identify a group.
        aggs: Mapping of column name to aggregation name. Supported names:
            sum, mean, count, min, max, first, last, collect.

    Returns:
        One dict per group, in order of each group's first appearance,
        holding the group_by columns plus the aggregated columns.
        sum/mean/min/max skip None values and yield None when the group has
        no non-None value; count/first/last/collect see every value,
        None included.

    Raises:
        ValueError: If ``aggs`` names an unsupported aggregation. This is
            checked before any row is read, so a misspelled name fails the
            same way regardless of the data (including empty or all-None
            input).
    """

    def _skip_none(reduce):
        # Wrap a reducer so it only sees non-None values; empty -> None.
        def apply(vals: list):
            present = [v for v in vals if v is not None]
            return reduce(present) if present else None

        return apply

    reducers = {
        "collect": lambda vals: vals,
        "count": len,
        "first": lambda vals: vals[0] if vals else None,
        "last": lambda vals: vals[-1] if vals else None,
        "sum": _skip_none(sum),
        "mean": _skip_none(lambda present: sum(present) / len(present)),
        "min": _skip_none(min),
        "max": _skip_none(max),
    }

    # Fail fast on unknown aggregation names instead of only when a group
    # happens to contain a non-None value.
    for func in aggs.values():
        if not isinstance(func, str) or func not in reducers:
            raise ValueError(f"Funcion de agregacion no soportada: {func}")

    # dict preserves insertion order, so it records first-appearance order
    # of the groups as well as their per-column value lists.
    buckets: dict[tuple, dict[str, list]] = {}
    for row in rows:
        key = tuple(row.get(col) for col in group_by)
        columns = buckets.get(key)
        if columns is None:
            columns = buckets[key] = {col: [] for col in aggs}
        for col in aggs:
            columns[col].append(row.get(col))

    result = []
    for key, columns in buckets.items():
        record: dict = dict(zip(group_by, key))
        for col, func in aggs.items():
            record[col] = reducers[func](columns[col])
        result.append(record)

    return result
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Tests para aggregate_by_group."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from aggregate_by_group import aggregate_by_group
|
||||
|
||||
|
||||
def test_group_by_una_columna_con_sum():
    """Group by a single column with sum."""
    data = [
        {"dept": "eng", "salary": 100},
        {"dept": "eng", "salary": 120},
        {"dept": "sales", "salary": 80},
    ]
    grouped = aggregate_by_group(data, group_by=["dept"], aggs={"salary": "sum"})
    assert len(grouped) == 2
    # Index by department, keeping the first record seen for each.
    by_dept = {}
    for record in grouped:
        by_dept.setdefault(record["dept"], record)
    assert by_dept["eng"]["salary"] == 220
    assert by_dept["sales"]["salary"] == 80
|
||||
|
||||
|
||||
def test_group_by_multiples_columnas():
    """Group by several columns."""
    data = [
        {"dept": "eng", "level": "senior", "salary": 150},
        {"dept": "eng", "level": "junior", "salary": 80},
        {"dept": "eng", "level": "senior", "salary": 160},
        {"dept": "sales", "level": "senior", "salary": 120},
    ]
    grouped = aggregate_by_group(data, group_by=["dept", "level"], aggs={"salary": "sum"})
    assert len(grouped) == 3
    # Index by (dept, level), keeping the first record seen for each pair.
    by_key = {}
    for record in grouped:
        by_key.setdefault((record["dept"], record["level"]), record)
    assert by_key[("eng", "senior")]["salary"] == 310
|
||||
|
||||
|
||||
def test_agregacion_mean_count_min_max():
    """Aggregations mean, count, min and max."""
    data = [{"cat": "A", "val": value} for value in (10, 20, 30)]
    expectations = (
        ("mean", 20.0),
        ("count", 3),
        ("min", 10),
        ("max", 30),
    )
    for func, expected in expectations:
        grouped = aggregate_by_group(data, group_by=["cat"], aggs={"val": func})
        assert grouped[0]["val"] == expected
|
||||
|
||||
|
||||
def test_collect_acumula_en_lista():
    """collect accumulates the values into a list."""
    data = [
        {"dept": "eng", "name": "Alice"},
        {"dept": "eng", "name": "Bob"},
        {"dept": "sales", "name": "Carol"},
    ]
    grouped = aggregate_by_group(data, group_by=["dept"], aggs={"name": "collect"})
    eng_names = next(record["name"] for record in grouped if record["dept"] == "eng")
    assert sorted(eng_names) == ["Alice", "Bob"]
|
||||
|
||||
|
||||
def test_grupo_con_una_sola_fila():
    """A group made of a single row."""
    only_row = {"dept": "eng", "salary": 100}
    grouped = aggregate_by_group([only_row], group_by=["dept"], aggs={"salary": "sum"})
    assert len(grouped) == 1
    assert grouped[0]["salary"] == 100
|
||||
|
||||
|
||||
def test_campo_con_none_se_ignora_en_agregaciones_numericas():
    """A None field is ignored by numeric aggregations."""
    data = [{"dept": "eng", "salary": pay} for pay in (100, None, 200)]
    for func, expected in (("sum", 300), ("mean", 150.0)):
        grouped = aggregate_by_group(data, group_by=["dept"], aggs={"salary": func})
        assert grouped[0]["salary"] == expected
|
||||
@@ -0,0 +1,62 @@
|
||||
---
|
||||
name: build_entity_schema_prompt
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def build_entity_schema_prompt(entity_presets: list[dict]) -> str"
|
||||
description: "Genera la seccion del system prompt que describe los entity types disponibles para extraccion. Formatea los presets del registry en texto legible para el LLM."
|
||||
tags: [prompt, llm, entity, schema, osint, graph, extraction]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "lista con varios presets"
|
||||
- "lista vacia retorna string vacio"
|
||||
- "preset sin metadata_fields"
|
||||
test_file_path: "python/functions/datascience/build_entity_schema_prompt_test.py"
|
||||
file_path: "python/functions/datascience/build_entity_schema_prompt.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from build_entity_schema_prompt import build_entity_schema_prompt
|
||||
|
||||
presets = [
|
||||
{
|
||||
"type_ref": "osint_person_go_cybersecurity",
|
||||
"label": "Person",
|
||||
"metadata_fields": ["full_name", "alias", "nationality", "dob", "risk_score"],
|
||||
},
|
||||
{
|
||||
"type_ref": "osint_organization_go_cybersecurity",
|
||||
"label": "Organization",
|
||||
"metadata_fields": ["legal_name", "country", "sector", "founded", "risk_score"],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = build_entity_schema_prompt(presets)
|
||||
# Entity types available for extraction:
|
||||
#
|
||||
# 1. Person (type_ref: osint_person_go_cybersecurity)
|
||||
# Attributes: full_name, alias, nationality, dob, risk_score
|
||||
#
|
||||
# 2. Organization (type_ref: osint_organization_go_cybersecurity)
|
||||
# Attributes: legal_name, country, sector, founded, risk_score
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura. No requiere dependencias externas.
|
||||
|
||||
El formato de salida es deliberadamente sencillo para maximizar la comprension por el LLM: numero de orden, label humano, type_ref del registry y lista de atributos en una sola linea.
|
||||
|
||||
Si un preset no tiene `metadata_fields` (o tiene lista vacia), se omite la linea de atributos.
|
||||
|
||||
Pensada para componer con `build_relation_schema_prompt` al construir el system prompt completo de extraccion de grafos OSINT.
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Genera la seccion del system prompt que describe los entity types disponibles para extraccion."""
|
||||
|
||||
|
||||
def build_entity_schema_prompt(entity_presets: list[dict]) -> str:
    """Render the available entity types as a system-prompt section for the LLM.

    Each preset becomes a numbered entry showing its label and type_ref,
    followed by an attributes line when the preset has metadata fields.
    Entries are separated by a blank line.

    Args:
        entity_presets: Presets with keys 'label', 'type_ref' and optionally
            'metadata_fields'. A missing label renders as "Unknown" and a
            missing type_ref as an empty string. Example:
            [{"type_ref": "osint_person_go_cybersecurity",
              "label": "Person",
              "metadata_fields": ["full_name", "alias"]}]

    Returns:
        The formatted prompt section, or an empty string when the preset
        list is empty.
    """
    if not entity_presets:
        return ""

    entries = []
    for position, preset in enumerate(entity_presets, start=1):
        label = preset.get("label", "Unknown")
        type_ref = preset.get("type_ref", "")
        entry_lines = [f"{position}. {label} (type_ref: {type_ref})"]

        metadata_fields = preset.get("metadata_fields", [])
        if metadata_fields:
            attrs = ", ".join(metadata_fields)
            entry_lines.append(f" Attributes: {attrs}")

        entries.append("\n".join(entry_lines))

    # Header, a blank line, then the entries separated by blank lines;
    # no trailing blank line.
    return "\n".join(["Entity types available for extraction:", "", "\n\n".join(entries)])
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Tests para build_entity_schema_prompt."""
|
||||
|
||||
from build_entity_schema_prompt import build_entity_schema_prompt
|
||||
|
||||
|
||||
def test_lista_con_varios_presets():
    """Several presets render a header plus one numbered entry with attributes each."""
    person = {
        "type_ref": "osint_person_go_cybersecurity",
        "label": "Person",
        "metadata_fields": ["full_name", "alias", "nationality", "dob", "risk_score"],
    }
    organization = {
        "type_ref": "osint_organization_go_cybersecurity",
        "label": "Organization",
        "metadata_fields": ["legal_name", "country", "sector", "founded", "risk_score"],
    }
    rendered = build_entity_schema_prompt([person, organization])
    expected_fragments = (
        "Entity types available for extraction:",
        "1. Person (type_ref: osint_person_go_cybersecurity)",
        " Attributes: full_name, alias, nationality, dob, risk_score",
        "2. Organization (type_ref: osint_organization_go_cybersecurity)",
        " Attributes: legal_name, country, sector, founded, risk_score",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_lista_vacia_retorna_string_vacio():
    """An empty preset list yields an empty string."""
    assert build_entity_schema_prompt([]) == ""
|
||||
|
||||
|
||||
def test_preset_sin_metadata_fields():
    """A preset without metadata_fields renders no attributes line."""
    bare_preset = {
        "type_ref": "osint_person_go_cybersecurity",
        "label": "Person",
    }
    rendered = build_entity_schema_prompt([bare_preset])
    assert "1. Person (type_ref: osint_person_go_cybersecurity)" in rendered
    assert "Attributes:" not in rendered
|
||||
@@ -0,0 +1,43 @@
|
||||
---
|
||||
name: build_relation_schema_prompt
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def build_relation_schema_prompt(relation_types: list[str]) -> str"
|
||||
description: "Genera la seccion del system prompt con los tipos de relacion permitidos para extraccion. Formatea la lista de tipos en texto legible para el LLM."
|
||||
tags: [prompt, llm, relation, schema, osint, graph, extraction]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "lista con varios tipos"
|
||||
- "lista vacia retorna string vacio"
|
||||
- "un solo tipo"
|
||||
test_file_path: "python/functions/datascience/build_relation_schema_prompt_test.py"
|
||||
file_path: "python/functions/datascience/build_relation_schema_prompt.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from build_relation_schema_prompt import build_relation_schema_prompt
|
||||
|
||||
types = ["funds", "employs", "communicates_with", "owns"]
|
||||
prompt = build_relation_schema_prompt(types)
|
||||
# Allowed relation types:
|
||||
# funds, employs, communicates_with, owns
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura. No requiere dependencias externas.
|
||||
|
||||
La salida es una sola linea con todos los tipos separados por coma, precedida por el encabezado. El formato es minimal para no consumir tokens innecesarios del contexto del LLM.
|
||||
|
||||
Pensada para componer con `build_entity_schema_prompt` al construir el system prompt completo de extraccion de grafos OSINT.
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Genera la seccion del system prompt con los tipos de relacion permitidos."""
|
||||
|
||||
|
||||
def build_relation_schema_prompt(relation_types: list[str]) -> str:
    """Render the allowed relation types as a system-prompt section for the LLM.

    The output is a header line followed by a single line listing every
    relation type, comma-separated.

    Args:
        relation_types: Allowed relation type names.
            Example: ["funds", "employs", "communicates_with"]

    Returns:
        The formatted prompt section, or an empty string when the list is
        empty.
    """
    if relation_types:
        listing = ", ".join(relation_types)
        return "Allowed relation types:" + "\n" + listing
    return ""
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Tests para build_relation_schema_prompt."""
|
||||
|
||||
from build_relation_schema_prompt import build_relation_schema_prompt
|
||||
|
||||
|
||||
def test_lista_normal():
    """A populated list yields the header followed by every relation type."""
    expected_types = ["funds", "employs", "communicates_with", "owns", "operates"]
    prompt = build_relation_schema_prompt(expected_types)
    assert prompt.startswith("Allowed relation types:")
    for relation_type in expected_types:
        assert relation_type in prompt
|
||||
|
||||
|
||||
def test_lista_vacia_retorna_string_vacio():
    """An empty list of relation types produces an empty prompt section."""
    assert build_relation_schema_prompt([]) == ""
|
||||
@@ -121,3 +121,72 @@ def linspace(start: float, stop: float, num: int) -> list:
|
||||
return [start]
|
||||
step = (stop - start) / (num - 1)
|
||||
return [start + i * step for i in range(num)]
|
||||
|
||||
|
||||
def estimate_hawkes(arrivals: list[int], max_lag: int = 30) -> dict:
    """Estimate Hawkes-process parameters from the autocorrelation of arrivals.

    Computes the sample autocorrelation function (ACF) of ``arrivals`` for
    lags ``1 .. max_lag - 1`` and fits ``a * exp(-b * lag)`` to it with
    ``scipy.optimize.curve_fit``. The absolute values of the fitted ``a`` and
    ``b`` are reported as ``alpha`` and ``beta``; ``branching_ratio`` is
    ``alpha / beta``.

    Args:
        arrivals: Arrival counts, one per time bucket.
        max_lag: Exclusive upper bound of the lags used for the ACF. Lags the
            series is too short to support (``lag >= len(arrivals)``) are
            dropped instead of producing NaN entries.

    Returns:
        Dict with ``alpha``, ``beta`` and ``branching_ratio`` (rounded to 4
        decimals) plus ``acf``, a list of floats whose first entry (lag 0) is
        1.0. The neutral result ``alpha=0.0, beta=1.0, branching_ratio=0.0``
        is returned when the series is empty or constant, when fewer than two
        lags are available, when the lag-1 autocorrelation is <= 0.01, or
        when the fit fails.
    """
    import numpy as np
    from scipy.optimize import curve_fit

    def _neutral(acf_values: list) -> dict:
        # Result used whenever no meaningful fit can be produced.
        return {'alpha': 0.0, 'beta': 1.0, 'branching_ratio': 0.0, 'acf': acf_values}

    arr = np.array(arrivals, dtype=float)
    if arr.size == 0:
        # mean/var of an empty array are NaN; there is nothing to estimate.
        return _neutral([1.0])

    mean_a = np.mean(arr)
    var_a = np.var(arr)
    if var_a == 0:
        return _neutral([1.0])

    # A lag >= len(arr) leaves an empty overlap window whose mean is NaN,
    # so the usable lags are capped by the series length.
    lag_stop = min(max_lag, int(arr.size))

    acf = [1.0] + [
        float(np.mean((arr[lag:] - mean_a) * (arr[:-lag] - mean_a)) / var_a)
        for lag in range(1, lag_stop)
    ]

    acf_vals = np.array(acf[1:])

    # No lags at all, or no positive short-range correlation: nothing to fit.
    if acf_vals.size == 0 or acf_vals[0] <= 0.01:
        return _neutral(acf)

    # curve_fit cannot fit two parameters to a single data point.
    if acf_vals.size < 2:
        return _neutral(acf)

    lags = np.arange(1, lag_stop)

    def exp_decay(x, a, b):
        return a * np.exp(-b * x)

    try:
        popt, _ = curve_fit(exp_decay, lags, acf_vals, p0=[0.5, 0.5], maxfev=5000)
        alpha_est, beta_est = abs(popt[0]), abs(popt[1])
    except (RuntimeError, ValueError):
        # RuntimeError: the least-squares fit did not converge.
        # ValueError: the data contained non-finite values.
        alpha_est, beta_est = 0.0, 1.0

    branching = alpha_est / beta_est if beta_est > 0 else 0.0
    return {
        'alpha': round(float(alpha_est), 4),
        'beta': round(float(beta_est), 4),
        'branching_ratio': round(float(branching), 4),
        'acf': acf,
    }
|
||||
|
||||
|
||||
def estimate_pareto_alpha(values: list[float], x_min_percentile: float = 90.0) -> dict:
    """Estimate the Pareto tail exponent alpha by maximum likelihood.

    Uses ``alpha = n / sum(ln(x_i / x_min))`` over the tail ``x_i >= x_min``,
    where ``x_min`` is the ``x_min_percentile`` percentile of the strictly
    positive values. A lower alpha means a heavier tail.

    Args:
        values: Observations; non-positive values are ignored.
        x_min_percentile: Percentile (0-100) of the positive values used as
            the tail threshold ``x_min``.

    Returns:
        Dict with ``alpha`` (rounded to 4 decimals), ``x_min`` and ``n_tail``.
        ``alpha`` is 0.0 when no estimate can be made: fewer than 10 positive
        values, fewer than 2 tail values, or a degenerate tail in which every
        value equals ``x_min``.
    """
    import numpy as np

    arr = np.array([v for v in values if v > 0], dtype=float)
    if len(arr) < 10:
        return {'alpha': 0.0, 'x_min': 0.0, 'n_tail': 0}

    x_min = float(np.percentile(arr, x_min_percentile))
    tail = arr[arr >= x_min]
    n_tail = len(tail)

    if n_tail < 2 or x_min <= 0:
        return {'alpha': 0.0, 'x_min': x_min, 'n_tail': n_tail}

    log_sum = float(np.sum(np.log(tail / x_min)))
    if log_sum <= 0.0:
        # Every tail value equals x_min: the estimator would divide by zero
        # (yielding inf plus a RuntimeWarning). Report "no estimate" instead.
        return {'alpha': 0.0, 'x_min': round(x_min, 6), 'n_tail': n_tail}

    alpha = n_tail / log_sum

    return {
        'alpha': round(alpha, 4),
        'x_min': round(x_min, 6),
        'n_tail': n_tail,
    }
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
---
|
||||
name: deduplicate_entities
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def deduplicate_entities(candidates: list[EntityCandidate], name_threshold: float = 0.85, same_type_only: bool = True) -> DeduplicationResult"
|
||||
description: "Agrupa entidades candidatas que refieren a la misma entidad real usando fuzzy matching de nombres (Levenshtein + Jaccard) y Union-Find para clusters transitivos. Retorna entidades mergeadas con mapas de resolucion de IDs y log de merges."
|
||||
tags: [deduplication, entity, fuzzy, levenshtein, jaccard, union-find, knowledge-graph, nlp, fuzzygraph, datascience]
|
||||
uses_functions:
|
||||
- normalize_entity_name_py_core
|
||||
- merge_entity_attributes_py_core
|
||||
uses_types:
|
||||
- entity_candidate_py_datascience
|
||||
- deduplication_result_py_datascience
|
||||
returns: [deduplication_result_py_datascience]
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports:
|
||||
- uuid
|
||||
tested: true
|
||||
tests:
|
||||
- "John Smith y Smith, John se mergean"
|
||||
- "Google y Google LLC se mergean"
|
||||
- "192.168.1.1 y 192.168.1.1 se mergean por matching exacto"
|
||||
- "John Smith (person) y John Smith (organization) NO se mergean"
|
||||
- "Clusters transitivos: A~B, B~C -> {A, B, C} en un solo cluster"
|
||||
- "Entidades sin duplicados pasan sin modificacion"
|
||||
- "Confidence toma el max del cluster; atributos se fusionan"
|
||||
- "Lista vacia retorna resultado vacio"
|
||||
- "name_to_id contiene todos los nombres originales del cluster"
|
||||
test_file_path: "python/functions/datascience/deduplicate_entities_test.py"
|
||||
file_path: "python/functions/datascience/deduplicate_entities.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.functions.datascience.deduplicate_entities import deduplicate_entities
|
||||
|
||||
candidates = [
|
||||
EntityCandidate(name="John Smith", type_ref="person", confidence=0.9),
|
||||
EntityCandidate(name="Smith, John", type_ref="person", confidence=0.85),
|
||||
EntityCandidate(name="Google", type_ref="organization", confidence=0.95),
|
||||
EntityCandidate(name="Google LLC", type_ref="organization", confidence=0.88),
|
||||
]
|
||||
|
||||
result = deduplicate_entities(candidates, name_threshold=0.85, same_type_only=True)
|
||||
# result.total_before = 4
|
||||
# result.total_after = 2
|
||||
# result.merge_log = [
|
||||
# {"canonical": "John Smith", "merged": ["Smith, John"], "score": 0.91, "reason": "fuzzy_name"},
|
||||
# {"canonical": "Google", "merged": ["Google LLC"], "score": 0.89, "reason": "fuzzy_name"},
|
||||
# ]
|
||||
```
|
||||
|
||||
## Algoritmo
|
||||
|
||||
1. **Normalizar nombres** usando `normalize_entity_name()` sobre cada candidato segun su `type_ref`
|
||||
2. **Comparacion pairwise** dentro del mismo tipo (si `same_type_only=True`):
|
||||
- Para tipos tecnicos (ip, email, domain, crypto_wallet, phone): matching exacto normalizado
|
||||
- Para el resto: `score = max(levenshtein_sim, jaccard_sim)` + bonus por contencion (+0.3) y acronimos (+0.3)
|
||||
3. **Union-Find** para clusters transitivos: si A~B y B~C, entonces {A, B, C} forman un cluster
|
||||
4. **Merge por cluster:**
|
||||
- Nombre canonico: candidato con mayor `confidence`
|
||||
- Atributos: `merge_entity_attributes()` sobre todos los candidatos del cluster
|
||||
- Confidence: `max` del cluster
|
||||
- Source chunks: union de todos los candidatos
|
||||
- `merged_from`: union de todos los nombres originales
|
||||
|
||||
## Heuristicas de similitud de nombres
|
||||
|
||||
| Heuristica | Efecto |
|
||||
|---|---|
|
||||
| Levenshtein | `1 - (edit_distance / max_len)` |
|
||||
| Jaccard sobre tokens | `\|A ∩ B\| / \|A ∪ B\|` |
|
||||
| Score base | `max(lev_sim, jaccard_sim)` |
|
||||
| Contencion (a in b o b in a) | `+0.3` hasta max 1.0 |
|
||||
| Acronimo ("FBI" ~ "Federal Bureau of Investigation") | `+0.3` hasta max 1.0 |
|
||||
| Tipos exactos (ip/email/domain/crypto_wallet/phone) | solo matching exacto, ignora umbral |
|
||||
|
||||
## Complejidad
|
||||
|
||||
- Pairwise: O(N^2) — aceptable para <1000 entidades (tipico por documento)
|
||||
- Union-Find con path compression: O(α(N)) amortizado por operacion
|
||||
- Para escalar a >1000: pre-filtrar por primera letra o n-gram index antes de comparar
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura. Implementa Levenshtein y Jaccard internamente para evitar dependencias externas a este modulo. Las funciones del registry `levenshtein_distance_py_cybersecurity` y `jaccard_similarity_py_cybersecurity` son equivalentes pero requieren imports adicionales — la implementacion inline mantiene la funcion sin dependencias fuera de la stdlib (de la stdlib solo usa `uuid`).
|
||||
|
||||
El `name_to_id` del resultado es el mapa de resolucion principal para la fase de deduplicacion de relaciones: permite resolver cualquier variante de nombre de una entidad a su ID canonico.
|
||||
@@ -0,0 +1,283 @@
|
||||
"""Deduplica entidades candidatas usando fuzzy matching de nombres."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import uuid
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.types.datascience.deduplication_result import DeduplicationResult
|
||||
from python.functions.core.normalize_entity_name import normalize_entity_name
|
||||
from python.functions.core.merge_entity_attributes import merge_entity_attributes
|
||||
|
||||
|
||||
# ── Similitud helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein edit distance between ``a`` and ``b``.

    Uses the two-row dynamic-programming formulation: only the previous and
    the current row of the edit matrix are kept in memory.
    """
    if a == b:
        return 0
    if not a or not b:
        # One side is empty: the distance is the length of the other side.
        return len(b) if not a else len(a)

    previous_row = list(range(len(b) + 1))
    for row_no, char_a in enumerate(a, start=1):
        current_row = [row_no]
        for col_no, char_b in enumerate(b, start=1):
            deletion = previous_row[col_no] + 1
            insertion = current_row[col_no - 1] + 1
            substitution = previous_row[col_no - 1] + (0 if char_a == char_b else 1)
            current_row.append(min(deletion, insertion, substitution))
        previous_row = current_row
    return previous_row[-1]
|
||||
|
||||
|
||||
def _jaccard(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Return the Jaccard similarity of two token collections.

    Both inputs are treated as sets (duplicates are ignored). Two empty
    inputs are considered identical and score 1.0.
    """
    left, right = set(tokens_a), set(tokens_b)
    combined = left | right
    if not combined:
        return 1.0
    return len(left & right) / len(combined)
|
||||
|
||||
|
||||
def _name_similarity(a: str, b: str) -> float:
    """Score how similar two normalised names are, in the range 0.0-1.0.

    The base score is the larger of the Levenshtein similarity
    (``1 - distance / longest_length``) and the Jaccard similarity of the
    whitespace-separated tokens. Two bonuses of +0.3 each (capped at 1.0)
    are then applied:

    * containment: one name is a substring of the other;
    * acronym: one name equals the initials of the other's tokens, as
      decided by ``_is_acronym_of``.
    """
    if a == b:
        return 1.0

    longest = max(len(a), len(b))
    lev_sim = 1.0 - (_levenshtein(a, b) / longest) if longest else 1.0

    words_a, words_b = a.split(), b.split()
    score = max(lev_sim, _jaccard(words_a, words_b))

    is_contained = a in b or b in a
    is_acronym = _is_acronym_of(a, words_b) or _is_acronym_of(b, words_a)

    # Each bonus is applied (and capped) separately, containment first.
    for bonus_applies in (is_contained, is_acronym):
        if bonus_applies:
            score = min(1.0, score + 0.3)

    return score
|
||||
|
||||
|
||||
# Connector words that acronyms conventionally skip ("Federal Bureau of
# Investigation" -> "FBI"). English and Spanish, lowercase.
_ACRONYM_STOPWORDS = frozenset(
    {"a", "an", "and", "de", "del", "el", "for", "la", "las", "los", "of", "the", "y"}
)


def _is_acronym_of(candidate: str, tokens: list[str]) -> bool:
    """Tell whether ``candidate`` is an acronym built from the initials of ``tokens``.

    The comparison is case-insensitive. Two spellings are accepted:

    * the initials of every non-empty token ("FBOI" for "Federal Bureau of
      Investigation");
    * the initials of the tokens that are not connector words, provided at
      least two such tokens remain ("FBI" for the same name).

    Args:
        candidate: Possible acronym.
        tokens: Words of the expanded name.

    Returns:
        True when ``candidate`` matches either spelling; False otherwise, or
        when ``candidate`` or ``tokens`` is empty.
    """
    if not candidate or not tokens:
        return False

    target = candidate.upper()
    words = [token for token in tokens if token]

    if "".join(word[0] for word in words).upper() == target:
        return True

    # Second chance: ignore connector words. Requiring two or more
    # significant words keeps a lone word from matching its own initial
    # just because a connector was dropped.
    significant = [word for word in words if word.lower() not in _ACRONYM_STOPWORDS]
    if len(significant) < 2:
        return False
    return "".join(word[0] for word in significant).upper() == target
|
||||
|
||||
|
||||
# Technical entity types whose names are identifiers rather than free text.
_EXACT_TYPES = {"ip", "email", "domain", "crypto_wallet", "phone"}


def _is_exact_type(entity_type: str) -> bool:
    """Tell whether ``entity_type`` only admits exact name matching.

    The check is case-insensitive against ``_EXACT_TYPES``.
    """
    lowered = entity_type.lower()
    return lowered in _EXACT_TYPES
|
||||
|
||||
|
||||
# ── Union-Find ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class _UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``union`` links by rank and ``find`` compresses the path it walks.
    """

    def __init__(self, n: int) -> None:
        self._parent = list(range(n))
        self._rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative (root) of the set containing ``x``."""
        # First pass: locate the root.
        root = x
        while self._parent[root] != root:
            root = self._parent[root]
        # Second pass: point every node on the walked path at the root.
        node = x
        while self._parent[node] != root:
            next_node = self._parent[node]
            self._parent[node] = root
            node = next_node
        return root

    def union(self, x: int, y: int) -> None:
        """Merge the sets containing ``x`` and ``y`` (no-op if already joined)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # The root with the higher rank wins; on a tie the root of ``x`` wins.
        if self._rank[root_x] < self._rank[root_y]:
            winner, loser = root_y, root_x
        else:
            winner, loser = root_x, root_y

        self._parent[loser] = winner
        if self._rank[winner] == self._rank[loser]:
            self._rank[winner] += 1
|
||||
|
||||
|
||||
# ── Implementacion principal ────────────────────────────────────────────────────
|
||||
|
||||
def deduplicate_entities(
    candidates: list[EntityCandidate],
    name_threshold: float = 0.85,
    same_type_only: bool = True,
) -> DeduplicationResult:
    """Group candidates that refer to the same entity and merge each group.

    Names are normalised with ``normalize_entity_name`` and compared
    pairwise. Matching pairs are joined in a Union-Find, so matches are
    transitive (A~B and B~C put A, B and C in one cluster). Each cluster is
    emitted as one ``EntityCandidate`` and assigned a fresh UUID4 identifier.

    Matching rules for a pair:
        * If either candidate has a technical type (``_is_exact_type``), the
          pair matches only when the normalised names are identical and
          ``name_threshold`` is ignored. Checking both sides keeps the result
          independent of input order when ``same_type_only`` is False.
        * Otherwise the pair matches when ``_name_similarity`` is at least
          ``name_threshold``.

    Merge rules for a cluster with more than one member:
        * name: the member with the highest confidence (first one on ties);
        * attributes: ``merge_entity_attributes`` over all members;
        * confidence: maximum of the cluster;
        * source_chunk_indices and merged_from: order-preserving unions;
        * type_ref and type_label: taken from the first member.

    Args:
        candidates: Entity candidates to deduplicate.
        name_threshold: Minimum similarity score (0-1) for a fuzzy match.
        same_type_only: When True, only candidates with equal ``type_ref``
            are compared.

    Returns:
        DeduplicationResult with:
            entities: one merged candidate per cluster;
            entity_id_map: canonical normalised name -> entity ID;
            name_to_id: every original name of the cluster, plus the
                canonical normalised name -> entity ID;
            merge_log: one dict per multi-member cluster with the keys
                ``canonical``, ``merged``, ``score`` and ``reason``;
            total_before / total_after: candidate counts before and after.
    """
    if not candidates:
        return DeduplicationResult(
            entities=[],
            entity_id_map={},
            name_to_id={},
            merge_log=[],
            total_before=0,
            total_after=0,
        )

    n = len(candidates)

    # Step 1: normalise every name once; the list is reused for matching and
    # for the canonical name of each cluster.
    normalized: list[str] = [
        normalize_entity_name(c.name, c.type_ref) for c in candidates
    ]

    # Step 2: pairwise comparison feeding the Union-Find.
    uf = _UnionFind(n)
    merge_pairs: list[tuple[int, int, float]] = []

    for i in range(n):
        type_i = candidates[i].type_ref
        for j in range(i + 1, n):
            type_j = candidates[j].type_ref
            if same_type_only and type_i != type_j:
                continue

            ni, nj = normalized[i], normalized[j]

            # Technical identifiers must never be fuzzy-matched, whichever
            # side of the pair carries the technical type.
            if _is_exact_type(type_i) or _is_exact_type(type_j):
                if ni == nj:
                    uf.union(i, j)
                    merge_pairs.append((i, j, 1.0))
                continue

            score = _name_similarity(ni, nj)
            if score >= name_threshold:
                uf.union(i, j)
                merge_pairs.append((i, j, score))

    # Step 3: group candidate indices by their Union-Find root.
    clusters: dict[int, list[int]] = {}
    for i in range(n):
        clusters.setdefault(uf.find(i), []).append(i)

    # Step 4: index the matched pairs by root so each cluster can report the
    # best score that contributed to it.
    merged_pairs_by_root: dict[int, list[tuple[int, int, float]]] = {}
    for i, j, score in merge_pairs:
        merged_pairs_by_root.setdefault(uf.find(i), []).append((i, j, score))

    # Step 5: build one merged entity per cluster.
    merged_entities: list[EntityCandidate] = []
    entity_id_map: dict[str, str] = {}
    name_to_id: dict[str, str] = {}
    merge_log: list[dict] = []

    for root, indices in clusters.items():
        cluster_candidates = [candidates[idx] for idx in indices]

        if len(cluster_candidates) == 1:
            # Singleton: pass the candidate's data through unchanged.
            c = cluster_candidates[0]
            canonical_name = c.name
            canonical_norm = normalized[indices[0]]
            merged_attrs = c.attributes
            merged_confidence = c.confidence
            merged_chunks = list(c.source_chunk_indices)
            merged_from = list(c.merged_from) if c.merged_from else [c.name]
        else:
            # The most confident member provides the canonical name.
            best_idx = max(indices, key=lambda idx: candidates[idx].confidence)
            best = candidates[best_idx]
            canonical_name = best.name
            canonical_norm = normalized[best_idx]

            merged_attrs = merge_entity_attributes(
                [c.attributes for c in cluster_candidates]
            )
            merged_confidence = max(c.confidence for c in cluster_candidates)

            # dict.fromkeys gives an order-preserving de-duplication.
            merged_chunks = list(
                dict.fromkeys(
                    chunk
                    for c in cluster_candidates
                    for chunk in c.source_chunk_indices
                )
            )
            merged_from = list(
                dict.fromkeys(
                    nm
                    for c in cluster_candidates
                    for nm in (c.merged_from if c.merged_from else [c.name])
                )
            )

            other_names = [c.name for c in cluster_candidates if c is not best]
            pairs = merged_pairs_by_root.get(root, [])
            max_score = max((s for _, _, s in pairs), default=1.0)
            merge_log.append(
                {
                    "canonical": canonical_name,
                    "merged": other_names,
                    "score": round(max_score, 4),
                    "reason": "fuzzy_name",
                }
            )

        ent_id = str(uuid.uuid4())
        merged_entities.append(
            EntityCandidate(
                name=canonical_name,
                name_normalized=canonical_norm,
                type_ref=cluster_candidates[0].type_ref,
                type_label=cluster_candidates[0].type_label,
                attributes=merged_attrs,
                confidence=merged_confidence,
                source_chunk_indices=merged_chunks,
                merged_from=merged_from,
            )
        )

        # Resolution maps: canonical normalised name and every original name.
        entity_id_map[canonical_norm] = ent_id
        for orig_name in merged_from:
            name_to_id[orig_name] = ent_id
        name_to_id[canonical_norm] = ent_id

    return DeduplicationResult(
        entities=merged_entities,
        entity_id_map=entity_id_map,
        name_to_id=name_to_id,
        merge_log=merge_log,
        total_before=n,
        total_after=len(merged_entities),
    )
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Tests para deduplicate_entities."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.functions.datascience.deduplicate_entities import deduplicate_entities
|
||||
|
||||
|
||||
def _make(name: str, type_ref: str = "person", confidence: float = 0.9, **attrs) -> EntityCandidate:
    """Build an EntityCandidate fixture.

    Extra keyword arguments become the candidate's attributes; the type label
    is the capitalised ``type_ref`` and the only source chunk is chunk 0.
    """
    fields = {
        "name": name,
        "type_ref": type_ref,
        "type_label": type_ref.capitalize(),
        "attributes": attrs,
        "confidence": confidence,
        "source_chunk_indices": [0],
    }
    return EntityCandidate(**fields)
|
||||
|
||||
|
||||
def test_john_smith_y_smith_john_merge():
    """ "John Smith" and "Smith, John" collapse into a single entity."""
    pair = [
        _make("John Smith", type_ref="person"),
        _make("Smith, John", type_ref="person"),
    ]
    outcome = deduplicate_entities(pair)
    assert outcome.total_before == 2
    assert outcome.total_after == 1
    assert len(outcome.entities) == 1
    assert len(outcome.merge_log) == 1
|
||||
|
||||
|
||||
def test_google_y_google_llc_merge():
    """ "Google" and "Google LLC" collapse into a single entity."""
    names = ("Google", "Google LLC")
    outcome = deduplicate_entities([_make(nm, type_ref="organization") for nm in names])
    assert outcome.total_after == 1
    assert len(outcome.entities) == 1
|
||||
|
||||
|
||||
def test_ip_matching_exacto():
    """Two candidates with the same IP address merge through exact matching."""
    low = _make("192.168.1.1", type_ref="ip", confidence=0.8)
    high = _make("192.168.1.1", type_ref="ip", confidence=0.9)
    outcome = deduplicate_entities([low, high])
    assert outcome.total_after == 1
|
||||
|
||||
|
||||
def test_same_name_different_type_no_merge():
    """The same name under different types is NOT merged when same_type_only is set."""
    as_person = _make("John Smith", type_ref="person")
    as_org = _make("John Smith", type_ref="organization")
    outcome = deduplicate_entities([as_person, as_org], same_type_only=True)
    assert outcome.total_after == 2
|
||||
|
||||
|
||||
def test_clusters_transitivos():
    """Chained near-duplicates (A~B, B~C) end up in one cluster {A, B, C}."""
    # Each name drops one trailing character from the previous one.
    chain = [
        _make(variant, type_ref="person")
        for variant in ("Alice Johnson", "Alice Johnso", "Alice Johns")
    ]
    outcome = deduplicate_entities(chain, name_threshold=0.80)
    assert outcome.total_after == 1
|
||||
|
||||
|
||||
def test_sin_duplicados_sin_cambios():
    """Entities without duplicates pass through and nothing is logged."""
    distinct = [
        _make(person, type_ref="person")
        for person in ("Alice Smith", "Bob Jones", "Charlie Brown")
    ]
    outcome = deduplicate_entities(distinct)
    assert outcome.total_before == 3
    assert outcome.total_after == 3
    assert len(outcome.merge_log) == 0
|
||||
|
||||
|
||||
def test_confidence_y_atributos_merge_correctos():
    """A merged entity keeps the max confidence and the attributes of every member."""
    weaker = _make("John Smith", type_ref="person", confidence=0.7, role="CEO")
    stronger = _make("Smith, John", type_ref="person", confidence=0.95, company="Acme")
    outcome = deduplicate_entities([weaker, stronger])
    assert outcome.total_after == 1

    merged = outcome.entities[0]
    # confidence is max(0.7, 0.95)
    assert merged.confidence == 0.95
    # attributes contributed by both candidates are present
    for attribute in ("role", "company"):
        assert attribute in merged.attributes
|
||||
|
||||
|
||||
def test_lista_vacia():
    """An empty candidate list yields an empty result."""
    outcome = deduplicate_entities([])
    assert outcome.total_before == 0
    assert outcome.total_after == 0
    assert outcome.entities == []
    assert outcome.merge_log == []
|
||||
|
||||
|
||||
def test_name_to_id_resolucion():
    """name_to_id maps every original name of a merged cluster to its entity ID."""
    a = _make("John Smith", type_ref="person")
    b = _make("Smith, John", type_ref="person")
    result = deduplicate_entities([a, b])

    # The two candidates merge, so exactly one entity ID exists.
    ids = list(result.entity_id_map.values())
    assert len(ids) == 1
    ent_id = ids[0]

    # Both original spellings must resolve to that ID. The previous
    # assertions only checked that *some* entry pointed at it.
    assert result.name_to_id.get("John Smith") == ent_id
    assert result.name_to_id.get("Smith, John") == ent_id
    assert len(result.name_to_id) >= 2
|
||||
@@ -0,0 +1,81 @@
|
||||
---
|
||||
name: deduplicate_relations
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def deduplicate_relations(relations: list[RelationCandidate], entity_id_map: dict[str, str]) -> list[RelationCandidate]"
|
||||
description: "Deduplica relaciones candidatas resolviendo from_name/to_name a entity IDs finales via entity_id_map. Descarta self-loops y relaciones sin match. Mergea duplicados (mismo from_id, to_id, relation_type) concatenando descripciones unicas y tomando max confidence."
|
||||
tags: [datascience, extraction, knowledge-graph, nlp, deduplication, fuzzy-match, fuzzygraph]
|
||||
uses_functions:
|
||||
- levenshtein_distance_py_cybersecurity
|
||||
uses_types:
|
||||
- relation_candidate_py_datascience
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "dos relaciones identicas se colapsan en una"
|
||||
- "relacion con nombre mergeado se resuelve al id correcto"
|
||||
- "self loop se descarta"
|
||||
- "nombre no mapeado sin fuzzy match se descarta"
|
||||
- "relaciones distintas se mantienen"
|
||||
- "merge descripcion concatena unicas"
|
||||
- "lista vacia retorna lista vacia"
|
||||
- "fuzzy match resuelve nombre cercano"
|
||||
test_file_path: "python/functions/datascience/deduplicate_relations_test.py"
|
||||
file_path: "python/functions/datascience/deduplicate_relations.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from python.types.datascience.relation_candidate import RelationCandidate
|
||||
from python.functions.datascience.deduplicate_relations import deduplicate_relations
|
||||
|
||||
# entity_id_map producido por deduplicate_entities
|
||||
entity_id_map = {
|
||||
"john smith": "entity_001",
|
||||
"smith, john": "entity_001", # alias mergeado
|
||||
"acme corp": "entity_002",
|
||||
}
|
||||
|
||||
relations = [
|
||||
RelationCandidate(from_name="John Smith", to_name="Acme Corp",
|
||||
relation_type="works_at", description="John es CEO",
|
||||
confidence=0.9, source_chunk_index=0),
|
||||
RelationCandidate(from_name="Smith, John", to_name="Acme Corp",
|
||||
relation_type="works_at", description="CEO de Acme",
|
||||
confidence=0.7, source_chunk_index=2),
|
||||
]
|
||||
|
||||
result = deduplicate_relations(relations, entity_id_map)
|
||||
# → 1 RelationCandidate con from_id="entity_001", to_id="entity_002",
|
||||
# confidence=0.9, description="John es CEO; CEO de Acme"
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
La funcion es pura: no hace I/O, no tiene efectos secundarios. El logging es
|
||||
de nivel DEBUG/WARNING — en produccion configurar el logger de la aplicacion.
|
||||
|
||||
**Resolucion de nombres:**
|
||||
- Lookup exacto primero (lowercase strip del nombre contra las claves del mapa).
|
||||
- Si no hay match exacto, fuzzy match con Levenshtein (threshold=3 ediciones).
|
||||
- Si sigue sin match, la relacion se descarta con `logger.warning`.
|
||||
|
||||
**Self-loops:** relaciones donde `from_id == to_id` siempre se descartan.
|
||||
|
||||
**Merge:** cuando varias relaciones comparten `(from_id, to_id, relation_type)`:
|
||||
- `confidence`: max del grupo.
|
||||
- `description`: union de descripciones unicas (no duplicadas), separadas por `'; '`.
|
||||
- `from_name` / `to_name` / `source_chunk_index`: del primer candidato del grupo.
|
||||
|
||||
**Integracion con fuzzygraph:**
|
||||
Esta funcion es el paso 4 del pipeline de extraccion. Recibe el output de
|
||||
`extract_relations_llm` (relaciones crudas con nombres de texto) y el
|
||||
`entity_id_map` producido por `deduplicate_entities`. Produce la lista final
|
||||
de relaciones para `ExtractionResult`.
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Deduplica RelationCandidate resolviendo nombres a IDs y colapsando duplicados."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Importar levenshtein_distance desde cybersecurity ---
|
||||
# Soporta dos contextos:
|
||||
# 1. Ejecutado desde python/functions/datascience/ (pytest local)
|
||||
# 2. Ejecutado desde la raiz del registry (fn run)
|
||||
def _levenshtein_distance(a: str, b: str) -> int:
    """Return the Levenshtein edit distance between two strings.

    Local fallback used when the shared implementation cannot be imported.
    The longer string drives the outer loop, so each row has
    ``len(shorter) + 1`` entries.
    """
    longer, shorter = (b, a) if len(a) < len(b) else (a, b)
    if not shorter:
        return len(longer)

    row_above = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        row_now = [row_no]
        for col, ch_short in enumerate(shorter):
            insertion = row_now[col] + 1
            deletion = row_above[col + 1] + 1
            substitution = row_above[col] + (0 if ch_long == ch_short else 1)
            row_now.append(min(insertion, deletion, substitution))
        row_above = row_now
    return row_above[-1]
|
||||
|
||||
|
||||
# Prefer the implementation from the sibling ``cybersecurity`` directory; if
# it cannot be imported, fall back to the local ``_levenshtein_distance``.
try:
    _here = os.path.dirname(os.path.abspath(__file__))
    _cyber_path = os.path.join(_here, "..", "cybersecurity")
    if _cyber_path not in sys.path:
        sys.path.insert(0, _cyber_path)
    from cybersecurity import levenshtein_distance as _lev
except ImportError:
    _lev = None  # type: ignore

if _lev is None:
    levenshtein_distance = _levenshtein_distance
else:
    levenshtein_distance = _lev
|
||||
|
||||
|
||||
def _fuzzy_resolve(name: str, entity_id_map: dict[str, str], threshold: int = 3) -> str:
    """Resolve ``name`` against the keys of ``entity_id_map`` by fuzzy match.

    Scans every key and keeps the one with the smallest Levenshtein distance
    to ``name`` (the first such key on ties).

    Args:
        name: Name to resolve; callers pass it lowercased and stripped.
        entity_id_map: Mapping of normalised name -> entity ID.
        threshold: Maximum edit distance accepted as a match (default 3).

    Returns:
        The entity ID of the closest key when its distance is at most
        ``threshold``; otherwise ''. A key is never accepted when the
        distance equals the length of the longer string: that means the two
        strings have nothing in common, which an absolute threshold would
        otherwise accept for empty or very short names ("abc" vs "xyz").
    """
    best_id = ""
    best_dist = threshold + 1
    for key, entity_id in entity_id_map.items():
        dist = levenshtein_distance(name, key)
        # Reject "every character replaced" pseudo-matches on short strings.
        if dist > 0 and dist >= max(len(name), len(key)):
            continue
        if dist < best_dist:
            best_dist = dist
            best_id = entity_id
    return best_id if best_dist <= threshold else ""
|
||||
|
||||
|
||||
def deduplicate_relations(
    relations: list,
    entity_id_map: dict[str, str],
) -> list:
    """Resolve relation endpoints to entity IDs and collapse duplicate relations.

    Algorithm:
        1. Resolve ``from_name`` and ``to_name`` of each relation: exact
           lookup of the lowercased, stripped name in ``entity_id_map``,
           then ``_fuzzy_resolve`` as a fallback. A relation with an
           unresolved endpoint is dropped with a warning.
        2. Drop self-loops (``from_id == to_id``).
        3. Group by ``(from_id, to_id, relation_type)``. A group with several
           members becomes one relation whose confidence is the group
           maximum and whose description joins the unique non-empty
           descriptions with '; '. Names and ``source_chunk_index`` come
           from the first member.

    Name resolution is memoised per call, so a name that appears in many
    relations triggers at most one fuzzy scan of ``entity_id_map``.

    Args:
        relations: RelationCandidate objects carrying the original
            ``from_name`` / ``to_name``.
        entity_id_map: Mapping of normalised name -> entity ID.

    Returns:
        Deduplicated list of RelationCandidate with ``from_id`` and
        ``to_id`` resolved, in order of first appearance of each group.
    """
    # Import the type lazily; works from datascience/ and from the registry root.
    try:
        _types_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..", "..", "..", "python", "types", "datascience",
        )
        if _types_path not in sys.path:
            sys.path.insert(0, _types_path)
        from relation_candidate import RelationCandidate
    except ImportError:
        from python.types.datascience.relation_candidate import RelationCandidate  # type: ignore

    resolution_cache: dict[str, str] = {}

    def _resolve(raw_name: str) -> str:
        # Exact lookup first, fuzzy fallback second; '' means unresolved.
        key = raw_name.lower().strip()
        if key not in resolution_cache:
            resolution_cache[key] = (
                entity_id_map.get(key, "") or _fuzzy_resolve(key, entity_id_map)
            )
        return resolution_cache[key]

    resolved: list = []

    for rel in relations:
        from_id = _resolve(rel.from_name)
        if not from_id:
            logger.warning(
                "deduplicate_relations: no se pudo resolver from_name=%r — descartando",
                rel.from_name,
            )
            continue

        to_id = _resolve(rel.to_name)
        if not to_id:
            logger.warning(
                "deduplicate_relations: no se pudo resolver to_name=%r — descartando",
                rel.to_name,
            )
            continue

        if from_id == to_id:
            logger.debug(
                "deduplicate_relations: self-loop descartado (from=%r, to=%r, type=%r)",
                rel.from_name,
                rel.to_name,
                rel.relation_type,
            )
            continue

        resolved.append(
            RelationCandidate(
                from_name=rel.from_name,
                to_name=rel.to_name,
                from_id=from_id,
                to_id=to_id,
                relation_type=rel.relation_type,
                description=rel.description,
                confidence=rel.confidence,
                source_chunk_index=rel.source_chunk_index,
            )
        )

    # Group by (from_id, to_id, relation_type); dict order keeps first appearance.
    groups: dict[tuple, list] = {}
    for rel in resolved:
        groups.setdefault((rel.from_id, rel.to_id, rel.relation_type), []).append(rel)

    result: list = []
    for (from_id, to_id, rel_type), group in groups.items():
        first = group[0]
        if len(group) == 1:
            result.append(first)
            continue

        # dict.fromkeys gives an order-preserving set of non-empty descriptions.
        descriptions = list(
            dict.fromkeys(r.description for r in group if r.description)
        )
        result.append(
            RelationCandidate(
                from_name=first.from_name,
                to_name=first.to_name,
                from_id=from_id,
                to_id=to_id,
                relation_type=rel_type,
                description="; ".join(descriptions),
                confidence=max(r.confidence for r in group),
                source_chunk_index=first.source_chunk_index,
            )
        )

    return result
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Tests para deduplicate_relations."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Permitir importar RelationCandidate desde python/types/datascience/
|
||||
_here = os.path.dirname(os.path.abspath(__file__))
|
||||
_types_path = os.path.join(_here, "..", "..", "..", "python", "types", "datascience")
|
||||
if _types_path not in sys.path:
|
||||
sys.path.insert(0, _types_path)
|
||||
|
||||
from relation_candidate import RelationCandidate
|
||||
from deduplicate_relations import deduplicate_relations
|
||||
|
||||
|
||||
def _make_rel(
    from_name: str,
    to_name: str,
    relation_type: str = "works_at",
    description: str = "",
    confidence: float = 0.8,
    source_chunk_index: int = 0,
) -> RelationCandidate:
    """Build a RelationCandidate for the tests, with defaults for optional fields."""
    fields = {
        "from_name": from_name,
        "to_name": to_name,
        "relation_type": relation_type,
        "description": description,
        "confidence": confidence,
        "source_chunk_index": source_chunk_index,
    }
    return RelationCandidate(**fields)
|
||||
|
||||
|
||||
# Typical entity_id_map: keys are normalized to lowercase
|
||||
_ENTITY_MAP: dict[str, str] = {
|
||||
"john smith": "entity_001",
|
||||
"acme corp": "entity_002",
|
||||
"jane doe": "entity_003",
|
||||
"google": "entity_004",
|
||||
}
|
||||
|
||||
|
||||
def test_dos_relaciones_identicas_se_colapsan_en_una():
    """Two relations sharing (from, to, type) must collapse into a single one."""
    duplicates = [
        _make_rel("John Smith", "Acme Corp", description="John es CEO", confidence=score)
        for score in (0.9, 0.7)
    ]
    merged = deduplicate_relations(duplicates, _ENTITY_MAP)
    assert len(merged) == 1
    survivor = merged[0]
    assert survivor.from_id == "entity_001"
    assert survivor.to_id == "entity_002"
    # The surviving relation carries the highest confidence of the group.
    assert survivor.confidence == 0.9
|
||||
|
||||
|
||||
def test_relacion_con_nombre_mergeado_se_resuelve_al_id_correcto():
    """A relation that uses a merged alias resolves to the expected entity ID."""
    # "smith, john" is registered as an extra alias of entity_001.
    alias_map = dict(_ENTITY_MAP)
    alias_map["smith, john"] = "entity_001"
    outcome = deduplicate_relations([_make_rel("Smith, John", "Acme Corp")], alias_map)
    assert len(outcome) == 1
    only = outcome[0]
    assert only.from_id == "entity_001"
    assert only.to_id == "entity_002"
|
||||
|
||||
|
||||
def test_self_loop_se_descarta():
    """A relation whose endpoints resolve to the same ID is discarded."""
    loop = _make_rel("John Smith", "John Smith", relation_type="knows")
    survivors = deduplicate_relations([loop], _ENTITY_MAP)
    assert len(survivors) == 0
|
||||
|
||||
|
||||
def test_nombre_no_mapeado_sin_fuzzy_match_se_descarta():
    """A relation with an unmapped name and no fuzzy match is discarded."""
    orphan = _make_rel("Unknown Entity XYZ", "Acme Corp")
    survivors = deduplicate_relations([orphan], _ENTITY_MAP)
    assert len(survivors) == 0
|
||||
|
||||
|
||||
def test_relaciones_distintas_se_mantienen():
    """Relations with distinct (from, to, type) triples are all preserved."""
    triples = (
        ("John Smith", "Acme Corp", "works_at"),
        ("Jane Doe", "Acme Corp", "works_at"),
        ("John Smith", "Google", "invested_in"),
    )
    candidates = [
        _make_rel(source, target, relation_type=kind)
        for source, target, kind in triples
    ]
    kept = deduplicate_relations(candidates, _ENTITY_MAP)
    assert len(kept) == 3
|
||||
|
||||
|
||||
def test_merge_descripcion_concatena_unicas():
    """Merging keeps each unique description once and the maximum confidence."""
    variants = (
        ("John es CEO", 0.9),
        ("Acme fue fundada por John", 0.7),
        ("John es CEO", 0.6),
    )
    candidates = [
        _make_rel("John Smith", "Acme Corp", description=text, confidence=score)
        for text, score in variants
    ]
    merged = deduplicate_relations(candidates, _ENTITY_MAP)
    assert len(merged) == 1
    combined = merged[0]
    assert "John es CEO" in combined.description
    assert "Acme fue fundada por John" in combined.description
    # The repeated description ("John es CEO") must not show up twice.
    assert combined.description.count("John es CEO") == 1
    assert combined.confidence == 0.9
|
||||
|
||||
|
||||
def test_lista_vacia_retorna_lista_vacia():
    """An empty list of relations yields an empty list."""
    no_relations = []
    assert deduplicate_relations(no_relations, _ENTITY_MAP) == []
|
||||
|
||||
|
||||
def test_fuzzy_match_resuelve_nombre_cercano():
    """A name with a small typo is still resolved through fuzzy matching."""
    # "john smit" is at distance 1 from "john smith".
    typo = _make_rel("John Smit", "Acme Corp")
    resolved = deduplicate_relations([typo], _ENTITY_MAP)
    assert len(resolved) == 1
    assert resolved[0].from_id == "entity_001"
|
||||
@@ -0,0 +1,56 @@
|
||||
---
|
||||
name: detect_drift
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def detect_drift(history: list[dict], current: dict, fields: list[str], threshold: float = 2.0) -> list[dict]"
|
||||
description: "Detecta drift estadistico comparando metricas de la ejecucion actual contra el historial usando z-score. Si |z| > threshold, el campo ha drifteado. Util para monitorizar executions en operations.db."
|
||||
tags: [drift, statistics, z-score, monitoring, executions, operations, datascience]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [math]
|
||||
tested: true
|
||||
tests:
|
||||
- "campo con drift claro (z > threshold)"
|
||||
- "campo estable (z < threshold)"
|
||||
- "historial con un solo punto → std=0, no puede calcular → drifted=False con nota"
|
||||
- "historial vacio → todos drifted=False"
|
||||
- "threshold custom"
|
||||
test_file_path: "python/functions/datascience/detect_drift_test.py"
|
||||
file_path: "python/functions/datascience/detect_drift.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
history = [
|
||||
{"records_out": 100, "duration_ms": 500},
|
||||
{"records_out": 105, "duration_ms": 480},
|
||||
{"records_out": 98, "duration_ms": 510},
|
||||
]
|
||||
current = {"records_out": 50, "duration_ms": 2000}
|
||||
|
||||
results = detect_drift(history, current, ["records_out", "duration_ms"])
|
||||
# [
|
||||
# {"field": "records_out", "current": 50.0, "mean": 101.0, "std": 2.9, "z_score": -17.3, "drifted": True},
|
||||
# {"field": "duration_ms", "current": 2000.0, "mean": 496.7, "std": 12.5, "z_score": 120.5, "drifted": True},
|
||||
# ]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura. Solo stdlib (`math`).
|
||||
|
||||
El z-score usa desviacion estandar poblacional (dividir por N, no N-1) para ser consistente con historial de cualquier tamanio.
|
||||
|
||||
Casos especiales:
|
||||
- **Historial vacio**: z_score=0.0, drifted=False para todos los campos.
|
||||
- **Un solo punto en historial**: std=0.0, z_score=0.0, drifted=False. No hay suficiente historia para calcular variabilidad.
|
||||
- **Std=0 con N>=2**: todos los valores historicos identicos. z_score=0.0, drifted=False (cualquier desviacion seria tecnicamente infinita, pero se asume que el sistema es muy estable).
|
||||
|
||||
Pensado para el paso ANALIZAR del bucle reactivo: comparar `metrics` de la ejecucion actual con executions historicas de `operations.db`.
|
||||
@@ -0,0 +1,86 @@
|
||||
"""detect_drift — detecta drift estadistico por z-score comparando metricas contra historial."""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def detect_drift(
    history: list[dict],
    current: dict,
    fields: list[str],
    threshold: float = 2.0,
) -> list[dict]:
    """Detect statistical drift of the current metrics against their history.

    For every requested field the z-score of the current value is computed
    against the historical values, using the population standard deviation
    (divide by N, not N-1). A field has drifted when ``|z| > threshold``
    (strict comparison).

    Args:
        history: Historical metric dicts. Each may hold any subset of
            ``fields``; entries where the field is missing or ``None`` are
            skipped for that field.
        current: Metrics of the current execution. A field that is missing
            or ``None`` is treated as ``0.0``.
        fields: Names of the numeric fields to analyse.
        threshold: Absolute z-score above which a field counts as drifted.

    Returns:
        One dict per field, in the order of ``fields``, with keys ``field``,
        ``current``, ``mean``, ``std``, ``z_score`` and ``drifted``. With
        fewer than two usable historical points, or when all historical
        values are identical (std == 0), ``z_score`` is 0.0 and ``drifted``
        is False because no meaningful z-score exists.
    """
    results = []

    for field in fields:
        values = [
            float(h[field])
            for h in history
            if field in h and h[field] is not None
        ]

        # An explicit None is handled like a missing field (0.0) instead of
        # crashing in float(None); history already tolerates None this way.
        raw_current = current.get(field)
        current_val = 0.0 if raw_current is None else float(raw_current)

        # Defaults cover the "not enough information" cases.
        mean = 0.0
        std = 0.0
        z_score = 0.0
        drifted = False

        n = len(values)
        if n:
            mean = sum(values) / n
        if n >= 2:
            variance = sum((v - mean) ** 2 for v in values) / n
            std = math.sqrt(variance)
            # std == 0 means every historical value is identical: the z-score
            # is undefined, so the field is reported as not drifted.
            if std != 0.0:
                z_score = (current_val - mean) / std
                drifted = abs(z_score) > threshold

        results.append({
            "field": field,
            "current": current_val,
            "mean": mean,
            "std": std,
            "z_score": z_score,
            "drifted": drifted,
        })

    return results
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Tests para detect_drift."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from detect_drift import detect_drift
|
||||
|
||||
|
||||
def test_campo_con_drift_claro_z_mayor_threshold():
    """A value far below the historical mean is flagged as drifted."""
    past = [{"records_out": count} for count in (100, 105, 98)]
    report = detect_drift(past, {"records_out": 50}, ["records_out"])
    assert len(report) == 1
    entry = report[0]
    assert entry["field"] == "records_out"
    assert entry["current"] == 50.0
    assert entry["drifted"] is True
    # Very far from the mean, on the low side.
    assert entry["z_score"] < -2.0
|
||||
|
||||
|
||||
def test_campo_estable_z_menor_threshold():
    """A value inside the normal range is not flagged as drifted."""
    past = [{"val": sample} for sample in (100.0, 102.0, 98.0, 101.0)]
    observed = {"val": 100.5}  # within the normal range
    report = detect_drift(past, observed, ["val"])
    assert len(report) == 1
    entry = report[0]
    assert entry["drifted"] is False
    assert abs(entry["z_score"]) < 2.0
|
||||
|
||||
|
||||
def test_historial_con_un_solo_punto_std_0_drifted_False_con_nota():
    """With a single historical point the field is never reported as drifted."""
    report = detect_drift([{"val": 100.0}], {"val": 999.0}, ["val"])
    assert len(report) == 1
    entry = report[0]
    assert entry["std"] == 0.0
    assert entry["z_score"] == 0.0
    assert entry["drifted"] is False
    assert entry["mean"] == 100.0
|
||||
|
||||
|
||||
def test_historial_vacio_todos_drifted_False():
    """With an empty history every field reports no drift and zeroed stats."""
    observed = {"records_out": 50, "duration_ms": 2000}
    report = detect_drift([], observed, ["records_out", "duration_ms"])
    assert len(report) == 2
    for entry in report:
        assert entry["drifted"] is False
        assert entry["z_score"] == 0.0
        assert entry["mean"] == 0.0
|
||||
|
||||
|
||||
def test_threshold_custom():
    """The same observation drifts under threshold 2.0 but not under 3.0."""
    # std ~ 7.07, mean = 100
    past = [{"val": sample} for sample in (100.0, 100.0, 110.0, 90.0)]
    observed = {"val": 115.0}  # z ~ 2.12

    verdicts = {
        limit: detect_drift(past, observed, ["val"], threshold=limit)[0]["drifted"]
        for limit in (2.0, 3.0)
    }
    # threshold 2.0 -> drifted
    assert verdicts[2.0] is True
    # threshold 3.0 -> not drifted
    assert verdicts[3.0] is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allows running this test module directly as a script.
    for case in (
        test_campo_con_drift_claro_z_mayor_threshold,
        test_campo_estable_z_menor_threshold,
        test_historial_con_un_solo_punto_std_0_drifted_False_con_nota,
        test_historial_vacio_todos_drifted_False,
        test_threshold_custom,
    ):
        case()
    print("All tests passed.")
|
||||
@@ -0,0 +1,58 @@
|
||||
---
|
||||
name: diff_entities
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def diff_entities(before: list[dict], after: list[dict], key: str = 'id', ignore_fields: list[str] | None = None, compare_fields: list[str] | None = None) -> dict"
|
||||
description: "Compara dos snapshots de entities y devuelve diferencias campo a campo. Detecta añadidas, eliminadas, modificadas e inalteradas. Ignora created_at y updated_at por defecto."
|
||||
tags: [diff, entities, snapshot, operations, comparison, datascience]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "entity añadida"
|
||||
- "entity eliminada"
|
||||
- "entity modificada con detalle de campos"
|
||||
- "entities identicas → unchanged"
|
||||
- "ignore_fields funciona"
|
||||
- "compare_fields filtra correctamente"
|
||||
- "lista vacia vs lista con datos"
|
||||
test_file_path: "python/functions/datascience/diff_entities_test.py"
|
||||
file_path: "python/functions/datascience/diff_entities.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
before = [
|
||||
{"id": "1", "name": "Alice", "status": "active", "updated_at": "2024-01-01"},
|
||||
{"id": "2", "name": "Bob", "status": "active", "updated_at": "2024-01-01"},
|
||||
]
|
||||
after = [
|
||||
{"id": "1", "name": "Alice", "status": "inactive", "updated_at": "2024-01-02"},
|
||||
{"id": "3", "name": "Carol", "status": "active", "updated_at": "2024-01-02"},
|
||||
]
|
||||
|
||||
result = diff_entities(before, after)
|
||||
# result["added"] -> [{"id": "3", "name": "Carol", ...}]
|
||||
# result["removed"] -> [{"id": "2", "name": "Bob", ...}]
|
||||
# result["modified"] -> [{"key": "1", "changes": {"status": {"old": "active", "new": "inactive"}}}]
|
||||
# result["unchanged"] -> 0
|
||||
# result["summary"] -> "1 added, 1 removed, 1 modified, 0 unchanged"
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura. No hace I/O — toma listas de dicts ya cargadas en memoria.
|
||||
|
||||
El campo `key` debe existir en todas las entities; las que no lo tengan se ignoran silenciosamente.
|
||||
|
||||
Si `compare_fields` se da, tiene prioridad sobre `ignore_fields`. Esto permite comparar solo un subconjunto especifico de campos sin preocuparse por los campos temporales.
|
||||
|
||||
El orden de `added` y `removed` no esta garantizado (depende del orden de iteracion de sets).
|
||||
@@ -0,0 +1,77 @@
|
||||
"""diff_entities — compara dos snapshots de entities detectando cambios campo a campo."""
|
||||
|
||||
|
||||
def diff_entities(
|
||||
before: list[dict],
|
||||
after: list[dict],
|
||||
key: str = "id",
|
||||
ignore_fields: list[str] | None = None,
|
||||
compare_fields: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Compara dos snapshots de entities y devuelve diferencias campo a campo.
|
||||
|
||||
Detecta entities añadidas, eliminadas, modificadas e inalteradas.
|
||||
Ignora campos de metadata temporal por defecto (created_at, updated_at).
|
||||
|
||||
Args:
|
||||
before: Lista de entities del snapshot anterior.
|
||||
after: Lista de entities del snapshot posterior.
|
||||
key: Campo que identifica unicamente cada entity. Default "id".
|
||||
ignore_fields: Campos a excluir de la comparacion.
|
||||
Default ["created_at", "updated_at"].
|
||||
compare_fields: Si se da, solo compara estos campos (tiene prioridad
|
||||
sobre ignore_fields).
|
||||
|
||||
Returns:
|
||||
Dict con keys: added, removed, modified, unchanged, summary.
|
||||
modified contiene lista de {"key": str, "changes": {"field": {"old": ..., "new": ...}}}.
|
||||
"""
|
||||
if ignore_fields is None:
|
||||
ignore_fields = ["created_at", "updated_at"]
|
||||
|
||||
before_map = {str(e[key]): e for e in before if key in e}
|
||||
after_map = {str(e[key]): e for e in after if key in e}
|
||||
|
||||
before_keys = set(before_map.keys())
|
||||
after_keys = set(after_map.keys())
|
||||
|
||||
added = [after_map[k] for k in after_keys - before_keys]
|
||||
removed = [before_map[k] for k in before_keys - after_keys]
|
||||
|
||||
modified = []
|
||||
unchanged = 0
|
||||
|
||||
for k in before_keys & after_keys:
|
||||
b = before_map[k]
|
||||
a = after_map[k]
|
||||
|
||||
if compare_fields is not None:
|
||||
fields_to_check = compare_fields
|
||||
else:
|
||||
all_fields = set(b.keys()) | set(a.keys())
|
||||
fields_to_check = [f for f in all_fields if f not in ignore_fields and f != key]
|
||||
|
||||
changes = {}
|
||||
for field in fields_to_check:
|
||||
old_val = b.get(field)
|
||||
new_val = a.get(field)
|
||||
if old_val != new_val:
|
||||
changes[field] = {"old": old_val, "new": new_val}
|
||||
|
||||
if changes:
|
||||
modified.append({"key": k, "changes": changes})
|
||||
else:
|
||||
unchanged += 1
|
||||
|
||||
n_added = len(added)
|
||||
n_removed = len(removed)
|
||||
n_modified = len(modified)
|
||||
summary = f"{n_added} added, {n_removed} removed, {n_modified} modified, {unchanged} unchanged"
|
||||
|
||||
return {
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"modified": modified,
|
||||
"unchanged": unchanged,
|
||||
"summary": summary,
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Tests para diff_entities."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from diff_entities import diff_entities
|
||||
|
||||
|
||||
def test_entity_anadida():
    """An entity present only in the later snapshot is reported as added."""
    earlier = [{"id": "1", "name": "Alice"}]
    later = [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}]
    diff = diff_entities(earlier, later)
    new_entities = diff["added"]
    assert len(new_entities) == 1
    assert new_entities[0]["id"] == "2"
    assert diff["removed"] == []
    assert diff["modified"] == []
    assert diff["unchanged"] == 1
    assert "1 added" in diff["summary"]
|
||||
|
||||
|
||||
def test_entity_eliminada():
    """An entity missing from the later snapshot is reported as removed."""
    earlier = [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}]
    later = [{"id": "1", "name": "Alice"}]
    diff = diff_entities(earlier, later)
    assert diff["added"] == []
    gone = diff["removed"]
    assert len(gone) == 1
    assert gone[0]["id"] == "2"
    assert diff["unchanged"] == 1
    assert "1 removed" in diff["summary"]
|
||||
|
||||
|
||||
def test_entity_modificada_con_detalle_de_campos():
    """A changed field is reported together with its old and new values."""
    earlier = [{"id": "1", "name": "Alice", "status": "active"}]
    later = [{"id": "1", "name": "Alice", "status": "inactive"}]
    diff = diff_entities(earlier, later)
    assert diff["added"] == []
    assert diff["removed"] == []
    assert len(diff["modified"]) == 1
    entry = diff["modified"][0]
    assert entry["key"] == "1"
    assert "status" in entry["changes"]
    status_change = entry["changes"]["status"]
    assert status_change["old"] == "active"
    assert status_change["new"] == "inactive"
    assert diff["unchanged"] == 0
|
||||
|
||||
|
||||
def test_entities_identicas_unchanged():
    """Identical snapshots yield only unchanged entities."""
    people = (("1", "Alice"), ("2", "Bob"))
    earlier = [{"id": pid, "name": who} for pid, who in people]
    later = [{"id": pid, "name": who} for pid, who in people]
    diff = diff_entities(earlier, later)
    for bucket in ("added", "removed", "modified"):
        assert diff[bucket] == []
    assert diff["unchanged"] == 2
    assert "2 unchanged" in diff["summary"]
|
||||
|
||||
|
||||
def test_ignore_fields_funciona():
    """Timestamp fields are ignored by default and compared when opted in."""
    earlier = [{"id": "1", "name": "Alice", "updated_at": "2024-01-01", "created_at": "2023-01-01"}]
    later = [{"id": "1", "name": "Alice", "updated_at": "2024-06-01", "created_at": "2023-01-01"}]

    # updated_at is ignored by default -> unchanged
    default_diff = diff_entities(earlier, later)
    assert default_diff["unchanged"] == 1
    assert default_diff["modified"] == []

    # When updated_at is not ignored, the change must be detected
    strict_diff = diff_entities(earlier, later, ignore_fields=[])
    assert len(strict_diff["modified"]) == 1
    assert "updated_at" in strict_diff["modified"][0]["changes"]
|
||||
|
||||
|
||||
def test_compare_fields_filtra_correctamente():
    """compare_fields restricts the comparison to the listed fields."""
    earlier = [{"id": "1", "name": "Alice", "status": "active", "score": 10}]
    later = [{"id": "1", "name": "Bob", "status": "inactive", "score": 10}]

    # Comparing only score -> score did not change, so unchanged
    by_score = diff_entities(earlier, later, compare_fields=["score"])
    assert by_score["unchanged"] == 1
    assert by_score["modified"] == []

    # Comparing only name -> the change is detected
    by_name = diff_entities(earlier, later, compare_fields=["name"])
    assert len(by_name["modified"]) == 1
    reported = by_name["modified"][0]["changes"]
    assert "name" in reported
    assert "status" not in reported
|
||||
|
||||
|
||||
def test_lista_vacia_vs_lista_con_datos():
    """An empty snapshot against a populated one yields only added or removed."""
    nothing = []
    populated = [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}]

    forward = diff_entities(nothing, populated)
    assert len(forward["added"]) == 2
    assert forward["removed"] == []
    assert forward["unchanged"] == 0

    # Reversed direction
    backward = diff_entities(populated, nothing)
    assert backward["added"] == []
    assert len(backward["removed"]) == 2
    assert backward["unchanged"] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allows running this test module directly as a script.
    for case in (
        test_entity_anadida,
        test_entity_eliminada,
        test_entity_modificada_con_detalle_de_campos,
        test_entities_identicas_unchanged,
        test_ignore_fields_funciona,
        test_compare_fields_filtra_correctamente,
        test_lista_vacia_vs_lista_con_datos,
    ):
        case()
    print("All tests passed.")
|
||||
@@ -0,0 +1,52 @@
|
||||
---
|
||||
name: diff_relations
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def diff_relations(before: list[dict], after: list[dict], key: tuple[str, str, str] = ('source_id', 'target_id', 'relation_type'), ignore_fields: list[str] | None = None, compare_fields: list[str] | None = None) -> dict"
|
||||
description: "Compara relaciones entre dos snapshots usando key compuesta (source_id, target_id, relation_type). Detecta relaciones añadidas, eliminadas y modificadas con detalle campo a campo."
|
||||
tags: [diff, relations, graph, snapshot, operations, comparison, datascience]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "relacion añadida"
|
||||
- "relacion eliminada"
|
||||
- "relacion con metadata modificada (mismo source/target/type, distinto weight)"
|
||||
- "key compuesta funciona correctamente"
|
||||
test_file_path: "python/functions/datascience/diff_relations_test.py"
|
||||
file_path: "python/functions/datascience/diff_relations.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
before = [
|
||||
{"source_id": "A", "target_id": "B", "relation_type": "knows", "weight": 1.0},
|
||||
{"source_id": "B", "target_id": "C", "relation_type": "owns", "weight": 0.5},
|
||||
]
|
||||
after = [
|
||||
{"source_id": "A", "target_id": "B", "relation_type": "knows", "weight": 2.0},
|
||||
{"source_id": "C", "target_id": "D", "relation_type": "knows", "weight": 1.0},
|
||||
]
|
||||
|
||||
result = diff_relations(before, after)
|
||||
# result["added"] -> [{"source_id": "C", "target_id": "D", ...}]
|
||||
# result["removed"] -> [{"source_id": "B", "target_id": "C", ...}]
|
||||
# result["modified"] -> [{"key": "A|B|knows", "changes": {"weight": {"old": 1.0, "new": 2.0}}}]
|
||||
# result["unchanged"] -> 0
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
La key compuesta se serializa como `source_id|target_id|relation_type`. Si alguno de los campos clave no existe en la relacion, se usa string vacio.
|
||||
|
||||
Misma semantica que `diff_entities_py_datascience` pero adaptada para relaciones donde no hay un ID unico — la identidad se define por los tres campos de la key.
|
||||
|
||||
Complemento natural de `diff_entities_py_datascience` para comparar grafos completos entre ejecuciones de pipelines.
|
||||
@@ -0,0 +1,82 @@
|
||||
"""diff_relations — compara dos snapshots de relaciones con key compuesta."""
|
||||
|
||||
|
||||
def diff_relations(
|
||||
before: list[dict],
|
||||
after: list[dict],
|
||||
key: tuple[str, str, str] = ("source_id", "target_id", "relation_type"),
|
||||
ignore_fields: list[str] | None = None,
|
||||
compare_fields: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Compara relaciones entre dos snapshots usando key compuesta.
|
||||
|
||||
Las relaciones se identifican por (source_id, target_id, relation_type)
|
||||
porque no tienen un ID unico propio. Detecta relaciones añadidas,
|
||||
eliminadas y modificadas (mismo source/target/type, distinta metadata).
|
||||
|
||||
Args:
|
||||
before: Lista de relaciones del snapshot anterior.
|
||||
after: Lista de relaciones del snapshot posterior.
|
||||
key: Tupla de campos que forman la key compuesta.
|
||||
Default ("source_id", "target_id", "relation_type").
|
||||
ignore_fields: Campos a excluir de la comparacion.
|
||||
Default ["created_at", "updated_at"].
|
||||
compare_fields: Si se da, solo compara estos campos.
|
||||
|
||||
Returns:
|
||||
Dict con keys: added, removed, modified, unchanged, summary.
|
||||
modified contiene lista de {"key": str, "changes": {"field": {"old": ..., "new": ...}}}.
|
||||
"""
|
||||
if ignore_fields is None:
|
||||
ignore_fields = ["created_at", "updated_at"]
|
||||
|
||||
def make_key(rel: dict) -> str:
|
||||
return "|".join(str(rel.get(k, "")) for k in key)
|
||||
|
||||
before_map = {make_key(r): r for r in before}
|
||||
after_map = {make_key(r): r for r in after}
|
||||
|
||||
before_keys = set(before_map.keys())
|
||||
after_keys = set(after_map.keys())
|
||||
|
||||
added = [after_map[k] for k in after_keys - before_keys]
|
||||
removed = [before_map[k] for k in before_keys - after_keys]
|
||||
|
||||
modified = []
|
||||
unchanged = 0
|
||||
|
||||
for k in before_keys & after_keys:
|
||||
b = before_map[k]
|
||||
a = after_map[k]
|
||||
|
||||
if compare_fields is not None:
|
||||
fields_to_check = compare_fields
|
||||
else:
|
||||
all_fields = set(b.keys()) | set(a.keys())
|
||||
key_set = set(key)
|
||||
fields_to_check = [f for f in all_fields if f not in ignore_fields and f not in key_set]
|
||||
|
||||
changes = {}
|
||||
for field in fields_to_check:
|
||||
old_val = b.get(field)
|
||||
new_val = a.get(field)
|
||||
if old_val != new_val:
|
||||
changes[field] = {"old": old_val, "new": new_val}
|
||||
|
||||
if changes:
|
||||
modified.append({"key": k, "changes": changes})
|
||||
else:
|
||||
unchanged += 1
|
||||
|
||||
n_added = len(added)
|
||||
n_removed = len(removed)
|
||||
n_modified = len(modified)
|
||||
summary = f"{n_added} added, {n_removed} removed, {n_modified} modified, {unchanged} unchanged"
|
||||
|
||||
return {
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"modified": modified,
|
||||
"unchanged": unchanged,
|
||||
"summary": summary,
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests para diff_relations."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from diff_relations import diff_relations
|
||||
|
||||
|
||||
def test_relacion_anadida():
    """A relation present only in the later snapshot is reported as added."""
    knows_ab = {"source_id": "A", "target_id": "B", "relation_type": "knows", "weight": 1.0}
    owns_cd = {"source_id": "C", "target_id": "D", "relation_type": "owns", "weight": 0.5}
    diff = diff_relations([dict(knows_ab)], [dict(knows_ab), owns_cd])
    new_relations = diff["added"]
    assert len(new_relations) == 1
    assert new_relations[0]["source_id"] == "C"
    assert diff["removed"] == []
    assert diff["unchanged"] == 1
    assert "1 added" in diff["summary"]
|
||||
|
||||
|
||||
def test_relacion_eliminada():
    """A relation missing from the later snapshot is reported as removed."""
    knows_ab = {"source_id": "A", "target_id": "B", "relation_type": "knows", "weight": 1.0}
    owns_cd = {"source_id": "C", "target_id": "D", "relation_type": "owns", "weight": 0.5}
    diff = diff_relations([dict(knows_ab), owns_cd], [dict(knows_ab)])
    assert diff["added"] == []
    gone = diff["removed"]
    assert len(gone) == 1
    assert gone[0]["source_id"] == "C"
    assert diff["unchanged"] == 1
    assert "1 removed" in diff["summary"]
|
||||
|
||||
|
||||
def test_relacion_con_metadata_modificada_mismo_source_target_type_distinto_weight():
    """Same source/target/type with a different weight is reported as modified."""
    identity = {"source_id": "A", "target_id": "B", "relation_type": "knows"}
    earlier = [{**identity, "weight": 1.0}]
    later = [{**identity, "weight": 5.0}]
    diff = diff_relations(earlier, later)
    assert diff["added"] == []
    assert diff["removed"] == []
    assert len(diff["modified"]) == 1
    entry = diff["modified"][0]
    assert entry["key"] == "A|B|knows"
    assert "weight" in entry["changes"]
    weight_change = entry["changes"]["weight"]
    assert weight_change["old"] == 1.0
    assert weight_change["new"] == 5.0
    assert diff["unchanged"] == 0
|
||||
|
||||
|
||||
def test_key_compuesta_funciona_correctamente():
    """The relation type is part of the identity of a relation."""

    def a_to_b(kind, weight):
        return {"source_id": "A", "target_id": "B", "relation_type": kind, "weight": weight}

    # Same A->B pair but a different relation type -> two distinct relations
    earlier = [a_to_b("knows", 1.0), a_to_b("owns", 0.5)]
    later = [a_to_b("knows", 1.0), a_to_b("trusts", 0.8)]
    diff = diff_relations(earlier, later)
    # owns removed, trusts added, knows unchanged
    assert len(diff["added"]) == 1
    assert diff["added"][0]["relation_type"] == "trusts"
    assert len(diff["removed"]) == 1
    assert diff["removed"][0]["relation_type"] == "owns"
    assert diff["unchanged"] == 1
    assert diff["modified"] == []
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allows running this test module directly as a script.
    for case in (
        test_relacion_anadida,
        test_relacion_eliminada,
        test_relacion_con_metadata_modificada_mismo_source_target_type_distinto_weight,
        test_key_compuesta_funciona_correctamente,
    ):
        case()
    print("All tests passed.")
|
||||
@@ -0,0 +1,36 @@
|
||||
---
|
||||
name: estimate_hawkes
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def estimate_hawkes(arrivals: list[int], max_lag: int = 30) -> dict"
|
||||
description: "Estima parámetros de un proceso Hawkes (alpha, beta, branching_ratio) desde la autocorrelación de arrivals ajustando una exponencial decreciente sobre la ACF."
|
||||
tags: [estimation, hawkes, stochastic-process, microstructure, timeseries]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [numpy, scipy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/datascience/datascience.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
arrivals = [0, 1, 3, 2, 0, 1, 4, 2, 1, 0] * 10
|
||||
result = estimate_hawkes(arrivals, max_lag=10)
|
||||
# {'alpha': 0.312, 'beta': 0.874, 'branching_ratio': 0.357, 'acf': [...]}
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Ajusta la función `a * exp(-b * lag)` sobre los lags 1..max_lag de la ACF usando `curve_fit` de scipy.
|
||||
Si el primer lag de la ACF es <= 0.01 (sin autocorrelación), retorna alpha=0, beta=1, branching_ratio=0.
|
||||
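El branching_ratio = alpha/beta mide cuántos eventos hijos genera, en promedio, cada evento; cuanto más se acerca a 1, más cerca está el proceso del régimen crítico, y con valores >= 1 el proceso es explosivo (no estacionario).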
|
||||
Función pura: requiere numpy y scipy instalados.
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
name: estimate_pareto_alpha
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def estimate_pareto_alpha(values: list[float], x_min_percentile: float = 90.0) -> dict"
|
||||
description: "Estima el exponente alpha de una distribución Pareto via MLE. Alpha bajo indica cola más pesada y mayor frecuencia de valores extremos."
|
||||
tags: [estimation, pareto, power-law, heavy-tail, statistics]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [numpy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/datascience/datascience.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
# Simular datos con cola pesada
|
||||
values = list(np.random.pareto(2.0, 1000) + 1)
|
||||
result = estimate_pareto_alpha(values, x_min_percentile=90.0)
|
||||
# {'alpha': ~2.0, 'x_min': ..., 'n_tail': 100}
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Usa el estimador MLE de Hill: α = n / Σ ln(xᵢ / x_min).
|
||||
x_min se determina como el percentil indicado de los valores positivos.
|
||||
Retorna alpha=0 si hay menos de 10 valores positivos o la cola tiene menos de 2 elementos.
|
||||
Función pura: requiere numpy instalado.
|
||||
@@ -0,0 +1,87 @@
|
||||
---
|
||||
name: extract_entities_llm
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def extract_entities_llm(text: str, entity_schema: list[dict], llm_chat_json: Callable[[list[dict]], dict], language_instruction: str = 'Respond in English.') -> list[EntityCandidate]"
|
||||
description: "Extrae entidades de un chunk de texto usando un LLM inyectado. Construye el system prompt con el schema, llama al LLM y valida la respuesta retornando EntityCandidate. JSON invalido o type_ref fuera del schema se descartan con warning."
|
||||
tags: [llm, extraction, entity, nlp, osint, graph, fuzzygraph, datascience, prompt]
|
||||
uses_functions: []
|
||||
uses_types: [entity_candidate_py_datascience]
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [warnings, typing.Callable]
|
||||
tested: true
|
||||
tests:
|
||||
- "texto con entidades claras retorna EntityCandidate"
|
||||
- "texto sin entidades retorna lista vacia"
|
||||
- "llm retorna json mal formado retorna lista vacia con warning"
|
||||
- "type_ref invalido en respuesta se descarta con warning"
|
||||
- "confidence se propaga correctamente"
|
||||
- "schema vacio lanza ValueError"
|
||||
test_file_path: "python/functions/datascience/extract_entities_llm_test.py"
|
||||
file_path: "python/functions/datascience/extract_entities_llm.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
import json
|
||||
from extract_entities_llm import extract_entities_llm
|
||||
|
||||
# LLM stub para tests — en produccion usar litellm o similar
|
||||
def mock_llm(messages: list[dict]) -> dict:
|
||||
return {
|
||||
"entities": [
|
||||
{
|
||||
"name": "John Smith",
|
||||
"type_ref": "osint_person_go_cybersecurity",
|
||||
"attributes": {"full_name": "John Smith", "nationality": "US"},
|
||||
"confidence": 0.95,
|
||||
},
|
||||
{
|
||||
"name": "evil-corp.com",
|
||||
"type_ref": "osint_domain_go_cybersecurity",
|
||||
"attributes": {"fqdn": "evil-corp.com"},
|
||||
"confidence": 0.88,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
schema = [
|
||||
{
|
||||
"type_ref": "osint_person_go_cybersecurity",
|
||||
"label": "Person",
|
||||
"metadata_fields": ["full_name", "alias", "nationality", "dob", "risk_score"],
|
||||
},
|
||||
{
|
||||
"type_ref": "osint_domain_go_cybersecurity",
|
||||
"label": "Domain",
|
||||
"metadata_fields": ["fqdn", "registrar", "created_date"],
|
||||
},
|
||||
]
|
||||
|
||||
text = "John Smith, a US citizen, was linked to the domain evil-corp.com."
|
||||
candidates = extract_entities_llm(text, schema, mock_llm)
|
||||
# [EntityCandidate(name='John Smith', type_ref='osint_person_go_cybersecurity', confidence=0.95),
|
||||
# EntityCandidate(name='evil-corp.com', type_ref='osint_domain_go_cybersecurity', confidence=0.88)]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
**Inyeccion de dependencia del LLM:** `llm_chat_json` recibe mensajes en formato OpenAI (`[{"role": "system", "content": "..."}, ...]`) y retorna un `dict` con la respuesta ya parseada como JSON. Esto desacopla la funcion de cualquier cliente especifico — puede usarse con OpenAI, Anthropic via litellm, o cualquier mock.
|
||||
|
||||
**Validacion de type_ref:** Solo se aceptan entidades cuyo `type_ref` aparece en el `entity_schema`. Entidades con type_ref desconocido se descartan con `warnings.warn` (no lanzan excepcion) para ser resiliente ante alucinaciones del LLM.
|
||||
|
||||
**Manejo de JSON invalido:** Si `llm_chat_json` lanza una excepcion o retorna un dict sin la clave `entities`, se retorna lista vacia y se emite un warning. El llamador puede decidir si reintentar.
|
||||
|
||||
**Confidence clamping:** El valor de confidence se clampea al rango [0.0, 1.0] automaticamente.
|
||||
|
||||
**Atributos null:** Los atributos con valor `None` se filtran del dict de atributos para mantener el output limpio.
|
||||
|
||||
**source_chunk_indices:** Esta funcion no setea `source_chunk_indices` — ese campo lo llena el pipeline exterior que conoce el indice del chunk actual.
|
||||
|
||||
Esta funcion es el bloque atomico de extraccion. El pipeline completo de grafos la llama por cada chunk del documento y luego deduplica los candidatos resultantes.
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Extrae entidades de un chunk de texto usando un LLM inyectado."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
|
||||
|
||||
def _build_system_prompt(entity_schema: list[dict], language_instruction: str) -> str:
|
||||
"""Construye el system prompt para extraccion de entidades."""
|
||||
lines = [
|
||||
"You are an entity extraction expert. Given text, extract all entities",
|
||||
"matching these types. For each entity, provide: name, type_ref,",
|
||||
"attributes (matching the metadata_fields for that type), and a",
|
||||
"confidence score (0.0-1.0).",
|
||||
"",
|
||||
"Entity types:",
|
||||
]
|
||||
|
||||
for schema_entry in entity_schema:
|
||||
label = schema_entry.get("label", "Unknown")
|
||||
type_ref = schema_entry.get("type_ref", "")
|
||||
metadata_fields = schema_entry.get("metadata_fields", [])
|
||||
lines.append(f"- {label} (type_ref: {type_ref})")
|
||||
if metadata_fields:
|
||||
lines.append(f" fields: {', '.join(metadata_fields)}")
|
||||
|
||||
lines += [
|
||||
"",
|
||||
'Output JSON: {"entities": [{"name": "...", "type_ref": "...", "attributes": {...}, "confidence": 0.9}]}',
|
||||
"",
|
||||
"Rules:",
|
||||
"- Only extract entities explicitly mentioned in the text",
|
||||
"- Use the exact type_ref from the schema",
|
||||
"- Leave unknown attributes as null",
|
||||
"- Confidence: 1.0 = explicitly named, 0.7 = strongly implied, 0.5 = weakly implied",
|
||||
f"- {language_instruction}",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def extract_entities_llm(
    text: str,
    entity_schema: list[dict],
    llm_chat_json: Callable[[list[dict]], dict],
    language_instruction: str = "Respond in English.",
) -> list[EntityCandidate]:
    """Extract entities from a text chunk using an injected LLM.

    Builds a system prompt from the entity schema, calls the LLM and validates
    the response, returning one EntityCandidate per accepted entity. The
    function never raises because of what the LLM returns: malformed responses
    and invalid entities are dropped, with a warning where noted below.

    Args:
        text: Text chunk to analyse; sent verbatim as the user message.
        entity_schema: Entity types to extract. Each entry is a dict with the
            keys 'type_ref', 'label' and optionally 'metadata_fields', e.g.
            [{"type_ref": "osint_person_go_cybersecurity", "label": "Person",
              "metadata_fields": ["full_name", "alias"]}]
        llm_chat_json: Callable that receives OpenAI-style messages
            ([{"role": "system", "content": "..."}, ...]) and returns the
            LLM's JSON response as a dict.
        language_instruction: Language instruction appended to the prompt.

    Returns:
        The extracted EntityCandidate list. Empty (with a warning) when the
        LLM call raises or when the response is not a dict holding a list
        under 'entities'; empty without a warning when that list is empty.

    Raises:
        ValueError: If entity_schema is empty.
    """
    if not entity_schema:
        raise ValueError("entity_schema no puede estar vacio")

    valid_type_refs = {entry.get("type_ref", "") for entry in entity_schema}
    type_ref_to_label = {
        entry.get("type_ref", ""): entry.get("label", "") for entry in entity_schema
    }

    messages = [
        {"role": "system", "content": _build_system_prompt(entity_schema, language_instruction)},
        {"role": "user", "content": text},
    ]

    try:
        response = llm_chat_json(messages)
    except Exception as exc:
        # Deliberate boundary: any failure of the injected client (network,
        # JSON decoding, ...) degrades to "no entities" so the caller can retry.
        warnings.warn(f"extract_entities_llm: error llamando al LLM: {exc}", stacklevel=2)
        return []

    # A non-dict response, a missing 'entities' key and a non-list value are
    # all treated the same way: warn and return nothing instead of crashing.
    raw_entities = response.get("entities") if isinstance(response, dict) else None
    if not isinstance(raw_entities, list):
        warnings.warn(
            "extract_entities_llm: la respuesta del LLM no contiene 'entities' como lista",
            stacklevel=2,
        )
        return []

    candidates: list[EntityCandidate] = []
    for item in raw_entities:
        if not isinstance(item, dict):
            continue

        name = item.get("name", "")
        if not name:
            continue

        type_ref = item.get("type_ref", "")
        try:
            known_type = type_ref in valid_type_refs
        except TypeError:
            # The LLM produced an unhashable type_ref (list/dict); it cannot
            # match any schema entry.
            known_type = False
        if not known_type:
            warnings.warn(
                f"extract_entities_llm: type_ref '{type_ref}' no esta en el schema, descartando entidad '{name}'",
                stacklevel=2,
            )
            continue

        attributes = item.get("attributes", {})
        if not isinstance(attributes, dict):
            attributes = {}
        # Drop attributes the LLM left as null so the output stays clean.
        attributes = {key: value for key, value in attributes.items() if value is not None}

        confidence = item.get("confidence", 0.0)
        # NaN compares unequal to itself; without this check min(1.0, nan)
        # would turn an undefined confidence into full confidence.
        if not isinstance(confidence, (int, float)) or confidence != confidence:
            confidence = 0.0
        confidence = float(max(0.0, min(1.0, confidence)))

        candidates.append(
            EntityCandidate(
                name=name,
                type_ref=type_ref,
                type_label=type_ref_to_label.get(type_ref, ""),
                attributes=attributes,
                confidence=confidence,
            )
        )

    return candidates
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Tests para extract_entities_llm."""
|
||||
|
||||
import warnings
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
from python.functions.datascience.extract_entities_llm import extract_entities_llm
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
|
||||
# Entity schema shared by the tests below: one person type and one domain type.
SCHEMA = [
    {"type_ref": _type_ref, "label": _label, "metadata_fields": _fields}
    for _type_ref, _label, _fields in (
        (
            "osint_person_go_cybersecurity",
            "Person",
            ["full_name", "alias", "nationality", "dob", "risk_score"],
        ),
        (
            "osint_domain_go_cybersecurity",
            "Domain",
            ["fqdn", "registrar", "created_date"],
        ),
    )
]
|
||||
|
||||
|
||||
def make_llm(response: dict):
    """Return an LLM stub that ignores its messages and always answers with ``response``."""
    return lambda messages: response
|
||||
|
||||
|
||||
def test_texto_con_entidades_claras_retorna_entity_candidate():
    """Clearly stated entities come back as fully populated EntityCandidate objects."""
    person_payload = {
        "name": "John Smith",
        "type_ref": "osint_person_go_cybersecurity",
        "attributes": {"full_name": "John Smith", "nationality": "US"},
        "confidence": 0.95,
    }
    domain_payload = {
        "name": "evil-corp.com",
        "type_ref": "osint_domain_go_cybersecurity",
        "attributes": {"fqdn": "evil-corp.com"},
        "confidence": 0.88,
    }
    stub = make_llm({"entities": [person_payload, domain_payload]})

    found = extract_entities_llm(
        "John Smith, US citizen, linked to evil-corp.com.", SCHEMA, stub
    )

    assert len(found) == 2
    by_name = {candidate.name: candidate for candidate in found}

    # name -> (type_ref, type_label, attribute holding the name, confidence)
    expected = {
        "John Smith": ("osint_person_go_cybersecurity", "Person", "full_name", 0.95),
        "evil-corp.com": ("osint_domain_go_cybersecurity", "Domain", "fqdn", 0.88),
    }
    for name, (type_ref, label, attr_key, confidence) in expected.items():
        candidate = by_name[name]
        assert candidate.type_ref == type_ref
        assert candidate.type_label == label
        assert candidate.attributes[attr_key] == name
        assert candidate.confidence == confidence
|
||||
|
||||
|
||||
def test_texto_sin_entidades_retorna_lista_vacia():
    """An empty 'entities' list from the LLM yields no candidates."""
    stub = make_llm({"entities": []})

    found = extract_entities_llm("The sky is blue and the grass is green.", SCHEMA, stub)

    assert found == []
|
||||
|
||||
|
||||
def test_llm_retorna_json_mal_formado_retorna_lista_vacia_con_warning():
    """An LLM client that raises results in an empty list and exactly one warning."""

    def failing_llm(messages: list[dict]) -> dict:
        raise ValueError("JSON decode error")

    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        found = extract_entities_llm("Some text with entities.", SCHEMA, failing_llm)

    warning_texts = [str(entry.message) for entry in recorded]
    assert found == []
    assert len(warning_texts) == 1
    assert "error llamando al LLM" in warning_texts[0]
|
||||
|
||||
|
||||
def test_type_ref_invalido_en_respuesta_se_descarta_con_warning():
    """An entity whose type_ref is absent from the schema is dropped and a warning names it."""
    accepted = {
        "name": "Valid Person",
        "type_ref": "osint_person_go_cybersecurity",
        "attributes": {},
        "confidence": 0.9,
    }
    rejected = {
        "name": "Unknown Thing",
        "type_ref": "nonexistent_type_ref",
        "attributes": {},
        "confidence": 0.8,
    }
    stub = make_llm({"entities": [accepted, rejected]})

    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        found = extract_entities_llm("Text with entities.", SCHEMA, stub)

    assert [candidate.name for candidate in found] == ["Valid Person"]
    assert any("nonexistent_type_ref" in str(entry.message) for entry in recorded)
|
||||
|
||||
|
||||
def test_confidence_se_propaga_correctamente():
    """Confidence values reported by the LLM reach the candidates unchanged."""
    # name -> (type_ref, confidence reported by the stubbed LLM)
    expected = {
        "Implied Person": ("osint_person_go_cybersecurity", 0.7),
        "Weakly Implied Domain": ("osint_domain_go_cybersecurity", 0.5),
        "Explicit Entity": ("osint_person_go_cybersecurity", 1.0),
    }
    stub = make_llm({
        "entities": [
            {"name": name, "type_ref": type_ref, "attributes": {}, "confidence": confidence}
            for name, (type_ref, confidence) in expected.items()
        ]
    })

    found = extract_entities_llm("Some text.", SCHEMA, stub)

    assert len(found) == 3
    observed = {candidate.name: candidate.confidence for candidate in found}
    for name, (_, confidence) in expected.items():
        assert observed[name] == confidence
|
||||
|
||||
|
||||
def test_schema_vacio_lanza_value_error():
    """An empty entity schema is rejected with ValueError."""
    stub = make_llm({"entities": []})

    with pytest.raises(ValueError, match="entity_schema no puede estar vacio"):
        extract_entities_llm(text="Some text.", entity_schema=[], llm_chat_json=stub)
|
||||
@@ -0,0 +1,75 @@
|
||||
---
|
||||
name: extract_relations_llm
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def extract_relations_llm(text: str, entities: list[EntityCandidate], relation_types: list[str], llm_chat_json: Callable[[list[dict]], dict], language_instruction: str = 'Respond in English.') -> list[RelationCandidate]"
|
||||
description: "Extrae relaciones entre entidades de un chunk de texto usando un LLM inyectado. Valida que from_name y to_name correspondan a entidades existentes, y usa 'related_to' como fallback para tipos de relacion no permitidos."
|
||||
tags: [extraction, relation, llm, knowledge-graph, nlp, datascience, fuzzygraph, graph]
|
||||
uses_functions: []
|
||||
uses_types:
|
||||
- entity_candidate_py_datascience
|
||||
- relation_candidate_py_datascience
|
||||
returns:
|
||||
- relation_candidate_py_datascience
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [logging, sys, os, typing]
|
||||
tested: true
|
||||
tests:
|
||||
- "texto con dos entidades relacionadas"
|
||||
- "texto con entidades pero sin relacion"
|
||||
- "menos de dos entidades retorna lista vacia"
|
||||
- "llm inventa entidad que no existe se descarta"
|
||||
test_file_path: "python/functions/datascience/extract_relations_llm_test.py"
|
||||
file_path: "python/functions/datascience/extract_relations_llm.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from extract_relations_llm import extract_relations_llm
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
|
||||
# Stub de llm_chat_json (en produccion usar llm_completion_retry o similar)
|
||||
def my_llm(messages: list[dict]) -> dict:
|
||||
# Llamar al LLM real aqui
|
||||
return {"relations": [...]}
|
||||
|
||||
entities = [
|
||||
EntityCandidate(name="Acme Corp", type_label="Organization", confidence=0.95),
|
||||
EntityCandidate(name="John Smith", type_label="Person", confidence=0.9),
|
||||
]
|
||||
|
||||
relation_types = ["employs", "funds", "owns", "communicates_with", "related_to"]
|
||||
|
||||
relations = extract_relations_llm(
|
||||
text="Acme Corp employs John Smith as CEO and funds his research.",
|
||||
entities=entities,
|
||||
relation_types=relation_types,
|
||||
llm_chat_json=my_llm,
|
||||
)
|
||||
|
||||
for rel in relations:
|
||||
print(f"{rel.from_name} --[{rel.relation_type}]--> {rel.to_name} ({rel.confidence:.2f})")
|
||||
# Acme Corp --[employs]--> John Smith (0.90)
|
||||
# Acme Corp --[funds]--> John Smith (0.85)
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
**Inyeccion de dependencia del LLM:** `llm_chat_json` recibe una lista de mensajes en formato OpenAI (`[{"role": "system", "content": ...}, {"role": "user", "content": ...}]`) y retorna un dict con la clave `"relations"`. Esto desacopla la funcion de cualquier proveedor de LLM concreto.
|
||||
|
||||
**Validacion de entidades:** Solo se aceptan relaciones donde `from_name` y `to_name` aparecen exactamente en los nombres de las entidades proporcionadas. Relaciones con nombres inventados por el LLM se descartan silenciosamente (con debug log).
|
||||
|
||||
**Fallback de tipo:** Si el LLM propone un `relation_type` que no esta en la lista permitida, se reemplaza por `"related_to"`. Si `"related_to"` tampoco esta en la lista, se incluye igualmente como catch-all seguro.
|
||||
|
||||
**Menos de 2 entidades:** La funcion retorna `[]` inmediatamente sin llamar al LLM, ya que no puede haber relaciones con menos de 2 participantes.
|
||||
|
||||
**Error handling:** Si `llm_chat_json` lanza una excepcion, se captura con warning y retorna `[]`. Si la respuesta no contiene la clave `"relations"` o no es una lista, idem.
|
||||
|
||||
**Confianza:** Los valores de confianza del LLM se clampean al rango `[0.0, 1.0]`. Valores no numericos se convierten a `0.0`.
|
||||
|
||||
Disenado para fuzzygraph — se compone con `extract_entities_llm` (paso anterior) y `deduplicate_relations` (paso siguiente en el pipeline de extraccion).
|
||||
@@ -0,0 +1,141 @@
|
||||
"""extract_relations_llm — extrae relaciones entre entidades usando un LLM."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ""))
|
||||
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.types.datascience.relation_candidate import RelationCandidate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_relations_llm(
    text: str,
    entities: list[EntityCandidate],
    relation_types: list[str],
    llm_chat_json: Callable[[list[dict]], dict],
    language_instruction: str = "Respond in English.",
) -> list[RelationCandidate]:
    """Extract relations between already-extracted entities using an injected LLM.

    Given the original text and its entities, asks the LLM for relations
    between pairs of entities. Relations whose from_name or to_name does not
    match an existing entity name exactly are discarded, and relation types
    outside the allowed list are replaced by "related_to". The function never
    raises because of what the LLM returns.

    Args:
        text: Text chunk (the same one the entities were extracted from);
            sent verbatim as the user message.
        entities: Entities already extracted from the chunk.
        relation_types: Allowed relation types, e.g. ["funds", "employs",
            "communicates_with", "owns", "related_to"].
        llm_chat_json: Injected callable that receives a list of messages
            (dicts with "role" and "content") and returns the LLM's JSON
            response as a dict.
        language_instruction: Language instruction appended to the prompt.

    Returns:
        The validated RelationCandidate list. Empty when fewer than two
        entities are given (the LLM is not called), when the LLM call raises,
        when the response is not a dict holding a list under "relations", or
        when no relation survives validation.
    """
    if len(entities) < 2:
        return []

    entity_names = {e.name for e in entities}
    relation_types_set = set(relation_types)

    def _is_member(value, allowed: set) -> bool:
        # The LLM may emit unhashable values (list/dict) where a string is
        # expected; those can never match, so treat them as "not found".
        try:
            return value in allowed
        except TypeError:
            return False

    # Entity list for the prompt, labelled with the most specific type available.
    entity_lines = "\n".join(
        f'- "{e.name}" ({e.type_label or e.type_ref or "Entity"})' for e in entities
    )
    relation_types_str = ", ".join(relation_types)

    system_prompt = f"""\
You are a relation extraction expert. Given text and a list of entities already \
extracted, identify relationships between them.

Entities found in this text:
{entity_lines}

Allowed relation types: {relation_types_str}

Output JSON: {{"relations": [
{{"from_name": "Entity A", "to_name": "Entity B",
"relation_type": "employs", "description": "...", "confidence": 0.8}}
]}}

Rules:
- Only extract relations explicitly stated or strongly implied in the text
- from_name and to_name must match entity names exactly as listed above
- relation_type must be one of the allowed types
- Confidence: 1.0 = explicitly stated, 0.7 = strongly implied, 0.5 = weakly implied
- Do not invent entities not in the list above
- {language_instruction}"""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]

    try:
        response = llm_chat_json(messages)
    except Exception as exc:
        # Deliberate boundary: any failure of the injected client degrades to
        # "no relations" so the caller can decide whether to retry.
        logger.warning("extract_relations_llm: LLM call failed: %s", exc)
        return []

    # A non-dict response, a missing "relations" key and a non-list value are
    # all handled alike: warn and return nothing instead of crashing.
    raw_relations = response.get("relations") if isinstance(response, dict) else None
    if not isinstance(raw_relations, list):
        logger.warning("extract_relations_llm: 'relations' is not a list in LLM response")
        return []

    results: list[RelationCandidate] = []
    for item in raw_relations:
        if not isinstance(item, dict):
            continue

        from_name = item.get("from_name", "")
        to_name = item.get("to_name", "")

        # Both endpoints must be entities we were given; anything else was
        # invented by the LLM.
        if not _is_member(from_name, entity_names):
            logger.debug(
                "extract_relations_llm: from_name '%s' no coincide con ninguna entidad — descartando",
                from_name,
            )
            continue
        if not _is_member(to_name, entity_names):
            logger.debug(
                "extract_relations_llm: to_name '%s' no coincide con ninguna entidad — descartando",
                to_name,
            )
            continue

        relation_type = item.get("relation_type", "")
        if not _is_member(relation_type, relation_types_set):
            logger.debug(
                "extract_relations_llm: tipo '%s' no permitido — usando 'related_to'",
                relation_type,
            )
            # Catch-all type, used even when it is not in the allowed list.
            relation_type = "related_to"

        confidence = item.get("confidence", 0.0)
        # NaN compares unequal to itself; without this check min(1.0, nan)
        # would turn an undefined confidence into full confidence.
        if not isinstance(confidence, (int, float)) or confidence != confidence:
            confidence = 0.0
        confidence = float(max(0.0, min(1.0, confidence)))

        # A JSON null description is normalised to the same "" used when the
        # key is absent. NOTE(review): assumes RelationCandidate.description
        # is meant to be a str — confirm against the type definition.
        description = item.get("description") or ""

        results.append(
            RelationCandidate(
                from_name=from_name,
                to_name=to_name,
                relation_type=relation_type,
                description=description,
                confidence=confidence,
            )
        )

    return results
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Tests para extract_relations_llm."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Rutas para importar desde el registry
|
||||
REGISTRY_ROOT = os.path.join(os.path.dirname(__file__), "..", "..", "..", "")
|
||||
sys.path.insert(0, REGISTRY_ROOT)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.types.datascience.relation_candidate import RelationCandidate
|
||||
from extract_relations_llm import extract_relations_llm
|
||||
|
||||
|
||||
def _make_entity(name: str, type_label: str = "Entity") -> EntityCandidate:
    """Build a test EntityCandidate with the given name and label and a fixed 0.9 confidence."""
    fixed_confidence = 0.9
    return EntityCandidate(confidence=fixed_confidence, type_label=type_label, name=name)
|
||||
|
||||
|
||||
def _make_llm(response: dict):
|
||||
"""Crea un stub de llm_chat_json que retorna la respuesta fija."""
|
||||
def llm_chat_json(messages: list[dict]) -> dict:
|
||||
return response
|
||||
return llm_chat_json
|
||||
|
||||
|
||||
def test_texto_con_dos_entidades_relacionadas():
    """A relation between two known entities is returned with all of its fields."""
    employer = _make_entity("Acme Corp", "Organization")
    employee = _make_entity("John Smith", "Person")
    stub = _make_llm({
        "relations": [
            {
                "from_name": "Acme Corp",
                "to_name": "John Smith",
                "relation_type": "employs",
                "description": "Acme Corp employs John Smith as CEO",
                "confidence": 0.9,
            }
        ]
    })

    found = extract_relations_llm(
        text="Acme Corp employs John Smith as CEO.",
        entities=[employer, employee],
        relation_types=["employs", "funds", "related_to"],
        llm_chat_json=stub,
    )

    assert len(found) == 1
    relation = found[0]
    assert (relation.from_name, relation.to_name) == ("Acme Corp", "John Smith")
    assert relation.relation_type == "employs"
    assert relation.confidence == 0.9
    assert "CEO" in relation.description
|
||||
|
||||
|
||||
def test_texto_con_entidades_pero_sin_relacion():
    """Two entities with an empty 'relations' answer from the LLM yield no relations."""
    people = [_make_entity(name, "Person") for name in ("Alice", "Bob")]
    stub = _make_llm({"relations": []})

    found = extract_relations_llm(
        text="Alice and Bob both attended the conference.",
        entities=people,
        relation_types=["funds", "employs"],
        llm_chat_json=stub,
    )

    assert found == []
|
||||
|
||||
|
||||
def test_menos_de_dos_entidades_retorna_lista_vacia():
    """With a single entity the result is empty even if the LLM would report a relation."""
    lone_entity = _make_entity("Solo Corp", "Organization")

    # The LLM should never be called; if it were, it would answer with a relation.
    bogus_relation = {
        "from_name": "Solo Corp",
        "to_name": "Nobody",
        "relation_type": "employs",
        "confidence": 0.9,
    }
    stub = _make_llm({"relations": [bogus_relation]})

    found = extract_relations_llm(
        text="Solo Corp is a company.",
        entities=[lone_entity],
        relation_types=["employs", "funds"],
        llm_chat_json=stub,
    )

    assert found == []
|
||||
|
||||
|
||||
def test_llm_inventa_entidad_que_no_existe_se_descarta():
    """Relations that mention names outside the given entities are discarded."""

    def relation(source, target, kind, description, confidence):
        return {
            "from_name": source,
            "to_name": target,
            "relation_type": kind,
            "description": description,
            "confidence": confidence,
        }

    stub = _make_llm({
        "relations": [
            # Valid: both Alice and Bob are known entities.
            relation("Alice", "Bob", "funds", "Alice funds Bob", 0.8),
            # Invalid: "Charlie" is not among the entities.
            relation("Alice", "Charlie", "employs", "Alice employs Charlie", 0.7),
            # Invalid: "Unknown Corp" is not among the entities.
            relation("Unknown Corp", "Bob", "related_to", "...", 0.6),
        ]
    })

    found = extract_relations_llm(
        text="Alice funds Bob. Alice also employs Charlie from Unknown Corp.",
        entities=[_make_entity("Alice", "Person"), _make_entity("Bob", "Person")],
        relation_types=["funds", "employs", "related_to"],
        llm_chat_json=stub,
    )

    # Only the first relation survives validation.
    assert len(found) == 1
    survivor = found[0]
    assert survivor.from_name == "Alice"
    assert survivor.to_name == "Bob"
    assert survivor.relation_type == "funds"
|
||||
@@ -0,0 +1,72 @@
|
||||
---
|
||||
name: hotness_score
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def hotness_score(active_count: int, updated_at: datetime | None, now: datetime | None = None, half_life_days: float = 7.0) -> float"
|
||||
description: "Calcula un score de hotness combinando frecuencia de acceso y recencia temporal. Util para ranking de resultados, memoria hot/cold y cache eviction."
|
||||
tags: [ranking, decay, recency, frequency, scoring, cache, memory, datascience]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [math, datetime]
|
||||
tested: true
|
||||
tests:
|
||||
- "active_count=0, updated_at reciente"
|
||||
- "active_count=100, updated_at reciente (score alto)"
|
||||
- "active_count=100, updated_at hace 30 dias (score bajo)"
|
||||
- "updated_at=None (retorna 0.0)"
|
||||
- "now explicito (determinista para tests)"
|
||||
- "half_life_days custom"
|
||||
test_file_path: "python/functions/datascience/hotness_score_test.py"
|
||||
file_path: "python/functions/datascience/hotness_score.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datascience.hotness_score import hotness_score
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Item reciente con muchos accesos -> score alto
|
||||
score = hotness_score(active_count=150, updated_at=now - timedelta(hours=2), now=now)
|
||||
# score > 0.95
|
||||
|
||||
# Item antiguo aunque muy accedido -> score bajo
|
||||
score = hotness_score(active_count=150, updated_at=now - timedelta(days=30), now=now)
|
||||
# score ~ 0.05
|
||||
|
||||
# Item sin fecha -> siempre 0
|
||||
score = hotness_score(active_count=999, updated_at=None)
|
||||
# score == 0.0
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Formula: `score = sigmoid(log1p(active_count)) * exp(-ln(2)/half_life_days * age_days)`
|
||||
|
||||
**Componente de frecuencia** — `sigmoid(log1p(count))` mapea enteros no negativos al rango `[0.5, 1.0)` (count=0 da exactamente 0.5):
|
||||
- count=0 -> 0.5
|
||||
- count=10 -> ~0.92
|
||||
- count=100 -> ~0.99
|
||||
|
||||
**Componente de recencia** — decaimiento exponencial con vida media configurable:
|
||||
- `half_life_days=7` (default): score se reduce a la mitad cada 7 dias
|
||||
- `half_life_days=1`: decaimiento agresivo (util para feeds en tiempo real)
|
||||
- `half_life_days=365`: decaimiento lento (util para contenido evergreen)
|
||||
|
||||
**Propiedades del score:**
|
||||
- `updated_at=None` -> 0.0 siempre (item sin fecha no tiene hotness)
|
||||
- `active_count=0, reciente` -> ~0.5 (neutro pero fresco)
|
||||
- `active_count alto, reciente` -> ~1.0 (muy caliente)
|
||||
- `active_count alto, antiguo` -> ~0.0 (frio a pesar de popularidad pasada)
|
||||
|
||||
Timestamps sin timezone se interpretan como UTC. Pasar `now` explicitamente garantiza determinismo en tests y reproducibilidad en pipelines batch.
|
||||
|
||||
Fuente conceptual: openviking/retrieve/memory_lifecycle.py (AGPL-3.0). Reimplementado desde cero con formula equivalente.
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Hotness score — combining access frequency and recency decay."""
|
||||
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def hotness_score(
|
||||
active_count: int,
|
||||
updated_at: datetime | None,
|
||||
now: datetime | None = None,
|
||||
half_life_days: float = 7.0,
|
||||
) -> float:
|
||||
"""Calcula un score de hotness combinando frecuencia de acceso y recencia.
|
||||
|
||||
Formula: sigmoid(log1p(active_count)) * exp_decay(age_days, half_life_days)
|
||||
|
||||
El componente de frecuencia mapea conteos enteros al rango (0, 1) via sigmoid(log1p).
|
||||
El componente de recencia decae exponencialmente con vida media configurable.
|
||||
|
||||
Args:
|
||||
active_count: Numero de accesos o activaciones. Debe ser >= 0.
|
||||
updated_at: Timestamp de la ultima actualizacion. None retorna 0.0.
|
||||
now: Momento de referencia para calcular la edad. Si es None usa datetime.now(UTC).
|
||||
half_life_days: Dias para que la recencia se reduzca a la mitad. Default 7.
|
||||
|
||||
Returns:
|
||||
float en [0.0, 1.0]. Valores mas cercanos a 1.0 indican mayor hotness.
|
||||
"""
|
||||
if updated_at is None:
|
||||
return 0.0
|
||||
|
||||
# Componente de frecuencia: sigmoid(log1p(count)) mapea 0..inf -> (0.5, 1.0)
|
||||
freq = 1.0 / (1.0 + math.exp(-math.log1p(active_count)))
|
||||
|
||||
# Componente de recencia: decaimiento exponencial
|
||||
if now is None:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Normalizar ambos timestamps a UTC para comparacion segura
|
||||
if updated_at.tzinfo is None:
|
||||
updated_at = updated_at.replace(tzinfo=timezone.utc)
|
||||
if now.tzinfo is None:
|
||||
now = now.replace(tzinfo=timezone.utc)
|
||||
|
||||
age_days = max((now - updated_at).total_seconds() / 86400.0, 0.0)
|
||||
decay_rate = math.log(2) / half_life_days
|
||||
recency = math.exp(-decay_rate * age_days)
|
||||
|
||||
return freq * recency
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Tests para hotness_score."""
|
||||
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from hotness_score import hotness_score
|
||||
|
||||
NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_active_count_zero_updated_at_reciente():
    """A never-accessed item updated an hour ago scores about 0.5."""
    score = hotness_score(active_count=0, updated_at=NOW - timedelta(hours=1), now=NOW)
    # sigmoid(log1p(0)) == sigmoid(0) == 0.5, and one hour of decay is negligible.
    assert 0.45 < score < 0.55, f"Expected ~0.5, got {score}"
|
||||
|
||||
|
||||
def test_active_count_alto_updated_at_reciente():
    """A heavily accessed item updated an hour ago scores above 0.95."""
    score = hotness_score(active_count=100, updated_at=NOW - timedelta(hours=1), now=NOW)
    # Frequency term sigmoid(log1p(100)) ~ 0.99; recency term ~ 1.0.
    assert score > 0.95, f"Expected > 0.95, got {score}"
|
||||
|
||||
|
||||
def test_active_count_alto_updated_at_hace_30_dias():
    """A heavily accessed item last updated 30 days ago scores low."""
    score = hotness_score(active_count=100, updated_at=NOW - timedelta(days=30), now=NOW)
    # Recency is exp(-ln2/7 * 30) ~ 0.051, so the score is ~ 0.99 * 0.051 ~ 0.05.
    assert score < 0.1, f"Expected < 0.1, got {score}"
|
||||
|
||||
|
||||
def test_updated_at_none_retorna_cero():
    """A missing ``updated_at`` always yields 0.0."""
    score = hotness_score(active_count=100, updated_at=None, now=NOW)
    assert score == 0.0, f"Expected 0.0, got {score}"
|
||||
|
||||
|
||||
def test_now_explicito():
    """An explicit ``now`` makes the score fully deterministic."""
    # Exactly one default half-life (7 days) old, so the recency term is 0.5.
    week_old = NOW - timedelta(days=7)
    # Frequency term: sigmoid(log1p(50)) ~ 0.981.
    frequency = 1.0 / (1.0 + math.exp(-math.log1p(50)))
    expected = frequency * 0.5
    score = hotness_score(50, week_old, now=NOW)
    assert abs(score - expected) < 1e-9, f"Expected {expected}, got {score}"
|
||||
|
||||
|
||||
def test_half_life_days_custom():
    """A custom ``half_life_days`` changes how fast the score decays."""
    # With a one-day half-life, a one-day-old item keeps half of its recency.
    day_old = NOW - timedelta(days=1)
    frequency = 1.0 / (1.0 + math.exp(-math.log1p(50)))
    expected = frequency * 0.5
    score = hotness_score(50, day_old, now=NOW, half_life_days=1.0)
    assert abs(score - expected) < 1e-6, f"Expected {expected}, got {score}"
|
||||
@@ -0,0 +1,40 @@
|
||||
---
|
||||
name: melt
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def melt(rows: list[dict], id_vars: list[str], value_vars: list[str] | None = None, var_name: str = 'variable', value_name: str = 'value') -> list[dict]"
|
||||
description: "Inversa de pivot. Convierte columnas en filas (formato largo). Cada combinacion de id_vars + value_var genera una fila. Si value_vars es None, derrite todas las columnas no-id."
|
||||
tags: [datascience, tabular, melt, unpivot, transform, python]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: true
|
||||
tests:
|
||||
- "Melt basico"
|
||||
- "Multiples id_vars"
|
||||
- "value_vars None derrite todas las columnas no-id"
|
||||
- "Fila con campo faltante en value_vars"
|
||||
test_file_path: "python/functions/datascience/melt_test.py"
|
||||
file_path: "python/functions/datascience/melt.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
rows = [{"region": "US", "q1": 10, "q2": 20}]
|
||||
melt(rows, id_vars=["region"], value_vars=["q1", "q2"])
|
||||
# [{"region": "US", "variable": "q1", "value": 10},
|
||||
# {"region": "US", "variable": "q2", "value": 20}]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura sin dependencias externas.
|
||||
Si un campo de value_vars no existe en la fila, su valor sera None.
|
||||
El parametro value_vars=None es util cuando se desconoce el schema exacto.
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Melt (unpivot) para datos tabulares list[dict]."""
|
||||
|
||||
|
||||
def melt(
|
||||
rows: list[dict],
|
||||
id_vars: list[str],
|
||||
value_vars: list[str] | None = None,
|
||||
var_name: str = "variable",
|
||||
value_name: str = "value",
|
||||
) -> list[dict]:
|
||||
"""Convierte columnas en filas (formato largo). Inversa de pivot.
|
||||
|
||||
Cada combinacion de id_vars + value_var genera una fila nueva.
|
||||
Si value_vars es None, se usan todas las columnas que no esten en id_vars.
|
||||
|
||||
Args:
|
||||
rows: Lista de dicts en formato ancho.
|
||||
id_vars: Columnas que se mantienen como identificadores en cada fila.
|
||||
value_vars: Columnas a convertir en filas. None = todas las no-id.
|
||||
var_name: Nombre de la columna que contendra los nombres de variables.
|
||||
value_name: Nombre de la columna que contendra los valores.
|
||||
|
||||
Returns:
|
||||
Lista de dicts en formato largo con una fila por combinacion id+variable.
|
||||
"""
|
||||
result = []
|
||||
for row in rows:
|
||||
# Determinar que columnas derretir
|
||||
if value_vars is None:
|
||||
vars_to_melt = [k for k in row if k not in id_vars]
|
||||
else:
|
||||
vars_to_melt = value_vars
|
||||
|
||||
for var in vars_to_melt:
|
||||
new_row: dict = {k: row.get(k) for k in id_vars}
|
||||
new_row[var_name] = var
|
||||
new_row[value_name] = row.get(var)
|
||||
result.append(new_row)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Tests para melt."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from melt import melt
|
||||
|
||||
|
||||
def test_melt_basico():
    """Basic melt: one id column and two value columns."""
    wide = [{"region": "US", "q1": 10, "q2": 20}]
    long_rows = melt(wide, id_vars=["region"], value_vars=["q1", "q2"])
    assert len(long_rows) == 2
    q1_row, q2_row = long_rows
    assert q1_row == {"region": "US", "variable": "q1", "value": 10}
    assert q2_row == {"region": "US", "variable": "q2", "value": 20}
|
||||
|
||||
|
||||
def test_melt_multiples_id_vars():
    """Several id columns are all carried into the melted rows."""
    wide = [{"region": "US", "year": 2023, "q1": 10, "q2": 20}]
    long_rows = melt(wide, id_vars=["region", "year"], value_vars=["q1", "q2"])
    assert len(long_rows) == 2
    q1_row, q2_row = long_rows
    assert q1_row["region"] == "US"
    assert q1_row["year"] == 2023
    assert (q1_row["variable"], q1_row["value"]) == ("q1", 10)
    assert (q2_row["variable"], q2_row["value"]) == ("q2", 20)
|
||||
|
||||
|
||||
def test_melt_value_vars_none_derrite_todas_las_columnas_no_id():
    """With ``value_vars=None`` every non-id column is melted."""
    long_rows = melt([{"id": 1, "a": 10, "b": 20, "c": 30}], id_vars=["id"])
    assert len(long_rows) == 3
    assert {row["variable"] for row in long_rows} == {"a", "b", "c"}
    assert {row["value"] for row in long_rows} == {10, 20, 30}
|
||||
|
||||
|
||||
def test_melt_fila_con_campo_faltante_en_value_vars():
    """A value column missing from a row is emitted with value ``None``."""
    wide = [{"region": "US", "q1": 10}]  # no "q2" key
    long_rows = melt(wide, id_vars=["region"], value_vars=["q1", "q2"])
    assert len(long_rows) == 2
    q2_row = next(row for row in long_rows if row["variable"] == "q2")
    assert q2_row["value"] is None
|
||||
@@ -0,0 +1,68 @@
|
||||
---
|
||||
name: merge_graphs
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def merge_graphs(graphs: list[dict], entity_key: str = 'name', similarity_threshold: float = 0.85) -> dict"
|
||||
description: "Mergea multiples grafos de conocimiento en uno deduplicando entities por similitud de nombre (Levenshtein normalizado). Relaciones se re-apuntan a las entities canonicas. Atributos se combinan por union."
|
||||
tags: [graph, merge, deduplication, knowledge-graph, levenshtein, similarity, datascience]
|
||||
uses_functions: [levenshtein_distance_py_cybersecurity]
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [sys, os]
|
||||
tested: true
|
||||
tests:
|
||||
- "dos grafos con entity duplicada → merge"
|
||||
- "entities similares pero bajo threshold → no merge"
|
||||
- "relaciones re-apuntadas correctamente"
|
||||
- "merge log registra cada merge"
|
||||
- "tres grafos → merge transitivo"
|
||||
- "grafos sin overlap → concatenacion simple"
|
||||
test_file_path: "python/functions/datascience/merge_graphs_test.py"
|
||||
file_path: "python/functions/datascience/merge_graphs.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
g1 = {
|
||||
"entities": [
|
||||
{"id": "1", "name": "Alice Corp", "type": "company"},
|
||||
{"id": "2", "name": "Bob", "type": "person"},
|
||||
],
|
||||
"relations": [
|
||||
{"source_id": "2", "target_id": "1", "relation_type": "works_at"},
|
||||
],
|
||||
}
|
||||
g2 = {
|
||||
"entities": [
|
||||
{"id": "3", "name": "Alice Corp.", "type": "company", "country": "US"},
|
||||
],
|
||||
"relations": [],
|
||||
}
|
||||
|
||||
result = merge_graphs([g1, g2], similarity_threshold=0.85)
|
||||
# result["entities"] -> 2 entities (Alice Corp mergeada, Bob)
|
||||
# result["merge_log"] -> [{"merged": ["1", "3"], "into": "3", "similarity": 0.9091}]
|
||||
# "Alice Corp" (id 1) se mergea en "Alice Corp." (id 3): similitud 0.9091 >= 0.85 y la entity 3 tiene mas campos no-null, por eso es la canonica
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura. Reutiliza `levenshtein_distance_py_cybersecurity` para calcular similitud normalizada entre nombres.
|
||||
|
||||
**Algoritmo de merge transitivo**: si A~B y B~C, entonces A, B, C se mergean en uno solo. Se implementa via union-find (path compression simple).
|
||||
|
||||
**Eleccion de canonical**: la entity con mas campos no-null gana. En caso de empate, la primera encontrada en el par.
|
||||
|
||||
**Conflictos de atributos**: si ambas entities tienen un campo con valor, el canonical conserva el suyo (primero gana). Solo se copian campos que el canonical no tiene o tiene null.
|
||||
|
||||
**Deduplicacion de relaciones**: por (source_id, target_id, relation_type). Si dos relaciones son identicas tras re-apuntar los IDs, se conserva la primera encontrada.
|
||||
|
||||
**Complejidad**: O(n^2) en numero de entities por la comparacion de pares. Adecuado para grafos de knowledge tipicos (< 10K entities). Para grafos muy grandes, usar indexado por prefijo antes de comparar.
|
||||
|
||||
**Importacion**: intenta importar `levenshtein_distance` desde el paquete `cybersecurity` del registry. Si no esta disponible, usa una reimplementacion inline equivalente.
|
||||
@@ -0,0 +1,169 @@
|
||||
"""merge_graphs — mergea multiples grafos de conocimiento deduplicando entities por similitud."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Importar levenshtein_distance desde el registry
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "cybersecurity"))
|
||||
try:
    from cybersecurity import levenshtein_distance
except ImportError:
    # Fallback: inline implementation used when the registry module cannot be imported.
    def levenshtein_distance(a: str, b: str) -> int:
        """Return the Levenshtein (edit) distance between two strings."""
        # Keep `b` as the shorter string so the working rows stay small.
        if len(a) < len(b):
            a, b = b, a
        if not b:
            return len(a)
        prev_row = list(range(len(b) + 1))
        for i, ca in enumerate(a):
            curr_row = [i + 1]
            for j, cb in enumerate(b):
                cost = 0 if ca == cb else 1
                curr_row.append(
                    min(curr_row[j] + 1, prev_row[j + 1] + 1, prev_row[j] + cost)
                )
            prev_row = curr_row
        return prev_row[-1]


def _name_similarity(a: str, b: str) -> float:
    """Return the case-insensitive normalized Levenshtein similarity in [0, 1].

    Two empty strings are considered identical (1.0).
    """
    # Lowercase before measuring: lowercasing can change a string's length, and
    # the distance is computed on the lowercased forms.
    a_norm, b_norm = a.lower(), b.lower()
    longest = max(len(a_norm), len(b_norm))
    if longest == 0:
        return 1.0
    return 1.0 - levenshtein_distance(a_norm, b_norm) / longest


def _count_non_null_fields(entity: dict) -> int:
    """Count the fields of ``entity`` whose value is not ``None``."""
    return sum(1 for value in entity.values() if value is not None)


def _merge_two_entities(canonical: dict, other: dict) -> dict:
    """Return a new dict with the union of both entities' fields.

    On conflict the canonical value wins; ``other`` only fills fields that the
    canonical entity lacks or holds as ``None``. Neither input is mutated.
    """
    merged = dict(canonical)
    for key, value in other.items():
        if merged.get(key) is None:
            merged[key] = value
    return merged


def merge_graphs(
    graphs: list[dict],
    entity_key: str = "name",
    similarity_threshold: float = 0.85,
) -> dict:
    """Merge several knowledge graphs, deduplicating entities by name similarity.

    Algorithm:
      1. Collect the entities and relations of every graph.
      2. Compare every pair of entities by their own ``entity_key`` value; a
         pair whose similarity is >= the threshold is joined (union-find), so
         merging is transitive: A~B and B~C puts A, B and C in one group even
         when A and C are not similar to each other.
      3. When two groups join, the representative with more non-null fields
         becomes canonical (ties keep the earlier entity) and absorbs the
         other's fields without overwriting its own.
      4. Relations are re-pointed to canonical ids and deduplicated by
         (source_id, target_id, relation_type), keeping the first occurrence.
      5. Every join is recorded in ``merge_log``.

    Entities without an "id" are dropped. If the same id appears more than
    once, the last occurrence is used. Entities whose ``entity_key`` value is
    missing, ``None`` or empty are kept but never merged, since there is no
    name to compare.

    Args:
        graphs: Graphs to merge. Each is a dict with optional keys "entities"
            (list[dict]) and "relations" (list[dict]).
        entity_key: Entity field whose text is compared. Default "name".
        similarity_threshold: Minimum normalized Levenshtein similarity, in
            [0, 1], for two entities to be merged. Default 0.85.

    Returns:
        Dict with keys "entities" (canonical entities, in order of first
        appearance), "relations" and "merge_log". Each log entry is
        ``{"merged": [other_id, canonical_id], "into": canonical_id,
        "similarity": <rounded to 4 decimals>}``.
    """
    all_entities: list[dict] = []
    all_relations: list[dict] = []
    for graph in graphs:
        all_entities.extend(graph.get("entities", []))
        all_relations.extend(graph.get("relations", []))

    # Current (possibly merged) data per id; the dict also fixes a
    # deterministic, duplicate-free processing order.
    entity_by_id: dict = {e["id"]: e for e in all_entities if "id" in e}
    entity_ids = list(entity_by_id)

    # Snapshot each entity's own name up front. Comparing these, instead of the
    # name of whatever canonical entity a group currently has, is what makes
    # the merge transitive.
    names: dict = {}
    for eid, entity in entity_by_id.items():
        raw_name = entity.get(entity_key)
        names[eid] = "" if raw_name is None else str(raw_name)

    parent: dict = {eid: eid for eid in entity_ids}
    merge_log: list[dict] = []

    def find_canonical(eid):
        """Return the canonical id for ``eid`` (ids not in the graph map to themselves)."""
        root = eid
        while parent.get(root, root) != root:
            root = parent[root]
        # Path compression: point every id on the walked chain straight at the root.
        while parent.get(eid, eid) != root:
            parent[eid], eid = root, parent[eid]
        return root

    # All pairs: O(n^2) comparisons.
    count = len(entity_ids)
    for i in range(count):
        left_id = entity_ids[i]
        left_name = names[left_id]
        if not left_name:
            continue  # nothing to compare
        for j in range(i + 1, count):
            right_id = entity_ids[j]
            right_name = names[right_id]
            if not right_name:
                continue

            left_root = find_canonical(left_id)
            right_root = find_canonical(right_id)
            if left_root == right_root:
                continue  # already in the same group

            sim = _name_similarity(left_name, right_name)
            if sim < similarity_threshold:
                continue

            # The richer representative stays canonical; ties keep the left one.
            left_fields = _count_non_null_fields(entity_by_id[left_root])
            right_fields = _count_non_null_fields(entity_by_id[right_root])
            if left_fields >= right_fields:
                canonical_id, other_id = left_root, right_root
            else:
                canonical_id, other_id = right_root, left_root

            entity_by_id[canonical_id] = _merge_two_entities(
                entity_by_id[canonical_id], entity_by_id[other_id]
            )
            parent[other_id] = canonical_id

            merge_log.append({
                "merged": [other_id, canonical_id],
                "into": canonical_id,
                "similarity": round(sim, 4),
            })

    # Keep only canonical entities, in order of first appearance.
    final_entities = [
        entity_by_id[eid] for eid in entity_ids if find_canonical(eid) == eid
    ]

    # Re-point relations to canonical ids and drop the duplicates this creates.
    unique_relations: dict[tuple, dict] = {}
    for rel in all_relations:
        new_rel = dict(rel)
        if "source_id" in new_rel:
            new_rel["source_id"] = find_canonical(new_rel["source_id"])
        if "target_id" in new_rel:
            new_rel["target_id"] = find_canonical(new_rel["target_id"])

        rel_key = (
            new_rel.get("source_id", ""),
            new_rel.get("target_id", ""),
            new_rel.get("relation_type", ""),
        )
        unique_relations.setdefault(rel_key, new_rel)

    return {
        "entities": final_entities,
        "relations": list(unique_relations.values()),
        "merge_log": merge_log,
    }
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Tests para merge_graphs."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from merge_graphs import merge_graphs
|
||||
|
||||
|
||||
def test_dos_grafos_con_entity_duplicada_merge():
    """Two entities with identical names collapse into one that keeps every field."""
    g1 = {
        "entities": [{"id": "1", "name": "Alice Corp", "type": "company"}],
        "relations": [],
    }
    g2 = {
        "entities": [{"id": "2", "name": "Alice Corp", "type": "company", "country": "US"}],
        "relations": [],
    }
    result = merge_graphs([g1, g2], similarity_threshold=0.95)
    # Identical names -> similarity 1.0 -> they must be merged.
    assert len(result["entities"]) == 1
    assert len(result["merge_log"]) == 1
    merged = result["entities"][0]
    # Checked separately: joined with `or`, the always-true name check would
    # hide a lost "country" field.
    assert merged.get("name") == "Alice Corp"
    assert merged.get("country") == "US"
|
||||
|
||||
|
||||
def test_entities_similares_pero_bajo_threshold_no_merge():
    """Entities whose names fall below the threshold stay separate."""
    graphs = [
        {"entities": [{"id": "1", "name": "Alice"}], "relations": []},
        {"entities": [{"id": "2", "name": "Bob"}], "relations": []},
    ]
    result = merge_graphs(graphs, similarity_threshold=0.85)
    # "Alice" and "Bob" are far apart, so nothing is merged.
    assert len(result["entities"]) == 2
    assert len(result["merge_log"]) == 0
|
||||
|
||||
|
||||
def test_relaciones_re_apuntadas_correctamente():
    """Relations of a merged entity are re-pointed to the surviving entity."""
    g1 = {
        "entities": [
            {"id": "1", "name": "Alice Corp"},
            {"id": "2", "name": "Bob"},
        ],
        "relations": [
            {"source_id": "2", "target_id": "1", "relation_type": "works_at"},
        ],
    }
    g2 = {
        "entities": [
            {"id": "3", "name": "Alice Corp"},  # duplicate of id=1
        ],
        "relations": [
            {"source_id": "3", "target_id": "2", "relation_type": "knows"},
        ],
    }
    result = merge_graphs([g1, g2], similarity_threshold=0.95)
    assert len(result["entities"]) == 2  # Alice Corp + Bob
    # Both relations differ in type, so neither may be dropped. Without this
    # check the loop below would pass on an empty relation list.
    assert len(result["relations"]) == 2
    # Every endpoint must name an entity that survived the merge.
    surviving_ids = {entity["id"] for entity in result["entities"]}
    for rel in result["relations"]:
        assert rel["source_id"] in surviving_ids
        assert rel["target_id"] in surviving_ids
|
||||
|
||||
|
||||
def test_merge_log_registra_cada_merge():
    """Each merge adds one entry with merged/into/similarity to the log."""
    graphs = [
        {"entities": [{"id": "1", "name": "OpenAI"}], "relations": []},
        {"entities": [{"id": "2", "name": "OpenAI"}], "relations": []},
    ]
    merge_log = merge_graphs(graphs, similarity_threshold=0.9)["merge_log"]
    assert len(merge_log) == 1
    entry = merge_log[0]
    for expected_key in ("merged", "into", "similarity"):
        assert expected_key in entry
    assert entry["similarity"] == 1.0
|
||||
|
||||
|
||||
def test_tres_grafos_merge_transitivo():
    """Three graphs holding the same-named entity end up with a single entity."""
    # A~B and B~C -> A, B and C must all collapse into one.
    graphs = [
        {"entities": [{"id": entity_id, "name": "Acme Corp"}], "relations": []}
        for entity_id in ("1", "2", "3")
    ]
    result = merge_graphs(graphs, similarity_threshold=0.9)
    assert len(result["entities"]) == 1
|
||||
|
||||
|
||||
def test_grafos_sin_overlap_concatenacion_simple():
    """Graphs with no similar entities are simply concatenated."""

    def graph(first, second):
        # Two entities given as (id, name) pairs, joined by one "knows" relation.
        return {
            "entities": [
                {"id": first[0], "name": first[1]},
                {"id": second[0], "name": second[1]},
            ],
            "relations": [
                {"source_id": first[0], "target_id": second[0], "relation_type": "knows"},
            ],
        }

    g1 = graph(("1", "Alice"), ("2", "Bob"))
    g2 = graph(("3", "Carol"), ("4", "Dave"))
    result = merge_graphs([g1, g2], similarity_threshold=0.85)
    # No similar entities -> everything is kept as is.
    assert len(result["entities"]) == 4
    assert len(result["relations"]) == 2
    assert len(result["merge_log"]) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Lets the file run as a plain script, without a test runner.
    for run_test in (
        test_dos_grafos_con_entity_duplicada_merge,
        test_entities_similares_pero_bajo_threshold_no_merge,
        test_relaciones_re_apuntadas_correctamente,
        test_merge_log_registra_cada_merge,
        test_tres_grafos_merge_transitivo,
        test_grafos_sin_overlap_concatenacion_simple,
    ):
        run_test()
    print("All tests passed.")
|
||||
@@ -0,0 +1,44 @@
|
||||
---
|
||||
name: pivot
|
||||
kind: function
|
||||
lang: py
|
||||
domain: datascience
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "def pivot(rows: list[dict], index: str, columns: str, values: str, agg: str = 'sum') -> list[dict]"
|
||||
description: "Pivot table sin pandas. Agrupa por index, expande valores unicos de columns como nuevas columnas y agrega values con la funcion indicada (sum, count, mean, min, max, first, last)."
|
||||
tags: [datascience, tabular, pivot, transform, aggregation, python]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: ["collections"]
|
||||
tested: true
|
||||
tests:
|
||||
- "Pivot basico con sum"
|
||||
- "Pivot con count y mean"
|
||||
- "Valores faltantes rellenados con 0"
|
||||
- "Una sola fila"
|
||||
- "Multiples valores por celda requieren agregacion"
|
||||
test_file_path: "python/functions/datascience/pivot_test.py"
|
||||
file_path: "python/functions/datascience/pivot.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
rows = [
|
||||
{"region": "US", "product": "A", "sales": 10},
|
||||
{"region": "US", "product": "B", "sales": 20},
|
||||
{"region": "EU", "product": "A", "sales": 15},
|
||||
]
|
||||
pivot(rows, index="region", columns="product", values="sales")
|
||||
# [{"region": "US", "A": 10, "B": 20}, {"region": "EU", "A": 15, "B": 0}]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura sin dependencias externas (solo collections.defaultdict de stdlib).
|
||||
Preserva el orden de aparicion de los valores de index y columns.
|
||||
Valores numericos faltantes se rellenan con 0; no numericos con None.
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Pivot table sin pandas para datos tabulares list[dict]."""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def pivot(
    rows: list[dict],
    index: str,
    columns: str,
    values: str,
    agg: str = "sum",
) -> list[dict]:
    """Turn long-format rows into a wide pivot table.

    Rows are grouped by ``index``; every distinct value of ``columns`` becomes
    an output column holding the aggregation of the matching ``values``.
    Rows whose ``values`` entry is missing or ``None`` still contribute their
    index and column labels but no value.

    Args:
        rows: Long-format rows.
        index: Key whose distinct values become the output rows.
        columns: Key whose distinct values become the output columns.
        values: Key holding the values to aggregate.
        agg: Aggregation name: sum, count, mean, min, max, first or last.

    Returns:
        One dict per distinct index value, in order of first appearance, with
        one key per distinct column value (also in order of first appearance).
        A cell with no values is filled with 0 when ``agg`` is "count" or when
        every collected value is an int/float, and with ``None`` otherwise.

    Raises:
        ValueError: If ``agg`` is not a supported aggregation. This is checked
            up front, so it no longer depends on the data having values.
    """
    aggregators = {
        "sum": sum,
        "count": len,
        "mean": lambda vals: sum(vals) / len(vals),
        "min": min,
        "max": max,
        "first": lambda vals: vals[0],
        "last": lambda vals: vals[-1],
    }
    aggregate = aggregators.get(agg)
    if aggregate is None:
        raise ValueError(f"Funcion de agregacion no soportada: {agg}")

    index_order: list = []
    column_order: list = []
    seen_index: set = set()
    seen_columns: set = set()
    # (index value, column value) -> values collected for that cell.
    cells: dict[tuple, list] = defaultdict(list)
    all_numeric = True

    for row in rows:
        idx = row.get(index)
        col = row.get(columns)
        if idx not in seen_index:
            seen_index.add(idx)
            index_order.append(idx)
        if col not in seen_columns:
            seen_columns.add(col)
            column_order.append(col)

        val = row.get(values)
        if val is not None:
            cells[(idx, col)].append(val)
            if not isinstance(val, (int, float)):
                all_numeric = False

    # An empty cell has a count of 0 whatever the value type; other
    # aggregations fall back to 0 for numeric data and None otherwise.
    fill = 0 if (agg == "count" or all_numeric) else None

    result = []
    for idx in index_order:
        record: dict = {index: idx}
        for col in column_order:
            vals = cells.get((idx, col))
            record[col] = aggregate(vals) if vals else fill
        result.append(record)

    return result
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests para pivot."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from pivot import pivot
|
||||
|
||||
|
||||
def test_pivot_basico_con_sum():
    """Basic pivot with sum."""
    sales_rows = [
        {"region": "US", "product": "A", "sales": 10},
        {"region": "US", "product": "B", "sales": 20},
        {"region": "EU", "product": "A", "sales": 15},
    ]

    table = pivot(sales_rows, index="region", columns="product", values="sales")

    # One output record per distinct region.
    assert len(table) == 2
    us_row = next(rec for rec in table if rec["region"] == "US")
    eu_row = next(rec for rec in table if rec["region"] == "EU")

    assert us_row["A"] == 10
    assert us_row["B"] == 20
    assert eu_row["A"] == 15
    # EU has no row for product B, so that cell is expected to be 0.
    assert eu_row["B"] == 0
|
||||
|
||||
|
||||
def test_pivot_con_count_y_mean():
    """Pivot with the count and mean aggregations."""
    sales_rows = [
        {"region": "US", "product": "A", "sales": 10},
        {"region": "US", "product": "A", "sales": 20},
        {"region": "EU", "product": "A", "sales": 15},
    ]
    pivot_kwargs = {"index": "region", "columns": "product", "values": "sales"}

    # Two US/A rows collapse into a single cell holding their count.
    counted = pivot(sales_rows, agg="count", **pivot_kwargs)
    us_counted = next(rec for rec in counted if rec["region"] == "US")
    assert us_counted["A"] == 2

    # Same rows, averaged: (10 + 20) / 2.
    averaged = pivot(sales_rows, agg="mean", **pivot_kwargs)
    us_averaged = next(rec for rec in averaged if rec["region"] == "US")
    assert us_averaged["A"] == 15.0
|
||||
|
||||
|
||||
def test_pivot_valores_faltantes_rellenados_con_0():
    """Missing cells are filled with 0."""
    sales_rows = [
        {"region": "US", "product": "A", "sales": 5},
        {"region": "EU", "product": "B", "sales": 8},
    ]

    table = pivot(sales_rows, index="region", columns="product", values="sales")
    us_row = next(rec for rec in table if rec["region"] == "US")
    eu_row = next(rec for rec in table if rec["region"] == "EU")

    # Each region is missing exactly the product the other one sold.
    assert us_row["B"] == 0
    assert eu_row["A"] == 0
|
||||
|
||||
|
||||
def test_pivot_una_sola_fila():
    """A single input row."""
    only_row = {"region": "US", "product": "A", "sales": 42}

    table = pivot([only_row], index="region", columns="product", values="sales")

    assert len(table) == 1
    record = table[0]
    assert record["region"] == "US"
    assert record["A"] == 42
|
||||
|
||||
|
||||
def test_pivot_multiples_valores_por_celda_requieren_agregacion():
    """Several values landing in one cell require an aggregation."""
    sales_rows = [
        {"region": "US", "product": "A", "sales": 10},
        {"region": "US", "product": "A", "sales": 30},
    ]
    pivot_kwargs = {"index": "region", "columns": "product", "values": "sales"}

    summed = pivot(sales_rows, agg="sum", **pivot_kwargs)
    assert summed[0]["A"] == 40

    smallest = pivot(sales_rows, agg="min", **pivot_kwargs)
    assert smallest[0]["A"] == 10

    largest = pivot(sales_rows, agg="max", **pivot_kwargs)
    assert largest[0]["A"] == 30
|
||||
@@ -0,0 +1,48 @@
|
||||
---
|
||||
name: avellaneda_stoikov_quotes
|
||||
kind: function
|
||||
lang: py
|
||||
domain: finance
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "avellaneda_stoikov_quotes(mid_price: float, inventory: float, gamma: float, sigma: float, spread_base: float, n_levels: int, qty_base: float) -> list[dict]"
|
||||
description: "Genera ordenes de market maker usando el modelo Avellaneda-Stoikov. Calcula precio de reserva y half spread optimos segun inventario y volatilidad."
|
||||
tags: [simulation, market-making, avellaneda-stoikov, montecarlo, finance, order-book]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: []
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/finance/finance.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
orders = avellaneda_stoikov_quotes(
|
||||
mid_price=100.0,
|
||||
inventory=0.0,
|
||||
gamma=0.1,
|
||||
sigma=0.02,
|
||||
spread_base=0.5,
|
||||
n_levels=3,
|
||||
qty_base=10.0,
|
||||
)
|
||||
# [
|
||||
# {'side': 'buy', 'price': 99.75, 'qty': 10.0},
|
||||
# {'side': 'sell', 'price': 100.25, 'qty': 10.0},
|
||||
# ...
|
||||
# ]
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura — sin aleatoriedad.
|
||||
`gamma` controla la aversion al riesgo de inventario: mayor gamma = spreads mas amplios.
|
||||
`inventory` positivo sesga los quotes hacia venta (reduce inventario largo).
|
||||
Cada nivel adicional ensancha el spread en `half_spread * 0.5` y aumenta la cantidad en `qty_base * 0.5`.
|
||||
Ordenes con precio <= 0 se descartan automaticamente.
|
||||
@@ -135,3 +135,104 @@ def annualized_volatility(returns: list, periods_per_year: float) -> float:
|
||||
mean = sum(returns) / n
|
||||
variance = sum((r - mean) ** 2 for r in returns) / (n - 1)
|
||||
return math.sqrt(variance) * math.sqrt(periods_per_year)
|
||||
|
||||
|
||||
def generate_gbm_prices(
    initial_price: float,
    n_ticks: int,
    sigma: float,
    mu: float = 0.0,
    jump_intensity: float = 0.0,
    jump_size_std: float = 0.05,
    seed: int = 42,
) -> list:
    """Generate a fundamental price series with GBM plus jump-diffusion.

    Each tick applies

        S(t+1) = S(t) * exp((mu - sigma^2/2)*dt + sigma*sqrt(dt)*Z + J*N)

    where Z ~ N(0, 1), N ~ Bernoulli(jump_intensity), J ~ N(0, jump_size_std)
    and dt is fixed at 1.0 per tick.

    Args:
        initial_price: Price stored at index 0 of the series.
        n_ticks: Number of prices to return.
        sigma: Per-tick volatility of the diffusion term.
        mu: Per-tick drift.
        jump_intensity: Per-tick probability of a jump; values <= 0 disable
            jumps and no extra random numbers are drawn.
        jump_size_std: Standard deviation of the log jump size.
        seed: Seed for the NumPy generator, making the series reproducible.

    Returns:
        A list with ``n_ticks`` prices. Generated prices are plain Python
        floats. The list is empty when ``n_ticks <= 0``.
    """
    import numpy as np

    # Guard: the series needs at least one slot for initial_price; without
    # this check n_ticks == 0 would fail with IndexError.
    if n_ticks <= 0:
        return []

    rng = np.random.default_rng(seed)
    dt = 1.0
    # Loop invariants, evaluated once.
    drift = (mu - 0.5 * sigma**2) * dt
    diffusion = sigma * np.sqrt(dt)

    prices = [initial_price]
    for _ in range(1, n_ticks):
        # Draw order (normal shock, then optional jump test, then jump size)
        # is fixed so that a given seed always yields the same series.
        z = rng.standard_normal()
        log_return = drift + diffusion * z
        if jump_intensity > 0 and rng.random() < jump_intensity:
            log_return += rng.normal(0, jump_size_std)
        # float() keeps np.float64 out of the returned list.
        prices.append(float(prices[-1] * np.exp(log_return)))
    return prices
|
||||
|
||||
|
||||
def avellaneda_stoikov_quotes(
    mid_price: float,
    inventory: float,
    gamma: float,
    sigma: float,
    spread_base: float,
    n_levels: int = 3,
    qty_base: float = 10.0,
) -> list:
    """Build market-maker quotes following the Avellaneda-Stoikov model.

    Reservation price: r = mid - inventory * gamma * sigma^2
    Half spread:       delta = spread_base/2 + gamma * sigma^2/2

    Level ``k`` (starting at 0) is placed ``k * delta * 0.5`` further away
    from the reservation price and carries ``qty_base * (1 + k * 0.5)``.
    Prices are rounded to 2 decimals; quotes whose rounded price is not
    strictly positive are dropped.

    Returns:
        List of dicts with keys ``side``, ``price`` and ``qty``, ordered by
        level with the buy quote before the sell quote.
    """
    reservation_price = mid_price - inventory * gamma * sigma**2
    half_spread = spread_base / 2 + gamma * sigma**2 / 2

    quotes = []
    for depth in range(n_levels):
        widening = depth * half_spread * 0.5
        size = qty_base * (1 + depth * 0.5)
        candidates = (
            ('buy', round(reservation_price - half_spread - widening, 2)),
            ('sell', round(reservation_price + half_spread + widening, 2)),
        )
        quotes.extend(
            {'side': side, 'price': price, 'qty': size}
            for side, price in candidates
            if price > 0
        )
    return quotes
|
||||
|
||||
|
||||
def generate_taker_order(
    alpha: float = 2.0,
    size_min: float = 1.0,
    size_max: float = 100.0,
    buy_prob: float = 0.5,
    seed: int | None = None,
) -> dict:
    """Generate a taker market order whose size follows a power law (Pareto).

    P(size > x) ~ x^(-alpha), so a lower alpha yields more very large orders.
    The size is ``(pareto_draw + 1) * size_min`` rounded to 1 decimal and
    capped at ``size_max``.

    Returns:
        Dict with keys ``side`` ('buy' or 'sell') and ``qty``.
    """
    import numpy as np

    generator = np.random.default_rng(seed)
    # Draw order is part of the seeded contract: side first, then size.
    is_buy = generator.random() < buy_prob
    pareto_draw = generator.pareto(alpha)
    unclipped = round((pareto_draw + 1) * size_min, 1)
    qty = min(unclipped, size_max)
    return {'side': 'buy' if is_buy else 'sell', 'qty': qty}
|
||||
|
||||
|
||||
def hawkes_intensity(
    base_rate: float,
    hawkes_alpha: float,
    hawkes_beta: float,
    event_times: list,
    current_time: float,
) -> float:
    """Compute the Hawkes-process intensity lambda(t) at ``current_time``.

    lambda(t) = base_rate + sum(alpha * exp(-beta * (t - ti)))

    Only events strictly earlier than ``current_time`` contribute. The
    result is floored at 0.0.
    """
    import numpy as np

    excitation = 0
    for event_time in event_times:
        if event_time < current_time:
            elapsed = current_time - event_time
            excitation += hawkes_alpha * np.exp(-hawkes_beta * elapsed)
    return max(0.0, base_rate + excitation)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
---
|
||||
name: generate_gbm_prices
|
||||
kind: function
|
||||
lang: py
|
||||
domain: finance
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "generate_gbm_prices(initial_price: float, n_ticks: int, sigma: float, mu: float, jump_intensity: float, jump_size_std: float, seed: int) -> list[float]"
|
||||
description: "Genera serie de precios fundamentales con Geometric Brownian Motion + jump-diffusion. S(t+1) = S(t) * exp((mu - sigma^2/2)*dt + sigma*sqrt(dt)*Z + J*N)."
|
||||
tags: [simulation, gbm, price, montecarlo, finance, stochastic]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [numpy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/finance/finance.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
prices = generate_gbm_prices(
|
||||
initial_price=100.0,
|
||||
n_ticks=1000,
|
||||
sigma=0.02,
|
||||
mu=0.0,
|
||||
jump_intensity=0.01,
|
||||
jump_size_std=0.05,
|
||||
seed=42,
|
||||
)
|
||||
# prices[0] == 100.0
|
||||
# len(prices) == 1000
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura — el seed fija el resultado deterministicamente.
|
||||
`jump_intensity=0.0` desactiva los saltos (GBM puro).
|
||||
`dt=1.0` por tick (tiempo discreto). Para tiempo continuo, ajustar sigma y mu en consecuencia.
|
||||
Requiere numpy para la generacion de numeros aleatorios y el calculo de exp.
|
||||
@@ -0,0 +1,41 @@
|
||||
---
|
||||
name: generate_taker_order
|
||||
kind: function
|
||||
lang: py
|
||||
domain: finance
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "generate_taker_order(alpha: float, size_min: float, size_max: float, buy_prob: float, seed: int | None) -> dict"
|
||||
description: "Genera una market order de taker con tamano distribuido segun power-law (Pareto). Alpha bajo produce ordenes mas grandes (ballenas)."
|
||||
tags: [simulation, taker, power-law, montecarlo, finance, order-book]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [numpy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/finance/finance.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
order = generate_taker_order(
|
||||
alpha=2.0,
|
||||
size_min=1.0,
|
||||
size_max=100.0,
|
||||
buy_prob=0.5,
|
||||
seed=42,
|
||||
)
|
||||
# {'side': 'buy', 'qty': 3.7}
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura cuando se fija seed. Con seed=None el resultado es no deterministico.
|
||||
La distribucion Pareto con alpha=2 modela bien la distribucion empirica de tamaños de ordenes en mercados reales.
|
||||
`size_max` actua como techo (clipping) para evitar ordenes extremas.
|
||||
Retorna dict con keys: `side` ('buy' o 'sell') y `qty` (float redondeado a 1 decimal).
|
||||
@@ -0,0 +1,43 @@
|
||||
---
|
||||
name: hawkes_intensity
|
||||
kind: function
|
||||
lang: py
|
||||
domain: finance
|
||||
version: "1.0.0"
|
||||
purity: pure
|
||||
signature: "hawkes_intensity(base_rate: float, hawkes_alpha: float, hawkes_beta: float, event_times: list[float], current_time: float) -> float"
|
||||
description: "Calcula la intensidad lambda(t) de un proceso de Hawkes en el tiempo actual. Modela la autocorrelacion temporal de eventos de mercado (rafagas de ordenes)."
|
||||
tags: [simulation, hawkes, stochastic-process, montecarlo, finance, point-process]
|
||||
uses_functions: []
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: ""
|
||||
imports: [numpy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/finance/finance.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
intensity = hawkes_intensity(
|
||||
base_rate=1.0,
|
||||
hawkes_alpha=0.8,
|
||||
hawkes_beta=2.0,
|
||||
event_times=[0.5, 1.2, 1.8],
|
||||
current_time=2.5,
|
||||
)
|
||||
# Intensidad > base_rate por excitacion de eventos pasados
|
||||
```
|
||||
|
||||
## Notas
|
||||
|
||||
Funcion pura — determinista dado el mismo historial de eventos.
|
||||
`hawkes_alpha` controla la magnitud del salto de intensidad por evento.
|
||||
`hawkes_beta` controla la velocidad de decaimiento (mayor beta = decaimiento mas rapido).
|
||||
La condicion de estabilidad del proceso es hawkes_alpha < hawkes_beta.
|
||||
Eventos con ti >= current_time se ignoran automaticamente.
|
||||
Retorna max(0.0, ...) para garantizar intensidad no negativa.
|
||||
@@ -0,0 +1,123 @@
|
||||
---
|
||||
name: extraction_pipeline
|
||||
kind: pipeline
|
||||
lang: py
|
||||
domain: pipelines
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def extraction_pipeline(file_path: str, entity_presets: list[dict], relation_types: list[str], llm_chat_json: Callable[[list[dict]], dict], chunk_size: int = 500, chunk_overlap: int = 50, confidence_threshold: float = 0.5, dedup_threshold: float = 0.85, on_progress: Callable[[str, float], None] | None = None) -> ExtractionResult"
|
||||
description: "Pipeline completa de extraccion de entidades y relaciones desde un documento. Orquesta extract_text_from_file -> preprocess_text -> split_text_into_chunks -> extract_entities_llm por chunk -> deduplicate_entities -> extract_relations_llm por chunk -> deduplicate_relations."
|
||||
tags: [pipeline, extraction, entities, relations, llm, nlp, fuzzygraph, datascience]
|
||||
uses_functions:
|
||||
- extract_text_from_file_py_core
|
||||
- preprocess_text_py_core
|
||||
- split_text_into_chunks_py_core
|
||||
- build_entity_schema_prompt_py_datascience
|
||||
- build_relation_schema_prompt_py_datascience
|
||||
- extract_entities_llm_py_datascience
|
||||
- extract_relations_llm_py_datascience
|
||||
- deduplicate_entities_py_datascience
|
||||
- deduplicate_relations_py_datascience
|
||||
uses_types:
|
||||
- entity_candidate_py_datascience
|
||||
- extraction_result_py_datascience
|
||||
- extraction_stats_py_datascience
|
||||
- relation_candidate_py_datascience
|
||||
returns:
|
||||
- extraction_result_py_datascience
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports:
|
||||
- time
|
||||
- warnings
|
||||
- typing.Callable
|
||||
tested: true
|
||||
tests:
|
||||
- "documento con entidades y relaciones retorna ExtractionResult completo"
|
||||
- "documento vacio retorna ExtractionResult con listas vacias"
|
||||
- "documento sin entidades detectables retorna listas vacias"
|
||||
- "archivo no encontrado lanza FileNotFoundError"
|
||||
- "entity presets vacio lanza ValueError"
|
||||
- "progress callback se invoca durante la ejecucion"
|
||||
- "stats se rellenan correctamente con conteos y tiempo"
|
||||
test_file_path: "python/functions/pipelines/extraction_pipeline_test.py"
|
||||
file_path: "python/functions/pipelines/extraction_pipeline.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```python
|
||||
from python.functions.pipelines.extraction_pipeline import extraction_pipeline
|
||||
|
||||
entity_presets = [
|
||||
{
|
||||
"type_ref": "osint_person_go_cybersecurity",
|
||||
"label": "Person",
|
||||
"metadata_fields": ["full_name", "alias", "nationality"],
|
||||
},
|
||||
{
|
||||
"type_ref": "osint_domain_go_cybersecurity",
|
||||
"label": "Domain",
|
||||
"metadata_fields": ["fqdn", "registrar"],
|
||||
},
|
||||
]
|
||||
|
||||
relation_types = ["operates", "owns", "funds", "communicates_with", "related_to"]
|
||||
|
||||
# Inyectar un cliente LLM real
|
||||
def llm_chat_json(messages):
|
||||
# llamada al proveedor LLM elegido
|
||||
...
|
||||
|
||||
result = extraction_pipeline(
|
||||
file_path="report.pdf",
|
||||
entity_presets=entity_presets,
|
||||
relation_types=relation_types,
|
||||
llm_chat_json=llm_chat_json,
|
||||
chunk_size=500,
|
||||
chunk_overlap=50,
|
||||
confidence_threshold=0.5,
|
||||
dedup_threshold=0.85,
|
||||
on_progress=lambda msg, pct: print(f"[{pct:.0%}] {msg}"),
|
||||
)
|
||||
|
||||
print(f"Entities: {len(result.entities)}, Relations: {len(result.relations)}")
|
||||
print(f"Stats: {result.stats}")
|
||||
|
||||
# Integrar con fuzzygraph / operations.db
|
||||
for entity in result.entities:
|
||||
db.add_entity(
|
||||
name=entity.name,
|
||||
type_ref=entity.type_ref,
|
||||
metadata=entity.attributes,
|
||||
)
|
||||
|
||||
for relation in result.relations:
|
||||
db.add_relation(
|
||||
name=relation.relation_type,
|
||||
from_entity=relation.from_id,
|
||||
to_entity=relation.to_id,
|
||||
)
|
||||
```
|
||||
|
||||
## Algoritmo
|
||||
|
||||
1. **Extract:** `extract_text_from_file(file_path)` — texto crudo desde PDF, TXT, Markdown
|
||||
2. **Preprocess:** `preprocess_text(text)` — normaliza espacios, caracteres especiales
|
||||
3. **Split:** `split_text_into_chunks(text, chunk_size, chunk_overlap)` — divide en ventanas solapadas
|
||||
4. **Extract entities per chunk (0-40%):** Para cada chunk llama `extract_entities_llm` con el schema de presets. Agrega el indice del chunk a `source_chunk_indices` de cada candidato
|
||||
5. **Filter:** filtra por `confidence >= confidence_threshold`
|
||||
6. **Deduplicate entities (40%):** `deduplicate_entities` con fuzzy matching, produce `entity_id_map`
|
||||
7. **Extract relations per chunk (40-80%):** Para cada chunk obtiene las entidades de ese chunk y llama `extract_relations_llm`
|
||||
8. **Deduplicate relations (80-100%):** `deduplicate_relations` resuelve nombres a IDs y colapsa duplicados
|
||||
9. **Return:** `ExtractionResult` con entidades, relaciones y stats del proceso
|
||||
|
||||
## Notas
|
||||
|
||||
- El parametro `llm_chat_json` inyecta el cliente LLM, sin acoplamiento a ningun proveedor (OpenAI, Anthropic, Ollama, etc.)
|
||||
- El progress callback cubre: 0-40% extraccion de entidades, 40-80% extraccion de relaciones, 80-100% deduplicacion
|
||||
- Si el archivo no existe lanza `FileNotFoundError` antes de cualquier llamada al LLM
|
||||
- Si `entity_presets` esta vacio lanza `ValueError`
|
||||
- Errores en chunks individuales se capturan con warnings y continuan (robustez)
|
||||
- Los `entity_id_map` de `deduplicate_entities` conectan nombres originales del texto con IDs UUID finales para `deduplicate_relations`
|
||||
- El `ExtractionResult` retornado esta listo para insertar en `operations.db` via `fn ops entity add` / `fn ops relation add`
|
||||
@@ -0,0 +1,211 @@
|
||||
"""Pipeline de extraccion de entidades y relaciones desde un documento."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
# Soporte para ejecucion desde la raiz del registry o desde el directorio del archivo
|
||||
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
from python.functions.core.extract_text_from_file import extract_text_from_file
|
||||
from python.functions.core.core import preprocess_text
|
||||
from python.functions.core.split_text_into_chunks import split_text_into_chunks
|
||||
from python.functions.datascience.build_entity_schema_prompt import build_entity_schema_prompt
|
||||
from python.functions.datascience.build_relation_schema_prompt import build_relation_schema_prompt
|
||||
from python.functions.datascience.extract_entities_llm import extract_entities_llm
|
||||
from python.functions.datascience.extract_relations_llm import extract_relations_llm
|
||||
from python.functions.datascience.deduplicate_entities import deduplicate_entities
|
||||
from python.functions.datascience.deduplicate_relations import deduplicate_relations
|
||||
from python.types.datascience.entity_candidate import EntityCandidate
|
||||
from python.types.datascience.extraction_result import ExtractionResult
|
||||
from python.types.datascience.extraction_stats import ExtractionStats
|
||||
|
||||
|
||||
def extraction_pipeline(
    file_path: str,
    entity_presets: list[dict],
    relation_types: list[str],
    llm_chat_json: Callable[[list[dict]], dict],
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    confidence_threshold: float = 0.5,
    dedup_threshold: float = 0.85,
    on_progress: Callable[[str, float], None] | None = None,
) -> ExtractionResult:
    """Run the full entity and relation extraction pipeline on one document.

    Orchestrates extract_text_from_file -> preprocess_text ->
    split_text_into_chunks -> extract_entities_llm per chunk ->
    deduplicate_entities -> extract_relations_llm per chunk ->
    deduplicate_relations.

    Args:
        file_path: Path of the file to process.
        entity_presets: Entity schema passed to extract_entities_llm, e.g.
            [{"type_ref": "osint_person_go_cybersecurity",
              "label": "Person",
              "metadata_fields": ["full_name", "nationality"]}].
        relation_types: Relation types passed to extract_relations_llm,
            e.g. ["funds", "employs", "communicates_with", "owns"].
        llm_chat_json: Injected LLM client; it is handed through to the
            per-chunk extraction functions unchanged.
        chunk_size: Forwarded to split_text_into_chunks as ``chunk_size``.
        chunk_overlap: Forwarded to split_text_into_chunks as ``overlap``.
        confidence_threshold: Candidates with a lower confidence are
            discarded before deduplication.
        dedup_threshold: Forwarded to deduplicate_entities as
            ``name_threshold``.
        on_progress: Optional callback ``(message, pct)`` with pct in 0-1.
            0-0.4 covers entity extraction, 0.4-0.8 relation extraction,
            0.8-1.0 relation deduplication. Exceptions raised by the
            callback are ignored.

    Returns:
        ExtractionResult holding the deduplicated entities and relations
        plus the run statistics. When the text yields no chunks the result
        has empty lists.

    Raises:
        ValueError: if entity_presets is empty.
        FileNotFoundError: if file_path does not exist.
    """
    if not entity_presets:
        raise ValueError("entity_presets no puede estar vacio")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Archivo no encontrado: {file_path}")

    def _progress(msg: str, pct: float) -> None:
        # Best-effort reporting: a failing callback must not abort the run.
        if on_progress is None:
            return
        try:
            on_progress(msg, pct)
        except Exception:
            pass

    started_at = time.monotonic()
    stats = ExtractionStats()

    # -- Step 1: extract text ------------------------------------------------
    _progress("Extracting text from file...", 0.0)
    try:
        raw_text = extract_text_from_file(file_path)
    except Exception as exc:
        # An unreadable file degrades to an empty document instead of failing.
        warnings.warn(f"extraction_pipeline: error al extraer texto: {exc}")
        raw_text = ""

    # -- Step 2: preprocess --------------------------------------------------
    clean_text = preprocess_text(raw_text)
    stats.total_chars = len(clean_text)

    # -- Step 3: split into chunks -------------------------------------------
    chunks = split_text_into_chunks(
        clean_text, chunk_size=chunk_size, overlap=chunk_overlap
    )
    chunk_count = len(chunks)
    stats.total_chunks = chunk_count

    if chunk_count == 0:
        stats.processing_time_seconds = time.monotonic() - started_at
        return ExtractionResult(entities=[], relations=[], stats=stats)

    # -- Step 4: extract entities per chunk ----------------------------------
    collected: list[EntityCandidate] = []
    for chunk_idx, chunk_text in enumerate(chunks):
        _progress(
            f"Extracting entities from chunk {chunk_idx + 1}/{chunk_count}",
            (chunk_idx / chunk_count) * 0.4,
        )
        try:
            found = extract_entities_llm(
                text=chunk_text,
                entity_schema=entity_presets,
                llm_chat_json=llm_chat_json,
            )
        except Exception as exc:
            # A failing chunk is reported and skipped; the others still run.
            warnings.warn(
                f"extraction_pipeline: error en extract_entities_llm chunk {chunk_idx}: {exc}"
            )
            found = []

        for cand in found:
            # Record which chunk produced this candidate.
            if chunk_idx not in cand.source_chunk_indices:
                cand.source_chunk_indices.append(chunk_idx)
            collected.append(cand)

    # -- Step 5: filter by confidence ----------------------------------------
    confident = [
        cand for cand in collected if cand.confidence >= confidence_threshold
    ]
    stats.raw_entities_count = len(confident)

    # Per-type tally of the candidates that passed the filter.
    type_tally = stats.entity_types_found
    for cand in confident:
        type_tally[cand.type_ref] = type_tally.get(cand.type_ref, 0) + 1

    # -- Step 6: deduplicate entities ----------------------------------------
    _progress("Deduplicating entities...", 0.4)
    dedup_result = deduplicate_entities(confident, name_threshold=dedup_threshold)

    stats.final_entities_count = dedup_result.total_after
    stats.entities_merged = dedup_result.total_before - dedup_result.total_after

    final_entities = dedup_result.entities
    entity_id_map = dedup_result.name_to_id  # original name -> entity id

    # -- Step 7: extract relations per chunk ---------------------------------
    raw_relations = []
    for chunk_idx, chunk_text in enumerate(chunks):
        _progress("Extracting relations...", 0.4 + (chunk_idx / chunk_count) * 0.4)

        # Prefer the entities seen in this chunk; fall back to every entity
        # when the chunk contributed none.
        local_entities = [
            ent for ent in final_entities if chunk_idx in ent.source_chunk_indices
        ] or final_entities

        # Fewer than two entities: the relation extractor is not called.
        if len(local_entities) < 2:
            continue

        try:
            chunk_relations = extract_relations_llm(
                text=chunk_text,
                entities=local_entities,
                relation_types=relation_types,
                llm_chat_json=llm_chat_json,
            )
        except Exception as exc:
            warnings.warn(
                f"extraction_pipeline: error en extract_relations_llm chunk {chunk_idx}: {exc}"
            )
            chunk_relations = []

        for rel in chunk_relations:
            rel.source_chunk_index = chunk_idx
        raw_relations.extend(chunk_relations)

    stats.raw_relations_count = len(raw_relations)

    # Per-type tally of the raw (pre-deduplication) relations.
    relation_tally = stats.relation_types_found
    for rel in raw_relations:
        relation_tally[rel.relation_type] = relation_tally.get(rel.relation_type, 0) + 1

    # -- Step 8: deduplicate relations ---------------------------------------
    _progress("Deduplicating relations...", 0.8)
    final_relations = deduplicate_relations(raw_relations, entity_id_map)

    stats.final_relations_count = len(final_relations)
    stats.relations_merged = stats.raw_relations_count - len(final_relations)
    stats.processing_time_seconds = time.monotonic() - started_at

    _progress("Done", 1.0)

    return ExtractionResult(
        entities=final_entities,
        relations=final_relations,
        stats=stats,
    )
|
||||
@@ -0,0 +1,227 @@
|
||||
"""Tests para extraction_pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
from python.functions.pipelines.extraction_pipeline import extraction_pipeline
|
||||
|
||||
|
||||
# ── LLM stubs ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def _llm_with_entities(messages: list[dict]) -> dict:
|
||||
"""LLM stub que retorna entidades fijas para el primer mensaje de extraccion."""
|
||||
system_content = messages[0]["content"] if messages else ""
|
||||
if "entity" in system_content.lower() or "entities" in system_content.lower():
|
||||
return {
|
||||
"entities": [
|
||||
{
|
||||
"name": "John Smith",
|
||||
"type_ref": "osint_person_go_cybersecurity",
|
||||
"attributes": {"full_name": "John Smith", "nationality": "US"},
|
||||
"confidence": 0.95,
|
||||
},
|
||||
{
|
||||
"name": "evil-corp.com",
|
||||
"type_ref": "osint_domain_go_cybersecurity",
|
||||
"attributes": {"fqdn": "evil-corp.com"},
|
||||
"confidence": 0.88,
|
||||
},
|
||||
]
|
||||
}
|
||||
# Llamada de relaciones
|
||||
return {
|
||||
"relations": [
|
||||
{
|
||||
"from_name": "John Smith",
|
||||
"to_name": "evil-corp.com",
|
||||
"relation_type": "operates",
|
||||
"description": "John Smith operates evil-corp.com",
|
||||
"confidence": 0.8,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _llm_empty(messages: list[dict]) -> dict:
|
||||
"""LLM stub que retorna siempre resultado vacio."""
|
||||
system_content = messages[0]["content"] if messages else ""
|
||||
if "entit" in system_content.lower():
|
||||
return {"entities": []}
|
||||
return {"relations": []}
|
||||
|
||||
|
||||
# Entity presets shared by the tests below; passed to extraction_pipeline as
# ``entity_presets``. Each entry carries a type_ref, a label and the list of
# metadata fields for that entity type.
ENTITY_PRESETS = [
    {
        "type_ref": "osint_person_go_cybersecurity",
        "label": "Person",
        "metadata_fields": ["full_name", "alias", "nationality"],
    },
    {
        "type_ref": "osint_domain_go_cybersecurity",
        "label": "Domain",
        "metadata_fields": ["fqdn", "registrar"],
    },
]

# Relation types shared by the tests below; passed to extraction_pipeline as
# ``relation_types``.
RELATION_TYPES = ["operates", "owns", "funds", "communicates_with", "related_to"]
|
||||
|
||||
|
||||
# ── Tests ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_documento_con_entidades_y_relaciones():
    """A document with entities and relations yields a complete ExtractionResult."""
    content = (
        "John Smith, a US national, operates the domain evil-corp.com. "
        "He was identified as the main administrator of the infrastructure."
    )
    # delete=False: the pipeline reopens the file by path, so it must outlive
    # the handle; it is removed in the finally block.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as handle:
        handle.write(content)
        doc_path = handle.name

    try:
        outcome = extraction_pipeline(
            file_path=doc_path,
            entity_presets=ENTITY_PRESETS,
            relation_types=RELATION_TYPES,
            llm_chat_json=_llm_with_entities,
            chunk_size=500,
            chunk_overlap=50,
            confidence_threshold=0.5,
            dedup_threshold=0.85,
        )
        assert outcome is not None
        assert len(outcome.entities) >= 1
        assert outcome.stats.total_chunks >= 1
        assert outcome.stats.total_chars > 0
    finally:
        os.unlink(doc_path)
|
||||
|
||||
|
||||
def test_documento_vacio():
    """An empty document yields an ExtractionResult with empty lists."""
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as handle:
        handle.write("")
        doc_path = handle.name

    try:
        outcome = extraction_pipeline(
            file_path=doc_path,
            entity_presets=ENTITY_PRESETS,
            relation_types=RELATION_TYPES,
            llm_chat_json=_llm_empty,
        )
        assert outcome is not None
        assert outcome.entities == []
        assert outcome.relations == []
        # No text means no chunks at all.
        assert outcome.stats.total_chunks == 0
    finally:
        os.unlink(doc_path)
|
||||
|
||||
|
||||
def test_documento_sin_entidades_detectables():
    """A document with nothing to extract yields empty entity/relation lists.

    The LLM stub used here is `_llm_empty`; the raw entity count in the stats
    must also be zero.
    """
    harmless_text = "The weather is nice today. The sun shines brightly over the mountains."
    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    )
    with handle:
        handle.write(harmless_text)
    doc_path = handle.name

    try:
        outcome = extraction_pipeline(
            file_path=doc_path,
            entity_presets=ENTITY_PRESETS,
            relation_types=RELATION_TYPES,
            llm_chat_json=_llm_empty,
            confidence_threshold=0.5,
        )
        assert outcome is not None
        assert outcome.entities == []
        assert outcome.relations == []
        assert outcome.stats.raw_entities_count == 0
    finally:
        # The temp file outlives the context manager (delete=False).
        os.unlink(doc_path)
|
||||
|
||||
|
||||
def test_archivo_no_encontrado_lanza_filenotfounderror():
    """A path that does not exist makes the pipeline raise FileNotFoundError."""
    import pytest

    missing_path = "/tmp/no_existe_para_test_extraccion_pipeline.txt"
    call_kwargs = {
        "file_path": missing_path,
        "entity_presets": ENTITY_PRESETS,
        "relation_types": RELATION_TYPES,
        "llm_chat_json": _llm_empty,
    }
    with pytest.raises(FileNotFoundError):
        extraction_pipeline(**call_kwargs)
|
||||
|
||||
|
||||
def test_entity_presets_vacio_lanza_valueerror():
    """An empty `entity_presets` list makes the pipeline raise ValueError."""
    import pytest

    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    )
    with handle:
        handle.write("some text")
    doc_path = handle.name

    no_presets = []
    try:
        with pytest.raises(ValueError):
            extraction_pipeline(
                file_path=doc_path,
                entity_presets=no_presets,
                relation_types=RELATION_TYPES,
                llm_chat_json=_llm_empty,
            )
    finally:
        # delete=False above, so remove the file ourselves.
        os.unlink(doc_path)
|
||||
|
||||
|
||||
def test_progress_callback_se_invoca():
    """The `on_progress` callback is invoked while the pipeline runs.

    Every (message, percentage) pair is recorded; at least one call must be
    made and at least one message must mention a known stage marker.
    """
    progress_events: list[tuple[str, float]] = []

    def _record(msg: str, pct: float) -> None:
        progress_events.append((msg, pct))

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as handle:
        handle.write("John Smith operates evil-corp.com.")
        doc_path = handle.name

    stage_markers = ("Extracting", "Done", "Dedup")
    try:
        extraction_pipeline(
            file_path=doc_path,
            entity_presets=ENTITY_PRESETS,
            relation_types=RELATION_TYPES,
            llm_chat_json=_llm_with_entities,
            on_progress=_record,
        )
        assert len(progress_events) > 0
        reported = [event[0] for event in progress_events]
        assert any(marker in text for text in reported for marker in stage_markers)
    finally:
        # The temp file was created with delete=False.
        os.unlink(doc_path)
|
||||
|
||||
|
||||
def test_stats_se_rellenan_correctamente():
    """The result stats carry character/chunk counts and a processing time."""
    sample = "John Smith, a US national, operates the domain evil-corp.com."
    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    )
    with handle:
        handle.write(sample)
    doc_path = handle.name

    try:
        outcome = extraction_pipeline(
            file_path=doc_path,
            entity_presets=ENTITY_PRESETS,
            relation_types=RELATION_TYPES,
            llm_chat_json=_llm_with_entities,
        )
        assert outcome.stats.total_chars > 0
        assert outcome.stats.total_chunks >= 1
        assert outcome.stats.processing_time_seconds >= 0.0
    finally:
        # delete=False keeps the file after the context manager exits.
        os.unlink(doc_path)
|
||||
@@ -0,0 +1,74 @@
|
||||
---
|
||||
name: monte_carlo_market
|
||||
kind: pipeline
|
||||
lang: py
|
||||
domain: pipelines
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def monte_carlo_market(n_simulations: int, base_params: dict, vary_params: dict, seed_start: int) -> list[dict]"
|
||||
description: "Ejecuta N simulaciones de mercado con parámetros variados uniformemente. Cada simulación usa run_market_sim y retorna métricas resumen: spreads, trades por tick, volatilidad realizada y PnL total de makers."
|
||||
tags: [montecarlo, simulation, market, launcher, finance, microstructure]
|
||||
uses_functions:
|
||||
- run_market_sim_py_pipelines
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [numpy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/pipelines/monte_carlo_market.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```bash
|
||||
# 10 simulaciones con sigma y gamma variables
|
||||
python python/functions/pipelines/monte_carlo_market.py -n 10
|
||||
```
|
||||
|
||||
```python
|
||||
from monte_carlo_market import monte_carlo_market
|
||||
|
||||
results = monte_carlo_market(
|
||||
n_simulations=50,
|
||||
base_params={'n_ticks': 300, 'n_makers': 3},
|
||||
vary_params={
|
||||
'sigma': (0.005, 0.05),
|
||||
'gamma': (0.01, 1.0),
|
||||
'hawkes_alpha': (0.1, 0.9),
|
||||
},
|
||||
seed_start=42,
|
||||
)
|
||||
# Cada resultado tiene: sim_id, seed, sigma, gamma, hawkes_alpha,
|
||||
# total_trades, mean_spread, std_spread, mean_trades_per_tick,
|
||||
# price_return, maker_total_pnl, realized_vol
|
||||
```
|
||||
|
||||
## Flujo
|
||||
|
||||
1. Para cada simulación i en range(n_simulations):
|
||||
- Tomar `base_params` + `seed = seed_start + i`
|
||||
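- Muestrear cada parámetro de `vary_params` uniformemente en su rango `(min, max)` con un único rng (`np.random.default_rng(seed_start)`) compartido por todas las simulaciones, y redondearlo a 6 decimales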
|
||||
- Llamar `run_market_sim(**params)`
|
||||
- Calcular métricas resumen sobre el resultado
|
||||
2. Reportar progreso cada 10% de simulaciones
|
||||
3. Retornar lista de dicts con params usados + métricas
|
||||
|
||||
## Métricas por simulación
|
||||
|
||||
| Campo | Descripción |
|
||||
|---|---|
|
||||
| `total_trades` | Número total de trades en la simulación |
|
||||
| `mean_spread` | Spread bid-ask medio |
|
||||
| `std_spread` | Desviación estándar del spread |
|
||||
| `mean_trades_per_tick` | Intensidad media del flujo de órdenes |
|
||||
| `price_return` | Retorno % del precio fundamental |
|
||||
| `maker_total_pnl` | PnL agregado de todos los makers |
|
||||
| `realized_vol` | Desviación estándar de los log-retornos entre trade prices consecutivos; la clave solo está presente si hay al menos 3 trades con precio positivo |
|
||||
|
||||
## Notas
|
||||
|
||||
`vary_params` acepta cualquier parámetro válido de `run_market_sim` como clave, con valor `(min, max)`.
|
||||
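Los parámetros en `base_params` tienen precedencia sobre los defaults de `run_market_sim`, pero son sobrescritos por `vary_params`. La clave `seed` de `base_params` se ignora: cada simulación usa siempre `seed = seed_start + i`.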
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Ejecuta N simulaciones de mercado con parámetros variables para análisis Monte Carlo."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
def monte_carlo_market(
|
||||
n_simulations: int = 100,
|
||||
base_params: dict | None = None,
|
||||
vary_params: dict | None = None,
|
||||
seed_start: int = 0,
|
||||
) -> list[dict]:
|
||||
"""Ejecuta N simulaciones variando parámetros.
|
||||
|
||||
base_params: parámetros fijos para run_market_sim
|
||||
vary_params: dict de param_name -> (min, max) para variar uniformemente
|
||||
|
||||
Retorna lista de dicts, cada uno con los params usados + métricas resumen.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.join(os.environ.get('FN_REGISTRY_ROOT', os.path.expanduser('~/fn_registry')), 'python', 'functions'))
|
||||
sys.path.insert(0, os.path.join(os.environ.get('FN_REGISTRY_ROOT', os.path.expanduser('~/fn_registry')), 'python', 'functions', 'pipelines'))
|
||||
from run_market_sim import run_market_sim
|
||||
|
||||
if base_params is None:
|
||||
base_params = {}
|
||||
if vary_params is None:
|
||||
vary_params = {}
|
||||
|
||||
rng = np.random.default_rng(seed_start)
|
||||
results = []
|
||||
|
||||
for i in range(n_simulations):
|
||||
params = dict(base_params)
|
||||
params['seed'] = seed_start + i
|
||||
|
||||
# Variar parámetros
|
||||
varied = {}
|
||||
for pname, (pmin, pmax) in vary_params.items():
|
||||
val = rng.uniform(pmin, pmax)
|
||||
params[pname] = round(val, 6)
|
||||
varied[pname] = params[pname]
|
||||
|
||||
sim = run_market_sim(**params)
|
||||
|
||||
# Métricas resumen
|
||||
spreads = sim['spreads']
|
||||
trade_prices = sim['trade_prices']
|
||||
n_per_tick = sim['n_trades_per_tick']
|
||||
|
||||
result = {
|
||||
'sim_id': i,
|
||||
'seed': params['seed'],
|
||||
**varied,
|
||||
'total_trades': sim['total_trades'],
|
||||
'mean_spread': round(np.mean(spreads), 6) if spreads else 0,
|
||||
'std_spread': round(np.std(spreads), 6) if spreads else 0,
|
||||
'mean_trades_per_tick': round(np.mean(n_per_tick), 2),
|
||||
'price_return': round((sim['fundamental_prices'][-1] / sim['fundamental_prices'][0] - 1) * 100, 4),
|
||||
'maker_total_pnl': round(sum(sim['maker_pnls']), 2),
|
||||
}
|
||||
|
||||
if trade_prices:
|
||||
tp = np.array(trade_prices)
|
||||
log_ret = np.diff(np.log(tp[tp > 0]))
|
||||
if len(log_ret) > 1:
|
||||
result['realized_vol'] = round(float(np.std(log_ret)), 6)
|
||||
|
||||
results.append(result)
|
||||
|
||||
if (i + 1) % max(1, n_simulations // 10) == 0:
|
||||
print(f' {i+1}/{n_simulations} simulaciones completadas')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # CLI entry point: sweep sigma and gamma over `-n` simulations and print
    # the summary of the last one.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-n', type=int, default=10)
    args = parser.parse_args()

    results = monte_carlo_market(
        n_simulations=args.n,
        base_params={'n_ticks': 200},
        vary_params={'sigma': (0.005, 0.05), 'gamma': (0.01, 1.0)},
    )
    # With `-n 0` (or a negative count) nothing runs, so there is no last
    # result to show; indexing results[-1] would raise IndexError.
    if results:
        print(json.dumps(results[-1], indent=2))
    print(f'\n{len(results)} simulaciones completadas')
|
||||
@@ -0,0 +1,65 @@
|
||||
---
|
||||
name: run_market_sim
|
||||
kind: pipeline
|
||||
lang: py
|
||||
domain: pipelines
|
||||
version: "1.0.0"
|
||||
purity: impure
|
||||
signature: "def run_market_sim(initial_price: float, n_ticks: int, sigma: float, mu: float, jump_intensity: float, jump_size_std: float, n_makers: int, maker_spread: float, gamma: float, maker_levels: int, maker_qty: float, n_takers_lambda: float, taker_size_alpha: float, taker_size_min: float, taker_size_max: float, hawkes_alpha: float, hawkes_beta: float, seed: int) -> dict"
|
||||
description: "Simula un mercado completo con matching engine FIFO. Makers usan Avellaneda-Stoikov, takers llegan según proceso Hawkes con tamaños power-law. Retorna trades, spreads, midprices y PnL de makers."
|
||||
tags: [simulation, market, matching-engine, montecarlo, launcher, finance, microstructure]
|
||||
uses_functions:
|
||||
- generate_gbm_prices_py_finance
|
||||
- avellaneda_stoikov_quotes_py_finance
|
||||
uses_types: []
|
||||
returns: []
|
||||
returns_optional: false
|
||||
error_type: "error_go_core"
|
||||
imports: [numpy]
|
||||
tested: false
|
||||
tests: []
|
||||
test_file_path: ""
|
||||
file_path: "python/functions/pipelines/run_market_sim.py"
|
||||
---
|
||||
|
||||
## Ejemplo
|
||||
|
||||
```bash
|
||||
python python/functions/pipelines/run_market_sim.py
|
||||
# {
|
||||
# "total_trades": 1234,
|
||||
# "mean_spread": 0.4821,
|
||||
# "maker_pnls": [12.5, -3.2, 8.1, 5.6, -1.4]
|
||||
# }
|
||||
```
|
||||
|
||||
```python
|
||||
from run_market_sim import run_market_sim
|
||||
|
||||
result = run_market_sim(
|
||||
initial_price=100.0,
|
||||
n_ticks=200,
|
||||
sigma=0.01,
|
||||
n_makers=3,
|
||||
seed=0,
|
||||
)
|
||||
print(result['total_trades'])
|
||||
print(result['maker_pnls'])
|
||||
```
|
||||
|
||||
## Flujo
|
||||
|
||||
1. `generate_gbm_prices` — genera la serie de precios fundamentales con GBM + saltos
|
||||
2. Loop por ticks:
|
||||
- Cada maker coloca quotes via `avellaneda_stoikov_quotes`
|
||||
- Takers llegan según Poisson con intensidad modulada por excitación Hawkes
|
||||
- Tamaños de taker siguen distribución Pareto (power-law)
|
||||
- Matching FIFO sobre el order book simplificado
|
||||
- Excitación Hawkes decae exponencialmente entre ticks
|
||||
3. Mark-to-market final de inventarios de makers
|
||||
|
||||
## Notas
|
||||
|
||||
Los parámetros Hawkes (`hawkes_alpha`, `hawkes_beta`) controlan la autocorrelación del flujo de órdenes.
|
||||
`branching_ratio = hawkes_alpha / hawkes_beta`; si > 1, el proceso es explosivo.
|
||||
El matching es simplificado: no hay cancelaciones intra-tick, el book se reconstituye en cada tick.
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Ejecuta una simulación de mercado completa con matching engine FIFO."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
def run_market_sim(
    initial_price: float = 100.0,
    n_ticks: int = 500,
    sigma: float = 0.02,
    mu: float = 0.0,
    jump_intensity: float = 0.02,
    jump_size_std: float = 0.05,
    n_makers: int = 5,
    maker_spread: float = 0.5,
    gamma: float = 0.1,
    maker_levels: int = 3,
    maker_qty: float = 10.0,
    n_takers_lambda: float = 2.0,
    taker_size_alpha: float = 2.0,
    taker_size_min: float = 1.0,
    taker_size_max: float = 100.0,
    hawkes_alpha: float = 0.5,
    hawkes_beta: float = 1.0,
    seed: int = 42,
) -> dict:
    """Simulate a market of quoting makers and randomly arriving takers.

    Per tick, every maker quotes through ``avellaneda_stoikov_quotes`` around
    the fundamental price plus a small uniform noise; the book is rebuilt from
    scratch each tick. Takers arrive as a Poisson count whose rate is
    ``n_takers_lambda`` plus a self-exciting term (floored at 0.1), pick a side
    with probability 0.5, draw a Pareto-distributed size capped at
    ``taker_size_max`` and walk the opposite side of the book best price first.
    Orders at the same price keep their insertion order (the sort is stable).

    Returns:
        dict with:
        - trade_prices, trade_times, trade_sizes: one entry per fill
        - spreads, midprices: one entry per tick, recorded before takers arrive
        - n_trades_per_tick: number of fills in each tick
        - fundamental_prices: the series returned by ``generate_gbm_prices``
        - maker_pnls: final cash plus inventory marked at the last fundamental
          price, rounded to 2 decimals, one entry per maker
        - total_trades: total number of fills

    Side effects:
        Adds the registry's ``python/functions`` directory to ``sys.path`` (once).

    NOTE(review): ``seed`` both seeds the local generator and is passed to
    ``generate_gbm_prices``; if that helper builds the same kind of generator,
    the two random streams coincide. Confirm against its implementation.
    """
    import numpy as np

    # Make the registry functions importable. Insert only when missing so that
    # repeated calls (e.g. from a Monte Carlo loop) do not keep growing sys.path.
    functions_dir = os.path.join(
        os.environ.get('FN_REGISTRY_ROOT', os.path.expanduser('~/fn_registry')),
        'python',
        'functions',
    )
    if functions_dir not in sys.path:
        sys.path.insert(0, functions_dir)
    from finance.finance import generate_gbm_prices, avellaneda_stoikov_quotes

    rng = np.random.default_rng(seed)

    # Fundamental price path that makers quote around.
    fund_prices = generate_gbm_prices(initial_price, n_ticks, sigma, mu, jump_intensity, jump_size_std, seed)

    # Simplified order book: lists of (price, qty, maker_idx), matched inline.
    trade_prices, trade_times, trade_sizes = [], [], []
    spreads, midprices = [], []
    n_trades_per_tick = []
    maker_inventories = [0.0] * n_makers
    maker_pnls = [0.0] * n_makers  # running cash; inventory is marked at the end
    hawkes_excitation = 0.0
    hawkes_decay = np.exp(-hawkes_beta)  # per-tick decay factor, loop invariant

    for t in range(n_ticks):
        mid = fund_prices[t]

        # Makers place orders. One uniform draw per maker, in maker order.
        all_bids = []  # (price, qty, maker_idx)
        all_asks = []
        for m in range(n_makers):
            noise = rng.uniform(-0.05, 0.05)
            quotes = avellaneda_stoikov_quotes(
                mid + noise, maker_inventories[m], gamma, sigma, maker_spread, maker_levels, maker_qty
            )
            for q in quotes:
                side_book = all_bids if q['side'] == 'buy' else all_asks
                side_book.append((q['price'], q['qty'], m))

        all_bids.sort(key=lambda x: -x[0])  # best (highest) bid first
        all_asks.sort(key=lambda x: x[0])   # best (lowest) ask first

        # Record book state before any taker trades in this tick.
        if all_bids and all_asks:
            spreads.append(all_asks[0][0] - all_bids[0][0])
            midprices.append((all_bids[0][0] + all_asks[0][0]) / 2)
        else:
            spreads.append(0.0)
            midprices.append(mid)

        # Takers arrive: Poisson count with a self-excited rate, floored at 0.1.
        lam = max(0.1, n_takers_lambda + hawkes_excitation)
        n_takers = rng.poisson(lam)
        tick_trades = 0

        for _ in range(n_takers):
            # Draw order matters for reproducibility: side first, then size.
            side = 'buy' if rng.random() < 0.5 else 'sell'
            raw_size = (rng.pareto(taker_size_alpha) + 1) * taker_size_min
            qty_remaining = min(round(raw_size, 1), taker_size_max)

            # A buying taker lifts asks; a selling taker hits bids.
            book = list(all_asks) if side == 'buy' else list(all_bids)

            for level_idx, (price, available, maker_idx) in enumerate(book):
                if qty_remaining <= 0:
                    break
                fill = min(qty_remaining, available)
                trade_prices.append(price)
                trade_times.append(t)
                trade_sizes.append(fill)
                tick_trades += 1
                qty_remaining -= fill

                if side == 'buy':
                    # Maker sold: inventory down, cash up.
                    maker_inventories[maker_idx] -= fill
                    maker_pnls[maker_idx] += price * fill
                else:
                    # Maker bought: inventory up, cash down.
                    maker_inventories[maker_idx] += fill
                    maker_pnls[maker_idx] -= price * fill

                book[level_idx] = (price, available - fill, maker_idx)

            # Drop fully consumed orders from the side that was traded against.
            remaining_orders = [(p, q, m) for p, q, m in book if q > 0]
            if side == 'buy':
                all_asks = remaining_orders
            else:
                all_bids = remaining_orders

        # Excitation decays once per tick, then grows with this tick's fills.
        hawkes_excitation = hawkes_excitation * hawkes_decay + hawkes_alpha * tick_trades
        n_trades_per_tick.append(tick_trades)

    # Mark remaining inventories to market at the last fundamental price.
    final_price = fund_prices[-1]
    for m in range(n_makers):
        maker_pnls[m] += maker_inventories[m] * final_price

    return {
        'trade_prices': trade_prices,
        'trade_times': trade_times,
        'trade_sizes': trade_sizes,
        'spreads': spreads,
        'midprices': midprices,
        'n_trades_per_tick': n_trades_per_tick,
        'fundamental_prices': fund_prices,
        'maker_pnls': [round(p, 2) for p in maker_pnls],
        'total_trades': len(trade_prices),
    }
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: run one simulation with the default parameters and
    # print a short JSON summary.
    result = run_market_sim()
    tick_spreads = result['spreads']
    summary = {
        'total_trades': result['total_trades'],
        'mean_spread': round(sum(tick_spreads) / len(tick_spreads), 4),
        'maker_pnls': result['maker_pnls'],
    }
    print(json.dumps(summary, indent=2))
|
||||
Reference in New Issue
Block a user