package ml import ( "context" "fmt" "os" "strconv" "strings" "time" "fn-registry/functions/core" ) // SdcliProgressCallback es una funcion llamada cada vez que se parsea una linea // de progreso del proceso sd. Puede ser nil. type SdcliProgressCallback func(p SdcliProgress) // SdcliGenerate ejecuta el binario sd para generar una imagen y escribe el // resultado en outPath. // // Flujo: // 1. Construye args con GenconfigToSdcliArgs(cfg) + ["-o", outPath]. // 2. Lanza el proceso via SubprocessStream. // 3. Goroutine interna lee eventos: lineas stderr se pasan a SdcliParseProgress; // si onProgress != nil y hay progreso reconocible, llama onProgress(p). // 4. Espera el resultado. ExitCode != 0 => error con las ultimas 10 lineas de stderr. // 5. Lee outPath y retorna ImageGenResult con bytes, meta y duration_ms. // // ctx controla el timeout y cancelacion: se pasa directamente a SubprocessStream, // que maneja SIGTERM -> grace 2s -> SIGKILL. func SdcliGenerate( ctx context.Context, bin SdcliBinary, cfg GenerationConfig, outPath string, onProgress SdcliProgressCallback, ) (ImageGenResult, error) { start := time.Now() args := GenconfigToSdcliArgs(cfg) args = append(args, "-o", outPath) events, results := core.SubprocessStream(ctx, bin.Path, args, nil, nil) // Consumir eventos en goroutine: parsear progreso y acumular stderr para // mensajes de error utiles. type collectedStderr struct { lines []string } stderrCh := make(chan collectedStderr, 1) go func() { var stderrLines []string for ev := range events { if ev.Stream == "stderr" { stderrLines = append(stderrLines, ev.Line) if onProgress != nil { if p, ok := SdcliParseProgress(ev.Line); ok { onProgress(p) } } } } stderrCh <- collectedStderr{lines: stderrLines} }() res := <-results collected := <-stderrCh if res.Err != nil { return ImageGenResult{}, fmt.Errorf("sdcli subprocess: %w", res.Err) } if res.ExitCode != 0 { tail := stderrTail(collected.lines, 10) return ImageGenResult{}, fmt.Errorf( "sdcli exited with code %d:\n%s", res.ExitCode, tail, ) } imageBytes, err := os.ReadFile(outPath) if err != nil { return ImageGenResult{}, fmt.Errorf("sdcli: reading output image %q: %w", outPath, err) } durationMs := time.Since(start).Milliseconds() meta := map[string]any{ "backend": "sdcli", "binary_path": bin.Path, "model": cfg.Model.Name, "seed": cfg.Seed, "steps": cfg.Steps, "cfg_scale": cfg.CfgScale, "sampler": cfg.Sampler, "width": cfg.Width, "height": cfg.Height, } if bin.Version != "" { meta["version"] = bin.Version } if cfg.Model.Path != "" { meta["model_path"] = cfg.Model.Path } if cfg.Model.ModelType != "" { meta["model_type"] = cfg.Model.ModelType } if cfg.Model.Quantization != "" { meta["quantization"] = cfg.Model.Quantization } if len(cfg.Loras) > 0 { loras := make([]string, len(cfg.Loras)) for i, l := range cfg.Loras { loras[i] = l.Path + ":" + strconv.FormatFloat(l.Weight, 'f', -1, 64) } meta["loras"] = strings.Join(loras, ",") } return ImageGenResult{ ImageBytes: imageBytes, Format: "png", Meta: meta, DurationMs: durationMs, }, nil } // stderrTail retorna las ultimas n lineas de lines, unidas con newline. func stderrTail(lines []string, n int) string { if len(lines) <= n { return strings.Join(lines, "\n") } return strings.Join(lines[len(lines)-n:], "\n") }