Files
llama_cpp_local/main.py
T
2025-11-03 00:20:04 +01:00

313 lines
10 KiB
Python

import sys
import time
import uuid
from typing import List, Optional
from dataclasses import dataclass
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
from llama_cpp import Llama
# Modelos de datos para la API compatible con OpenAI
class ChatCompletionMessage(BaseModel):
role: str = Field(..., description="El rol del mensaje: 'system', 'user', o 'assistant'")
content: str = Field(..., description="El contenido del mensaje")
class ChatCompletionRequest(BaseModel):
model: str = Field(default="llama", description="Nombre del modelo")
messages: List[ChatCompletionMessage] = Field(..., description="Lista de mensajes")
max_tokens: Optional[int] = Field(default=2048, description="Máximo número de tokens a generar")
temperature: Optional[float] = Field(default=0.7, description="Temperatura para el muestreo")
top_p: Optional[float] = Field(default=0.9, description="Top-p para el muestreo")
stream: Optional[bool] = Field(default=False, description="Si devolver respuesta en streaming")
class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionMessage
finish_reason: str
class ChatCompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[ChatCompletionChoice]
usage: ChatCompletionUsage
class ModelInfo(BaseModel):
id: str
object: str = "model"
created: int
owned_by: str = "llama-cpp"
class ModelsResponse(BaseModel):
object: str = "list"
data: List[ModelInfo]
@dataclass
class LlamaAPI:
"""Clase para manejar la API de Llama"""
def __init__(self, model_path: str, **kwargs):
if not Path(model_path).exists():
raise FileNotFoundError(f"El archivo del modelo no existe: {model_path}")
print(f"Cargando modelo desde: {model_path}")
# Parámetros por defecto para llama.cpp
default_params = {
"n_ctx": 4096,
"n_batch": 512,
"n_threads": None, # Usar todos los threads disponibles
"verbose": False,
"n_gpu_layers": -1, # Usar todas las capas en GPU (-1 = auto)
"main_gpu": 0, # GPU principal a usar
"split_mode": 1, # Modo de división entre GPUs
}
# Actualizar con parámetros personalizados
default_params.update(kwargs)
try:
self.llama = Llama(model_path=model_path, **default_params)
print("Modelo cargado exitosamente!")
print(f"GPU layers: {default_params.get('n_gpu_layers', 'N/A')}")
print(f"Main GPU: {default_params.get('main_gpu', 'N/A')}")
except Exception as e:
print(f"Error al cargar el modelo: {e}")
print("Nota: Si tienes problemas con CUDA, intenta instalar: pip install llama-cpp-python[cublas]")
raise
def generate_chat_completion(self, messages: List[ChatCompletionMessage],
max_tokens: int = 2048,
temperature: float = 0.7,
top_p: float = 0.9) -> str:
"""Genera una respuesta de chat usando el modelo"""
# Formatear los mensajes en un prompt
prompt = self._format_messages_to_prompt(messages)
try:
# Generar respuesta usando llama.cpp
output = self.llama(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
echo=False,
stop=["</s>", "<|im_end|>", "<|endoftext|>"]
)
return output['choices'][0]['text'].strip()
except Exception as e:
print(f"Error durante la generación: {e}")
raise HTTPException(status_code=500, detail=f"Error en la generación: {str(e)}")
def _format_messages_to_prompt(self, messages: List[ChatCompletionMessage]) -> str:
"""Convierte la lista de mensajes en un prompt formateado"""
prompt_parts = []
for message in messages:
if message.role == "system":
prompt_parts.append(f"Sistema: {message.content}")
elif message.role == "user":
prompt_parts.append(f"Usuario: {message.content}")
elif message.role == "assistant":
prompt_parts.append(f"Asistente: {message.content}")
prompt_parts.append("Asistente:")
return "\n".join(prompt_parts)
# Instancia global de la API
llama_api = None
# Crear la aplicación FastAPI
app = FastAPI(
title="Llama.cpp OpenAI Compatible API",
description="API compatible con OpenAI para modelos GGUF usando llama.cpp",
version="1.0.0"
)
# Configurar CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""Endpoint raíz"""
return {"message": "Llama.cpp OpenAI Compatible API", "status": "running"}
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
"""Lista los modelos disponibles"""
return ModelsResponse(
data=[
ModelInfo(
id="llama",
created=int(time.time()),
owned_by="llama-cpp"
)
]
)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
"""Crea una completion de chat"""
if llama_api is None:
raise HTTPException(status_code=503, detail="Modelo no cargado")
if not request.messages:
raise HTTPException(status_code=400, detail="Se requiere al menos un mensaje")
try:
# Generar respuesta
response_text = llama_api.generate_chat_completion(
messages=request.messages,
max_tokens=request.max_tokens or 2048,
temperature=request.temperature or 0.7,
top_p=request.top_p or 0.9
)
# Crear respuesta en formato OpenAI
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
response = ChatCompletionResponse(
id=completion_id,
created=int(time.time()),
model=request.model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=response_text
),
finish_reason="stop"
)
],
usage=ChatCompletionUsage(
prompt_tokens=len(str(request.messages)) // 4, # Estimación aproximada
completion_tokens=len(response_text) // 4, # Estimación aproximada
total_tokens=(len(str(request.messages)) + len(response_text)) // 4
)
)
return response
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error procesando la solicitud: {str(e)}")
def main():
"""Función principal para iniciar el servidor"""
import argparse
parser = argparse.ArgumentParser(description="Servidor API compatible con OpenAI usando llama.cpp")
parser.add_argument("--model-path", "-m", type=str, required=True,
help="Ruta al archivo del modelo GGUF")
parser.add_argument("--host", type=str, default="0.0.0.0",
help="Host para el servidor (default: 0.0.0.0)")
parser.add_argument("--port", type=int, default=8000,
help="Puerto para el servidor (default: 8000)")
parser.add_argument("--n-ctx", type=int, default=4096,
help="Tamaño del contexto (default: 4096)")
parser.add_argument("--n-batch", type=int, default=512,
help="Tamaño del batch (default: 512)")
parser.add_argument("--n-threads", type=int, default=None,
help="Número de threads (default: auto)")
parser.add_argument("--n-gpu-layers", type=int, default=-1,
help="Número de capas a cargar en GPU (-1 = todas, 0 = ninguna, default: -1)")
parser.add_argument("--main-gpu", type=int, default=0,
help="GPU principal a usar (default: 0)")
parser.add_argument("--split-mode", type=int, default=1,
help="Modo de división entre GPUs (default: 1)")
args = parser.parse_args()
# Validar que el archivo del modelo existe
if not Path(args.model_path).exists():
print(f"Error: El archivo del modelo no existe: {args.model_path}")
sys.exit(1)
# Configurar parámetros del modelo
model_params = {
"n_ctx": args.n_ctx,
"n_batch": args.n_batch,
"n_gpu_layers": args.n_gpu_layers,
"main_gpu": args.main_gpu,
"split_mode": args.split_mode,
}
if args.n_threads:
model_params["n_threads"] = args.n_threads
try:
# Inicializar la API de Llama
global llama_api
llama_api = LlamaAPI(args.model_path, **model_params)
print(f"\n🚀 Servidor iniciado en http://{args.host}:{args.port}")
print(f"📚 Documentación disponible en http://{args.host}:{args.port}/docs")
print(f"🤖 Modelo cargado desde: {args.model_path}")
print(f"🎮 GPU layers: {args.n_gpu_layers} (usa -1 para todas las capas)")
print(f"🎯 Main GPU: {args.main_gpu}")
print("\n💡 Ejemplo de uso con curl:")
print(f"""
curl -X POST http://{args.host}:{args.port}/v1/chat/completions \\
-H "Content-Type: application/json" \\
-d '{{
"model": "llama",
"messages": [
{{"role": "user", "content": "Hola, ¿cómo estás?"}}
],
"max_tokens": 100,
"temperature": 0.7
}}'
""")
# Iniciar el servidor
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info"
)
except KeyboardInterrupt:
print("\n👋 Servidor detenido por el usuario")
except Exception as e:
print(f"❌ Error al iniciar el servidor: {e}")
sys.exit(1)
if __name__ == "__main__":
main()