478 lines
18 KiB
Python
478 lines
18 KiB
Python
import sys
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
from typing import Literal

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel, Field, field_validator
|
|
|
|
|
|
def str_to_bool(value: str) -> bool:
    """Parse a boolean command-line value.

    A real ``bool`` is returned as-is. Strings are compared case-insensitively
    after trimming surrounding whitespace against the accepted spellings.

    Raises:
        ValueError: If the string is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value

    token = value.strip().lower()
    spellings = {
        **dict.fromkeys(("true", "1", "yes", "y", "t"), True),
        **dict.fromkeys(("false", "0", "no", "n", "f"), False),
    }
    parsed = spellings.get(token)
    if parsed is None:
        raise ValueError(f"Valor booleano invalido: {value}")
    return parsed
|
|
|
|
|
|
# Data models for the OpenAI-compatible API.
class ChatCompletionMessage(BaseModel):
    """A single chat message: a role plus text or a list of multimodal parts."""

    # Required fields (no default); descriptions surface in the OpenAPI schema.
    role: str = Field(description="El rol del mensaje: 'system', 'user', o 'assistant'")
    content: Union[str, List["MessageContentPart"]] = Field(
        description="Contenido del mensaje (texto o partes multimodales)"
    )
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions``.

    Every sampling field is optional; ``None`` means "use the server default".
    """

    model: str = Field("llama", description="Nombre del modelo")
    messages: List[ChatCompletionMessage] = Field(..., description="Lista de mensajes")
    max_tokens: Optional[int] = Field(None, description="Máximo número de tokens a generar")
    temperature: Optional[float] = Field(None, description="Temperatura para el muestreo")
    top_p: Optional[float] = Field(None, description="Top-p para el muestreo")
    top_k: Optional[int] = Field(None, description="Top-k para el muestreo")
    repeat_penalty: Optional[float] = Field(None, description="Penalización por repetición")
    min_p: Optional[float] = Field(None, description="Umbral mínimo de probabilidad (min_p sampling)")
    stream: Optional[bool] = Field(False, description="Si devolver respuesta en streaming")
|
|
|
|
|
|
class MessageContentPart(BaseModel):
    """One element of a multimodal message ``content`` list.

    ``image_url`` is stored as a plain string. Besides a bare string, the
    OpenAI object form ``{"url": "...", "detail": "..."}`` is accepted and
    reduced to its ``url`` so standard OpenAI clients validate instead of
    receiving a 422.
    """

    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[str] = None

    @field_validator("image_url", mode="before")
    @classmethod
    def _unwrap_image_url(cls, value):
        """Extract ``url`` from the OpenAI ``{"url": ...}`` object form."""
        # A dict without "url" is passed through unchanged so that string
        # validation rejects it (422) rather than silently dropping the image.
        if isinstance(value, dict) and "url" in value:
            return value["url"]
        return value
|
|
|
|
|
|
# Resolve the "MessageContentPart" forward reference now that the class exists.
# ChatCompletionRequest embeds ChatCompletionMessage and was defined before
# MessageContentPart as well, so rebuild it explicitly too; otherwise it may
# only become complete lazily, on first use.
ChatCompletionMessage.model_rebuild()
ChatCompletionRequest.model_rebuild()
|
|
|
|
|
|
def normalize_image_input(raw: str) -> str:
    """Return ``raw`` as a data URI.

    Existing ``data:`` URIs are returned unchanged. Anything else without a
    ``://`` is treated as bare base64 and wrapped as ``data:image/png;base64,``.

    Raises:
        HTTPException: 400 if ``raw`` looks like a remote URL (contains ``://``);
            remote images are never fetched.
    """
    if raw.startswith("data:"):
        return raw

    looks_remote = "://" in raw
    if not looks_remote:
        # Assume plain base64 and label it as PNG.
        return f"data:image/png;base64,{raw}"

    raise HTTPException(status_code=400, detail="Solo se aceptan imágenes en data URI o base64 inline")
