"""Carga (y cachea) el modelo mREBEL para extraccion de relaciones multilingue.""" from __future__ import annotations from typing import Any # Cache global: (model_name, src_lang) -> (tokenizer, model) _MODEL_CACHE: dict[tuple[str, str], tuple[Any, Any]] = {} def mrebel_load_model( model_name: str = "Babelscape/mrebel-large", src_lang: str = "es_XX", tgt_lang: str = "tp_XX", ) -> tuple[Any, Any]: """Loads (and caches) the mREBEL tokenizer and model. mREBEL is a multilingual seq2seq model (mBART-based, ~600M params, ~2.4 GB) for relation extraction. It supports 30+ languages via language codes (``src_lang``). LICENSE NOTICE: Babelscape/mrebel-large is licensed under CC BY-NC-SA 4.0 (Creative Commons Non-Commercial Share-Alike). Do NOT use in commercial products without replacing this model with a commercially-licensed alternative (e.g. Babelscape/rebel-large which is Apache 2.0 but English-only). The first call downloads the model from HuggingFace Hub (~2.4 GB). Subsequent calls with the same ``(model_name, src_lang)`` return the cached instance without re-loading. Args: model_name: HuggingFace Hub model ID. Default is the large variant. src_lang: Source language code for the mBART tokenizer, e.g. ``"es_XX"`` (Spanish), ``"en_XX"`` (English), ``"fr_XX"`` (French). tgt_lang: Target language token for the decoder (always ``"tp_XX"`` for the triplet format — only change if using a custom checkpoint). Returns: Tuple ``(tokenizer, model)`` both ready for inference with ``model.generate(...)`` and ``tokenizer.decode(...)``. Raises: ImportError: if ``transformers`` is not installed. OSError: if the model cannot be downloaded or loaded from disk. """ cache_key = (model_name, src_lang) cached = _MODEL_CACHE.get(cache_key) if cached is not None: return cached try: from transformers import AutoModelForSeq2SeqLM, AutoTokenizer except ImportError as exc: raise ImportError( "transformers no esta instalado. Instalalo con " "`uv pip install transformers` o `uv pip install -e '.[nlp]'`." ) from exc tokenizer = AutoTokenizer.from_pretrained( model_name, src_lang=src_lang, tgt_lang=tgt_lang, ) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) model.eval() _MODEL_CACHE[cache_key] = (tokenizer, model) return tokenizer, model