package infra import ( "context" "fmt" "os" "os/exec" "path/filepath" "strings" "time" ) // MlEnvCheck holds the result of a single ML environment probe. type MlEnvCheck struct { Name string `json:"name"` // e.g. "cuda_toolkit", "python_venv" Status string `json:"status"` // "ok" | "missing" | "warning" | "unknown" Version string `json:"version,omitempty"` // version string if detected Detail string `json:"detail,omitempty"` // human-readable extra info } // MlEnvReport is the full ML environment audit result. type MlEnvReport struct { Gpus []GpuInfo `json:"gpus"` Checks []MlEnvCheck `json:"checks"` OverallOK bool `json:"overall_ok"` GeneratedAt int64 `json:"generated_at"` } // AuditMlEnv probes the ML environment rooted at registryRoot. // It checks for NVIDIA drivers, CUDA toolkit, Python venv, key Python // packages and optional tools (sd, llama-cli) and a local vault path. // Returns a non-nil MlEnvReport even when individual checks fail; the // function itself only errors if a fundamental system call cannot be // attempted. func AuditMlEnv(registryRoot string) (MlEnvReport, error) { report := MlEnvReport{ GeneratedAt: time.Now().Unix(), } // --- GPU detection (composes GetGpuInfo) --- gpus, err := GetGpuInfo() if err != nil { // Non-fatal: record absence. gpus = []GpuInfo{} } report.Gpus = gpus checks := []MlEnvCheck{} // --- nvidia-smi --- checks = append(checks, probeCommand("nvidia_smi", "nvidia-smi", []string{"--version"}, 5)) // --- nvcc (CUDA toolkit compiler) --- nvcc := probeNvcc() checks = append(checks, nvcc) // --- Python venv --- venvCheck := probeVenv(registryRoot) checks = append(checks, venvCheck) // Python venv path for subsequent checks. venvPy := filepath.Join(registryRoot, "python", ".venv", "bin", "python3") // --- Python packages --- for _, pkg := range []string{"torch", "diffusers", "transformers", "huggingface_hub", "stable_diffusion_cpp_python"} { checks = append(checks, probePythonPackage(venvPy, pkg)) } // --- sd.cpp CLI --- checks = append(checks, probeCommand("sd_cli", "sd", []string{"--version"}, 5)) // --- llama.cpp CLI --- checks = append(checks, probeCommand("llama_cpp", "llama-cli", []string{"--version"}, 5)) // --- imagegen_vault --- checks = append(checks, probeImagegenVault()) report.Checks = checks // OverallOK: no "missing" checks (warning is tolerated) and at least 1 GPU. overallOK := len(gpus) > 0 for _, c := range checks { if c.Status == "missing" { // stable_diffusion_cpp_python and sd_cli are optional — downgrade to warning-only. if c.Name == "stable_diffusion_cpp_python" || c.Name == "sd_cli" || c.Name == "llama_cpp" { continue } overallOK = false } } report.OverallOK = overallOK return report, nil } // probeCommand checks whether a binary is available in PATH by running it with // the given args and recording any version output. func probeCommand(name, binary string, args []string, timeoutSec int) MlEnvCheck { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() path, err := exec.LookPath(binary) if err != nil { return MlEnvCheck{Name: name, Status: "missing", Detail: fmt.Sprintf("%s not found in PATH", binary)} } out, err := exec.CommandContext(ctx, path, args...).CombinedOutput() version := strings.TrimSpace(string(out)) if len(version) > 120 { version = version[:120] } if err != nil { return MlEnvCheck{Name: name, Status: "warning", Version: version, Detail: fmt.Sprintf("exit error: %v", err)} } return MlEnvCheck{Name: name, Status: "ok", Version: version} } // probeNvcc extracts the CUDA toolkit version from nvcc --version output. func probeNvcc() MlEnvCheck { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() path, err := exec.LookPath("nvcc") if err != nil { return MlEnvCheck{Name: "nvcc", Status: "missing", Detail: "nvcc not found in PATH (CUDA toolkit not installed)"} } out, err := exec.CommandContext(ctx, path, "--version").CombinedOutput() if err != nil { return MlEnvCheck{Name: "nvcc", Status: "warning", Detail: fmt.Sprintf("nvcc --version failed: %v", err)} } // Extract version from line like: "Cuda compilation tools, release 12.4, V12.4.99" version := "" for _, line := range strings.Split(string(out), "\n") { if strings.Contains(line, "release") { parts := strings.Split(line, ",") for _, p := range parts { p = strings.TrimSpace(p) if strings.HasPrefix(p, "release") { version = strings.TrimSpace(strings.TrimPrefix(p, "release")) break } } break } } if version == "" { version = strings.TrimSpace(string(out)) if len(version) > 80 { version = version[:80] } } return MlEnvCheck{Name: "nvcc", Status: "ok", Version: version} } // probeVenv checks that the Python venv exists and is functional. func probeVenv(registryRoot string) MlEnvCheck { py := filepath.Join(registryRoot, "python", ".venv", "bin", "python3") if _, err := os.Stat(py); os.IsNotExist(err) { return MlEnvCheck{Name: "python_venv", Status: "missing", Detail: fmt.Sprintf("not found: %s", py)} } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() out, err := exec.CommandContext(ctx, py, "--version").CombinedOutput() version := strings.TrimSpace(string(out)) if err != nil { return MlEnvCheck{Name: "python_venv", Status: "warning", Version: version, Detail: fmt.Sprintf("python3 --version failed: %v", err)} } return MlEnvCheck{Name: "python_venv", Status: "ok", Version: version} } // probePythonPackage imports a package in the venv Python and extracts __version__. func probePythonPackage(venvPy, pkg string) MlEnvCheck { // Map package name → import name (for packages with different import names). importName := pkg switch pkg { case "stable_diffusion_cpp_python": importName = "stable_diffusion_cpp" case "huggingface_hub": importName = "huggingface_hub" } // Check that the venv python binary exists first. if _, err := os.Stat(venvPy); os.IsNotExist(err) { return MlEnvCheck{Name: pkg, Status: "unknown", Detail: "python_venv not available"} } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() script := fmt.Sprintf("import %s; v = getattr(%s, '__version__', None); print(v or 'unknown')", importName, importName) out, err := exec.CommandContext(ctx, venvPy, "-c", script).CombinedOutput() output := strings.TrimSpace(string(out)) if err != nil { // Module not found → missing; other errors → warning. detail := output if len(detail) > 200 { detail = detail[:200] } if strings.Contains(output, "ModuleNotFoundError") || strings.Contains(output, "No module named") { return MlEnvCheck{Name: pkg, Status: "missing", Detail: fmt.Sprintf("%s not installed", importName)} } return MlEnvCheck{Name: pkg, Status: "warning", Detail: detail} } return MlEnvCheck{Name: pkg, Status: "ok", Version: output} } // probeImagegenVault checks that ~/vaults/imagegen_models exists and lists subdirs. func probeImagegenVault() MlEnvCheck { home, err := os.UserHomeDir() if err != nil { return MlEnvCheck{Name: "imagegen_vault", Status: "unknown", Detail: "cannot determine home directory"} } vaultPath := filepath.Join(home, "vaults", "imagegen_models") entries, err := os.ReadDir(vaultPath) if os.IsNotExist(err) { return MlEnvCheck{Name: "imagegen_vault", Status: "missing", Detail: fmt.Sprintf("vault not found: %s", vaultPath)} } if err != nil { return MlEnvCheck{Name: "imagegen_vault", Status: "warning", Detail: fmt.Sprintf("cannot read vault: %v", err)} } subdirs := []string{} for _, e := range entries { if e.IsDir() { subdirs = append(subdirs, e.Name()) } } detail := fmt.Sprintf("subdirs: %s", strings.Join(subdirs, ", ")) if len(subdirs) == 0 { detail = "vault exists but is empty" } return MlEnvCheck{Name: "imagegen_vault", Status: "ok", Detail: detail} }