|
|
|
|
|
|
def likely_vision_model(model_path: str) -> bool:
    """Guess whether a model file is multimodal from its file name alone.

    Only the final path component is inspected (case-insensitively) for the
    substrings "vl", "vision", "llava" or "mmproj".
    """
    file_name = Path(model_path).name.lower()
    for marker in ("vl", "vision", "llava", "mmproj"):
        if marker in file_name:
            return True
    return False
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """One entry of ``choices`` in a chat completion response."""

    index: int  # position of this choice in the list
    message: ChatCompletionMessage  # the generated assistant message
    finish_reason: str  # why generation ended
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
    """Token accounting block of a chat completion response."""

    prompt_tokens: int  # tokens attributed to the input messages
    completion_tokens: int  # tokens attributed to the generated text
    total_tokens: int  # overall count reported to the client
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Top-level body returned by ``POST /v1/chat/completions``."""

    id: str  # unique completion identifier
    object: str = "chat.completion"  # fixed OpenAI object tag
    created: int  # creation time as a Unix timestamp (seconds)
    model: str  # model name echoed back to the client
    choices: List[ChatCompletionChoice]  # generated alternatives
    usage: ChatCompletionUsage  # token accounting
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Description of one model in the ``/v1/models`` listing."""

    id: str  # identifier clients pass as ``model``
    object: str = "model"  # fixed OpenAI object tag
    created: int  # Unix timestamp (seconds)
    owned_by: str = "llama-cpp"  # owner label shown to clients
|
|
|
|
|
|
class ModelsResponse(BaseModel):
    """Body returned by ``GET /v1/models``: a list wrapper of ``ModelInfo``."""

    object: str = "list"  # fixed OpenAI object tag
    data: List[ModelInfo]  # the available models
