Files
fn_registry/functions/ml/genconfig_test.go
T
egutierrez a802f59f55 chore: auto-commit (95 archivos)
- cmd/fn/doctor.go
- cmd/fn/main.go
- cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt
- cpp/apps/primitives_gallery/playground/tables/data_table.cpp
- cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp
- cpp/apps/primitives_gallery/playground/tables/data_table_logic.h
- cpp/apps/primitives_gallery/playground/tables/self_test.cpp
- cpp/apps/primitives_gallery/playground/tables/tql.cpp
- cpp/apps/primitives_gallery/playground/tables/viz.cpp
- cpp/apps/primitives_gallery/playground/tables/viz.h
- ...

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 00:50:34 +02:00

261 lines
6.8 KiB
Go

package ml
import (
"reflect"
"strings"
"testing"
)
// ---------------------------------------------------------------------------
// TestGenconfigToSdcliArgs
// ---------------------------------------------------------------------------
func TestGenconfigToSdcliArgs(t *testing.T) {
clipSkip := 2
t.Run("config basico sin loras ni clip_skip", func(t *testing.T) {
cfg := GenerationConfig{
Prompt: "a cat",
Seed: 42,
Steps: 20,
CfgScale: 7.5,
Sampler: "euler",
Width: 512,
Height: 512,
Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16"},
}
args := GenconfigToSdcliArgs(cfg)
want := []string{
"--prompt", "a cat",
"--seed", "42",
"--steps", "20",
"--cfg-scale", "7.5",
"--width", "512",
"--height", "512",
"--sampling-method", "euler",
}
if !reflect.DeepEqual(args, want) {
t.Errorf("got %v\nwant %v", args, want)
}
})
t.Run("loras se emiten como pares path:weight", func(t *testing.T) {
cfg := GenerationConfig{
Prompt: "portrait",
Seed: 1,
Steps: 10,
CfgScale: 7.0,
Sampler: "euler",
Width: 512,
Height: 512,
Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16", Path: "/models/v1.safetensors"},
Loras: []LoraRef{
{Path: "/loras/detail.safetensors", Weight: 0.8},
{Path: "/loras/style.safetensors", Weight: 0.5},
},
ClipSkip: &clipSkip,
}
args := GenconfigToSdcliArgs(cfg)
// Verificar que existen los pares --lora para ambas loras
loraIdx := indexAll(args, "--lora")
if len(loraIdx) != 2 {
t.Fatalf("esperaba 2 flags --lora, got %d en %v", len(loraIdx), args)
}
wantLoras := []string{
"/loras/detail.safetensors:0.8",
"/loras/style.safetensors:0.5",
}
for i, idx := range loraIdx {
if idx+1 >= len(args) {
t.Fatalf("--lora[%d] sin valor siguiente", i)
}
if args[idx+1] != wantLoras[i] {
t.Errorf("lora[%d]: got %q, want %q", i, args[idx+1], wantLoras[i])
}
}
// Verificar --model y --clip-skip presentes
if !containsPair(args, "--model", "/models/v1.safetensors") {
t.Errorf("--model no encontrado en %v", args)
}
if !containsPair(args, "--clip-skip", "2") {
t.Errorf("--clip-skip no encontrado en %v", args)
}
})
t.Run("sampler dpm++2m se traduce a dpmpp2m", func(t *testing.T) {
cfg := GenerationConfig{
Prompt: "x",
Seed: 0,
Steps: 1,
CfgScale: 1.0,
Sampler: "dpm++2m",
Width: 64,
Height: 64,
Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"},
}
args := GenconfigToSdcliArgs(cfg)
if !containsPair(args, "--sampling-method", "dpmpp2m") {
t.Errorf("sampler no traducido; args=%v", args)
}
})
t.Run("negative_prompt vacio no genera flag", func(t *testing.T) {
cfg := GenerationConfig{
Prompt: "x",
NegativePrompt: "",
Seed: 0,
Steps: 1,
CfgScale: 1.0,
Sampler: "euler",
Width: 64,
Height: 64,
Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"},
}
args := GenconfigToSdcliArgs(cfg)
for _, a := range args {
if a == "--negative-prompt" {
t.Errorf("flag --negative-prompt presente aunque NegativePrompt es vacio")
}
}
})
}
// ---------------------------------------------------------------------------
// TestGenconfigMarshalRoundtrip
// ---------------------------------------------------------------------------
// TestGenconfigMarshalRoundtrip verifies that a fully populated
// GenerationConfig passed through GenconfigMarshal and then
// GenconfigUnmarshal compares deeply equal to the value it started from.
func TestGenconfigMarshalRoundtrip(t *testing.T) {
	t.Run("roundtrip marshal unmarshal produce config igual", func(t *testing.T) {
		skip := 2
		original := GenerationConfig{
			Prompt:         "sunset over the mountains",
			NegativePrompt: "blurry, low quality",
			Seed:           99,
			Steps:          30,
			CfgScale:       7.5,
			Sampler:        "dpm++2m",
			Width:          768,
			Height:         512,
			Model: ModelRef{
				Name:         "sdxl-base",
				ModelType:    "sdxl",
				Quantization: "fp16",
				Path:         "/models/sdxl.safetensors",
			},
			Loras: []LoraRef{
				{Path: "/loras/detail.safetensors", Weight: 0.8},
			},
			ClipSkip: &skip,
		}

		encoded, err := GenconfigMarshal(original)
		if err != nil {
			t.Fatalf("GenconfigMarshal: %v", err)
		}
		decoded, err := GenconfigUnmarshal(encoded)
		if err != nil {
			t.Fatalf("GenconfigUnmarshal: %v", err)
		}

		if reflect.DeepEqual(original, decoded) {
			return
		}
		t.Errorf("roundtrip diverge\norig: %+v\ngot: %+v", original, decoded)
	})
}
// ---------------------------------------------------------------------------
// TestGenconfigCrossLanguageJSON
// ---------------------------------------------------------------------------
// TestGenconfigCrossLanguageJSON decodes a hand-written snake_case JSON
// document with GenconfigUnmarshal, checks a selection of the decoded fields,
// then re-encodes the value with GenconfigMarshal and checks that the
// snake_case keys are still present in the output.
func TestGenconfigCrossLanguageJSON(t *testing.T) {
	// Hand-written fixture replicating what Python would generate:
	//   json.dumps(config.model_dump(), indent=2)
	// Keys are snake_case, in the declaration order of the Python dataclass.
	const pythonJSON = `{
"prompt": "a dragon",
"negative_prompt": "ugly",
"seed": 1234,
"steps": 25,
"cfg_scale": 7.0,
"sampler": "euler_a",
"width": 512,
"height": 512,
"model": {
"name": "v1-5",
"model_type": "sd15",
"quantization": "fp16"
},
"loras": [
{
"path": "/loras/dragon.safetensors",
"weight": 0.9
}
]
}`

	t.Run("json cross-language snake_case keys se deserializan correctamente", func(t *testing.T) {
		decoded, err := GenconfigUnmarshal([]byte(pythonJSON))
		if err != nil {
			t.Fatalf("GenconfigUnmarshal fixture: %v", err)
		}

		// Spot-check the key fields of the decoded config.
		if got := decoded.Prompt; got != "a dragon" {
			t.Errorf("Prompt: got %q", got)
		}
		if got := decoded.NegativePrompt; got != "ugly" {
			t.Errorf("NegativePrompt: got %q", got)
		}
		if got := decoded.CfgScale; got != 7.0 {
			t.Errorf("CfgScale: got %v", got)
		}
		if got := decoded.Model.ModelType; got != "sd15" {
			t.Errorf("Model.ModelType: got %q", got)
		}
		if !(len(decoded.Loras) == 1 && decoded.Loras[0].Weight == 0.9) {
			t.Errorf("Loras: got %+v", decoded.Loras)
		}

		// Re-marshal and confirm the snake_case keys are still emitted.
		encoded, err := GenconfigMarshal(decoded)
		if err != nil {
			t.Fatalf("GenconfigMarshal: %v", err)
		}
		text := string(encoded)
		for _, key := range []string{"negative_prompt", "cfg_scale", "model_type", "quantization"} {
			quoted := `"` + key + `"`
			if strings.Contains(text, quoted) {
				continue
			}
			t.Errorf("key %q ausente en JSON re-serializado:\n%s", key, text)
		}
	})
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// indexAll returns every index at which val occurs in slice, in ascending
// order. The result is nil when val does not occur.
func indexAll(slice []string, val string) []int {
	var hits []int
	for pos := 0; pos < len(slice); pos++ {
		if slice[pos] != val {
			continue
		}
		hits = append(hits, pos)
	}
	return hits
}
// containsPair reports whether flag appears in slice immediately followed by
// value.
func containsPair(slice []string, flag, value string) bool {
	// Walk the candidate value positions and look back one element for the flag.
	for next := 1; next < len(slice); next++ {
		if slice[next-1] == flag && slice[next] == value {
			return true
		}
	}
	return false
}