Files
llama_cpp_local/main.py
T

478 lines
18 KiB
Python

import sys
import time
import uuid
from typing import List, Optional, Union
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from llama_cpp import Llama
def str_to_bool(value: str) -> bool:
"""Parsea valores booleanos desde la CLI."""
if isinstance(value, bool):
return value
lowered = value.strip().lower()
if lowered in {"true", "1", "yes", "y", "t"}:
return True
if lowered in {"false", "0", "no", "n", "f"}:
return False
raise ValueError(f"Valor booleano invalido: {value}")
# Modelos de datos para la API compatible con OpenAI
class ChatCompletionMessage(BaseModel):
role: str = Field(..., description="El rol del mensaje: 'system', 'user', o 'assistant'")
content: Union[str, List["MessageContentPart"]] = Field(..., description="Contenido del mensaje (texto o partes multimodales)")
class ChatCompletionRequest(BaseModel):
model: str = Field(default="llama", description="Nombre del modelo")
messages: List[ChatCompletionMessage] = Field(..., description="Lista de mensajes")
max_tokens: Optional[int] = Field(default=None, description="Máximo número de tokens a generar")
temperature: Optional[float] = Field(default=None, description="Temperatura para el muestreo")
top_p: Optional[float] = Field(default=None, description="Top-p para el muestreo")
top_k: Optional[int] = Field(default=None, description="Top-k para el muestreo")
repeat_penalty: Optional[float] = Field(default=None, description="Penalización por repetición")
min_p: Optional[float] = Field(default=None, description="Umbral mínimo de probabilidad (min_p sampling)")
stream: Optional[bool] = Field(default=False, description="Si devolver respuesta en streaming")
class MessageContentPart(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[str] = None
# Resolver forward refs
ChatCompletionMessage.model_rebuild()
def normalize_image_input(raw: str) -> str:
"""Normaliza imágenes a data URI; si llega base64 simple, se envuelve como data:image/png;base64."""
if raw.startswith("data:"):
return raw
if "://" in raw:
# En este entorno no se hace fetch remoto
raise HTTPException(status_code=400, detail="Solo se aceptan imágenes en data URI o base64 inline")
# Asumimos base64 puro
return f"data:image/png;base64,{raw}"
def likely_vision_model(model_path: str) -> bool:
"""Heurística simple para detectar modelos multimodales."""
lowered = Path(model_path).name.lower()
return any(key in lowered for key in ("vl", "vision", "llava", "mmproj"))
class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionMessage
finish_reason: str
class ChatCompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[ChatCompletionChoice]
usage: ChatCompletionUsage
class ModelInfo(BaseModel):
id: str
object: str = "model"
created: int
owned_by: str = "llama-cpp"
class ModelsResponse(BaseModel):
object: str = "list"
data: List[ModelInfo]
@dataclass
class LlamaAPI:
"""Clase para manejar la API de Llama"""
def __init__(self, model_path: str, mmproj_path: Optional[str] = None, **kwargs):
if not Path(model_path).exists():
raise FileNotFoundError(f"El archivo del modelo no existe: {model_path}")
if mmproj_path:
if not Path(mmproj_path).exists():
raise FileNotFoundError(f"El archivo mmproj no existe: {mmproj_path}")
kwargs["mmproj_path"] = mmproj_path
print(f"Cargando modelo desde: {model_path}")
# Parámetros por defecto para llama.cpp
default_params = {
"n_ctx": 4096,
"n_batch": 512,
"n_threads": None, # Usar todos los threads disponibles
"verbose": False,
"n_gpu_layers": -1, # Usar todas las capas en GPU (-1 = auto)
"main_gpu": 0, # GPU principal a usar
"split_mode": 1, # Modo de división entre GPUs
"rope_freq_base": 10000.0,
"rope_freq_scale": 1.0,
"offload_kv": True,
"use_mmap": True,
"use_mlock": False,
"seed": 0,
"flash_attn": False,
}
# Actualizar con parámetros personalizados
default_params.update(kwargs)
try:
self.llama = Llama(model_path=model_path, **default_params)
print("Modelo cargado exitosamente!")
print(f"GPU layers: {default_params.get('n_gpu_layers', 'N/A')}")
print(f"Main GPU: {default_params.get('main_gpu', 'N/A')}")
except Exception as e:
vision_hint = likely_vision_model(model_path)
if vision_hint and not mmproj_path:
raise RuntimeError(
"Fallo al cargar modelo multimodal. Se requiere un archivo mmproj. "
"Define --mmproj-path o coloca un .mmproj en model_choice/."
) from e
print(f"Error al cargar el modelo: {e}")
print("Nota: Si tienes problemas con CUDA, intenta instalar: pip install llama-cpp-python[cublas]")
raise
def generate_chat_completion(self, messages: List[ChatCompletionMessage],
max_tokens: int = 2048,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 40,
repeat_penalty: float = 1.1,
min_p: float = 0.05) -> str:
"""Genera una respuesta de chat usando el modelo"""
# Formatear los mensajes en un prompt
prompt = self._format_messages_to_prompt(messages)
images = self._collect_images(messages)
try:
# Generar respuesta usando llama.cpp
output = self.llama(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
min_p=min_p,
echo=False,
images=images if images else None,
stop=["</s>", "<|im_end|>", "<|endoftext|>"]
)
return output['choices'][0]['text'].strip()
except Exception as e:
print(f"Error durante la generación: {e}")
raise HTTPException(status_code=500, detail=f"Error en la generación: {str(e)}")
def _format_messages_to_prompt(self, messages: List[ChatCompletionMessage]) -> str:
"""Convierte la lista de mensajes en un prompt formateado"""
prompt_parts = []
for message in messages:
text_content = self._extract_text_from_content(message.content)
if message.role == "system":
prompt_parts.append(f"Sistema: {text_content}")
elif message.role == "user":
prompt_parts.append(f"Usuario: {text_content}")
elif message.role == "assistant":
prompt_parts.append(f"Asistente: {text_content}")
prompt_parts.append("Asistente:")
return "\n".join(prompt_parts)
def _extract_text_from_content(self, content: Union[str, List["MessageContentPart"]]) -> str:
"""Extrae el texto de un contenido que puede ser string o lista multimodal."""
if isinstance(content, str):
return content
texts = []
for part in content:
if part.type == "text" and part.text:
texts.append(part.text)
elif part.type == "image_url":
texts.append("[imagen]")
return " ".join(texts).strip()
def _collect_images(self, messages: List[ChatCompletionMessage]) -> List[str]:
"""Recolecta imágenes (data URI) de los mensajes."""
images: List[str] = []
for message in messages:
if isinstance(message.content, str):
continue
for part in message.content:
if part.type == "image_url" and part.image_url:
images.append(normalize_image_input(part.image_url))
return images
# Instancia global de la API
llama_api = None
generation_defaults = {
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"repeat_penalty": 1.1,
"min_p": 0.05,
}
# Crear la aplicación FastAPI
app = FastAPI(
title="Llama.cpp OpenAI Compatible API",
description="API compatible con OpenAI para modelos GGUF usando llama.cpp",
version="1.0.0"
)
# Configurar CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""Endpoint raíz"""
return {"message": "Llama.cpp OpenAI Compatible API", "status": "running"}
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
"""Lista los modelos disponibles"""
return ModelsResponse(
data=[
ModelInfo(
id="llama",
created=int(time.time()),
owned_by="llama-cpp"
)
]
)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
"""Crea una completion de chat"""
if llama_api is None:
raise HTTPException(status_code=503, detail="Modelo no cargado")
if not request.messages:
raise HTTPException(status_code=400, detail="Se requiere al menos un mensaje")
try:
# Generar respuesta
temperature = request.temperature if request.temperature is not None else generation_defaults["temperature"]
top_p = request.top_p if request.top_p is not None else generation_defaults["top_p"]
top_k = request.top_k if request.top_k is not None else generation_defaults["top_k"]
repeat_penalty = request.repeat_penalty if request.repeat_penalty is not None else generation_defaults["repeat_penalty"]
min_p = request.min_p if request.min_p is not None else generation_defaults["min_p"]
max_tokens = request.max_tokens if request.max_tokens is not None else generation_defaults["max_tokens"]
response_text = llama_api.generate_chat_completion(
messages=request.messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
min_p=min_p
)
# Crear respuesta en formato OpenAI
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
response = ChatCompletionResponse(
id=completion_id,
created=int(time.time()),
model=request.model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=response_text
),
finish_reason="stop"
)
],
usage=ChatCompletionUsage(
prompt_tokens=len(str(request.messages)) // 4, # Estimación aproximada
completion_tokens=len(response_text) // 4, # Estimación aproximada
total_tokens=(len(str(request.messages)) + len(response_text)) // 4
)
)
return response
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error procesando la solicitud: {str(e)}")
def main():
"""Función principal para iniciar el servidor"""
import argparse
parser = argparse.ArgumentParser(description="Servidor API compatible con OpenAI usando llama.cpp")
parser.add_argument("--model-path", "-m", type=str, required=True,
help="Ruta al archivo del modelo GGUF")
parser.add_argument("--host", type=str, default="0.0.0.0",
help="Host para el servidor (default: 0.0.0.0)")
parser.add_argument("--port", type=int, default=8000,
help="Puerto para el servidor (default: 8000)")
parser.add_argument("--n-ctx", type=int, default=4096,
help="Tamaño del contexto (default: 4096)")
parser.add_argument("--n-batch", type=int, default=512,
help="Tamaño del batch (default: 512)")
parser.add_argument("--eval-batch-size", type=int, default=None,
help="Tamaño del batch de evaluación (sobrescribe n-batch si se setea)")
parser.add_argument("--n-threads", type=int, default=None,
help="Número de threads (default: auto)")
parser.add_argument("--n-gpu-layers", type=int, default=-1,
help="Número de capas a cargar en GPU (-1 = todas, 0 = ninguna, default: -1)")
parser.add_argument("--main-gpu", type=int, default=0,
help="GPU principal a usar (default: 0)")
parser.add_argument("--split-mode", type=int, default=1,
help="Modo de división entre GPUs (default: 1)")
parser.add_argument("--rope-freq-base", type=float, default=10000.0,
help="Base de frecuencia ROPE (default: 10000)")
parser.add_argument("--rope-freq-scale", type=float, default=1.0,
help="Escala de frecuencia ROPE (default: 1.0)")
parser.add_argument("--offload-kv-cache", type=str_to_bool, default=True,
help="Offload del KV cache a GPU (default: true)")
parser.add_argument("--keep-model-in-memory", type=str_to_bool, default=False,
help="Usa mlock para mantener el modelo en RAM (default: false)")
parser.add_argument("--try-mmap", type=str_to_bool, default=True,
help="Usar mmap para mapear el modelo (default: true)")
parser.add_argument("--seed", type=int, default=0,
help="Seed para la generación (0 = aleatorio)")
parser.add_argument("--flash-attn", type=str_to_bool, default=False,
help="Habilitar Flash Attention si está disponible")
# Defaults de generación
parser.add_argument("--default-max-tokens", type=int, default=2048,
help="Límite de tokens de respuesta por defecto")
parser.add_argument("--default-temperature", type=float, default=0.7,
help="Temperatura por defecto")
parser.add_argument("--default-top-k", type=int, default=40,
help="Top-k por defecto")
parser.add_argument("--default-repeat-penalty", type=float, default=1.1,
help="Penalización de repetición por defecto")
parser.add_argument("--default-min-p", type=float, default=0.05,
help="Min-p por defecto")
parser.add_argument("--default-top-p", type=float, default=0.9,
help="Top-p por defecto")
parser.add_argument("--mmproj-path", type=str, default=None,
help="Ruta al proyector multimodal (mmproj) si el modelo lo requiere")
args = parser.parse_args()
# Validar que el archivo del modelo existe
if not Path(args.model_path).exists():
print(f"Error: El archivo del modelo no existe: {args.model_path}")
sys.exit(1)
# Configurar parámetros del modelo
model_params = {
"n_ctx": args.n_ctx,
"n_batch": args.eval_batch_size or args.n_batch,
"n_gpu_layers": args.n_gpu_layers,
"main_gpu": args.main_gpu,
"split_mode": args.split_mode,
"rope_freq_base": args.rope_freq_base,
"rope_freq_scale": args.rope_freq_scale,
"offload_kv": args.offload_kv_cache,
"use_mmap": args.try_mmap,
"use_mlock": args.keep_model_in_memory,
"seed": args.seed,
"flash_attn": args.flash_attn,
}
if args.n_threads:
model_params["n_threads"] = args.n_threads
# Defaults de generación para usar si la request no los especifica
global generation_defaults
generation_defaults = {
"max_tokens": args.default_max_tokens,
"temperature": args.default_temperature,
"top_p": args.default_top_p,
"top_k": args.default_top_k,
"repeat_penalty": args.default_repeat_penalty,
"min_p": args.default_min_p,
}
try:
# Inicializar la API de Llama
global llama_api
llama_api = LlamaAPI(args.model_path, mmproj_path=args.mmproj_path, **model_params)
print(f"\n🚀 Servidor iniciado en http://{args.host}:{args.port}")
print(f"📚 Documentación disponible en http://{args.host}:{args.port}/docs")
print(f"🤖 Modelo cargado desde: {args.model_path}")
print(f"🎮 GPU layers: {args.n_gpu_layers} (usa -1 para todas las capas)")
print(f"🎯 Main GPU: {args.main_gpu}")
print(f"🧠 ROPE base/scale: {args.rope_freq_base} / {args.rope_freq_scale}")
print(f"🧮 Offload KV cache: {args.offload_kv_cache} | mmap: {args.try_mmap} | mlock: {args.keep_model_in_memory}")
print(f"⚡ Flash Attn: {args.flash_attn} | Seed: {args.seed}")
print(f"🖼️ mmproj: {args.mmproj_path or '<none>'}")
print(f"🎛 Defaults -> max_tokens: {generation_defaults['max_tokens']}, temp: {generation_defaults['temperature']}, top_k: {generation_defaults['top_k']}, repeat_penalty: {generation_defaults['repeat_penalty']}, min_p: {generation_defaults['min_p']}, top_p: {generation_defaults['top_p']}")
print("\n💡 Ejemplo de uso con curl:")
print(f"""
curl -X POST http://{args.host}:{args.port}/v1/chat/completions \\
-H "Content-Type: application/json" \\
-d '{{
"model": "llama",
"messages": [
{{"role": "user", "content": "Hola, ¿cómo estás?"}}
],
"max_tokens": 100,
"temperature": 0.7
}}'
""")
# Iniciar el servidor
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info"
)
except KeyboardInterrupt:
print("\n👋 Servidor detenido por el usuario")
except Exception as e:
print(f"❌ Error al iniciar el servidor: {e}")
sys.exit(1)
if __name__ == "__main__":
main()