|
|
|
|
|
|
class LlamaAPI:
    """Wrapper around a single ``llama_cpp.Llama`` model used by the HTTP layer.

    The model is loaded once in ``__init__``. ``generate_chat_completion``
    flattens OpenAI-style chat messages into a plain-text
    ``Sistema:/Usuario:/Asistente:`` prompt and runs a text completion on it.
    """

    def __init__(self, model_path: str, mmproj_path: Optional[str] = None, **kwargs):
        """Validate the input files and load the model.

        Args:
            model_path: Path to the GGUF model file.
            mmproj_path: Optional multimodal projector file. When given it must
                exist, and it is forwarded to ``Llama`` as ``mmproj_path``.
            **kwargs: Overrides for the default ``Llama`` parameters below.

        Raises:
            FileNotFoundError: If ``model_path`` or ``mmproj_path`` does not exist.
            RuntimeError: If loading fails for a file whose name looks
                multimodal (see ``likely_vision_model``) and no
                ``mmproj_path`` was given.
            Exception: Any other loading error is re-raised unchanged.
        """
        if not Path(model_path).exists():
            raise FileNotFoundError(f"El archivo del modelo no existe: {model_path}")

        if mmproj_path:
            if not Path(mmproj_path).exists():
                raise FileNotFoundError(f"El archivo mmproj no existe: {mmproj_path}")
            kwargs["mmproj_path"] = mmproj_path

        print(f"Cargando modelo desde: {model_path}")

        # Default llama.cpp parameters; anything passed in ``kwargs`` wins.
        default_params = {
            "n_ctx": 4096,
            "n_batch": 512,
            "n_threads": None,  # None: let llama-cpp-python choose
            "verbose": False,
            "n_gpu_layers": -1,  # -1 offloads every layer to the GPU
            "main_gpu": 0,  # index of the primary GPU
            "split_mode": 1,  # how work is split across several GPUs
            # 0.0 tells llama.cpp to use the RoPE settings stored in the GGUF
            # file. Hard-coding 10000.0 / 1.0 silently overrode models trained
            # with another base (e.g. 500000) or with RoPE scaling.
            "rope_freq_base": 0.0,
            "rope_freq_scale": 0.0,
            "offload_kv": True,
            "use_mmap": True,
            "use_mlock": False,
            "seed": 0,
            "flash_attn": False,
        }
        default_params.update(kwargs)

        # llama-cpp-python names the KV-cache offload option ``offload_kqv``;
        # ``offload_kv`` is not one of its named parameters and had no effect.
        # Mirror the value into the real name (an explicit ``offload_kqv``
        # from the caller takes precedence).
        default_params.setdefault("offload_kqv", default_params["offload_kv"])

        try:
            self.llama = Llama(model_path=model_path, **default_params)
            print("Modelo cargado exitosamente!")
            print(f"GPU layers: {default_params.get('n_gpu_layers', 'N/A')}")
            print(f"Main GPU: {default_params.get('main_gpu', 'N/A')}")
        except Exception as e:
            if likely_vision_model(model_path) and not mmproj_path:
                raise RuntimeError(
                    "Fallo al cargar modelo multimodal. Se requiere un archivo mmproj. "
                    "Define --mmproj-path o coloca un .mmproj en model_choice/."
                ) from e
            print(f"Error al cargar el modelo: {e}")
            print("Nota: Si tienes problemas con CUDA, intenta instalar: pip install llama-cpp-python[cublas]")
            raise

    def generate_chat_completion(self, messages: List[ChatCompletionMessage],
                                 max_tokens: int = 2048,
                                 temperature: float = 0.7,
                                 top_p: float = 0.9,
                                 top_k: int = 40,
                                 repeat_penalty: float = 1.1,
                                 min_p: float = 0.05) -> str:
        """Run the model on ``messages`` and return the generated text.

        Args:
            messages: Conversation to continue.
            max_tokens, temperature, top_p, top_k, repeat_penalty, min_p:
                Sampling parameters forwarded to ``Llama.__call__``.

        Returns:
            The completion text with surrounding whitespace stripped.

        Raises:
            HTTPException: 400 (via ``normalize_image_input``) for an image
                given as a remote URL; 500 if the model call fails.
        """
        prompt = self._format_messages_to_prompt(messages)
        images = self._collect_images(messages)

        completion_kwargs = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repeat_penalty": repeat_penalty,
            "min_p": min_p,
            "echo": False,
            # End-of-turn tokens of common chat models, plus the role labels of
            # the plain-text prompt built by _format_messages_to_prompt, so the
            # model stops instead of inventing the next user/system turn.
            "stop": ["</s>", "<|im_end|>", "<|endoftext|>", "\nUsuario:", "\nSistema:"],
        }
        # Forward ``images`` only when there are some: the stock
        # ``Llama.__call__`` has no ``images`` parameter and raises TypeError
        # even for ``images=None``, which broke every text-only request.
        if images:
            completion_kwargs["images"] = images

        try:
            output = self.llama(prompt, **completion_kwargs)
            return output['choices'][0]['text'].strip()
        except Exception as e:
            print(f"Error durante la generación: {e}")
            raise HTTPException(status_code=500, detail=f"Error en la generación: {str(e)}") from e

    def _format_messages_to_prompt(self, messages: List[ChatCompletionMessage]) -> str:
        """Render messages as ``Label: text`` lines ending with ``Asistente:``.

        Messages whose role is not system/user/assistant are skipped.
        """
        role_labels = {"system": "Sistema", "user": "Usuario", "assistant": "Asistente"}
        prompt_parts = []
        for message in messages:
            label = role_labels.get(message.role)
            if label is None:
                continue
            prompt_parts.append(f"{label}: {self._extract_text_from_content(message.content)}")

        # Open the assistant turn the model is asked to complete.
        prompt_parts.append("Asistente:")
        return "\n".join(prompt_parts)

    def _extract_text_from_content(self, content: Union[str, List["MessageContentPart"]]) -> str:
        """Return the text of ``content``; image parts become ``[imagen]``."""
        if isinstance(content, str):
            return content

        texts = []
        for part in content:
            if part.type == "text" and part.text:
                texts.append(part.text)
            elif part.type == "image_url":
                texts.append("[imagen]")
        return " ".join(texts).strip()

    def _collect_images(self, messages: List[ChatCompletionMessage]) -> List[str]:
        """Collect every image part of ``messages`` as a data URI, in order.

        Raises:
            HTTPException: 400 (via ``normalize_image_input``) for remote URLs.
        """
        images: List[str] = []
        for message in messages:
            if isinstance(message.content, str):
                continue
            for part in message.content:
                if part.type == "image_url" and part.image_url:
                    images.append(normalize_image_input(part.image_url))
        return images
|
|
|
|
|
|
# Loaded model wrapper; stays None until main() assigns it.
llama_api: Optional["LlamaAPI"] = None

# Fallback sampling parameters for requests that omit them. main() rebinds
# this name to a dict built from the --default-* command-line flags.
generation_defaults = dict(
    max_tokens=2048,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repeat_penalty=1.1,
    min_p=0.05,
)
|
|
|
|
# FastAPI application that serves the OpenAI-compatible endpoints.
app = FastAPI(
    title="Llama.cpp OpenAI Compatible API",
    description="API compatible con OpenAI para modelos GGUF usando llama.cpp",
    version="1.0.0",
)

# CORS: allow any origin, method and header, with credentials.
# NOTE(review): wildcard origins plus allow_credentials=True is very
# permissive; revisit if authentication is ever added.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint returning a static service name and status."""
    payload = {"message": "Llama.cpp OpenAI Compatible API", "status": "running"}
    return payload
