e3c8979e8d
- cmd/fn/doctor.go - cmd/fn/main.go - cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt - cpp/apps/primitives_gallery/playground/tables/data_table.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.h - cpp/apps/primitives_gallery/playground/tables/self_test.cpp - cpp/apps/primitives_gallery/playground/tables/tql.cpp - cpp/apps/primitives_gallery/playground/tables/viz.cpp - cpp/apps/primitives_gallery/playground/tables/viz.h - ... Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
261 lines
6.8 KiB
Go
261 lines
6.8 KiB
Go
package ml
|
|
|
|
import (
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestGenconfigToSdcliArgs
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestGenconfigToSdcliArgs(t *testing.T) {
|
|
clipSkip := 2
|
|
|
|
t.Run("config basico sin loras ni clip_skip", func(t *testing.T) {
|
|
cfg := GenerationConfig{
|
|
Prompt: "a cat",
|
|
Seed: 42,
|
|
Steps: 20,
|
|
CfgScale: 7.5,
|
|
Sampler: "euler",
|
|
Width: 512,
|
|
Height: 512,
|
|
Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16"},
|
|
}
|
|
args := GenconfigToSdcliArgs(cfg)
|
|
|
|
want := []string{
|
|
"--prompt", "a cat",
|
|
"--seed", "42",
|
|
"--steps", "20",
|
|
"--cfg-scale", "7.5",
|
|
"--width", "512",
|
|
"--height", "512",
|
|
"--sampling-method", "euler",
|
|
}
|
|
if !reflect.DeepEqual(args, want) {
|
|
t.Errorf("got %v\nwant %v", args, want)
|
|
}
|
|
})
|
|
|
|
t.Run("loras se emiten como pares path:weight", func(t *testing.T) {
|
|
cfg := GenerationConfig{
|
|
Prompt: "portrait",
|
|
Seed: 1,
|
|
Steps: 10,
|
|
CfgScale: 7.0,
|
|
Sampler: "euler",
|
|
Width: 512,
|
|
Height: 512,
|
|
Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16", Path: "/models/v1.safetensors"},
|
|
Loras: []LoraRef{
|
|
{Path: "/loras/detail.safetensors", Weight: 0.8},
|
|
{Path: "/loras/style.safetensors", Weight: 0.5},
|
|
},
|
|
ClipSkip: &clipSkip,
|
|
}
|
|
args := GenconfigToSdcliArgs(cfg)
|
|
|
|
// Verificar que existen los pares --lora para ambas loras
|
|
loraIdx := indexAll(args, "--lora")
|
|
if len(loraIdx) != 2 {
|
|
t.Fatalf("esperaba 2 flags --lora, got %d en %v", len(loraIdx), args)
|
|
}
|
|
wantLoras := []string{
|
|
"/loras/detail.safetensors:0.8",
|
|
"/loras/style.safetensors:0.5",
|
|
}
|
|
for i, idx := range loraIdx {
|
|
if idx+1 >= len(args) {
|
|
t.Fatalf("--lora[%d] sin valor siguiente", i)
|
|
}
|
|
if args[idx+1] != wantLoras[i] {
|
|
t.Errorf("lora[%d]: got %q, want %q", i, args[idx+1], wantLoras[i])
|
|
}
|
|
}
|
|
|
|
// Verificar --model y --clip-skip presentes
|
|
if !containsPair(args, "--model", "/models/v1.safetensors") {
|
|
t.Errorf("--model no encontrado en %v", args)
|
|
}
|
|
if !containsPair(args, "--clip-skip", "2") {
|
|
t.Errorf("--clip-skip no encontrado en %v", args)
|
|
}
|
|
})
|
|
|
|
t.Run("sampler dpm++2m se traduce a dpmpp2m", func(t *testing.T) {
|
|
cfg := GenerationConfig{
|
|
Prompt: "x",
|
|
Seed: 0,
|
|
Steps: 1,
|
|
CfgScale: 1.0,
|
|
Sampler: "dpm++2m",
|
|
Width: 64,
|
|
Height: 64,
|
|
Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"},
|
|
}
|
|
args := GenconfigToSdcliArgs(cfg)
|
|
if !containsPair(args, "--sampling-method", "dpmpp2m") {
|
|
t.Errorf("sampler no traducido; args=%v", args)
|
|
}
|
|
})
|
|
|
|
t.Run("negative_prompt vacio no genera flag", func(t *testing.T) {
|
|
cfg := GenerationConfig{
|
|
Prompt: "x",
|
|
NegativePrompt: "",
|
|
Seed: 0,
|
|
Steps: 1,
|
|
CfgScale: 1.0,
|
|
Sampler: "euler",
|
|
Width: 64,
|
|
Height: 64,
|
|
Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"},
|
|
}
|
|
args := GenconfigToSdcliArgs(cfg)
|
|
for _, a := range args {
|
|
if a == "--negative-prompt" {
|
|
t.Errorf("flag --negative-prompt presente aunque NegativePrompt es vacio")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestGenconfigMarshalRoundtrip
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestGenconfigMarshalRoundtrip(t *testing.T) {
|
|
t.Run("roundtrip marshal unmarshal produce config igual", func(t *testing.T) {
|
|
clip := 2
|
|
cfg := GenerationConfig{
|
|
Prompt: "sunset over the mountains",
|
|
NegativePrompt: "blurry, low quality",
|
|
Seed: 99,
|
|
Steps: 30,
|
|
CfgScale: 7.5,
|
|
Sampler: "dpm++2m",
|
|
Width: 768,
|
|
Height: 512,
|
|
Model: ModelRef{
|
|
Name: "sdxl-base",
|
|
ModelType: "sdxl",
|
|
Quantization: "fp16",
|
|
Path: "/models/sdxl.safetensors",
|
|
},
|
|
Loras: []LoraRef{
|
|
{Path: "/loras/detail.safetensors", Weight: 0.8},
|
|
},
|
|
ClipSkip: &clip,
|
|
}
|
|
|
|
b, err := GenconfigMarshal(cfg)
|
|
if err != nil {
|
|
t.Fatalf("GenconfigMarshal: %v", err)
|
|
}
|
|
|
|
got, err := GenconfigUnmarshal(b)
|
|
if err != nil {
|
|
t.Fatalf("GenconfigUnmarshal: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(cfg, got) {
|
|
t.Errorf("roundtrip diverge\norig: %+v\ngot: %+v", cfg, got)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestGenconfigCrossLanguageJSON
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestGenconfigCrossLanguageJSON(t *testing.T) {
|
|
// Fixture escrito a mano replicando lo que generaria Python:
|
|
// json.dumps(config.model_dump(), indent=2)
|
|
// Keys en snake_case, orden de declaracion del dataclass Python.
|
|
fixture := `{
|
|
"prompt": "a dragon",
|
|
"negative_prompt": "ugly",
|
|
"seed": 1234,
|
|
"steps": 25,
|
|
"cfg_scale": 7.0,
|
|
"sampler": "euler_a",
|
|
"width": 512,
|
|
"height": 512,
|
|
"model": {
|
|
"name": "v1-5",
|
|
"model_type": "sd15",
|
|
"quantization": "fp16"
|
|
},
|
|
"loras": [
|
|
{
|
|
"path": "/loras/dragon.safetensors",
|
|
"weight": 0.9
|
|
}
|
|
]
|
|
}`
|
|
|
|
t.Run("json cross-language snake_case keys se deserializan correctamente", func(t *testing.T) {
|
|
cfg, err := GenconfigUnmarshal([]byte(fixture))
|
|
if err != nil {
|
|
t.Fatalf("GenconfigUnmarshal fixture: %v", err)
|
|
}
|
|
|
|
// Verificar campos clave
|
|
if cfg.Prompt != "a dragon" {
|
|
t.Errorf("Prompt: got %q", cfg.Prompt)
|
|
}
|
|
if cfg.NegativePrompt != "ugly" {
|
|
t.Errorf("NegativePrompt: got %q", cfg.NegativePrompt)
|
|
}
|
|
if cfg.CfgScale != 7.0 {
|
|
t.Errorf("CfgScale: got %v", cfg.CfgScale)
|
|
}
|
|
if cfg.Model.ModelType != "sd15" {
|
|
t.Errorf("Model.ModelType: got %q", cfg.Model.ModelType)
|
|
}
|
|
if len(cfg.Loras) != 1 || cfg.Loras[0].Weight != 0.9 {
|
|
t.Errorf("Loras: got %+v", cfg.Loras)
|
|
}
|
|
|
|
// Re-marshal y verificar que las keys snake_case siguen presentes
|
|
b, err := GenconfigMarshal(cfg)
|
|
if err != nil {
|
|
t.Fatalf("GenconfigMarshal: %v", err)
|
|
}
|
|
s := string(b)
|
|
for _, key := range []string{"negative_prompt", "cfg_scale", "model_type", "quantization"} {
|
|
if !strings.Contains(s, `"`+key+`"`) {
|
|
t.Errorf("key %q ausente en JSON re-serializado:\n%s", key, s)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// indexAll returns, in ascending order, every index at which val occurs in
// slice. It returns nil when val does not occur (including for a nil slice).
func indexAll(slice []string, val string) []int {
	var positions []int
	for i := range slice {
		if slice[i] != val {
			continue
		}
		positions = append(positions, i)
	}
	return positions
}
|
|
|
|
// containsPair reports whether slice contains flag immediately followed by
// value. Slices with fewer than two elements never contain a pair.
func containsPair(slice []string, flag, value string) bool {
	if len(slice) < 2 {
		return false
	}
	// i indexes slice[1:], so slice[i] is the element just before next.
	for i, next := range slice[1:] {
		if next == value && slice[i] == flag {
			return true
		}
	}
	return false
}
|