package ml import ( "reflect" "strings" "testing" ) // --------------------------------------------------------------------------- // TestGenconfigToSdcliArgs // --------------------------------------------------------------------------- func TestGenconfigToSdcliArgs(t *testing.T) { clipSkip := 2 t.Run("config basico sin loras ni clip_skip", func(t *testing.T) { cfg := GenerationConfig{ Prompt: "a cat", Seed: 42, Steps: 20, CfgScale: 7.5, Sampler: "euler", Width: 512, Height: 512, Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16"}, } args := GenconfigToSdcliArgs(cfg) want := []string{ "--prompt", "a cat", "--seed", "42", "--steps", "20", "--cfg-scale", "7.5", "--width", "512", "--height", "512", "--sampling-method", "euler", } if !reflect.DeepEqual(args, want) { t.Errorf("got %v\nwant %v", args, want) } }) t.Run("loras se emiten como pares path:weight", func(t *testing.T) { cfg := GenerationConfig{ Prompt: "portrait", Seed: 1, Steps: 10, CfgScale: 7.0, Sampler: "euler", Width: 512, Height: 512, Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16", Path: "/models/v1.safetensors"}, Loras: []LoraRef{ {Path: "/loras/detail.safetensors", Weight: 0.8}, {Path: "/loras/style.safetensors", Weight: 0.5}, }, ClipSkip: &clipSkip, } args := GenconfigToSdcliArgs(cfg) // Verificar que existen los pares --lora para ambas loras loraIdx := indexAll(args, "--lora") if len(loraIdx) != 2 { t.Fatalf("esperaba 2 flags --lora, got %d en %v", len(loraIdx), args) } wantLoras := []string{ "/loras/detail.safetensors:0.8", "/loras/style.safetensors:0.5", } for i, idx := range loraIdx { if idx+1 >= len(args) { t.Fatalf("--lora[%d] sin valor siguiente", i) } if args[idx+1] != wantLoras[i] { t.Errorf("lora[%d]: got %q, want %q", i, args[idx+1], wantLoras[i]) } } // Verificar --model y --clip-skip presentes if !containsPair(args, "--model", "/models/v1.safetensors") { t.Errorf("--model no encontrado en %v", args) } if !containsPair(args, "--clip-skip", "2") { 
t.Errorf("--clip-skip no encontrado en %v", args) } }) t.Run("sampler dpm++2m se traduce a dpmpp2m", func(t *testing.T) { cfg := GenerationConfig{ Prompt: "x", Seed: 0, Steps: 1, CfgScale: 1.0, Sampler: "dpm++2m", Width: 64, Height: 64, Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"}, } args := GenconfigToSdcliArgs(cfg) if !containsPair(args, "--sampling-method", "dpmpp2m") { t.Errorf("sampler no traducido; args=%v", args) } }) t.Run("negative_prompt vacio no genera flag", func(t *testing.T) { cfg := GenerationConfig{ Prompt: "x", NegativePrompt: "", Seed: 0, Steps: 1, CfgScale: 1.0, Sampler: "euler", Width: 64, Height: 64, Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"}, } args := GenconfigToSdcliArgs(cfg) for _, a := range args { if a == "--negative-prompt" { t.Errorf("flag --negative-prompt presente aunque NegativePrompt es vacio") } } }) } // --------------------------------------------------------------------------- // TestGenconfigMarshalRoundtrip // --------------------------------------------------------------------------- func TestGenconfigMarshalRoundtrip(t *testing.T) { t.Run("roundtrip marshal unmarshal produce config igual", func(t *testing.T) { clip := 2 cfg := GenerationConfig{ Prompt: "sunset over the mountains", NegativePrompt: "blurry, low quality", Seed: 99, Steps: 30, CfgScale: 7.5, Sampler: "dpm++2m", Width: 768, Height: 512, Model: ModelRef{ Name: "sdxl-base", ModelType: "sdxl", Quantization: "fp16", Path: "/models/sdxl.safetensors", }, Loras: []LoraRef{ {Path: "/loras/detail.safetensors", Weight: 0.8}, }, ClipSkip: &clip, } b, err := GenconfigMarshal(cfg) if err != nil { t.Fatalf("GenconfigMarshal: %v", err) } got, err := GenconfigUnmarshal(b) if err != nil { t.Fatalf("GenconfigUnmarshal: %v", err) } if !reflect.DeepEqual(cfg, got) { t.Errorf("roundtrip diverge\norig: %+v\ngot: %+v", cfg, got) } }) } // --------------------------------------------------------------------------- // 
TestGenconfigCrossLanguageJSON // --------------------------------------------------------------------------- func TestGenconfigCrossLanguageJSON(t *testing.T) { // Fixture escrito a mano replicando lo que generaria Python: // json.dumps(config.model_dump(), indent=2) // Keys en snake_case, orden de declaracion del dataclass Python. fixture := `{ "prompt": "a dragon", "negative_prompt": "ugly", "seed": 1234, "steps": 25, "cfg_scale": 7.0, "sampler": "euler_a", "width": 512, "height": 512, "model": { "name": "v1-5", "model_type": "sd15", "quantization": "fp16" }, "loras": [ { "path": "/loras/dragon.safetensors", "weight": 0.9 } ] }` t.Run("json cross-language snake_case keys se deserializan correctamente", func(t *testing.T) { cfg, err := GenconfigUnmarshal([]byte(fixture)) if err != nil { t.Fatalf("GenconfigUnmarshal fixture: %v", err) } // Verificar campos clave if cfg.Prompt != "a dragon" { t.Errorf("Prompt: got %q", cfg.Prompt) } if cfg.NegativePrompt != "ugly" { t.Errorf("NegativePrompt: got %q", cfg.NegativePrompt) } if cfg.CfgScale != 7.0 { t.Errorf("CfgScale: got %v", cfg.CfgScale) } if cfg.Model.ModelType != "sd15" { t.Errorf("Model.ModelType: got %q", cfg.Model.ModelType) } if len(cfg.Loras) != 1 || cfg.Loras[0].Weight != 0.9 { t.Errorf("Loras: got %+v", cfg.Loras) } // Re-marshal y verificar que las keys snake_case siguen presentes b, err := GenconfigMarshal(cfg) if err != nil { t.Fatalf("GenconfigMarshal: %v", err) } s := string(b) for _, key := range []string{"negative_prompt", "cfg_scale", "model_type", "quantization"} { if !strings.Contains(s, `"`+key+`"`) { t.Errorf("key %q ausente en JSON re-serializado:\n%s", key, s) } } }) } // --------------------------------------------------------------------------- // helpers // --------------------------------------------------------------------------- // indexAll retorna todos los indices de val en slice. 
func indexAll(slice []string, val string) []int {
	var positions []int
	for i := 0; i < len(slice); i++ {
		if slice[i] == val {
			positions = append(positions, i)
		}
	}
	return positions
}

// containsPair reports whether flag occurs in slice immediately followed by
// value, as in a "--flag value" command-line pair. A flag in the last position
// has no following element and never matches.
func containsPair(slice []string, flag, value string) bool {
	for i, s := range slice {
		if s != flag {
			continue
		}
		if i+1 < len(slice) && slice[i+1] == value {
			return true
		}
	}
	return false
}