|
|
|
|
|
|
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List the single model this server exposes, under the id ``llama``.

    ``created`` is the time of the request, not of model loading.
    """
    only_model = ModelInfo(id="llama", created=int(time.time()), owned_by="llama-cpp")
    return ModelsResponse(data=[only_model])
|
|
|
|
|
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Create a chat completion and return it in OpenAI response format.

    Sampling parameters missing from the request fall back to
    ``generation_defaults``. The ``usage`` counts are rough estimates
    (characters / 4), not tokenizer counts.

    Raises:
        HTTPException: 503 if no model is loaded; 400 if ``messages`` is empty
            or an image is not inline data; 500 for any other failure.
    """
    # NOTE(review): ``request.stream`` is accepted but ignored; the reply is
    # always one JSON body. Generation also runs synchronously inside this
    # async handler, blocking the event loop (and serializing requests)
    # while the model generates.
    if llama_api is None:
        raise HTTPException(status_code=503, detail="Modelo no cargado")

    if not request.messages:
        raise HTTPException(status_code=400, detail="Se requiere al menos un mensaje")

    def pick(name):
        """Return the request's value for ``name``, or the server default."""
        value = getattr(request, name)
        return value if value is not None else generation_defaults[name]

    try:
        response_text = llama_api.generate_chat_completion(
            messages=request.messages,
            max_tokens=pick("max_tokens"),
            temperature=pick("temperature"),
            top_p=pick("top_p"),
            top_k=pick("top_k"),
            repeat_penalty=pick("repeat_penalty"),
            min_p=pick("min_p"),
        )

        completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"

        prompt_tokens = len(str(request.messages)) // 4  # rough estimate
        completion_tokens = len(response_text) // 4  # rough estimate

        return ChatCompletionResponse(
            id=completion_id,
            created=int(time.time()),
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatCompletionMessage(role="assistant", content=response_text),
                    # NOTE(review): always "stop", even if max_tokens cut the text.
                    finish_reason="stop",
                )
            ],
            usage=ChatCompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                # Must equal the sum of its parts; computing it separately
                # with its own floor division could disagree by one.
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )

    except HTTPException:
        # Keep the status chosen downstream (e.g. 400 for a remote image URL,
        # 500 with the generation detail) instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error procesando la solicitud: {str(e)}") from e
|
|
|
|
|
|
def _build_arg_parser():
    """Return the argparse parser describing every command-line option."""
    import argparse

    parser = argparse.ArgumentParser(description="Servidor API compatible con OpenAI usando llama.cpp")
    parser.add_argument("--model-path", "-m", type=str, required=True,
                        help="Ruta al archivo del modelo GGUF")
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="Host para el servidor (default: 0.0.0.0)")
    parser.add_argument("--port", type=int, default=8000,
                        help="Puerto para el servidor (default: 8000)")
    parser.add_argument("--n-ctx", type=int, default=4096,
                        help="Tamaño del contexto (default: 4096)")
    parser.add_argument("--n-batch", type=int, default=512,
                        help="Tamaño del batch (default: 512)")
    parser.add_argument("--eval-batch-size", type=int, default=None,
                        help="Tamaño del batch de evaluación (sobrescribe n-batch si se setea)")
    parser.add_argument("--n-threads", type=int, default=None,
                        help="Número de threads (default: auto)")
    parser.add_argument("--n-gpu-layers", type=int, default=-1,
                        help="Número de capas a cargar en GPU (-1 = todas, 0 = ninguna, default: -1)")
    parser.add_argument("--main-gpu", type=int, default=0,
                        help="GPU principal a usar (default: 0)")
    parser.add_argument("--split-mode", type=int, default=1,
                        help="Modo de división entre GPUs (default: 1)")
    # 0 = keep the RoPE values stored in the model file (llama.cpp convention);
    # a fixed 10000 / 1.0 default overrode models trained differently.
    parser.add_argument("--rope-freq-base", type=float, default=0.0,
                        help="Base de frecuencia ROPE (default: 0 = valor del modelo)")
    parser.add_argument("--rope-freq-scale", type=float, default=0.0,
                        help="Escala de frecuencia ROPE (default: 0 = valor del modelo)")
    parser.add_argument("--offload-kv-cache", type=str_to_bool, default=True,
                        help="Offload del KV cache a GPU (default: true)")
    parser.add_argument("--keep-model-in-memory", type=str_to_bool, default=False,
                        help="Usa mlock para mantener el modelo en RAM (default: false)")
    parser.add_argument("--try-mmap", type=str_to_bool, default=True,
                        help="Usar mmap para mapear el modelo (default: true)")
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed para la generación (0 = aleatorio)")
    parser.add_argument("--flash-attn", type=str_to_bool, default=False,
                        help="Habilitar Flash Attention si está disponible")
    # Generation defaults, used when a request omits the parameter.
    parser.add_argument("--default-max-tokens", type=int, default=2048,
                        help="Límite de tokens de respuesta por defecto")
    parser.add_argument("--default-temperature", type=float, default=0.7,
                        help="Temperatura por defecto")
    parser.add_argument("--default-top-k", type=int, default=40,
                        help="Top-k por defecto")
    parser.add_argument("--default-repeat-penalty", type=float, default=1.1,
                        help="Penalización de repetición por defecto")
    parser.add_argument("--default-min-p", type=float, default=0.05,
                        help="Min-p por defecto")
    parser.add_argument("--default-top-p", type=float, default=0.9,
                        help="Top-p por defecto")
    parser.add_argument("--mmproj-path", type=str, default=None,
                        help="Ruta al proyector multimodal (mmproj) si el modelo lo requiere")
    return parser


