import sys import time import uuid from typing import List, Optional from dataclasses import dataclass from pathlib import Path from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import uvicorn from llama_cpp import Llama # Modelos de datos para la API compatible con OpenAI class ChatCompletionMessage(BaseModel): role: str = Field(..., description="El rol del mensaje: 'system', 'user', o 'assistant'") content: str = Field(..., description="El contenido del mensaje") class ChatCompletionRequest(BaseModel): model: str = Field(default="llama", description="Nombre del modelo") messages: List[ChatCompletionMessage] = Field(..., description="Lista de mensajes") max_tokens: Optional[int] = Field(default=2048, description="Máximo número de tokens a generar") temperature: Optional[float] = Field(default=0.7, description="Temperatura para el muestreo") top_p: Optional[float] = Field(default=0.9, description="Top-p para el muestreo") stream: Optional[bool] = Field(default=False, description="Si devolver respuesta en streaming") class ChatCompletionChoice(BaseModel): index: int message: ChatCompletionMessage finish_reason: str class ChatCompletionUsage(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int class ChatCompletionResponse(BaseModel): id: str object: str = "chat.completion" created: int model: str choices: List[ChatCompletionChoice] usage: ChatCompletionUsage class ModelInfo(BaseModel): id: str object: str = "model" created: int owned_by: str = "llama-cpp" class ModelsResponse(BaseModel): object: str = "list" data: List[ModelInfo] @dataclass class LlamaAPI: """Clase para manejar la API de Llama""" def __init__(self, model_path: str, **kwargs): if not Path(model_path).exists(): raise FileNotFoundError(f"El archivo del modelo no existe: {model_path}") print(f"Cargando modelo desde: {model_path}") # Parámetros por defecto para llama.cpp default_params = { "n_ctx": 4096, "n_batch": 512, "n_threads": None, # Usar todos los threads disponibles "verbose": False, "n_gpu_layers": -1, # Usar todas las capas en GPU (-1 = auto) "main_gpu": 0, # GPU principal a usar "split_mode": 1, # Modo de división entre GPUs } # Actualizar con parámetros personalizados default_params.update(kwargs) try: self.llama = Llama(model_path=model_path, **default_params) print("Modelo cargado exitosamente!") print(f"GPU layers: {default_params.get('n_gpu_layers', 'N/A')}") print(f"Main GPU: {default_params.get('main_gpu', 'N/A')}") except Exception as e: print(f"Error al cargar el modelo: {e}") print("Nota: Si tienes problemas con CUDA, intenta instalar: pip install llama-cpp-python[cublas]") raise def generate_chat_completion(self, messages: List[ChatCompletionMessage], max_tokens: int = 2048, temperature: float = 0.7, top_p: float = 0.9) -> str: """Genera una respuesta de chat usando el modelo""" # Formatear los mensajes en un prompt prompt = self._format_messages_to_prompt(messages) try: # Generar respuesta usando llama.cpp output = self.llama( prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p, echo=False, stop=["", "<|im_end|>", "<|endoftext|>"] ) return output['choices'][0]['text'].strip() except Exception as e: print(f"Error durante la generación: {e}") raise HTTPException(status_code=500, detail=f"Error en la generación: {str(e)}") def _format_messages_to_prompt(self, messages: List[ChatCompletionMessage]) -> str: """Convierte la lista de mensajes en un prompt formateado""" prompt_parts = [] for message in messages: if message.role == "system": prompt_parts.append(f"Sistema: {message.content}") elif message.role == "user": prompt_parts.append(f"Usuario: {message.content}") elif message.role == "assistant": prompt_parts.append(f"Asistente: {message.content}") prompt_parts.append("Asistente:") return "\n".join(prompt_parts) # Instancia global de la API llama_api = None # Crear la aplicación FastAPI app = FastAPI( title="Llama.cpp OpenAI Compatible API", description="API compatible con OpenAI para modelos GGUF usando llama.cpp", version="1.0.0" ) # Configurar CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root(): """Endpoint raíz""" return {"message": "Llama.cpp OpenAI Compatible API", "status": "running"} @app.get("/v1/models", response_model=ModelsResponse) async def list_models(): """Lista los modelos disponibles""" return ModelsResponse( data=[ ModelInfo( id="llama", created=int(time.time()), owned_by="llama-cpp" ) ] ) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest): """Crea una completion de chat""" if llama_api is None: raise HTTPException(status_code=503, detail="Modelo no cargado") if not request.messages: raise HTTPException(status_code=400, detail="Se requiere al menos un mensaje") try: # Generar respuesta response_text = llama_api.generate_chat_completion( messages=request.messages, max_tokens=request.max_tokens or 2048, temperature=request.temperature or 0.7, top_p=request.top_p or 0.9 ) # Crear respuesta en formato OpenAI completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" response = ChatCompletionResponse( id=completion_id, created=int(time.time()), model=request.model, choices=[ ChatCompletionChoice( index=0, message=ChatCompletionMessage( role="assistant", content=response_text ), finish_reason="stop" ) ], usage=ChatCompletionUsage( prompt_tokens=len(str(request.messages)) // 4, # Estimación aproximada completion_tokens=len(response_text) // 4, # Estimación aproximada total_tokens=(len(str(request.messages)) + len(response_text)) // 4 ) ) return response except Exception as e: raise HTTPException(status_code=500, detail=f"Error procesando la solicitud: {str(e)}") def main(): """Función principal para iniciar el servidor""" import argparse parser = argparse.ArgumentParser(description="Servidor API compatible con OpenAI usando llama.cpp") parser.add_argument("--model-path", "-m", type=str, required=True, help="Ruta al archivo del modelo GGUF") parser.add_argument("--host", type=str, default="0.0.0.0", help="Host para el servidor (default: 0.0.0.0)") parser.add_argument("--port", type=int, default=8000, help="Puerto para el servidor (default: 8000)") parser.add_argument("--n-ctx", type=int, default=4096, help="Tamaño del contexto (default: 4096)") parser.add_argument("--n-batch", type=int, default=512, help="Tamaño del batch (default: 512)") parser.add_argument("--n-threads", type=int, default=None, help="Número de threads (default: auto)") parser.add_argument("--n-gpu-layers", type=int, default=-1, help="Número de capas a cargar en GPU (-1 = todas, 0 = ninguna, default: -1)") parser.add_argument("--main-gpu", type=int, default=0, help="GPU principal a usar (default: 0)") parser.add_argument("--split-mode", type=int, default=1, help="Modo de división entre GPUs (default: 1)") args = parser.parse_args() # Validar que el archivo del modelo existe if not Path(args.model_path).exists(): print(f"Error: El archivo del modelo no existe: {args.model_path}") sys.exit(1) # Configurar parámetros del modelo model_params = { "n_ctx": args.n_ctx, "n_batch": args.n_batch, "n_gpu_layers": args.n_gpu_layers, "main_gpu": args.main_gpu, "split_mode": args.split_mode, } if args.n_threads: model_params["n_threads"] = args.n_threads try: # Inicializar la API de Llama global llama_api llama_api = LlamaAPI(args.model_path, **model_params) print(f"\n🚀 Servidor iniciado en http://{args.host}:{args.port}") print(f"📚 Documentación disponible en http://{args.host}:{args.port}/docs") print(f"🤖 Modelo cargado desde: {args.model_path}") print(f"🎮 GPU layers: {args.n_gpu_layers} (usa -1 para todas las capas)") print(f"🎯 Main GPU: {args.main_gpu}") print("\n💡 Ejemplo de uso con curl:") print(f""" curl -X POST http://{args.host}:{args.port}/v1/chat/completions \\ -H "Content-Type: application/json" \\ -d '{{ "model": "llama", "messages": [ {{"role": "user", "content": "Hola, ¿cómo estás?"}} ], "max_tokens": 100, "temperature": 0.7 }}' """) # Iniciar el servidor uvicorn.run( app, host=args.host, port=args.port, log_level="info" ) except KeyboardInterrupt: print("\n👋 Servidor detenido por el usuario") except Exception as e: print(f"❌ Error al iniciar el servidor: {e}") sys.exit(1) if __name__ == "__main__": main()