falta por mejorar modelos de vision

This commit is contained in:
2025-11-28 23:01:32 +01:00
parent 7ca6ae3dd4
commit b68a4ec43b
5 changed files with 357 additions and 17 deletions
+180 -15
View File
@@ -1,9 +1,10 @@
import sys
import time
import uuid
from typing import List, Optional
from typing import List, Optional, Union
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
@@ -12,21 +13,63 @@ import uvicorn
from llama_cpp import Llama
def str_to_bool(value: str) -> bool:
"""Parsea valores booleanos desde la CLI."""
if isinstance(value, bool):
return value
lowered = value.strip().lower()
if lowered in {"true", "1", "yes", "y", "t"}:
return True
if lowered in {"false", "0", "no", "n", "f"}:
return False
raise ValueError(f"Valor booleano invalido: {value}")
# Modelos de datos para la API compatible con OpenAI
class ChatCompletionMessage(BaseModel):
role: str = Field(..., description="El rol del mensaje: 'system', 'user', o 'assistant'")
content: str = Field(..., description="El contenido del mensaje")
content: Union[str, List["MessageContentPart"]] = Field(..., description="Contenido del mensaje (texto o partes multimodales)")
class ChatCompletionRequest(BaseModel):
model: str = Field(default="llama", description="Nombre del modelo")
messages: List[ChatCompletionMessage] = Field(..., description="Lista de mensajes")
max_tokens: Optional[int] = Field(default=2048, description="Máximo número de tokens a generar")
temperature: Optional[float] = Field(default=0.7, description="Temperatura para el muestreo")
top_p: Optional[float] = Field(default=0.9, description="Top-p para el muestreo")
max_tokens: Optional[int] = Field(default=None, description="Máximo número de tokens a generar")
temperature: Optional[float] = Field(default=None, description="Temperatura para el muestreo")
top_p: Optional[float] = Field(default=None, description="Top-p para el muestreo")
top_k: Optional[int] = Field(default=None, description="Top-k para el muestreo")
repeat_penalty: Optional[float] = Field(default=None, description="Penalización por repetición")
min_p: Optional[float] = Field(default=None, description="Umbral mínimo de probabilidad (min_p sampling)")
stream: Optional[bool] = Field(default=False, description="Si devolver respuesta en streaming")
class MessageContentPart(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[str] = None
# Resolver forward refs
ChatCompletionMessage.model_rebuild()
def normalize_image_input(raw: str) -> str:
"""Normaliza imágenes a data URI; si llega base64 simple, se envuelve como data:image/png;base64."""
if raw.startswith("data:"):
return raw
if "://" in raw:
# En este entorno no se hace fetch remoto
raise HTTPException(status_code=400, detail="Solo se aceptan imágenes en data URI o base64 inline")
# Asumimos base64 puro
return f"data:image/png;base64,{raw}"
def likely_vision_model(model_path: str) -> bool:
"""Heurística simple para detectar modelos multimodales."""
lowered = Path(model_path).name.lower()
return any(key in lowered for key in ("vl", "vision", "llava", "mmproj"))
class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionMessage
@@ -64,9 +107,13 @@ class ModelsResponse(BaseModel):
class LlamaAPI:
"""Clase para manejar la API de Llama"""
def __init__(self, model_path: str, **kwargs):
def __init__(self, model_path: str, mmproj_path: Optional[str] = None, **kwargs):
if not Path(model_path).exists():
raise FileNotFoundError(f"El archivo del modelo no existe: {model_path}")
if mmproj_path:
if not Path(mmproj_path).exists():
raise FileNotFoundError(f"El archivo mmproj no existe: {mmproj_path}")
kwargs["mmproj_path"] = mmproj_path
print(f"Cargando modelo desde: {model_path}")
@@ -79,6 +126,13 @@ class LlamaAPI:
"n_gpu_layers": -1, # Usar todas las capas en GPU (-1 = auto)
"main_gpu": 0, # GPU principal a usar
"split_mode": 1, # Modo de división entre GPUs
"rope_freq_base": 10000.0,
"rope_freq_scale": 1.0,
"offload_kv": True,
"use_mmap": True,
"use_mlock": False,
"seed": 0,
"flash_attn": False,
}
# Actualizar con parámetros personalizados
@@ -90,6 +144,12 @@ class LlamaAPI:
print(f"GPU layers: {default_params.get('n_gpu_layers', 'N/A')}")
print(f"Main GPU: {default_params.get('main_gpu', 'N/A')}")
except Exception as e:
vision_hint = likely_vision_model(model_path)
if vision_hint and not mmproj_path:
raise RuntimeError(
"Fallo al cargar modelo multimodal. Se requiere un archivo mmproj. "
"Define --mmproj-path o coloca un .mmproj en model_choice/."
) from e
print(f"Error al cargar el modelo: {e}")
print("Nota: Si tienes problemas con CUDA, intenta instalar: pip install llama-cpp-python[cublas]")
raise
@@ -97,11 +157,15 @@ class LlamaAPI:
def generate_chat_completion(self, messages: List[ChatCompletionMessage],
max_tokens: int = 2048,
temperature: float = 0.7,
top_p: float = 0.9) -> str:
top_p: float = 0.9,
top_k: int = 40,
repeat_penalty: float = 1.1,
min_p: float = 0.05) -> str:
"""Genera una respuesta de chat usando el modelo"""
# Formatear los mensajes en un prompt
prompt = self._format_messages_to_prompt(messages)
images = self._collect_images(messages)
try:
# Generar respuesta usando llama.cpp
@@ -110,7 +174,11 @@ class LlamaAPI:
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
min_p=min_p,
echo=False,
images=images if images else None,
stop=["</s>", "<|im_end|>", "<|endoftext|>"]
)
@@ -126,19 +194,52 @@ class LlamaAPI:
prompt_parts = []
for message in messages:
text_content = self._extract_text_from_content(message.content)
if message.role == "system":
prompt_parts.append(f"Sistema: {message.content}")
prompt_parts.append(f"Sistema: {text_content}")
elif message.role == "user":
prompt_parts.append(f"Usuario: {message.content}")
prompt_parts.append(f"Usuario: {text_content}")
elif message.role == "assistant":
prompt_parts.append(f"Asistente: {message.content}")
prompt_parts.append(f"Asistente: {text_content}")
prompt_parts.append("Asistente:")
return "\n".join(prompt_parts)
def _extract_text_from_content(self, content: Union[str, List["MessageContentPart"]]) -> str:
"""Extrae el texto de un contenido que puede ser string o lista multimodal."""
if isinstance(content, str):
return content
texts = []
for part in content:
if part.type == "text" and part.text:
texts.append(part.text)
elif part.type == "image_url":
texts.append("[imagen]")
return " ".join(texts).strip()
def _collect_images(self, messages: List[ChatCompletionMessage]) -> List[str]:
"""Recolecta imágenes (data URI) de los mensajes."""
images: List[str] = []
for message in messages:
if isinstance(message.content, str):
continue
for part in message.content:
if part.type == "image_url" and part.image_url:
images.append(normalize_image_input(part.image_url))
return images
# Instancia global de la API
llama_api = None
generation_defaults = {
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"repeat_penalty": 1.1,
"min_p": 0.05,
}
# Crear la aplicación FastAPI
app = FastAPI(
@@ -189,11 +290,21 @@ async def create_chat_completion(request: ChatCompletionRequest):
try:
# Generar respuesta
temperature = request.temperature if request.temperature is not None else generation_defaults["temperature"]
top_p = request.top_p if request.top_p is not None else generation_defaults["top_p"]
top_k = request.top_k if request.top_k is not None else generation_defaults["top_k"]
repeat_penalty = request.repeat_penalty if request.repeat_penalty is not None else generation_defaults["repeat_penalty"]
min_p = request.min_p if request.min_p is not None else generation_defaults["min_p"]
max_tokens = request.max_tokens if request.max_tokens is not None else generation_defaults["max_tokens"]
response_text = llama_api.generate_chat_completion(
messages=request.messages,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature or 0.7,
top_p=request.top_p or 0.9
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
min_p=min_p
)
# Crear respuesta en formato OpenAI
@@ -241,6 +352,8 @@ def main():
help="Tamaño del contexto (default: 4096)")
parser.add_argument("--n-batch", type=int, default=512,
help="Tamaño del batch (default: 512)")
parser.add_argument("--eval-batch-size", type=int, default=None,
help="Tamaño del batch de evaluación (sobrescribe n-batch si se setea)")
parser.add_argument("--n-threads", type=int, default=None,
help="Número de threads (default: auto)")
parser.add_argument("--n-gpu-layers", type=int, default=-1,
@@ -249,6 +362,35 @@ def main():
help="GPU principal a usar (default: 0)")
parser.add_argument("--split-mode", type=int, default=1,
help="Modo de división entre GPUs (default: 1)")
parser.add_argument("--rope-freq-base", type=float, default=10000.0,
help="Base de frecuencia ROPE (default: 10000)")
parser.add_argument("--rope-freq-scale", type=float, default=1.0,
help="Escala de frecuencia ROPE (default: 1.0)")
parser.add_argument("--offload-kv-cache", type=str_to_bool, default=True,
help="Offload del KV cache a GPU (default: true)")
parser.add_argument("--keep-model-in-memory", type=str_to_bool, default=False,
help="Usa mlock para mantener el modelo en RAM (default: false)")
parser.add_argument("--try-mmap", type=str_to_bool, default=True,
help="Usar mmap para mapear el modelo (default: true)")
parser.add_argument("--seed", type=int, default=0,
help="Seed para la generación (0 = aleatorio)")
parser.add_argument("--flash-attn", type=str_to_bool, default=False,
help="Habilitar Flash Attention si está disponible")
# Defaults de generación
parser.add_argument("--default-max-tokens", type=int, default=2048,
help="Límite de tokens de respuesta por defecto")
parser.add_argument("--default-temperature", type=float, default=0.7,
help="Temperatura por defecto")
parser.add_argument("--default-top-k", type=int, default=40,
help="Top-k por defecto")
parser.add_argument("--default-repeat-penalty", type=float, default=1.1,
help="Penalización de repetición por defecto")
parser.add_argument("--default-min-p", type=float, default=0.05,
help="Min-p por defecto")
parser.add_argument("--default-top-p", type=float, default=0.9,
help="Top-p por defecto")
parser.add_argument("--mmproj-path", type=str, default=None,
help="Ruta al proyector multimodal (mmproj) si el modelo lo requiere")
args = parser.parse_args()
@@ -260,25 +402,48 @@ def main():
# Configurar parámetros del modelo
model_params = {
"n_ctx": args.n_ctx,
"n_batch": args.n_batch,
"n_batch": args.eval_batch_size or args.n_batch,
"n_gpu_layers": args.n_gpu_layers,
"main_gpu": args.main_gpu,
"split_mode": args.split_mode,
"rope_freq_base": args.rope_freq_base,
"rope_freq_scale": args.rope_freq_scale,
"offload_kv": args.offload_kv_cache,
"use_mmap": args.try_mmap,
"use_mlock": args.keep_model_in_memory,
"seed": args.seed,
"flash_attn": args.flash_attn,
}
if args.n_threads:
model_params["n_threads"] = args.n_threads
# Defaults de generación para usar si la request no los especifica
global generation_defaults
generation_defaults = {
"max_tokens": args.default_max_tokens,
"temperature": args.default_temperature,
"top_p": args.default_top_p,
"top_k": args.default_top_k,
"repeat_penalty": args.default_repeat_penalty,
"min_p": args.default_min_p,
}
try:
# Inicializar la API de Llama
global llama_api
llama_api = LlamaAPI(args.model_path, **model_params)
llama_api = LlamaAPI(args.model_path, mmproj_path=args.mmproj_path, **model_params)
print(f"\n🚀 Servidor iniciado en http://{args.host}:{args.port}")
print(f"📚 Documentación disponible en http://{args.host}:{args.port}/docs")
print(f"🤖 Modelo cargado desde: {args.model_path}")
print(f"🎮 GPU layers: {args.n_gpu_layers} (usa -1 para todas las capas)")
print(f"🎯 Main GPU: {args.main_gpu}")
print(f"🧠 ROPE base/scale: {args.rope_freq_base} / {args.rope_freq_scale}")
print(f"🧮 Offload KV cache: {args.offload_kv_cache} | mmap: {args.try_mmap} | mlock: {args.keep_model_in_memory}")
print(f"⚡ Flash Attn: {args.flash_attn} | Seed: {args.seed}")
print(f"🖼️ mmproj: {args.mmproj_path or '<none>'}")
print(f"🎛 Defaults -> max_tokens: {generation_defaults['max_tokens']}, temp: {generation_defaults['temperature']}, top_k: {generation_defaults['top_k']}, repeat_penalty: {generation_defaults['repeat_penalty']}, min_p: {generation_defaults['min_p']}, top_p: {generation_defaults['top_p']}")
print("\n💡 Ejemplo de uso con curl:")
print(f"""
curl -X POST http://{args.host}:{args.port}/v1/chat/completions \\