def _print_startup_banner(args, defaults) -> None:
    """Print the effective configuration and an example curl request."""
    print(f"\n🚀 Servidor iniciado en http://{args.host}:{args.port}")
    print(f"📚 Documentación disponible en http://{args.host}:{args.port}/docs")
    print(f"🤖 Modelo cargado desde: {args.model_path}")
    print(f"🎮 GPU layers: {args.n_gpu_layers} (usa -1 para todas las capas)")
    print(f"🎯 Main GPU: {args.main_gpu}")
    print(f"🧠 ROPE base/scale: {args.rope_freq_base} / {args.rope_freq_scale}")
    print(f"🧮 Offload KV cache: {args.offload_kv_cache} | mmap: {args.try_mmap} | mlock: {args.keep_model_in_memory}")
    print(f"⚡ Flash Attn: {args.flash_attn} | Seed: {args.seed}")
    print(f"🖼️ mmproj: {args.mmproj_path or '<none>'}")
    print(f"🎛 Defaults -> max_tokens: {defaults['max_tokens']}, temp: {defaults['temperature']}, top_k: {defaults['top_k']}, repeat_penalty: {defaults['repeat_penalty']}, min_p: {defaults['min_p']}, top_p: {defaults['top_p']}")
    print("\n💡 Ejemplo de uso con curl:")
    print(f"""
curl -X POST http://{args.host}:{args.port}/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -d '{{
    "model": "llama",
    "messages": [
      {{"role": "user", "content": "Hola, ¿cómo estás?"}}
    ],
    "max_tokens": 100,
    "temperature": 0.7
  }}'
""")


def main():
    """Parse the CLI, load the model and serve the API with uvicorn.

    Side effects: rebinds the module globals ``llama_api`` and
    ``generation_defaults``. Exits with status 1 if the model file is missing
    or if startup fails.
    """
    global llama_api, generation_defaults

    args = _build_arg_parser().parse_args()

    if not Path(args.model_path).exists():
        print(f"Error: El archivo del modelo no existe: {args.model_path}")
        sys.exit(1)

    model_params = {
        "n_ctx": args.n_ctx,
        "n_batch": args.eval_batch_size or args.n_batch,
        "n_gpu_layers": args.n_gpu_layers,
        "main_gpu": args.main_gpu,
        "split_mode": args.split_mode,
        "rope_freq_base": args.rope_freq_base,
        "rope_freq_scale": args.rope_freq_scale,
        # llama-cpp-python's keyword for KV-cache offload is ``offload_kqv``;
        # under the old ``offload_kv`` key the flag had no effect.
        "offload_kqv": args.offload_kv_cache,
        "use_mmap": args.try_mmap,
        "use_mlock": args.keep_model_in_memory,
        "seed": args.seed,
        "flash_attn": args.flash_attn,
    }
    if args.n_threads:
        model_params["n_threads"] = args.n_threads

    # Fallbacks used when a request does not specify a sampling parameter.
    generation_defaults = {
        "max_tokens": args.default_max_tokens,
        "temperature": args.default_temperature,
        "top_p": args.default_top_p,
        "top_k": args.default_top_k,
        "repeat_penalty": args.default_repeat_penalty,
        "min_p": args.default_min_p,
    }

    try:
        llama_api = LlamaAPI(args.model_path, mmproj_path=args.mmproj_path, **model_params)
        _print_startup_banner(args, generation_defaults)
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
    except KeyboardInterrupt:
        print("\n👋 Servidor detenido por el usuario")
    except Exception as e:
        print(f"❌ Error al iniciar el servidor: {e}")
        sys.exit(1)
|
|
|
|
|
|
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|