fa5bcca155
- glirel_load_model: cache por (model_name, device); device='auto' resuelve via torch - extract_relations_glirel: tokeniza por whitespace, mapea spans char->token, llama predict_relations y devuelve RelationCandidate; fallback text.find si la entidad llega sin offsets; max_pairs=N -> top-N por score - pyproject.toml: glirel en extra nlp Closes #0039 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
"""Carga (y cachea) un modelo GLiREL en el device deseado."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
# Process-wide cache of loaded GLiREL models, keyed by
# (model_name, resolved_device); `glirel_load_model` stores the key after
# resolving device='auto', so 'auto' and its concrete device share an entry.
# Nothing in this module evicts entries, so every cached model stays alive
# for the lifetime of the process.
_MODEL_CACHE: dict[tuple[str, str], Any] = dict()
|
|
|
|
|
|
def _resolve_device(device: str) -> str:
    """Turn ``device='auto'`` into a concrete device string.

    Any value other than ``"auto"`` is passed through untouched (e.g.
    ``"cpu"``, ``"cuda"``, ``"cuda:1"``). For ``"auto"`` the result is
    ``"cuda"`` when torch can be imported and reports CUDA as available,
    and ``"cpu"`` otherwise, including when torch is not installed.
    """
    if device == "auto":
        try:
            import torch
        except ImportError:
            # Without torch there is no way to probe for a GPU; stay on CPU.
            return "cpu"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"
    return device
|
|
|
|
|
|
def glirel_load_model(
    model_name: str = "jackboyla/glirel-large-v0",
    device: str = "auto",
) -> Any:
    """Return a GLiREL model, loading it at most once per (model_name, device).

    The device is resolved first (``"auto"`` becomes ``"cuda"`` or ``"cpu"``),
    so ``"auto"`` and the device it resolves to share a single cache entry.
    On a cache miss the model is built with ``GLiREL.from_pretrained``. The
    first such call downloads it from the HuggingFace Hub (~500 MB for
    ``glirel-large-v0``). The model is then moved to the resolved device when
    it exposes ``.to()``, stored in the module cache, and returned. Later calls
    with the same parameters return the cached instance.

    Args:
        model_name: Model ID on the HuggingFace Hub.
        device: ``"auto"`` to use CUDA when available (CPU otherwise), or an
            explicit device such as ``"cpu"``, ``"cuda"`` or ``"cuda:N"``.

    Returns:
        A GLiREL model instance ready for ``predict_relations``.

    Raises:
        ImportError: if the ``glirel`` package is not installed. Fix it with
            ``uv pip install glirel`` or by installing the project's ``nlp``
            extra (``uv pip install -e '.[nlp]'``).
    """
    target = _resolve_device(device)
    key = (model_name, target)
    model = _MODEL_CACHE.get(key)

    if model is None:
        # Import lazily so the module stays importable without the optional
        # `glirel` dependency; only an actual load requires it.
        try:
            from glirel import GLiREL
        except ImportError as exc:
            raise ImportError(
                "glirel no esta instalado. Instalalo con "
                "`uv pip install glirel` o `uv pip install -e '.[nlp]'`."
            ) from exc

        model = GLiREL.from_pretrained(model_name)
        # Only move the model when it exposes `.to()`. The return value is not
        # used, which presumes an in-place move as with torch.nn.Module
        # (confirm against the glirel class).
        if hasattr(model, "to"):
            model.to(target)
        _MODEL_CACHE[key] = model

    return model
|