From a802f59f55be356d17950bddf2bff2b89055c4a0 Mon Sep 17 00:00:00 2001 From: Egutierrez Date: Wed, 13 May 2026 00:50:34 +0200 Subject: [PATCH] chore: auto-commit (95 archivos) - cmd/fn/doctor.go - cmd/fn/main.go - cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt - cpp/apps/primitives_gallery/playground/tables/data_table.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp - cpp/apps/primitives_gallery/playground/tables/data_table_logic.h - cpp/apps/primitives_gallery/playground/tables/self_test.cpp - cpp/apps/primitives_gallery/playground/tables/tql.cpp - cpp/apps/primitives_gallery/playground/tables/viz.cpp - cpp/apps/primitives_gallery/playground/tables/viz.h - ... Co-Authored-By: Claude Opus 4.7 (1M context) --- bash/functions/infra/cuda_toolkit_check.md | 73 ++ bash/functions/infra/cuda_toolkit_check.sh | 99 ++ .../infra/tests/cuda_toolkit_check_test.sh | 111 ++ bash/functions/pipelines/vault_audit.md | 90 ++ bash/functions/pipelines/vault_audit.sh | 172 +++ cmd/fn/cmd_vault.go | 1059 +++++++++++++++++ cmd/fn/cmd_vault_test.go | 318 +++++ cmd/fn/doctor.go | 95 ++ cmd/fn/main.go | 3 + .../playground/tables/CMakeLists.txt | 5 + .../playground/tables/data_table.cpp | 450 ++++++- .../playground/tables/data_table_logic.cpp | 420 ++++++- .../playground/tables/data_table_logic.h | 307 ++--- .../playground/tables/llm_anthropic.cpp | 295 +++++ .../playground/tables/llm_anthropic.h | 58 + .../playground/tables/self_test.cpp | 779 ++++++++++++ .../playground/tables/tql.cpp | 9 +- .../playground/tables/tql_to_sql.cpp | 862 ++++++++++++++ .../playground/tables/tql_to_sql.h | 41 + .../playground/tables/viz.cpp | 98 +- .../playground/tables/viz.h | 7 +- cpp/functions/core/data_table_types.h | 212 ++++ cpp/functions/gfx/gpu_check.cpp | 96 ++ cpp/functions/gfx/gpu_check.h | 38 + cpp/functions/gfx/gpu_check.md | 86 ++ cpp/types/core/agg_fn.md | 20 + cpp/types/core/aggregation.md | 22 + cpp/types/core/color_rule.md | 21 + cpp/types/core/column_type.md | 28 + cpp/types/core/date_granularity.md | 19 + cpp/types/core/derived_column.md | 26 + cpp/types/core/filter.md | 30 + cpp/types/core/join.md | 25 + cpp/types/core/join_strategy.md | 17 + cpp/types/core/op.md | 36 + cpp/types/core/sort_clause.md | 20 + cpp/types/core/stage.md | 33 + cpp/types/core/stage_output.md | 26 + cpp/types/core/state.md | 40 + cpp/types/core/table_input.md | 33 + cpp/types/viz/view_config.md | 29 + cpp/types/viz/view_mode.md | 29 + cpp/types/viz/viz_panel.md | 21 + dev/issues/0078-tables-joins-mbql.md | 3 +- dev/issues/0079-tables-drill-ext.md | 3 +- dev/issues/0080-tables-llm-api.md | 235 +++- docs/TQL.md | 84 ++ functions/core/subprocess_stream.go | 155 +++ functions/core/subprocess_stream.md | 69 ++ functions/core/subprocess_stream_test.go | 132 ++ functions/infra/audit_ml_env.go | 238 ++++ functions/infra/audit_ml_env.md | 67 ++ functions/infra/audit_ml_env_test.go | 53 + functions/infra/get_gpu_info.go | 60 + functions/infra/get_gpu_info.md | 70 ++ functions/infra/get_gpu_info_test.go | 165 +++ functions/infra/gpu_info.go | 12 + functions/infra/vault_aggregate_index.go | 171 +++ functions/infra/vault_aggregate_index.md | 58 + functions/infra/vault_aggregate_index_test.go | 175 +++ functions/infra/vault_diff.go | 68 ++ functions/infra/vault_diff.md | 49 + functions/infra/vault_diff_test.go | 126 ++ functions/infra/vault_doctor.go | 230 ++++ functions/infra/vault_doctor.md | 66 + functions/infra/vault_doctor_test.go | 211 ++++ functions/infra/vault_file.go | 21 + .../infra/vault_index_migrations/001_init.sql | 49 + functions/infra/vault_index_open.go | 30 + functions/infra/vault_index_open.md | 54 + functions/infra/vault_index_open_test.go | 107 ++ functions/infra/vault_index_write.go | 154 +++ functions/infra/vault_index_write.md | 84 ++ functions/infra/vault_index_write_test.go | 210 ++++ functions/infra/vault_inventory_scan.go | 174 +++ functions/infra/vault_inventory_scan.md | 74 ++ functions/infra/vault_inventory_scan_test.go | 182 +++ functions/infra/vault_layout_ensure.go | 252 ++++ functions/infra/vault_layout_ensure.md | 95 ++ functions/infra/vault_layout_ensure_test.go | 394 ++++++ functions/infra/vault_manifest_read.go | 96 ++ functions/infra/vault_manifest_read.md | 59 + functions/infra/vault_manifest_read_test.go | 113 ++ functions/infra/vault_search.go | 265 +++++ functions/infra/vault_search.md | 61 + functions/infra/vault_search_test.go | 147 +++ functions/ml/genconfig_json_marshal.go | 20 + functions/ml/genconfig_json_marshal.md | 84 ++ functions/ml/genconfig_test.go | 260 ++++ functions/ml/genconfig_to_sdcli_args.go | 59 + functions/ml/genconfig_to_sdcli_args.md | 59 + functions/ml/generation_config.go | 18 + functions/ml/image_gen_result.go | 12 + functions/ml/image_generator.go | 9 + functions/ml/lora_ref.go | 8 + functions/ml/model_ref.go | 10 + functions/ml/sdcli_parse_progress.go | 78 ++ functions/ml/sdcli_parse_progress.md | 50 + functions/ml/sdcli_parse_progress_test.go | 103 ++ .../tests/test_vault_csv_profile.py | 161 +++ .../tests/test_vault_pdf_extract.py | 147 +++ .../datascience/vault_csv_profile.md | 61 + .../datascience/vault_csv_profile.py | 216 ++++ .../datascience/vault_pdf_extract.md | 60 + .../datascience/vault_pdf_extract.py | 121 ++ .../infra/tests/test_vault_dedupe_report.py | 154 +++ .../infra/tests/test_vault_knowledge_parse.py | 153 +++ python/functions/infra/vault_dedupe_report.md | 57 + python/functions/infra/vault_dedupe_report.py | 122 ++ .../functions/infra/vault_knowledge_parse.md | 60 + .../functions/infra/vault_knowledge_parse.py | 142 +++ .../functions/infra/vault_profile_dispatch.md | 54 + .../functions/infra/vault_profile_dispatch.py | 92 ++ python/functions/ml/__init__.py | 1 + python/functions/ml/cuda_available.md | 67 ++ python/functions/ml/cuda_available.py | 42 + python/functions/ml/diffusers_generate.md | 71 ++ python/functions/ml/diffusers_generate.py | 98 ++ python/functions/ml/diffusers_load_lora.md | 49 + python/functions/ml/diffusers_load_lora.py | 55 + .../functions/ml/diffusers_load_pipeline.md | 60 + .../functions/ml/diffusers_load_pipeline.py | 102 ++ .../functions/ml/diffusers_set_scheduler.md | 61 + .../functions/ml/diffusers_set_scheduler.py | 70 ++ python/functions/ml/diffusers_unload.md | 49 + python/functions/ml/diffusers_unload.py | 47 + python/functions/ml/genconfig_load_json.md | 47 + python/functions/ml/genconfig_load_json.py | 77 ++ python/functions/ml/genconfig_save_json.md | 59 + python/functions/ml/genconfig_save_json.py | 58 + .../ml/genconfig_to_diffusers_kwargs.md | 65 + .../ml/genconfig_to_diffusers_kwargs.py | 41 + .../functions/ml/genconfig_to_sdcpp_args.md | 74 ++ .../functions/ml/genconfig_to_sdcpp_args.py | 65 + python/functions/ml/generation_config.py | 111 ++ python/functions/ml/gpu_info.md | 58 + python/functions/ml/gpu_info.py | 73 ++ python/functions/ml/hf_snapshot_download.md | 82 ++ python/functions/ml/hf_snapshot_download.py | 59 + .../ml/image_compare_side_by_side.md | 84 ++ .../ml/image_compare_side_by_side.py | 147 +++ python/functions/ml/image_gen_result.py | 97 ++ python/functions/ml/image_generator.py | 46 + python/functions/ml/image_grid.md | 77 ++ python/functions/ml/image_grid.py | 80 ++ python/functions/ml/image_save_png.md | 69 ++ python/functions/ml/image_save_png.py | 48 + python/functions/ml/lora_ref.py | 47 + python/functions/ml/model_ref.py | 67 ++ python/functions/ml/safetensors_inspect.md | 99 ++ python/functions/ml/safetensors_inspect.py | 100 ++ python/functions/ml/sampler_name.py | 16 + .../functions/ml/tests/test_cuda_available.py | 54 + .../ml/tests/test_diffusers_backend.py | 212 ++++ .../ml/tests/test_genconfig_json_roundtrip.py | 165 +++ .../test_genconfig_to_diffusers_kwargs.py | 113 ++ .../ml/tests/test_genconfig_to_sdcpp_args.py | 150 +++ .../test_generation_config_serialization.py | 131 ++ python/functions/ml/tests/test_gpu_info.py | 48 + .../ml/tests/test_hf_snapshot_download.py | 128 ++ .../tests/test_image_compare_side_by_side.py | 128 ++ .../ml/tests/test_image_gen_result.py | 99 ++ .../ml/tests/test_image_generator_protocol.py | 70 ++ python/functions/ml/tests/test_image_grid.py | 85 ++ .../functions/ml/tests/test_image_save_png.py | 81 ++ .../ml/tests/test_model_ref_lora_ref.py | 136 +++ .../ml/tests/test_safetensors_inspect.py | 160 +++ .../ml/tests/test_torch_device_select.py | 79 ++ python/functions/ml/tests/test_vram_budget.py | 78 ++ python/functions/ml/torch_device_select.md | 67 ++ python/functions/ml/torch_device_select.py | 108 ++ python/functions/ml/vram_budget.md | 73 ++ python/functions/ml/vram_budget.py | 111 ++ python/pyproject.toml | 2 + python/types/ml/generation_config.md | 71 ++ python/types/ml/image_gen_result.md | 46 + python/types/ml/image_generator.md | 57 + python/types/ml/lora_ref.md | 29 + python/types/ml/model_ref.md | 38 + python/types/ml/sampler_name.md | 39 + python/uv.lock | 63 + registry/migrations/012_vault_files.sql | 16 + types/infra/gpu_info.md | 39 + types/infra/vault_file.md | 51 + types/ml/generation_config.md | 53 + types/ml/image_gen_result.md | 39 + types/ml/image_generator.md | 33 + types/ml/lora_ref.md | 39 + types/ml/model_ref.md | 36 + 189 files changed, 18964 insertions(+), 330 deletions(-) create mode 100644 bash/functions/infra/cuda_toolkit_check.md create mode 100644 bash/functions/infra/cuda_toolkit_check.sh create mode 100644 bash/functions/infra/tests/cuda_toolkit_check_test.sh create mode 100644 bash/functions/pipelines/vault_audit.md create mode 100644 bash/functions/pipelines/vault_audit.sh create mode 100644 cmd/fn/cmd_vault.go create mode 100644 cmd/fn/cmd_vault_test.go create mode 100644 cpp/apps/primitives_gallery/playground/tables/llm_anthropic.cpp create mode 100644 cpp/apps/primitives_gallery/playground/tables/llm_anthropic.h create mode 100644 cpp/apps/primitives_gallery/playground/tables/tql_to_sql.cpp create mode 100644 cpp/apps/primitives_gallery/playground/tables/tql_to_sql.h create mode 100644 cpp/functions/core/data_table_types.h create mode 100644 cpp/functions/gfx/gpu_check.cpp create mode 100644 cpp/functions/gfx/gpu_check.h create mode 100644 cpp/functions/gfx/gpu_check.md create mode 100644 cpp/types/core/agg_fn.md create mode 100644 cpp/types/core/aggregation.md create mode 100644 cpp/types/core/color_rule.md create mode 100644 cpp/types/core/column_type.md create mode 100644 cpp/types/core/date_granularity.md create mode 100644 cpp/types/core/derived_column.md create mode 100644 cpp/types/core/filter.md create mode 100644 cpp/types/core/join.md create mode 100644 cpp/types/core/join_strategy.md create mode 100644 cpp/types/core/op.md create mode 100644 cpp/types/core/sort_clause.md create mode 100644 cpp/types/core/stage.md create mode 100644 cpp/types/core/stage_output.md create mode 100644 cpp/types/core/state.md create mode 100644 cpp/types/core/table_input.md create mode 100644 cpp/types/viz/view_config.md create mode 100644 cpp/types/viz/view_mode.md create mode 100644 cpp/types/viz/viz_panel.md create mode 100644 functions/core/subprocess_stream.go create mode 100644 functions/core/subprocess_stream.md create mode 100644 functions/core/subprocess_stream_test.go create mode 100644 functions/infra/audit_ml_env.go create mode 100644 functions/infra/audit_ml_env.md create mode 100644 functions/infra/audit_ml_env_test.go create mode 100644 functions/infra/get_gpu_info.go create mode 100644 functions/infra/get_gpu_info.md create mode 100644 functions/infra/get_gpu_info_test.go create mode 100644 functions/infra/gpu_info.go create mode 100644 functions/infra/vault_aggregate_index.go create mode 100644 functions/infra/vault_aggregate_index.md create mode 100644 functions/infra/vault_aggregate_index_test.go create mode 100644 functions/infra/vault_diff.go create mode 100644 functions/infra/vault_diff.md create mode 100644 functions/infra/vault_diff_test.go create mode 100644 functions/infra/vault_doctor.go create mode 100644 functions/infra/vault_doctor.md create mode 100644 functions/infra/vault_doctor_test.go create mode 100644 functions/infra/vault_file.go create mode 100644 functions/infra/vault_index_migrations/001_init.sql create mode 100644 functions/infra/vault_index_open.go create mode 100644 functions/infra/vault_index_open.md create mode 100644 functions/infra/vault_index_open_test.go create mode 100644 functions/infra/vault_index_write.go create mode 100644 functions/infra/vault_index_write.md create mode 100644 functions/infra/vault_index_write_test.go create mode 100644 functions/infra/vault_inventory_scan.go create mode 100644 functions/infra/vault_inventory_scan.md create mode 100644 functions/infra/vault_inventory_scan_test.go create mode 100644 functions/infra/vault_layout_ensure.go create mode 100644 functions/infra/vault_layout_ensure.md create mode 100644 functions/infra/vault_layout_ensure_test.go create mode 100644 functions/infra/vault_manifest_read.go create mode 100644 functions/infra/vault_manifest_read.md create mode 100644 functions/infra/vault_manifest_read_test.go create mode 100644 functions/infra/vault_search.go create mode 100644 functions/infra/vault_search.md create mode 100644 functions/infra/vault_search_test.go create mode 100644 functions/ml/genconfig_json_marshal.go create mode 100644 functions/ml/genconfig_json_marshal.md create mode 100644 functions/ml/genconfig_test.go create mode 100644 functions/ml/genconfig_to_sdcli_args.go create mode 100644 functions/ml/genconfig_to_sdcli_args.md create mode 100644 functions/ml/generation_config.go create mode 100644 functions/ml/image_gen_result.go create mode 100644 functions/ml/image_generator.go create mode 100644 functions/ml/lora_ref.go create mode 100644 functions/ml/model_ref.go create mode 100644 functions/ml/sdcli_parse_progress.go create mode 100644 functions/ml/sdcli_parse_progress.md create mode 100644 functions/ml/sdcli_parse_progress_test.go create mode 100644 python/functions/datascience/tests/test_vault_csv_profile.py create mode 100644 python/functions/datascience/tests/test_vault_pdf_extract.py create mode 100644 python/functions/datascience/vault_csv_profile.md create mode 100644 python/functions/datascience/vault_csv_profile.py create mode 100644 python/functions/datascience/vault_pdf_extract.md create mode 100644 python/functions/datascience/vault_pdf_extract.py create mode 100644 python/functions/infra/tests/test_vault_dedupe_report.py create mode 100644 python/functions/infra/tests/test_vault_knowledge_parse.py create mode 100644 python/functions/infra/vault_dedupe_report.md create mode 100644 python/functions/infra/vault_dedupe_report.py create mode 100644 python/functions/infra/vault_knowledge_parse.md create mode 100644 python/functions/infra/vault_knowledge_parse.py create mode 100644 python/functions/infra/vault_profile_dispatch.md create mode 100644 python/functions/infra/vault_profile_dispatch.py create mode 100644 python/functions/ml/__init__.py create mode 100644 python/functions/ml/cuda_available.md create mode 100644 python/functions/ml/cuda_available.py create mode 100644 python/functions/ml/diffusers_generate.md create mode 100644 python/functions/ml/diffusers_generate.py create mode 100644 python/functions/ml/diffusers_load_lora.md create mode 100644 python/functions/ml/diffusers_load_lora.py create mode 100644 python/functions/ml/diffusers_load_pipeline.md create mode 100644 python/functions/ml/diffusers_load_pipeline.py create mode 100644 python/functions/ml/diffusers_set_scheduler.md create mode 100644 python/functions/ml/diffusers_set_scheduler.py create mode 100644 python/functions/ml/diffusers_unload.md create mode 100644 python/functions/ml/diffusers_unload.py create mode 100644 python/functions/ml/genconfig_load_json.md create mode 100644 python/functions/ml/genconfig_load_json.py create mode 100644 python/functions/ml/genconfig_save_json.md create mode 100644 python/functions/ml/genconfig_save_json.py create mode 100644 python/functions/ml/genconfig_to_diffusers_kwargs.md create mode 100644 python/functions/ml/genconfig_to_diffusers_kwargs.py create mode 100644 python/functions/ml/genconfig_to_sdcpp_args.md create mode 100644 python/functions/ml/genconfig_to_sdcpp_args.py create mode 100644 python/functions/ml/generation_config.py create mode 100644 python/functions/ml/gpu_info.md create mode 100644 python/functions/ml/gpu_info.py create mode 100644 python/functions/ml/hf_snapshot_download.md create mode 100644 python/functions/ml/hf_snapshot_download.py create mode 100644 python/functions/ml/image_compare_side_by_side.md create mode 100644 python/functions/ml/image_compare_side_by_side.py create mode 100644 python/functions/ml/image_gen_result.py create mode 100644 python/functions/ml/image_generator.py create mode 100644 python/functions/ml/image_grid.md create mode 100644 python/functions/ml/image_grid.py create mode 100644 python/functions/ml/image_save_png.md create mode 100644 python/functions/ml/image_save_png.py create mode 100644 python/functions/ml/lora_ref.py create mode 100644 python/functions/ml/model_ref.py create mode 100644 python/functions/ml/safetensors_inspect.md create mode 100644 python/functions/ml/safetensors_inspect.py create mode 100644 python/functions/ml/sampler_name.py create mode 100644 python/functions/ml/tests/test_cuda_available.py create mode 100644 python/functions/ml/tests/test_diffusers_backend.py create mode 100644 python/functions/ml/tests/test_genconfig_json_roundtrip.py create mode 100644 python/functions/ml/tests/test_genconfig_to_diffusers_kwargs.py create mode 100644 python/functions/ml/tests/test_genconfig_to_sdcpp_args.py create mode 100644 python/functions/ml/tests/test_generation_config_serialization.py create mode 100644 python/functions/ml/tests/test_gpu_info.py create mode 100644 python/functions/ml/tests/test_hf_snapshot_download.py create mode 100644 python/functions/ml/tests/test_image_compare_side_by_side.py create mode 100644 python/functions/ml/tests/test_image_gen_result.py create mode 100644 python/functions/ml/tests/test_image_generator_protocol.py create mode 100644 python/functions/ml/tests/test_image_grid.py create mode 100644 python/functions/ml/tests/test_image_save_png.py create mode 100644 python/functions/ml/tests/test_model_ref_lora_ref.py create mode 100644 python/functions/ml/tests/test_safetensors_inspect.py create mode 100644 python/functions/ml/tests/test_torch_device_select.py create mode 100644 python/functions/ml/tests/test_vram_budget.py create mode 100644 python/functions/ml/torch_device_select.md create mode 100644 python/functions/ml/torch_device_select.py create mode 100644 python/functions/ml/vram_budget.md create mode 100644 python/functions/ml/vram_budget.py create mode 100644 python/types/ml/generation_config.md create mode 100644 python/types/ml/image_gen_result.md create mode 100644 python/types/ml/image_generator.md create mode 100644 python/types/ml/lora_ref.md create mode 100644 python/types/ml/model_ref.md create mode 100644 python/types/ml/sampler_name.md create mode 100644 registry/migrations/012_vault_files.sql create mode 100644 types/infra/gpu_info.md create mode 100644 types/infra/vault_file.md create mode 100644 types/ml/generation_config.md create mode 100644 types/ml/image_gen_result.md create mode 100644 types/ml/image_generator.md create mode 100644 types/ml/lora_ref.md create mode 100644 types/ml/model_ref.md diff --git a/bash/functions/infra/cuda_toolkit_check.md b/bash/functions/infra/cuda_toolkit_check.md new file mode 100644 index 00000000..46d57331 --- /dev/null +++ b/bash/functions/infra/cuda_toolkit_check.md @@ -0,0 +1,73 @@ +--- +name: cuda_toolkit_check +kind: function +lang: bash +domain: infra +version: "1.0.0" +purity: impure +signature: "cuda_toolkit_check() -> void" +description: "Detecta componentes CUDA instalados en el sistema y emite pares key=value a stdout: nvcc (version o missing), nvidia_smi (present/missing), driver_version, cuda_libs (path o missing) y overall (ok|partial|missing). Exit code 0 siempre — funcion informativa, no fatal." +tags: [cuda, nvidia, gpu, hardware, probe, infra, toolkit] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: (ninguno) + desc: "No toma parametros. Lee el estado del sistema via nvcc, nvidia-smi y busqueda en rutas canonicas de CUDA." +output: "Cinco pares key=value en stdout: nvcc, nvidia_smi, driver_version, cuda_libs, overall. overall=ok si los tres componentes principales estan presentes; partial si algunos; missing si ninguno." +tested: false +tests: [] +test_file_path: "" +file_path: "bash/functions/infra/cuda_toolkit_check.sh" +--- + +## Ejemplo + +```bash +source bash/functions/infra/cuda_toolkit_check.sh +cuda_toolkit_check +``` + +Salida en maquina con CUDA completo: +``` +nvcc=12.4 +nvidia_smi=present +driver_version=550.54.15 +cuda_libs=/usr/local/cuda +overall=ok +``` + +Salida en maquina sin CUDA: +``` +nvcc=missing +nvidia_smi=missing +driver_version=missing +cuda_libs=missing +overall=missing +``` + +Invocar directamente: +```bash +bash bash/functions/infra/cuda_toolkit_check.sh +``` + +Parsear desde otro script: +```bash +eval "$(cuda_toolkit_check)" +echo "CUDA overall: $overall" +if [[ "$overall" == "ok" ]]; then + echo "CUDA completo: nvcc=$nvcc driver=$driver_version libs=$cuda_libs" +fi +``` + +## Notas + +- Idempotente: no instala, no modifica nada, solo consulta. +- Exit code 0 siempre — ausencia de CUDA es informacion, no fallo. +- Busca `libcuda.so` en `/usr/local/cuda*`, `/opt/cuda*` y via `ldconfig -p`. +- `driver_version` refleja el driver NVIDIA del kernel, reportado por nvidia-smi. +- `nvcc` reporta la version del compilador CUDA toolkit (puede diferir de la version soportada por el driver). +- Para obtener la version CUDA maxima soportada por el driver, usar `get_gpu_info_go_infra` (campo CudaVersion del struct GpuInfo). diff --git a/bash/functions/infra/cuda_toolkit_check.sh b/bash/functions/infra/cuda_toolkit_check.sh new file mode 100644 index 00000000..9a45a820 --- /dev/null +++ b/bash/functions/infra/cuda_toolkit_check.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash +# cuda_toolkit_check — Detecta componentes CUDA instalados en el sistema. +# +# Emite pares key=value a stdout: +# nvcc= +# nvidia_smi= +# driver_version= +# cuda_libs= +# overall= +# +# Exit code 0 siempre (funcion informativa, no fatal). +# Idempotente: se puede invocar multiples veces sin efectos secundarios. + +cuda_toolkit_check() { + local nvcc_ver="missing" + local nvidia_smi_status="missing" + local driver_version="missing" + local cuda_libs_path="missing" + + # --- nvcc --- + if command -v nvcc &>/dev/null; then + # nvcc --version imprime algo como: + # Cuda compilation tools, release 12.4, V12.4.131 + local raw + raw="$(nvcc --version 2>&1)" + # Extraer "12.4" de "release 12.4," + local ver + ver="$(echo "$raw" | grep -oP 'release \K[0-9]+\.[0-9]+')" + nvcc_ver="${ver:-present}" + fi + + # --- nvidia-smi + driver_version --- + if command -v nvidia-smi &>/dev/null; then + nvidia_smi_status="present" + # nvidia-smi --query-gpu=driver_version --format=csv,noheader retorna la version + local drv + drv="$(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ')" + if [[ -n "$drv" ]]; then + driver_version="$drv" + fi + fi + + # --- cuda_libs: buscar en rutas canonicas --- + local search_dirs=( + "/usr/local/cuda" + "/usr/local/cuda-"* + "/opt/cuda" + "/opt/cuda-"* + "/usr/lib/x86_64-linux-gnu/libcuda.so"* + "/usr/lib/aarch64-linux-gnu/libcuda.so"* + ) + + for candidate in "${search_dirs[@]}"; do + # shellcheck disable=SC2206 + # Expandir globs: si el candidato no existe el glob no expande + for path in $candidate; do + if [[ -e "$path" ]]; then + # Normalizar: tomar solo el directorio raiz /usr/local/cuda* + local base + base="${path%%/lib*}" + cuda_libs_path="$base" + break 2 + fi + done + done + + # Si no encontramos directorio CUDA pero si libcuda.so en rutas de lib estandar + if [[ "$cuda_libs_path" == "missing" ]]; then + local libcuda + libcuda="$(ldconfig -p 2>/dev/null | grep 'libcuda\.so' | head -n1 | awk '{print $NF}')" + if [[ -n "$libcuda" ]]; then + cuda_libs_path="$(dirname "$libcuda")" + fi + fi + + # --- overall --- + local found_count=0 + [[ "$nvcc_ver" != "missing" ]] && ((found_count++)) + [[ "$nvidia_smi_status" != "missing" ]] && ((found_count++)) + [[ "$cuda_libs_path" != "missing" ]] && ((found_count++)) + + local overall + if [[ $found_count -eq 0 ]]; then overall="missing" + elif [[ $found_count -eq 3 ]]; then overall="ok" + else overall="partial" + fi + + # --- emitir resultados --- + echo "nvcc=${nvcc_ver}" + echo "nvidia_smi=${nvidia_smi_status}" + echo "driver_version=${driver_version}" + echo "cuda_libs=${cuda_libs_path}" + echo "overall=${overall}" +} + +# Ejecutar si se invoca directamente +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + cuda_toolkit_check "$@" +fi diff --git a/bash/functions/infra/tests/cuda_toolkit_check_test.sh b/bash/functions/infra/tests/cuda_toolkit_check_test.sh new file mode 100644 index 00000000..b7e37530 --- /dev/null +++ b/bash/functions/infra/tests/cuda_toolkit_check_test.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash +# Tests para cuda_toolkit_check +# Smoke: verifica que stdout contiene todas las keys requeridas y exit code 0. +set -uo pipefail +# Nota: set -e NO se usa para que los asserts fallen de forma acumulativa +# en lugar de abortar el script al primer fallo. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "$SCRIPT_DIR/../cuda_toolkit_check.sh" + +PASS=0 +FAIL=0 + +assert_eq() { + local test_name="$1" expected="$2" got="$3" + if [[ "$expected" == "$got" ]]; then + echo "PASS: $test_name" + ((PASS++)) || true + else + echo "FAIL: $test_name — expected '$expected', got '$got'" + ((FAIL++)) || true + fi +} + +assert_contains() { + local test_name="$1" needle="$2" haystack="$3" + if echo "$haystack" | grep -qF "$needle"; then + echo "PASS: $test_name" + ((PASS++)) || true + else + echo "FAIL: $test_name — '$needle' not found in output" + ((FAIL++)) || true + fi +} + +assert_matches_pattern() { + local test_name="$1" pattern="$2" value="$3" + if echo "$value" | grep -qE "$pattern"; then + echo "PASS: $test_name" + ((PASS++)) || true + else + echo "FAIL: $test_name — '$value' does not match pattern '$pattern'" + ((FAIL++)) || true + fi +} + +assert_nonempty() { + local test_name="$1" value="$2" + if [[ -n "$value" ]]; then + echo "PASS: $test_name" + ((PASS++)) || true + else + echo "FAIL: $test_name — valor vacio" + ((FAIL++)) || true + fi +} + +# --- Capturar salida --- +OUTPUT="$(cuda_toolkit_check)" +EXIT_CODE=$? + +# --- Test: exit code 0 --- +assert_eq "exit code es 0" "0" "$EXIT_CODE" + +# --- Test: stdout contiene clave nvcc= --- +assert_contains "stdout contiene clave nvcc=" "nvcc=" "$OUTPUT" + +# --- Test: stdout contiene clave nvidia_smi= --- +assert_contains "stdout contiene clave nvidia_smi=" "nvidia_smi=" "$OUTPUT" + +# --- Test: stdout contiene clave driver_version= --- +assert_contains "stdout contiene clave driver_version=" "driver_version=" "$OUTPUT" + +# --- Test: stdout contiene clave cuda_libs= --- +assert_contains "stdout contiene clave cuda_libs=" "cuda_libs=" "$OUTPUT" + +# --- Test: stdout contiene clave overall= --- +assert_contains "stdout contiene clave overall=" "overall=" "$OUTPUT" + +# --- Test: overall tiene valor valido (ok|partial|missing) --- +OVERALL_VAL="$(echo "$OUTPUT" | grep '^overall=' | cut -d= -f2)" +assert_matches_pattern "overall tiene valor valido ok|partial|missing" "^(ok|partial|missing)$" "$OVERALL_VAL" + +# --- Test: nvcc tiene valor no vacio --- +NVCC_VAL="$(echo "$OUTPUT" | grep '^nvcc=' | cut -d= -f2)" +assert_nonempty "nvcc tiene valor no vacio" "$NVCC_VAL" + +# --- Test: nvidia_smi tiene valor valido (present|missing) --- +SMI_VAL="$(echo "$OUTPUT" | grep '^nvidia_smi=' | cut -d= -f2)" +assert_matches_pattern "nvidia_smi tiene valor valido present|missing" "^(present|missing)$" "$SMI_VAL" + +# --- Test: driver_version tiene valor no vacio --- +DRV_VAL="$(echo "$OUTPUT" | grep '^driver_version=' | cut -d= -f2)" +assert_nonempty "driver_version tiene valor no vacio" "$DRV_VAL" + +# --- Test: cuda_libs tiene valor no vacio --- +LIBS_VAL="$(echo "$OUTPUT" | grep '^cuda_libs=' | cut -d= -f2)" +assert_nonempty "cuda_libs tiene valor no vacio" "$LIBS_VAL" + +# --- Test: exactamente 5 lineas en la salida --- +LINE_COUNT="$(echo "$OUTPUT" | wc -l | tr -d ' ')" +assert_eq "salida tiene exactamente 5 lineas" "5" "$LINE_COUNT" + +# --- Test: segunda invocacion idempotente (mismo resultado) --- +OUTPUT2="$(cuda_toolkit_check)" +assert_eq "segunda invocacion produce mismo resultado (idempotente)" "$OUTPUT" "$OUTPUT2" + +# --- Resumen --- +echo "---" +echo "Results: $PASS passed, $FAIL failed" +[[ $FAIL -eq 0 ]] || exit 1 diff --git a/bash/functions/pipelines/vault_audit.md b/bash/functions/pipelines/vault_audit.md new file mode 100644 index 00000000..6acf9405 --- /dev/null +++ b/bash/functions/pipelines/vault_audit.md @@ -0,0 +1,90 @@ +--- +name: vault_audit +kind: pipeline +lang: bash +domain: pipelines +version: "1.0.0" +purity: impure +signature: "vault_audit( | --all) [--skip-profilers] [--dry-run-layout] -> void" +description: "Pipeline completo de auditoria para uno o todos los vaults declarados: layout-ensure, index, profile (csv/pdf/md), dedupe, aggregate y doctor. Produce tabla resumen con estado por vault y codigo de salida 4 si hay warnings." +tags: [vault, audit, pipeline, launcher, infra, bash] +uses_functions: + - vault_layout_ensure_go_infra + - vault_inventory_scan_go_infra + - vault_index_open_go_infra + - vault_index_write_go_infra + - vault_csv_profile_py_datascience + - vault_pdf_extract_py_datascience + - vault_knowledge_parse_py_infra + - vault_dedupe_report_py_infra + - vault_aggregate_index_go_infra + - vault_doctor_go_infra +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: vault_name + desc: "Nombre del vault a auditar (como aparece en registry.db tabla vaults). Usar --all para todos." + - name: --all + desc: "Audita todos los vaults declarados en registry.db. Mutuamente excluyente con vault_name." + - name: --skip-profilers + desc: "Omite el paso de profiling CSV/PDF/MD. Util para auditorias rapidas de inventario." + - name: --dry-run-layout + desc: "Pasa --dry-run a vault layout-ensure: calcula cambios sin tocar el disco." +output: "Tabla de resumen por vault con status ok/warn. Codigo de salida 0=exito, 1=root no localizable, 4=uno o mas vaults con warnings." +tested: false +tests: [] +test_file_path: "" +file_path: "bash/functions/pipelines/vault_audit.sh" +--- + +## Ejemplo + +```bash +# Auditar un vault especifico +FN_REGISTRY_ROOT=/home/lucas/fn_registry \ + bash bash/functions/pipelines/vault_audit.sh turismo_spain + +# Auditar todos los vaults +FN_REGISTRY_ROOT=/home/lucas/fn_registry \ + bash bash/functions/pipelines/vault_audit.sh --all + +# Solo layout + index + aggregate (sin profilers, mas rapido) +bash bash/functions/pipelines/vault_audit.sh turismo_spain --skip-profilers + +# Ver que haria layout-ensure sin tocar disco +bash bash/functions/pipelines/vault_audit.sh turismo_spain --dry-run-layout + +# Equivalente via fn run (desde la raiz del registry) +./fn run vault_audit_bash_pipelines turismo_spain +``` + +## Pasos del pipeline + +1. **layout-ensure** — `fn vault layout-ensure ` asegura `data/{raw,processed,exports}` y `knowledge/{...}`. +2. **index** — `fn vault index ` escanea archivos y persiste en `vault_index.db`. +3. **profile** — `fn vault profile ` llama `vault_profile_dispatch.py` para CSV/PDF/MD. +4. **dedupe** — `fn vault dedupe ` detecta duplicados por sha256 (informacional, no fatal). +5. **aggregate** — `fn vault aggregate` copia todo a `registry.db` tabla `vault_files` (una sola vez al final). +6. **doctor** — `fn vault doctor` muestra estado de salud de todos los vaults. + +## Codigos de salida + +| Codigo | Significado | +|--------|-------------| +| 0 | Todos los vaults procesados sin errores | +| 1 | FN_REGISTRY_ROOT no localizable o fn binary no encontrado | +| 4 | Uno o mas vaults con warnings (layout o index fallaron) | + +## Variables de entorno + +- `FN_REGISTRY_ROOT` — raiz del registry (auto-detectada si no esta seteada). +- `FN_BIN` — path al binario `fn` (default: `$FN_REGISTRY_ROOT/fn`). + +## Notas + +Requiere `sqlite3` en PATH para resolver la lista de vaults con `--all`. +El paso de profile es non-fatal: errores en profilers individuales se reportan como warnings. +El paso de dedupe es siempre informacional (no borra archivos). diff --git a/bash/functions/pipelines/vault_audit.sh b/bash/functions/pipelines/vault_audit.sh new file mode 100644 index 00000000..a7288471 --- /dev/null +++ b/bash/functions/pipelines/vault_audit.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash +# vault_audit — Full audit pipeline for one or all declared vaults. +# Runs: layout-ensure → index → profile → dedupe → aggregate → doctor +# +# Usage: +# vault_audit.sh +# vault_audit.sh --all +# vault_audit.sh --skip-profilers +# vault_audit.sh --dry-run-layout +# vault_audit.sh --all --skip-profilers +set -euo pipefail + +# --- locate FN_REGISTRY_ROOT --- +_find_registry_root() { + local dir + dir="$(pwd)" + while [[ "$dir" != "/" ]]; do + if [[ -f "$dir/registry.db" ]]; then + echo "$dir" + return 0 + fi + dir="$(dirname "$dir")" + done + return 1 +} + +if [[ -n "${FN_REGISTRY_ROOT:-}" && -f "${FN_REGISTRY_ROOT}/registry.db" ]]; then + REGISTRY_ROOT="$FN_REGISTRY_ROOT" +elif REGISTRY_ROOT="$(_find_registry_root 2>/dev/null)"; then + : # found +else + echo "ERROR: Cannot locate registry.db. Set FN_REGISTRY_ROOT or run from registry root." >&2 + exit 1 +fi + +FN_BIN="${FN_BIN:-${REGISTRY_ROOT}/fn}" +if [[ ! -x "$FN_BIN" ]]; then + echo "ERROR: fn binary not found at $FN_BIN. Build with: CGO_ENABLED=1 go build -tags fts5 -o fn ./cmd/fn/" >&2 + exit 1 +fi + +# --- parse args --- +AUDIT_ALL=0 +SKIP_PROFILERS=0 +DRY_RUN_LAYOUT=0 +VAULT_NAMES=() +START_TS=$(date +%s) + +while [[ $# -gt 0 ]]; do + case "$1" in + --all) AUDIT_ALL=1 ;; + --skip-profilers) SKIP_PROFILERS=1 ;; + --dry-run-layout) DRY_RUN_LAYOUT=1 ;; + -*) + echo "ERROR: Unknown flag: $1" >&2 + echo "Usage: vault_audit.sh | --all [--skip-profilers] [--dry-run-layout]" >&2 + exit 1 + ;; + *) + VAULT_NAMES+=("$1") + ;; + esac + shift +done + +if [[ $AUDIT_ALL -eq 0 && ${#VAULT_NAMES[@]} -eq 0 ]]; then + echo "Usage: vault_audit.sh | --all [--skip-profilers] [--dry-run-layout]" >&2 + exit 1 +fi + +# --- resolve vault list --- +if [[ $AUDIT_ALL -eq 1 ]]; then + mapfile -t VAULT_NAMES < <( + sqlite3 "${REGISTRY_ROOT}/registry.db" "SELECT name FROM vaults ORDER BY name;" 2>/dev/null || true + ) + if [[ ${#VAULT_NAMES[@]} -eq 0 ]]; then + echo "No vaults registered in registry.db. Run 'fn index' first." >&2 + exit 1 + fi + echo "Found ${#VAULT_NAMES[@]} vault(s): ${VAULT_NAMES[*]}" +fi + +# --- build fn vault flags --- +LAYOUT_FLAGS=() +if [[ $DRY_RUN_LAYOUT -eq 1 ]]; then + LAYOUT_FLAGS+=(--dry-run) +fi + +# --- per-vault audit --- +PASS_COUNT=0 +FAIL_COUNT=0 +declare -A VAULT_STATUS + +audit_one() { + local name="$1" + local vault_ok=1 + echo "" + echo "=== vault: $name ===" + + # Step 1: layout-ensure + echo " [1/5] layout-ensure" + if ! "$FN_BIN" vault layout-ensure "$name" "${LAYOUT_FLAGS[@]}" 2>&1 | sed 's/^/ /'; then + echo " WARN: layout-ensure failed (non-fatal)" >&2 + vault_ok=0 + fi + + # Step 2: index + echo " [2/5] index" + if ! "$FN_BIN" vault index "$name" 2>&1 | sed 's/^/ /'; then + echo " ERROR: index failed" >&2 + vault_ok=0 + fi + + # Step 3: profile + if [[ $SKIP_PROFILERS -eq 0 ]]; then + echo " [3/5] profile" + if ! "$FN_BIN" vault profile "$name" 2>&1 | sed 's/^/ /'; then + echo " WARN: profile had errors (non-fatal)" >&2 + fi + else + echo " [3/5] profile (skipped)" + fi + + # Step 4: dedupe (informational, non-fatal) + echo " [4/5] dedupe" + "$FN_BIN" vault dedupe "$name" 2>&1 | sed 's/^/ /' || true + + # Step 5 deferred — aggregate runs once at the end + echo " [5/5] aggregate (deferred)" + + if [[ $vault_ok -eq 1 ]]; then + VAULT_STATUS["$name"]="ok" + PASS_COUNT=$((PASS_COUNT + 1)) + else + VAULT_STATUS["$name"]="warn" + FAIL_COUNT=$((FAIL_COUNT + 1)) + fi +} + +for vault_name in "${VAULT_NAMES[@]}"; do + audit_one "$vault_name" +done + +# --- aggregate (once, after all vaults) --- +echo "" +echo "=== aggregate ===" +"$FN_BIN" vault aggregate 2>&1 | sed 's/^/ /' + +# --- doctor (read-only health check) --- +echo "" +echo "=== doctor ===" +"$FN_BIN" vault doctor 2>&1 | sed 's/^/ /' || true + +# --- summary table --- +END_TS=$(date +%s) +ELAPSED=$(( END_TS - START_TS )) + +echo "" +echo "=== summary ===" +printf "%-30s %s\n" "VAULT" "STATUS" +printf "%-30s %s\n" "-----" "------" +for vault_name in "${VAULT_NAMES[@]}"; do + status="${VAULT_STATUS[$vault_name]:-unknown}" + printf "%-30s %s\n" "$vault_name" "$status" +done +echo "" +echo "Done: ${PASS_COUNT} ok, ${FAIL_COUNT} warn (${ELAPSED}s)" + +if [[ $FAIL_COUNT -gt 0 ]]; then + exit 4 +fi +exit 0 diff --git a/cmd/fn/cmd_vault.go b/cmd/fn/cmd_vault.go new file mode 100644 index 00000000..b9cf3cbd --- /dev/null +++ b/cmd/fn/cmd_vault.go @@ -0,0 +1,1059 @@ +package main + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "text/tabwriter" + "time" + + "fn-registry/functions/infra" + "fn-registry/registry" +) + +func cmdVault(args []string) { + if len(args) < 1 { + vaultUsage() + os.Exit(1) + } + + switch args[0] { + case "list": + vaultList() + case "search": + vaultSearch(args[1:]) + case "index": + vaultIndex(args[1:]) + case "info": + vaultInfo(args[1:]) + case "layout-ensure": + vaultLayoutEnsure(args[1:]) + case "profile": + vaultProfile(args[1:]) + case "dedupe": + vaultDedupe(args[1:]) + case "aggregate": + vaultAggregate() + case "doctor": + vaultDoctorCmd() + case "audit": + vaultAudit(args[1:]) + case "help", "-h", "--help": + vaultUsage() + default: + fmt.Fprintf(os.Stderr, "unknown vault subcommand: %s\n", args[0]) + vaultUsage() + os.Exit(1) + } +} + +func vaultUsage() { + fmt.Println(`fn vault — manage data vaults + +Usage: + fn vault list List declared vaults + fn vault search [--limit N] [--vault ] [--json] + Search files in vault(s) + fn vault index Index a vault (scan + write) + fn vault index --all Index all declared vaults + fn vault info Show vault summary and stats + fn vault layout-ensure [--dry-run] + Ensure canonical data/knowledge layout + fn vault profile Profile CSV/PDF/MD files in a vault + fn vault dedupe Report duplicate files in a vault + fn vault aggregate Aggregate all vault indexes into registry.db + fn vault doctor Audit vault health (alias for fn doctor vaults) + fn vault audit Run full audit pipeline on a vault + fn vault audit --all Run audit on all declared vaults`) +} + +// --- list --- + +func vaultList() { + db := openDB() + defer db.Close() + + vaults, err := db.AllVaults() + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + if len(vaults) == 0 { + fmt.Println("No vaults registered. Declare vaults in projects/*/vaults/vault.yaml and run 'fn index'.") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "NAME\tPROJECT\tPATH\tTAGS") + for _, v := range vaults { + tags := strings.Join(v.Tags, ",") + path := v.Path + if path == "" { + path = "-" + } + proj := v.ProjectID + if proj == "" { + proj = "-" + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", v.Name, proj, path, tags) + } + w.Flush() +} + +// --- search --- + +func vaultSearch(args []string) { + var query string + var vaultName string + var limitN int + var jsonOut bool + + i := 0 + for i < len(args) { + switch args[i] { + case "--limit": + i++ + if i >= len(args) { + fmt.Fprintln(os.Stderr, "--limit requires a value") + os.Exit(1) + } + n, err := strconv.Atoi(args[i]) + if err != nil { + fmt.Fprintf(os.Stderr, "--limit: invalid number %q\n", args[i]) + os.Exit(1) + } + limitN = n + case "--vault": + i++ + if i >= len(args) { + fmt.Fprintln(os.Stderr, "--vault requires a value") + os.Exit(1) + } + vaultName = args[i] + case "--json": + jsonOut = true + default: + if query == "" && !strings.HasPrefix(args[i], "--") { + query = args[i] + } + } + i++ + } + + if query == "" { + fmt.Fprintln(os.Stderr, "usage: fn vault search [--limit N] [--vault ] [--json]") + os.Exit(1) + } + if limitN <= 0 { + limitN = 50 + } + + db := openDB() + defer db.Close() + + // Determine which vaults to search. + vaults, err := resolveSearchVaults(db, vaultName) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if len(vaults) == 0 { + fmt.Fprintln(os.Stderr, "no vaults found") + os.Exit(1) + } + + var allHits []infra.VaultSearchHit + for _, v := range vaults { + if v.Path == "" { + continue + } + hits, err := infra.VaultSearch(v.Path, query, limitN) + if err != nil { + fmt.Fprintf(os.Stderr, "warn: vault %s: %v\n", v.Name, err) + continue + } + allHits = append(allHits, hits...) + } + + if jsonOut { + if allHits == nil { + allHits = []infra.VaultSearchHit{} + } + b, err := json.MarshalIndent(allHits, "", " ") + if err != nil { + fmt.Fprintf(os.Stderr, "json error: %v\n", err) + os.Exit(1) + } + fmt.Println(string(b)) + return + } + + if len(allHits) == 0 { + fmt.Println("No results.") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + for _, h := range allHits { + mtime := time.Unix(h.Mtime, 0).UTC().Format("2006-01-02") + sizeStr := formatBytes(h.Size) + snip := truncate(h.Snippet, 50) + fmt.Fprintf(w, "[%s]\t%s\t%s\t%s\t%s\t%s\n", + h.VaultName, h.RelPath, sizeStr, mtime, h.Mime, snip) + } + w.Flush() +} + +// --- index --- + +func vaultIndex(args []string) { + indexAll := false + var name string + + for _, a := range args { + if a == "--all" { + indexAll = true + } else if !strings.HasPrefix(a, "--") && name == "" { + name = a + } + } + + if !indexAll && name == "" { + fmt.Fprintln(os.Stderr, "usage: fn vault index | --all") + os.Exit(1) + } + + db := openDB() + defer db.Close() + + var vaults []registry.Vault + if indexAll { + all, err := db.AllVaults() + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + vaults = all + } else { + v, err := resolveVaultByName(db, name) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + vaults = []registry.Vault{*v} + } + + for _, v := range vaults { + if v.Path == "" { + fmt.Printf("vault %s: path not set, skipping\n", v.Name) + continue + } + if err := runVaultIndex(v); err != nil { + fmt.Fprintf(os.Stderr, "vault %s: %v\n", v.Name, err) + } + } +} + +// runVaultIndex runs the full inventory scan + write cycle for a single vault. +func runVaultIndex(v registry.Vault) error { + fmt.Printf("indexing %s (%s)...\n", v.Name, v.Path) + + files, err := infra.VaultInventoryScan(v.Path, v.ID, v.Name) + if err != nil { + return fmt.Errorf("scan: %w", err) + } + + vaultDB, err := infra.VaultIndexOpen(v.Path) + if err != nil { + return fmt.Errorf("open index: %w", err) + } + defer vaultDB.Close() + + report, err := infra.VaultIndexWrite(vaultDB, files, true) + if err != nil { + return fmt.Errorf("write: %w", err) + } + + fmt.Printf(" indexed %d files, %d inserted, %d updated, %d pruned\n", + len(files), report.Inserted, report.Updated, report.Pruned) + return nil +} + +// --- info --- + +func vaultInfo(args []string) { + if len(args) < 1 { + fmt.Fprintln(os.Stderr, "usage: fn vault info ") + os.Exit(1) + } + name := args[0] + + db := openDB() + defer db.Close() + + v, err := resolveVaultByName(db, name) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if v.Path == "" { + fmt.Fprintf(os.Stderr, "vault %s has no path set\n", name) + os.Exit(1) + } + + vaultDB, err := infra.VaultIndexOpen(v.Path) + if err != nil { + fmt.Fprintf(os.Stderr, "error opening vault index: %v\n", err) + os.Exit(1) + } + defer vaultDB.Close() + + // Summary stats. + var totalFiles int + var totalSize int64 + vaultDB.QueryRow(`SELECT count(*), coalesce(sum(size),0) FROM files`).Scan(&totalFiles, &totalSize) + + var lastIndexedAt int64 + vaultDB.QueryRow(`SELECT coalesce(max(indexed_at), 0) FROM files`).Scan(&lastIndexedAt) + + lastIndexed := "-" + if lastIndexedAt > 0 { + lastIndexed = time.Unix(lastIndexedAt, 0).UTC().Format("2006-01-02 15:04:05") + } + + fmt.Printf("Vault: %s (%s)\n", v.Name, v.Path) + fmt.Printf("Files: %d Total: %s Last indexed: %s\n\n", + totalFiles, formatBytes(totalSize), lastIndexed) + + // By bucket. + bucketRows, err := vaultDB.Query(` + SELECT bucket, sub_bucket, count(*), coalesce(sum(size),0) + FROM files + GROUP BY bucket, sub_bucket + ORDER BY bucket, sub_bucket`) + if err == nil { + defer bucketRows.Close() + fmt.Println("By bucket:") + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + for bucketRows.Next() { + var bucket, sub string + var cnt int + var sz int64 + bucketRows.Scan(&bucket, &sub, &cnt, &sz) + key := bucket + if sub != "" { + key = bucket + "/" + sub + } + fmt.Fprintf(w, " %s\t%d files\t%s\n", key, cnt, formatBytes(sz)) + } + w.Flush() + fmt.Println() + } + + // By mime. + mimeRows, err := vaultDB.Query(` + SELECT mime, count(*), coalesce(sum(size),0) + FROM files + GROUP BY mime + ORDER BY sum(size) DESC`) + if err == nil { + defer mimeRows.Close() + fmt.Println("By mime:") + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + for mimeRows.Next() { + var mime string + var cnt int + var sz int64 + mimeRows.Scan(&mime, &cnt, &sz) + if mime == "" { + mime = "unknown" + } + fmt.Fprintf(w, " %s\t%d files\t%s\n", mime, cnt, formatBytes(sz)) + } + w.Flush() + } +} + +// --- layout-ensure --- + +func vaultLayoutEnsure(args []string) { + dryRun := false + var name string + + for _, a := range args { + switch a { + case "--dry-run": + dryRun = true + default: + if !strings.HasPrefix(a, "--") && name == "" { + name = a + } + } + } + + if name == "" { + fmt.Fprintln(os.Stderr, "usage: fn vault layout-ensure [--dry-run]") + os.Exit(1) + } + + db := openDB() + defer db.Close() + + v, err := resolveVaultByName(db, name) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if v.Path == "" { + fmt.Fprintf(os.Stderr, "vault %s has no path set\n", name) + os.Exit(1) + } + + report, err := infra.VaultLayoutEnsure(v.Path, dryRun) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + prefix := "" + if report.DryRun { + prefix = "[dry-run] " + } + fmt.Printf("%sVault: %s (%s)\n", prefix, name, report.VaultPath) + if len(report.Created) > 0 { + fmt.Printf("%s created: %s\n", prefix, strings.Join(report.Created, ", ")) + } + if len(report.Migrated) > 0 { + fmt.Printf("%s migrated: %s\n", prefix, strings.Join(report.Migrated, "; ")) + } + if len(report.AlreadyOK) > 0 { + fmt.Printf("%s already ok: %s\n", prefix, strings.Join(report.AlreadyOK, ", ")) + } + if len(report.Skipped) > 0 { + fmt.Printf("%s skipped (unrecognized): %s\n", prefix, strings.Join(report.Skipped, ", ")) + } + if len(report.Created) == 0 && len(report.Migrated) == 0 { + fmt.Printf("%s layout already canonical\n", prefix) + } +} + +// --- profile --- + +// profileKind returns "csv", "pdf", "md", or "" for a file based on extension/mime. +func profileKind(ext, mime string) string { + ext = strings.ToLower(strings.TrimPrefix(ext, ".")) + switch ext { + case "csv": + return "csv" + case "pdf": + return "pdf" + case "md", "markdown": + return "md" + } + // Fall back to mime + if strings.Contains(mime, "csv") || strings.Contains(mime, "text/csv") { + return "csv" + } + if strings.Contains(mime, "pdf") { + return "pdf" + } + if strings.Contains(mime, "markdown") { + return "md" + } + return "" +} + +func vaultProfile(args []string) { + if len(args) < 1 || strings.HasPrefix(args[0], "--") { + fmt.Fprintln(os.Stderr, "usage: fn vault profile ") + os.Exit(1) + } + name := args[0] + + db := openDB() + defer db.Close() + + v, err := resolveVaultByName(db, name) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if v.Path == "" { + fmt.Fprintf(os.Stderr, "vault %s has no path set\n", name) + os.Exit(1) + } + + vaultDB, err := infra.VaultIndexOpen(v.Path) + if err != nil { + fmt.Fprintf(os.Stderr, "error opening vault index: %v\n", err) + os.Exit(1) + } + defer vaultDB.Close() + + // List files with their ext and mime from vault_index.db + type fileRow struct { + RelPath string + Ext string + Mime string + } + rows, err := vaultDB.Query(`SELECT rel_path, ext, mime FROM files ORDER BY rel_path`) + if err != nil { + fmt.Fprintf(os.Stderr, "error querying vault index: %v\n", err) + os.Exit(1) + } + var files []fileRow + for rows.Next() { + var f fileRow + if scanErr := rows.Scan(&f.RelPath, &f.Ext, &f.Mime); scanErr == nil { + files = append(files, f) + } + } + rows.Close() + + if len(files) == 0 { + fmt.Printf("vault %s: no files in index (run 'fn vault index %s' first)\n", name, name) + return + } + + // Locate the Python dispatcher + registryRoot := root() + dispatchScript := filepath.Join(registryRoot, "python", "functions", "infra", "vault_profile_dispatch.py") + if _, err := os.Stat(dispatchScript); os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "error: dispatch script not found: %s\n", dispatchScript) + os.Exit(1) + } + + pythonBin := filepath.Join(registryRoot, "python", ".venv", "bin", "python3") + if _, err := os.Stat(pythonBin); os.IsNotExist(err) { + pythonBin = "python3" + } + + pythonPath := filepath.Join(registryRoot, "python", "functions") + + var nCSV, nPDF, nMD, nSkip, nErr int + fmt.Printf("profiling vault: %s (%s)\n", name, v.Path) + + for _, f := range files { + kind := profileKind(f.Ext, f.Mime) + if kind == "" { + nSkip++ + continue + } + + cmd := exec.Command(pythonBin, dispatchScript, + "--vault", v.Path, + "--rel-path", f.RelPath, + "--kind", kind, + ) + cmd.Env = append(os.Environ(), + "PYTHONPATH="+pythonPath, + "FN_REGISTRY_ROOT="+registryRoot, + ) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + if runErr := cmd.Run(); runErr != nil { + fmt.Fprintf(os.Stderr, " warn: %s (%s): %v\n", f.RelPath, kind, strings.TrimSpace(stderr.String())) + nErr++ + continue + } + + switch kind { + case "csv": + nCSV++ + case "pdf": + nPDF++ + case "md": + nMD++ + } + } + + fmt.Printf(" csv: %d pdf: %d md: %d skipped: %d errors: %d\n", + nCSV, nPDF, nMD, nSkip, nErr) +} + +// --- dedupe --- + +func vaultDedupe(args []string) { + if len(args) < 1 || strings.HasPrefix(args[0], "--") { + fmt.Fprintln(os.Stderr, "usage: fn vault dedupe ") + os.Exit(1) + } + name := args[0] + + db := openDB() + defer db.Close() + + v, err := resolveVaultByName(db, name) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if v.Path == "" { + fmt.Fprintf(os.Stderr, "vault %s has no path set\n", name) + os.Exit(1) + } + + vaultDB, err := infra.VaultIndexOpen(v.Path) + if err != nil { + fmt.Fprintf(os.Stderr, "error opening vault index: %v\n", err) + os.Exit(1) + } + defer vaultDB.Close() + + // Find duplicates: groups with the same sha256 hash and size > 0 + rows, err := vaultDB.Query(` + SELECT sha256, count(*) as cnt, sum(size) as total_size, min(size) as file_size, + group_concat(rel_path, '|') as paths + FROM files + WHERE sha256 != '' AND size > 0 + GROUP BY sha256 + HAVING count(*) > 1 + ORDER BY sum(size) DESC + LIMIT 50`) + if err != nil { + fmt.Fprintf(os.Stderr, "error querying duplicates: %v\n", err) + os.Exit(1) + } + defer rows.Close() + + type dupeGroup struct { + Sha256 string + Count int + TotalSize int64 + FileSize int64 + Paths []string + } + var groups []dupeGroup + var totalWasted int64 + for rows.Next() { + var g dupeGroup + var pathsConcat string + if scanErr := rows.Scan(&g.Sha256, &g.Count, &g.TotalSize, &g.FileSize, &pathsConcat); scanErr == nil { + g.Paths = strings.Split(pathsConcat, "|") + wasted := g.FileSize * int64(g.Count-1) + totalWasted += wasted + groups = append(groups, g) + } + } + rows.Close() + + fmt.Printf("Vault: %s — duplicate report\n\n", name) + if len(groups) == 0 { + fmt.Println(" No duplicates found.") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "SHA256\tCOUNT\tSIZE\tWASTED\tPATHS") + for _, g := range groups { + sha := g.Sha256 + if len(sha) > 12 { + sha = sha[:12] + "..." + } + wasted := g.FileSize * int64(g.Count-1) + pathsStr := strings.Join(g.Paths, ", ") + if len(pathsStr) > 60 { + pathsStr = pathsStr[:57] + "..." + } + fmt.Fprintf(w, "%s\t%d\t%s\t%s\t%s\n", + sha, g.Count, formatBytes(g.FileSize), formatBytes(wasted), pathsStr) + } + w.Flush() + fmt.Printf("\nTotal wasted space: %s (%d duplicate groups)\n", + formatBytes(totalWasted), len(groups)) +} + +// --- aggregate --- + +func vaultAggregate() { + registryRoot := root() + fmt.Println("aggregating vault indexes into registry.db...") + report, err := infra.VaultAggregateIndex(registryRoot) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Printf(" vaults processed: %d skipped: %d total files: %d\n", + report.VaultsProcessed, report.VaultsSkipped, report.TotalFiles) + if len(report.Errors) > 0 { + for _, e := range report.Errors { + fmt.Fprintf(os.Stderr, " warn: %s\n", e) + } + } +} + +// --- doctor (vault alias) --- + +func vaultDoctorCmd() { + registryRoot := root() + entries, err := infra.VaultDoctor(registryRoot) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if len(entries) == 0 { + fmt.Println("No vaults registered.") + return + } + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "VAULT\tSTATUS\tDISK\tINDEXED\tISSUES") + for _, e := range entries { + issues := "-" + if len(e.Issues) > 0 { + issues = strings.Join(e.Issues, ", ") + } + fmt.Fprintf(w, "%s\t%s\t%d\t%d\t%s\n", + e.VaultName, e.Status, e.DiskFiles, e.IndexedFiles, issues) + } + w.Flush() +} + +// --- audit --- + +func vaultAudit(args []string) { + auditAll := false + skipProfilers := false + dryRunLayout := false + var names []string + + for _, a := range args { + switch a { + case "--all": + auditAll = true + case "--skip-profilers": + skipProfilers = true + case "--dry-run-layout": + dryRunLayout = true + default: + if !strings.HasPrefix(a, "--") { + names = append(names, a) + } + } + } + + if !auditAll && len(names) == 0 { + fmt.Fprintln(os.Stderr, "usage: fn vault audit | --all [--skip-profilers] [--dry-run-layout]") + os.Exit(1) + } + + db := openDB() + defer db.Close() + + var vaults []registry.Vault + if auditAll { + all, err := db.AllVaults() + if err != nil { + fmt.Fprintf(os.Stderr, "error listing vaults: %v\n", err) + os.Exit(1) + } + vaults = all + } else { + for _, n := range names { + v, err := resolveVaultByName(db, n) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + vaults = append(vaults, *v) + } + } + db.Close() + + type auditResult struct { + Name string + Status string + Errors []string + } + var results []auditResult + + for _, v := range vaults { + fmt.Printf("\n=== vault: %s ===\n", v.Name) + if v.Path == "" { + fmt.Printf(" SKIP: no path set\n") + results = append(results, auditResult{Name: v.Name, Status: "skip"}) + continue + } + + var errs []string + + // Step 1: layout-ensure + fmt.Printf(" [1/5] layout-ensure") + if dryRunLayout { + fmt.Printf(" (dry-run)") + } + fmt.Println() + layoutReport, layoutErr := infra.VaultLayoutEnsure(v.Path, dryRunLayout) + if layoutErr != nil { + fmt.Printf(" ERROR: %v\n", layoutErr) + errs = append(errs, "layout-ensure: "+layoutErr.Error()) + } else { + if len(layoutReport.Created) > 0 { + fmt.Printf(" created: %s\n", strings.Join(layoutReport.Created, ", ")) + } + if len(layoutReport.Migrated) > 0 { + fmt.Printf(" migrated: %s\n", strings.Join(layoutReport.Migrated, "; ")) + } + if len(layoutReport.Created) == 0 && len(layoutReport.Migrated) == 0 { + fmt.Printf(" layout ok\n") + } + } + + // Step 2: index + fmt.Println(" [2/5] index") + if indexErr := runVaultIndex(v); indexErr != nil { + fmt.Printf(" ERROR: %v\n", indexErr) + errs = append(errs, "index: "+indexErr.Error()) + } + + // Step 3: profile (optional) + if !skipProfilers { + fmt.Println(" [3/5] profile") + runVaultProfileSubcmd(v) + } else { + fmt.Println(" [3/5] profile (skipped)") + } + + // Step 4: dedupe (informational, non-fatal) + fmt.Println(" [4/5] dedupe") + runVaultDedupeSubcmd(v) + + // Step 5: aggregate is done once after all vaults + fmt.Println(" [5/5] aggregate (deferred to end)") + + status := "ok" + if len(errs) > 0 { + status = "error" + } + results = append(results, auditResult{Name: v.Name, Status: status, Errors: errs}) + } + + // Final aggregate + fmt.Println("\n=== aggregate ===") + vaultAggregate() + + // Summary table + fmt.Println("\n=== summary ===") + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "VAULT\tSTATUS\tERRORS") + for _, r := range results { + errStr := "-" + if len(r.Errors) > 0 { + errStr = strings.Join(r.Errors, "; ") + } + fmt.Fprintf(w, "%s\t%s\t%s\n", r.Name, r.Status, errStr) + } + w.Flush() +} + +// runVaultProfileSubcmd runs the profiling loop for a single vault (used by audit). +func runVaultProfileSubcmd(v registry.Vault) { + registryRoot := root() + + vaultDB, err := infra.VaultIndexOpen(v.Path) + if err != nil { + fmt.Printf(" warn: cannot open vault index: %v\n", err) + return + } + defer vaultDB.Close() + + rows, err := vaultDB.Query(`SELECT rel_path, ext, mime FROM files ORDER BY rel_path`) + if err != nil { + fmt.Printf(" warn: query failed: %v\n", err) + return + } + type fileRow struct { + RelPath string + Ext string + Mime string + } + var files []fileRow + for rows.Next() { + var f fileRow + if scanErr := rows.Scan(&f.RelPath, &f.Ext, &f.Mime); scanErr == nil { + files = append(files, f) + } + } + rows.Close() + + if len(files) == 0 { + fmt.Printf(" no files in index\n") + return + } + + dispatchScript := filepath.Join(registryRoot, "python", "functions", "infra", "vault_profile_dispatch.py") + pythonBin := filepath.Join(registryRoot, "python", ".venv", "bin", "python3") + if _, statErr := os.Stat(pythonBin); os.IsNotExist(statErr) { + pythonBin = "python3" + } + pythonPath := filepath.Join(registryRoot, "python", "functions") + + var nCSV, nPDF, nMD, nSkip, nErr int + for _, f := range files { + kind := profileKind(f.Ext, f.Mime) + if kind == "" { + nSkip++ + continue + } + cmd := exec.Command(pythonBin, dispatchScript, + "--vault", v.Path, + "--rel-path", f.RelPath, + "--kind", kind, + ) + cmd.Env = append(os.Environ(), + "PYTHONPATH="+pythonPath, + "FN_REGISTRY_ROOT="+registryRoot, + ) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if runErr := cmd.Run(); runErr != nil { + nErr++ + } else { + switch kind { + case "csv": + nCSV++ + case "pdf": + nPDF++ + case "md": + nMD++ + } + } + } + fmt.Printf(" csv: %d pdf: %d md: %d skipped: %d errors: %d\n", + nCSV, nPDF, nMD, nSkip, nErr) +} + +// runVaultDedupeSubcmd prints dedupe summary for a single vault (used by audit). +func runVaultDedupeSubcmd(v registry.Vault) { + vaultDB, err := infra.VaultIndexOpen(v.Path) + if err != nil { + fmt.Printf(" warn: cannot open vault index: %v\n", err) + return + } + defer vaultDB.Close() + + var dupeGroups int + var totalWasted int64 + rows, err := vaultDB.Query(` + SELECT count(*) as cnt, min(size) as file_size + FROM files + WHERE sha256 != '' AND size > 0 + GROUP BY sha256 + HAVING count(*) > 1`) + if err != nil { + fmt.Printf(" warn: query failed: %v\n", err) + return + } + for rows.Next() { + var cnt int + var fileSize int64 + if scanErr := rows.Scan(&cnt, &fileSize); scanErr == nil { + dupeGroups++ + totalWasted += fileSize * int64(cnt-1) + } + } + rows.Close() + + if dupeGroups == 0 { + fmt.Printf(" no duplicates\n") + } else { + fmt.Printf(" %d duplicate groups, %s wasted (run 'fn vault dedupe %s' for details)\n", + dupeGroups, formatBytes(totalWasted), v.Name) + } +} + +// suppress unused sql import if needed +var _ = sql.ErrNoRows + +// --- helpers --- + +// resolveVaultByName looks up a vault by name in registry.db. +// Returns an error if not found or if name is ambiguous. +func resolveVaultByName(db *registry.DB, name string) (*registry.Vault, error) { + // Try direct ID first. + if v, err := db.GetVault(name); err == nil { + return v, nil + } + + // Search by name. + vaults, err := db.SearchVaults(name, "") + if err != nil { + return nil, fmt.Errorf("search vaults: %w", err) + } + + // Exact name match. + var exact []registry.Vault + for _, v := range vaults { + if v.Name == name { + exact = append(exact, v) + } + } + if len(exact) == 1 { + return &exact[0], nil + } + if len(exact) > 1 { + ids := make([]string, len(exact)) + for i, v := range exact { + ids[i] = v.ID + } + return nil, fmt.Errorf("ambiguous vault name %q: %s", name, strings.Join(ids, ", ")) + } + + // Partial match fallback. + if len(vaults) == 1 { + return &vaults[0], nil + } + if len(vaults) > 1 { + ids := make([]string, len(vaults)) + for i, v := range vaults { + ids[i] = v.ID + } + return nil, fmt.Errorf("ambiguous vault %q — use full name or ID: %s", name, strings.Join(ids, ", ")) + } + + return nil, fmt.Errorf("vault not found: %q (run 'fn index' to register vaults)", name) +} + +// resolveSearchVaults returns the vault(s) to search. +// If name is non-empty, returns only that vault. Otherwise returns all vaults. +func resolveSearchVaults(db *registry.DB, name string) ([]registry.Vault, error) { + if name != "" { + v, err := resolveVaultByName(db, name) + if err != nil { + return nil, err + } + return []registry.Vault{*v}, nil + } + return db.AllVaults() +} + +// formatBytes formats a byte count as a human-readable string (KB, MB, GB). +func formatBytes(b int64) string { + switch { + case b >= 1<<30: + return fmt.Sprintf("%.1f GB", float64(b)/float64(1<<30)) + case b >= 1<<20: + return fmt.Sprintf("%.1f MB", float64(b)/float64(1<<20)) + case b >= 1<<10: + return fmt.Sprintf("%.1f KB", float64(b)/float64(1<<10)) + default: + return fmt.Sprintf("%d B", b) + } +} + +// resolveVaultPath resolves the actual directory path for a vault, +// following symlinks if needed. Returns the resolved absolute path. +func resolveVaultPath(vaultPath string) string { + resolved, err := filepath.EvalSymlinks(vaultPath) + if err != nil { + return vaultPath + } + return resolved +} diff --git a/cmd/fn/cmd_vault_test.go b/cmd/fn/cmd_vault_test.go new file mode 100644 index 00000000..367e9777 --- /dev/null +++ b/cmd/fn/cmd_vault_test.go @@ -0,0 +1,318 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "fn-registry/functions/infra" + "fn-registry/registry" +) + +// fnBinDir holds the temp directory for the compiled fn binary. +// It is created by TestMain and cleaned up at test end. +var fnBinDir string +var fnBinPath string + +// TestMain compiles the fn binary once before all tests. +func TestMain(m *testing.M) { + var err error + fnBinDir, err = os.MkdirTemp("", "fn-vault-test-*") + if err != nil { + fmt.Fprintf(os.Stderr, "create temp dir: %v\n", err) + os.Exit(1) + } + defer os.RemoveAll(fnBinDir) + + fnBinPath = filepath.Join(fnBinDir, "fn") + // Find registry root by walking up from current directory. + regRoot, err := findRoot() + if err != nil { + fmt.Fprintf(os.Stderr, "find root: %v\n", err) + os.Exit(1) + } + cmd := exec.Command("go", "build", "-tags", "fts5", "-o", fnBinPath, ".") + cmd.Dir = filepath.Join(regRoot, "cmd", "fn") + if out, errB := cmd.CombinedOutput(); errB != nil { + fmt.Fprintf(os.Stderr, "build fn: %v\n%s\n", errB, out) + os.Exit(1) + } + + os.Exit(m.Run()) +} + +func findRoot() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir, nil + } + parent := filepath.Dir(dir) + if parent == dir { + return "", fmt.Errorf("could not find go.mod from %s", dir) + } + dir = parent + } +} + +func ensureFnBin(t *testing.T) string { + t.Helper() + return fnBinPath +} + +// setupTestRegistry creates a minimal registry root with: +// - registry.db (opened + migrations applied via registry.Open) +// - a project with a vault declared in vault.yaml +// - a vault directory with some test files +// - a symlink from projects/test_proj/vaults/test_vault -> vault dir +// +// Returns (repoRoot, vaultDir). +func setupTestRegistry(t *testing.T) (string, string) { + t.Helper() + repoRoot := t.TempDir() + + // Create vault directory with files. + vaultDir := filepath.Join(t.TempDir(), "test_vault") + if err := os.MkdirAll(filepath.Join(vaultDir, "data", "raw"), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(vaultDir, "data", "raw", "report.csv"), + []byte("name,value\nfoo,1"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(vaultDir, "data", "raw", "notes.md"), + []byte("# Notes\nsome text"), 0644); err != nil { + t.Fatal(err) + } + + // Create project directory structure. + projDir := filepath.Join(repoRoot, "projects", "test_proj") + vaultsDir := filepath.Join(projDir, "vaults") + if err := os.MkdirAll(vaultsDir, 0755); err != nil { + t.Fatal(err) + } + + // Create vault.yaml. + vaultYAML := "vaults:\n - name: test_vault\n description: Test vault for unit tests\n path: " + vaultDir + "\n tags: [test]\n" + if err := os.WriteFile(filepath.Join(vaultsDir, "vault.yaml"), []byte(vaultYAML), 0644); err != nil { + t.Fatal(err) + } + + // Create project.md. + projMD := "---\nname: test_proj\ndescription: Test project\ntags: [test]\n---\n" + if err := os.WriteFile(filepath.Join(projDir, "project.md"), []byte(projMD), 0644); err != nil { + t.Fatal(err) + } + + // Open registry.db (creates schema + runs migrations). + db, err := registry.Open(filepath.Join(repoRoot, "registry.db")) + if err != nil { + t.Fatalf("registry.Open: %v", err) + } + + // Index so the vault is registered in registry.db. + if _, err := registry.Index(db, repoRoot); err != nil { + t.Fatalf("registry.Index: %v", err) + } + db.Close() + + return repoRoot, vaultDir +} + +// runFn runs the fn binary in repoRoot with the given args. +func runFn(t *testing.T, repoRoot string, args ...string) (string, string, int) { + t.Helper() + bin := ensureFnBin(t) + cmd := exec.Command(bin, args...) + cmd.Dir = repoRoot + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + code := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + code = exitErr.ExitCode() + } else { + t.Logf("cmd error: %v", err) + } + } + return stdout.String(), stderr.String(), code +} + +// TestVaultList verifies that 'fn vault list' shows the indexed vault. +func TestVaultList(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + out, stderr, code := runFn(t, repoRoot, "vault", "list") + if code != 0 { + t.Fatalf("fn vault list exit %d\nstderr: %s", code, stderr) + } + if !strings.Contains(out, "test_vault") { + t.Errorf("expected 'test_vault' in output, got:\n%s", out) + } +} + +// TestVaultIndex verifies that 'fn vault index ' runs without error. +func TestVaultIndex(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + out, stderr, code := runFn(t, repoRoot, "vault", "index", "test_vault") + if code != 0 { + t.Fatalf("fn vault index exit %d\nstderr: %s\nstdout: %s", code, stderr, out) + } + if !strings.Contains(out, "indexed") { + t.Errorf("expected 'indexed' in output, got:\n%s", out) + } +} + +// TestVaultSearchJSON verifies that 'fn vault search --json' returns valid JSON array. +func TestVaultSearchJSON(t *testing.T) { + repoRoot, vaultDir := setupTestRegistry(t) + + // First index the vault so there is something to search. + if _, _, code := runFn(t, repoRoot, "vault", "index", "test_vault"); code != 0 { + t.Fatal("fn vault index failed") + } + + // Seed some content into the vault index for the search to find. + db, err := infra.VaultIndexOpen(vaultDir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + // Update content_text for FTS search. + db.Exec(`DELETE FROM files_fts WHERE rel_path = 'data/raw/report.csv'`) + db.Exec(`INSERT INTO files_fts(rel_path, content_text) VALUES ('data/raw/report.csv', 'foo report data')`) + db.Close() + + out, stderr, code := runFn(t, repoRoot, "vault", "search", "report", "--json", "--vault", "test_vault") + if code != 0 { + t.Fatalf("fn vault search exit %d\nstderr: %s", code, stderr) + } + + var result []map[string]interface{} + if err := json.Unmarshal([]byte(out), &result); err != nil { + t.Fatalf("output is not valid JSON: %v\nraw: %s", err, out) + } + // Should be a JSON array (possibly empty if search finds nothing, but must be valid). + t.Logf("search returned %d hits", len(result)) +} + +// TestVaultInfo verifies that 'fn vault info ' outputs vault stats. +func TestVaultInfo(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + + // Index first. + if _, _, code := runFn(t, repoRoot, "vault", "index", "test_vault"); code != 0 { + t.Fatal("fn vault index failed") + } + + out, stderr, code := runFn(t, repoRoot, "vault", "info", "test_vault") + if code != 0 { + t.Fatalf("fn vault info exit %d\nstderr: %s", code, stderr) + } + if !strings.Contains(out, "test_vault") { + t.Errorf("expected vault name in output, got:\n%s", out) + } + if !strings.Contains(out, "Files:") { + t.Errorf("expected 'Files:' in output, got:\n%s", out) + } +} + +// TestFormatBytes verifies the formatBytes helper. +func TestFormatBytes(t *testing.T) { + cases := []struct { + input int64 + expected string + }{ + {500, "500 B"}, + {1024, "1.0 KB"}, + {1536, "1.5 KB"}, + {1048576, "1.0 MB"}, + {1073741824, "1.0 GB"}, + } + for _, tc := range cases { + got := formatBytes(tc.input) + if got != tc.expected { + t.Errorf("formatBytes(%d) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +// TestVaultLayoutEnsure verifies that 'fn vault layout-ensure --dry-run' works. +func TestVaultLayoutEnsure(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + out, stderr, code := runFn(t, repoRoot, "vault", "layout-ensure", "test_vault", "--dry-run") + if code != 0 { + t.Fatalf("fn vault layout-ensure exit %d\nstderr: %s\nstdout: %s", code, stderr, out) + } + if !strings.Contains(out, "test_vault") { + t.Errorf("expected vault name in output, got:\n%s", out) + } +} + +// TestVaultAggregate verifies that 'fn vault aggregate' runs without error on a clean registry. +func TestVaultAggregate(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + + // Index first so there is something to aggregate. + if _, _, code := runFn(t, repoRoot, "vault", "index", "test_vault"); code != 0 { + t.Fatal("fn vault index failed") + } + + _, stderr, code := runFn(t, repoRoot, "vault", "aggregate") + if code != 0 { + t.Fatalf("fn vault aggregate exit %d\nstderr: %s", code, stderr) + } +} + +// TestVaultDoctor verifies that 'fn vault doctor' runs and reports on vaults. +func TestVaultDoctor(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + out, stderr, code := runFn(t, repoRoot, "vault", "doctor") + if code != 0 { + t.Fatalf("fn vault doctor exit %d\nstderr: %s", code, stderr) + } + if !strings.Contains(out, "test_vault") { + t.Errorf("expected 'test_vault' in doctor output, got:\n%s", out) + } +} + +// TestVaultDedupe verifies that 'fn vault dedupe' runs without error after indexing. +func TestVaultDedupe(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + + if _, _, code := runFn(t, repoRoot, "vault", "index", "test_vault"); code != 0 { + t.Fatal("fn vault index failed") + } + + out, stderr, code := runFn(t, repoRoot, "vault", "dedupe", "test_vault") + if code != 0 { + t.Fatalf("fn vault dedupe exit %d\nstderr: %s", code, stderr) + } + // Should say "No duplicates" or show a table — either is fine. + _ = out +} + +// TestVaultAuditDryRun verifies that 'fn vault audit --dry-run-layout --skip-profilers' works. +func TestVaultAuditDryRun(t *testing.T) { + repoRoot, _ := setupTestRegistry(t) + out, stderr, code := runFn(t, repoRoot, "vault", "audit", "test_vault", + "--dry-run-layout", "--skip-profilers") + // Exit 0 = fully ok; exit 4 = warnings (layout issues) — both acceptable here. + if code != 0 && code != 4 { + t.Fatalf("fn vault audit exit %d\nstderr: %s\nstdout: %s", code, stderr, out) + } + if !strings.Contains(out, "summary") { + t.Errorf("expected 'summary' section in audit output, got:\n%s", out) + } +} + +// Suppress unused import for time. +var _ = time.Now diff --git a/cmd/fn/doctor.go b/cmd/fn/doctor.go index 6e1cce03..9734b9f8 100644 --- a/cmd/fn/doctor.go +++ b/cmd/fn/doctor.go @@ -44,6 +44,10 @@ func cmdDoctor(args []string) { doctorUnused(r, jsonOut) case "cpp-apps": doctorCppApps(r, jsonOut) + case "ml": + doctorML(r, jsonOut) + case "vaults": + doctorVaults(r, jsonOut) default: fmt.Fprintf(os.Stderr, "unknown doctor subcommand: %s\n", sub) doctorUsage() @@ -65,6 +69,8 @@ Subcommands: uses-functions Audit imports reales vs uses_functions del app.md unused Funciones del registry sin consumidores cpp-apps Conformidad de apps C++ con cpp/PATTERNS.md (cfg.about, dockspace, menubar) + ml Entorno ML: GPUs NVIDIA, CUDA toolkit, venv Python, paquetes torch/diffusers, CLIs y vault + vaults Salud de vaults: directorio, layout, índice, staleness, drift Flags: --json Salida JSON (para scripting/agentes)`) @@ -103,6 +109,16 @@ func doctorAll(root string, jsonOut bool) { } else { all["cpp_apps_error"] = err.Error() } + if v, err := infra.AuditMlEnv(root); err == nil { + all["ml"] = v + } else { + all["ml_error"] = err.Error() + } + if v, err := infra.VaultDoctor(root); err == nil { + all["vaults"] = v + } else { + all["vaults_error"] = err.Error() + } emit(all) return } @@ -119,6 +135,10 @@ func doctorAll(root string, jsonOut bool) { doctorUnused(root, false) fmt.Println("\n=== C++ apps standard conformance ===") doctorCppApps(root, false) + fmt.Println("\n=== ML environment ===") + doctorML(root, false) + fmt.Println("\n=== Vaults ===") + doctorVaults(root, false) } func doctorCppApps(root string, jsonOut bool) { @@ -280,6 +300,81 @@ func doctorUnused(root string, jsonOut bool) { fmt.Printf("\n%d unused functions (candidates to remove).\n", len(unused)) } +func doctorVaults(root string, jsonOut bool) { + entries, err := infra.VaultDoctor(root) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if jsonOut { + emit(entries) + return + } + if len(entries) == 0 { + fmt.Println("No vaults declared (no projects/*/vaults/vault.yaml found).") + return + } + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "NAME\tSTATUS\tFILES\tINDEXED\tISSUES") + ok := 0 + for _, e := range entries { + issues := "-" + if len(e.Issues) > 0 { + issues = strings.Join(e.Issues, "; ") + } + fmt.Fprintf(w, "%s\t%s\t%d\t%d\t%s\n", + e.VaultName, e.Status, e.DiskFiles, e.IndexedFiles, issues) + if e.Status == "ok" { + ok++ + } + } + w.Flush() + fmt.Printf("\n%d/%d vaults healthy.\n", ok, len(entries)) +} + +func doctorML(root string, jsonOut bool) { + report, err := infra.AuditMlEnv(root) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if jsonOut { + emit(report) + return + } + + fmt.Printf("GPUs detected: %d\n", len(report.Gpus)) + for _, g := range report.Gpus { + fmt.Printf(" [%d] %s VRAM: %d/%d MiB Driver: %s CUDA: %s\n", + g.Index, g.Name, g.VramFreeMb, g.VramTotalMb, g.DriverVersion, g.CudaVersion) + } + fmt.Println() + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "CHECK\tSTATUS\tVERSION\tDETAIL") + for _, c := range report.Checks { + version := c.Version + if version == "" { + version = "-" + } + detail := c.Detail + if len(detail) > 60 { + detail = detail[:60] + "..." + } + if detail == "" { + detail = "-" + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", c.Name, c.Status, version, detail) + } + w.Flush() + + overall := "OK" + if !report.OverallOK { + overall = "INCOMPLETE" + } + fmt.Printf("\nOverall ML environment: %s\n", overall) +} + func emit(v any) { b, err := json.MarshalIndent(v, "", " ") if err != nil { diff --git a/cmd/fn/main.go b/cmd/fn/main.go index 8e8b82ba..b72ee19f 100644 --- a/cmd/fn/main.go +++ b/cmd/fn/main.go @@ -45,6 +45,8 @@ func main() { cmdAnalysis(os.Args[2:]) case "sync": cmdSync(os.Args[2:]) + case "vault": + cmdVault(os.Args[2:]) case "doctor": cmdDoctor(os.Args[2:]) case "help", "-h", "--help": @@ -73,6 +75,7 @@ Usage: fn app Gestiona apps externas (Gitea) fn analysis Gestiona analyses externas (Gitea) fn sync [status|locations] Sincroniza con servidor central + fn vault Gestiona y busca en data vaults fn doctor [artefacts|services|sync|uses-functions|unused] [--json] Diagnostico read-only del registry`) } diff --git a/cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt b/cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt index 903dd223..a0091e3d 100644 --- a/cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt +++ b/cpp/apps/primitives_gallery/playground/tables/CMakeLists.txt @@ -3,8 +3,10 @@ add_imgui_app(tables_playground main.cpp data_table.cpp data_table_logic.cpp + llm_anthropic.cpp lua_engine.cpp tql.cpp + tql_to_sql.cpp viz.cpp ) target_link_libraries(tables_playground PRIVATE lua54 implot) @@ -13,10 +15,13 @@ target_link_libraries(tables_playground PRIVATE lua54 implot) add_executable(tables_playground_self_test self_test.cpp data_table_logic.cpp + llm_anthropic.cpp lua_engine.cpp tql.cpp + tql_to_sql.cpp ) target_include_directories(tables_playground_self_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/functions ) target_link_libraries(tables_playground_self_test PRIVATE lua54) diff --git a/cpp/apps/primitives_gallery/playground/tables/data_table.cpp b/cpp/apps/primitives_gallery/playground/tables/data_table.cpp index 14879d7e..56dbde91 100644 --- a/cpp/apps/primitives_gallery/playground/tables/data_table.cpp +++ b/cpp/apps/primitives_gallery/playground/tables/data_table.cpp @@ -1,20 +1,33 @@ #include "data_table.h" #include "app_base.h" #include "imgui.h" +#include "llm_anthropic.h" #include "lua_engine.h" #include "tql.h" +#include "tql_to_sql.h" #include "viz.h" #include #include #include #include +#include #include #include #include namespace data_table { +// UTC date today as ISO YYYY-MM-DD. Para preset filtros Last7/30/90d. +static std::string today_iso() { + std::time_t t = std::time(nullptr); + std::tm tm = *std::gmtime(&t); + char buf[16]; + std::snprintf(buf, sizeof(buf), "%04d-%02d-%02d", + tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday); + return buf; +} + namespace { // --------------------------------------------------------------------------- @@ -122,10 +135,106 @@ struct UiState { // Toggle Table <-> View: remember last non-table display. ViewMode last_non_table_main = ViewMode::Bar; + + // Drill history (fase 10). Stacks per-app; no persistido en TQL. + std::vector drill_back; + std::vector drill_forward; + + // Row inspector (fase 10). -1 cerrado, sino row idx en el output del stage activo. + int inspect_row = -1; + bool inspect_open = false; + + // Ask AI modal (fase 11 — issue 0080). + bool ask_open = false; + bool ask_busy = false; + int ask_mode = 0; // 0 = TQL, 1 = SQL + char ask_question[2048] = {0}; + std::string ask_current_tql; // emit del state actual al abrir modal + std::string ask_response_raw; // texto del modelo + std::string ask_response_code; // bloque extraido (Lua o SQL) + std::string ask_error; + std::string ask_status; // "Sent. Waiting..." / "OK" / error + char ask_edit_buf[8192] = {0}; // buffer editable de propuesta }; UiState& ui() { static UiState s; return s; } +// Row inspector modal (fase 10). Muestra todas cols + valores de la fila +// inspect_row del output del stage activo. Read-only + Copy TSV + Filter +// by this row (anade filters al stage previo si existe). +static void draw_row_inspector_modal(State& st, int active, + const char* const* cells, int rows, int cols, + const std::vector& headers, + const std::vector& types, + const std::vector& prev_input_headers) { + auto& U = ui(); + if (!U.inspect_open) return; + if (U.inspect_row < 0 || U.inspect_row >= rows) { + U.inspect_open = false; + return; + } + ImGui::OpenPopup("##row_inspector"); + ImGui::SetNextWindowSize(ImVec2(560, 400), ImGuiCond_Appearing); + if (ImGui::BeginPopupModal("##row_inspector", &U.inspect_open, + ImGuiWindowFlags_NoSavedSettings)) { + ImGui::Text("Row %d", U.inspect_row); + ImGui::SameLine(0, 20); + if (ImGui::SmallButton("Copy TSV")) { + std::string tsv = row_to_tsv(cells, rows, cols, U.inspect_row, headers); + ImGui::SetClipboardText(tsv.c_str()); + } + ImGui::SameLine(); + bool can_filter = (active > 0 && !prev_input_headers.empty()); + ImGui::BeginDisabled(!can_filter); + if (ImGui::SmallButton("Filter prev stage by this row")) { + int target = active - 1; + for (int c = 0; c < cols; ++c) { + const char* v = cells[U.inspect_row * cols + c]; + if (!v || !*v) continue; + const std::string& h = headers[c]; + std::string h_clean; + parse_breakout_granularity(h, h_clean); + int ci = -1; + for (size_t i = 0; i < prev_input_headers.size(); ++i) { + if (prev_input_headers[i] == h_clean) { ci = (int)i; break; } + } + if (ci < 0) continue; + DrillStep step; + step.target_stage = target; + step.filter_pos = (int)st.stages[target].filters.size(); + step.prev_active_stage = st.active_stage; + step.added = make_drill_filter(ci, v); + if (apply_drill_step(st, step)) { + U.drill_back.push_back(step); + } + } + U.drill_forward.clear(); + U.inspect_open = false; + } + ImGui::EndDisabled(); + ImGui::Separator(); + ImGuiTableFlags flags = ImGuiTableFlags_Borders | ImGuiTableFlags_RowBg + | ImGuiTableFlags_ScrollY | ImGuiTableFlags_Resizable; + if (ImGui::BeginTable("##inspector_tbl", 2, flags, ImVec2(-1, -1))) { + ImGui::TableSetupColumn("col"); + ImGui::TableSetupColumn("value"); + ImGui::TableHeadersRow(); + for (int c = 0; c < cols; ++c) { + ImGui::TableNextRow(); + ImGui::TableSetColumnIndex(0); + ColumnType t = (c < (int)types.size()) ? types[c] : ColumnType::String; + ImGui::Text("%s %s", column_type_icon(t), + (c < (int)headers.size()) ? headers[c].c_str() : "?"); + ImGui::TableSetColumnIndex(1); + const char* v = cells[U.inspect_row * cols + c]; + ImGui::TextWrapped("%s", v ? v : ""); + } + ImGui::EndTable(); + } + ImGui::EndPopup(); + } +} + int autocomplete_cb(ImGuiInputTextCallbackData* data) { UiState* U = (UiState*)data->UserData; if (data->EventFlag == ImGuiInputTextFlags_CallbackAlways) { @@ -180,6 +289,47 @@ void ensure_init(State& st, int eff_cols) { // --------------------------------------------------------------------------- void draw_stage_breadcrumb(State& st) { st.ensure_stage0(); + + // Drill history back/forward (fase 10). Botones al inicio. + auto& U = ui(); + { + bool can_back = !U.drill_back.empty(); + ImGui::BeginDisabled(!can_back); + if (ImGui::SmallButton("<##drill_back")) { + DrillStep s = U.drill_back.back(); + U.drill_back.pop_back(); + if (undo_drill_step(st, s)) { + U.drill_forward.push_back(s); + } + } + ImGui::EndDisabled(); + if (can_back && ImGui::IsItemHovered()) + ImGui::SetTooltip("Drill back (%zu)", U.drill_back.size()); + ImGui::SameLine(); + bool can_fwd = !U.drill_forward.empty(); + ImGui::BeginDisabled(!can_fwd); + if (ImGui::SmallButton(">##drill_fwd")) { + DrillStep s = U.drill_forward.back(); + U.drill_forward.pop_back(); + if (apply_drill_step(st, s)) { + U.drill_back.push_back(s); + } + } + ImGui::EndDisabled(); + if (can_fwd && ImGui::IsItemHovered()) + ImGui::SetTooltip("Drill forward (%zu)", U.drill_forward.size()); + ImGui::SameLine(); + bool can_up = (st.active_stage > 0); + ImGui::BeginDisabled(!can_up); + if (ImGui::SmallButton("^##drill_up")) drill_up(st); + ImGui::EndDisabled(); + if (can_up && ImGui::IsItemHovered()) + ImGui::SetTooltip("Drill up (stage previo, sin perder filters)"); + ImGui::SameLine(); + ImGui::TextDisabled("|"); + ImGui::SameLine(); + } + for (int si = 0; si < (int)st.stages.size(); ++si) { if (si > 0) { ImGui::SameLine(); ImGui::TextDisabled(">"); ImGui::SameLine(); } @@ -610,6 +760,19 @@ void draw_viz_selector(State& st) { ImGui::OpenPopup("##viz_cfg_popup"); } ImGui::SameLine(); + if (ImGui::SmallButton("Ask AI##ask_open")) { + auto& U2 = ui(); + U2.ask_open = true; + U2.ask_busy = false; + U2.ask_error.clear(); + U2.ask_status.clear(); + U2.ask_response_code.clear(); + U2.ask_response_raw.clear(); + U2.ask_current_tql = tql::emit(st, + std::vector(), // emit headers stage 0 (caller fill si necesario) + std::vector()); + } + ImGui::SameLine(); if (ImGui::SmallButton("+ Viz##viz_add")) { VizPanel p; p.display = ViewMode::Bar; @@ -737,7 +900,8 @@ void draw_joins_chips(State& st, const std::vector& joinables, // Filter chips para el stage activo. eff_headers/eff_cols son del INPUT del // stage activo (= orig+derived para stage 0; output del stage previo para 1+). // --------------------------------------------------------------------------- -void draw_filter_chips(Stage& stg, const char* const* eff_headers, int eff_cols) { +void draw_filter_chips(Stage& stg, const char* const* eff_headers, int eff_cols, + const std::vector& eff_types) { auto& U = ui(); ImGui::PushStyleColor(ImGuiCol_Button, IM_COL32(120, 60, 170, 220)); ImGui::PushStyleColor(ImGuiCol_ButtonHovered, IM_COL32(150, 85, 200, 240)); @@ -746,6 +910,50 @@ void draw_filter_chips(Stage& stg, const char* const* eff_headers, int eff_cols) ImGui::PopStyleColor(3); ImGui::SameLine(); + // Presets (fase 10): menu con Last7/30/90d (cols Date), ExcludeNulls (any), + // NonZero (cols numericas). Apply append a stg.filters via build_preset_filters. + if (ImGui::SmallButton("Presets##fpresets")) ImGui::OpenPopup("##presets_menu"); + if (ImGui::BeginPopup("##presets_menu")) { + int first_date = -1, first_num = -1; + for (int c = 0; c < eff_cols && c < (int)eff_types.size(); ++c) { + if (first_date < 0 && eff_types[c] == ColumnType::Date) first_date = c; + if (first_num < 0 && (eff_types[c] == ColumnType::Int || + eff_types[c] == ColumnType::Float)) first_num = c; + } + auto apply_preset = [&](FilterPreset p, int col) { + auto fs = build_preset_filters(p, col, today_iso()); + for (auto& f : fs) stg.filters.push_back(f); + }; + if (first_date >= 0) { + char l1[96], l2[96], l3[96]; + std::snprintf(l1, sizeof(l1), "Last 7 days on \"%s\"", eff_headers[first_date]); + std::snprintf(l2, sizeof(l2), "Last 30 days on \"%s\"", eff_headers[first_date]); + std::snprintf(l3, sizeof(l3), "Last 90 days on \"%s\"", eff_headers[first_date]); + if (ImGui::MenuItem(l1)) apply_preset(FilterPreset::Last7d, first_date); + if (ImGui::MenuItem(l2)) apply_preset(FilterPreset::Last30d, first_date); + if (ImGui::MenuItem(l3)) apply_preset(FilterPreset::Last90d, first_date); + ImGui::Separator(); + } + if (ImGui::BeginMenu("Exclude nulls in...")) { + for (int c = 0; c < eff_cols; ++c) { + if (ImGui::MenuItem(eff_headers[c])) apply_preset(FilterPreset::ExcludeNulls, c); + } + ImGui::EndMenu(); + } + if (first_num >= 0) { + if (ImGui::BeginMenu("Non-zero in...")) { + for (int c = 0; c < eff_cols && c < (int)eff_types.size(); ++c) { + if (eff_types[c] == ColumnType::Int || eff_types[c] == ColumnType::Float) { + if (ImGui::MenuItem(eff_headers[c])) apply_preset(FilterPreset::NonZero, c); + } + } + ImGui::EndMenu(); + } + } + ImGui::EndPopup(); + } + ImGui::SameLine(); + if (stg.filters.empty()) { ImGui::TextDisabled("Sin filtros."); return; @@ -778,7 +986,8 @@ void draw_filter_chips(Stage& stg, const char* const* eff_headers, int eff_cols) } // Chips de breakout (stage > 0). -void draw_breakout_chips(Stage& stg, const char* const* in_headers, int in_cols) { +void draw_breakout_chips(Stage& stg, const char* const* in_headers, int in_cols, + const std::vector& in_types) { auto& U = ui(); ImGui::PushStyleColor(ImGuiCol_Button, IM_COL32( 60, 160, 170, 220)); ImGui::PushStyleColor(ImGuiCol_ButtonHovered, IM_COL32( 80, 190, 200, 240)); @@ -792,6 +1001,17 @@ void draw_breakout_chips(Stage& stg, const char* const* in_headers, int in_cols) return; } for (size_t i = 0; i < stg.breakouts.size(); ) { + std::string col_name; + DateGranularity g = parse_breakout_granularity(stg.breakouts[i], col_name); + + // Resolve col index para lookup de tipo. + int col_idx = -1; + for (int c = 0; c < in_cols; ++c) { + if (std::strcmp(in_headers[c], col_name.c_str()) == 0) { col_idx = c; break; } + } + bool is_date_col = (col_idx >= 0 && col_idx < (int)in_types.size() + && in_types[col_idx] == ColumnType::Date); + char buf[256]; std::snprintf(buf, sizeof(buf), "%s x##bk%zu", stg.breakouts[i].c_str(), i); ImGui::PushStyleColor(ImGuiCol_Button, IM_COL32( 60, 160, 170, 220)); @@ -802,20 +1022,42 @@ void draw_breakout_chips(Stage& stg, const char* const* in_headers, int in_cols) if (ImGui::IsItemClicked(ImGuiMouseButton_Right)) { U.edit_chip_kind = 2; U.edit_chip_idx = (int)i; - // resolve current col name to index in in_headers - U.edit_col_idx = 0; - for (int c = 0; c < in_cols; ++c) { - if (std::strcmp(in_headers[c], stg.breakouts[i].c_str()) == 0) { - U.edit_col_idx = c; break; - } - } + U.edit_col_idx = (col_idx >= 0) ? col_idx : 0; ImGui::OpenPopup("##edit_breakout"); } if (clicked) { stg.breakouts.erase(stg.breakouts.begin() + i); continue; } + + // Granularity combo inline cuando col Date (fase 10). + if (is_date_col) { + ImGui::SameLine(); + const char* preview = (g == DateGranularity::None) + ? "(raw)" : date_granularity_token(g); + char combo_id[32]; + std::snprintf(combo_id, sizeof(combo_id), "##gran%zu", i); + ImGui::SetNextItemWidth(72); + if (ImGui::BeginCombo(combo_id, preview)) { + DateGranularity opts[] = { + DateGranularity::None, + DateGranularity::Year, + DateGranularity::Month, + DateGranularity::Week, + DateGranularity::Day, + DateGranularity::Hour, + }; + for (auto o : opts) { + const char* lbl = (o == DateGranularity::None) + ? "(raw)" : date_granularity_token(o); + if (ImGui::Selectable(lbl, o == g)) { + stg.breakouts[i] = compose_breakout(col_name, o); + } + } + ImGui::EndCombo(); + } + } + ImGui::SameLine(); ++i; } - (void)in_headers; (void)in_cols; ImGui::NewLine(); } @@ -1220,7 +1462,8 @@ void draw_add_filter_popup(Stage& stg, const char* const* eff_headers_arr, int e } void draw_add_breakout_popup(Stage& stg, const char* const* in_headers, int in_cols, - const std::vector& in_types) { + const std::vector& in_types, + const char* const* in_cells, int in_rows) { auto& U = ui(); if (!ImGui::BeginPopup("##addbreakout")) return; if (U.brk_picker_col < 0 || U.brk_picker_col >= in_cols) U.brk_picker_col = 0; @@ -1236,7 +1479,18 @@ void draw_add_breakout_popup(Stage& stg, const char* const* in_headers, int in_c ImGui::EndCombo(); } if (ImGui::Button("Add##bk")) { - stg.breakouts.emplace_back(in_headers[U.brk_picker_col]); + int c = U.brk_picker_col; + std::string col = in_headers[c]; + // Fase 10: si col es Date, auto-detect granularidad via rango lexical + // (ISO YYYY-MM-DD ordena bien). Default Day si rango invalido. + if (c >= 0 && c < (int)in_types.size() && in_types[c] == ColumnType::Date) { + std::string lo, hi; + column_min_max(in_cells, in_rows, in_cols, c, lo, hi); + DateGranularity g = auto_date_granularity(lo, hi); + stg.breakouts.emplace_back(compose_breakout(col, g)); + } else { + stg.breakouts.emplace_back(col); + } ImGui::CloseCurrentPopup(); } ImGui::EndPopup(); @@ -1441,8 +1695,17 @@ void drill_into(State& st, int from_stage, if (prev_input_headers[i] == col_name) { ci = (int)i; break; } } if (ci < 0) return; - st.stages[target].filters.push_back(make_drill_filter(ci, value)); - st.active_stage = target; + + // Fase 10: graba step en drill_back, limpia forward (rama nueva). + DrillStep step; + step.target_stage = target; + step.filter_pos = (int)st.stages[target].filters.size(); + step.prev_active_stage = st.active_stage; + step.added = make_drill_filter(ci, value); + apply_drill_step(st, step); + auto& U = ui(); + U.drill_back.push_back(step); + U.drill_forward.clear(); } } // anon namespace @@ -1659,7 +1922,7 @@ void render(const char* id, draw_joins_chips(st, *joinables, mh); } - draw_filter_chips(act, eff_headers.data(), eff_cols); + draw_filter_chips(act, eff_headers.data(), eff_cols, eff_types); draw_add_filter_popup(act, eff_headers.data(), eff_cols, eff_types); draw_edit_filter_popup(act, eff_headers.data(), eff_cols, eff_types); @@ -2290,12 +2553,13 @@ void render(const char* id, if (chrome_visible) { ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(8, 2)); - draw_filter_chips(act, ih_ptrs.data(), in_cols_n); + draw_filter_chips(act, ih_ptrs.data(), in_cols_n, input_types_active); draw_add_filter_popup(act, ih_ptrs.data(), in_cols_n, input_types_active); draw_edit_filter_popup(act, ih_ptrs.data(), in_cols_n, input_types_active); - draw_breakout_chips(act, ih_ptrs.data(), in_cols_n); - draw_add_breakout_popup(act, ih_ptrs.data(), in_cols_n, input_types_active); + draw_breakout_chips(act, ih_ptrs.data(), in_cols_n, input_types_active); + draw_add_breakout_popup(act, ih_ptrs.data(), in_cols_n, input_types_active, + cur_cells, cur_rows); draw_edit_breakout_popup(act, ih_ptrs.data(), in_cols_n); draw_aggregation_chips(act, ih_ptrs.data(), in_cols_n); @@ -2524,7 +2788,22 @@ void render(const char* id, so_local.cells.push_back(cur_cells[i]); so_ptr = &so_local; } - viz::render(*so_ptr, st.display, st.viz_config, ImVec2(-1, -1)); + int clicked_row = -1; + viz::render(*so_ptr, st.display, st.viz_config, ImVec2(-1, -1), &clicked_row); + // Fase 10: click sobre chart -> drill al stage previo usando + // breakout col[0] como filtro Op::Eq sobre cells[clicked_row]. + if (clicked_row >= 0 && active > 0 && + so_ptr->cols > 0 && clicked_row < so_ptr->rows) { + int n_brk = (int)st.stages[active].breakouts.size(); + if (n_brk > 0) { + const char* v = so_ptr->cells[clicked_row * so_ptr->cols + 0]; + std::string col_clean; + parse_breakout_granularity(so_ptr->headers[0], col_clean); + drill_into(st, active, col_clean, + v ? std::string(v) : "", + input_headers_active); + } + } goto stage_n_table_end; } @@ -2613,12 +2892,10 @@ void render(const char* id, ImGui::PushID(r * cur_cols_n + c); ImGui::Selectable(cell ? cell : ""); if (ImGui::IsItemHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Right)) { - // Drill-down solo si c es col de breakout (c < n_brk). - if (c < n_brk) { - U.pending_col = c; - U.pending_value = cell ? cell : ""; - ImGui::OpenPopup("##drill_popup"); - } + U.pending_col = c; + U.pending_value = cell ? cell : ""; + U.inspect_row = r; + ImGui::OpenPopup("##drill_popup"); } if (ImGui::BeginPopup("##drill_popup")) { if (c < n_brk) { @@ -2631,6 +2908,12 @@ void render(const char* id, input_headers_active); ImGui::CloseCurrentPopup(); } + ImGui::Separator(); + } + if (ImGui::MenuItem("Inspect row...")) { + U.inspect_row = r; + U.inspect_open = true; + ImGui::CloseCurrentPopup(); } ImGui::EndPopup(); } @@ -2642,6 +2925,11 @@ void render(const char* id, } stage_n_table_end:; + // Row inspector modal (fase 10). Activado via right-click "Inspect row..." + // sobre celdas del table del stage activo. `cur_cells` ya es row-major. + draw_row_inspector_modal(st, active, cur_cells, cur_rows, cur_cols_n, + cur_headers, cur_types, input_headers_active); + // Render extras (stage>0 path) if (!st.extra_panels.empty() && cur_cols_n > 0) { StageOutput so_local; @@ -2958,6 +3246,118 @@ void render(const char* id, ImGui::EndPopup(); } + // Ask AI modal (fase 11 — issue 0080). + if (U.ask_open) ImGui::OpenPopup("Ask AI"); + ImGui::SetNextWindowSize(ImVec2(820, 560), ImGuiCond_Appearing); + if (ImGui::BeginPopupModal("Ask AI", &U.ask_open, + ImGuiWindowFlags_NoSavedSettings)) { + ImGui::TextDisabled("Ask en lenguaje natural. Default TQL. SQL solo si DuckDB linkado."); + const char* modes[] = {"TQL", "SQL (DuckDB)"}; +#ifndef FN_TQL_DUCKDB + // SQL mode disabled visually pero el toggle existe (informativo) + if (U.ask_mode == 1) U.ask_mode = 0; +#endif + ImGui::Combo("Output##askmode", &U.ask_mode, modes, IM_ARRAYSIZE(modes)); +#ifndef FN_TQL_DUCKDB + if (U.ask_mode == 1) { + ImGui::TextColored(ImVec4(1, 0.5f, 0.3f, 1), + "SQL mode requires FN_TQL_DUCKDB=1 build flag."); + } +#endif + ImGui::InputTextMultiline("##ask_q", U.ask_question, sizeof(U.ask_question), + ImVec2(-1, 80)); + ImGui::BeginDisabled(U.ask_busy); + if (ImGui::Button("Send")) { + U.ask_busy = true; + U.ask_status = "Sending..."; + U.ask_error.clear(); + U.ask_response_code.clear(); + U.ask_response_raw.clear(); + + // Build AskInput desde el state actual. + llm_anthropic::AskInput in; + in.question = U.ask_question; + in.tql_current = U.ask_current_tql; + in.col_names = U.active_headers; + in.col_types = U.active_types; + in.mode = (U.ask_mode == 1) + ? llm_anthropic::OutputMode::SQL + : llm_anthropic::OutputMode::TQL; + + // Llamada blocking (UI freeze breve durante red). + auto r = llm_anthropic::ask(in); + U.ask_busy = false; + if (!r.error.empty()) { + U.ask_error = r.error; + U.ask_status = "Error"; + } else { + U.ask_response_raw = r.raw; + U.ask_response_code = r.code; + U.ask_status = "Got response."; + // Llenar edit buffer + std::snprintf(U.ask_edit_buf, sizeof(U.ask_edit_buf), + "%s", r.code.c_str()); + } + } + ImGui::EndDisabled(); + ImGui::SameLine(); + if (!U.ask_status.empty()) { + ImGui::TextDisabled("%s", U.ask_status.c_str()); + } + if (!U.ask_error.empty()) { + ImGui::TextColored(ImVec4(1, 0.4f, 0.4f, 1), "%s", U.ask_error.c_str()); + } + ImGui::Separator(); + ImGui::Columns(2, "ask_cols", true); + ImGui::TextUnformatted("Current"); + ImGui::InputTextMultiline("##ask_cur", + const_cast(U.ask_current_tql.c_str()), + U.ask_current_tql.size() + 1, + ImVec2(-1, 240), + ImGuiInputTextFlags_ReadOnly); + ImGui::NextColumn(); + ImGui::TextUnformatted("Proposed (editable before apply)"); + ImGui::InputTextMultiline("##ask_new", U.ask_edit_buf, sizeof(U.ask_edit_buf), + ImVec2(-1, 240)); + ImGui::Columns(1); + + bool can_apply = !U.ask_busy && U.ask_edit_buf[0] != '\0'; + ImGui::BeginDisabled(!can_apply); + if (ImGui::Button("Apply")) { + std::string err; + if (U.ask_mode == 0) { + // TQL apply + bool ok = tql::apply(U.ask_edit_buf, st, + U.active_headers, + U.active_types, + nullptr, 0, + (int)U.active_headers.size(), + &err); + if (ok) { + U.ask_status = "Applied OK."; + U.ask_open = false; + } else { + U.ask_error = "tql::apply error: " + err; + U.ask_status = "Apply failed."; + } + } else { + // SQL apply: requires DuckDB adapter (no v1). + U.ask_status = "SQL execute requires FN_TQL_DUCKDB build flag."; + } + } + ImGui::EndDisabled(); + ImGui::SameLine(); + if (ImGui::Button("Reject")) { + U.ask_response_code.clear(); + U.ask_edit_buf[0] = '\0'; + } + ImGui::SameLine(); + if (ImGui::Button("Close")) { + U.ask_open = false; + } + ImGui::EndPopup(); + } + if (U.open_cell_popup) { ImGui::OpenPopup("##cell_op"); U.open_cell_popup = false; } if (ImGui::BeginPopup("##cell_op")) { ColumnType t = (U.pending_col >= 0 && U.pending_col < eff_cols) diff --git a/cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp b/cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp index 51065e06..f81e09e6 100644 --- a/cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp +++ b/cpp/apps/primitives_gallery/playground/tables/data_table_logic.cpp @@ -567,6 +567,69 @@ Filter make_drill_filter(int col_idx, const std::string& value) { return f; } +bool apply_drill_step(State& st, const DrillStep& step) { + if (step.target_stage < 0 || step.target_stage >= (int)st.stages.size()) return false; + Stage& s = st.stages[step.target_stage]; + int pos = step.filter_pos; + if (pos < 0 || pos > (int)s.filters.size()) return false; + s.filters.insert(s.filters.begin() + pos, step.added); + st.active_stage = step.target_stage; + return true; +} + +bool drill_up(State& st) { + if (st.stages.empty()) return false; + if (st.active_stage <= 0) return false; + st.active_stage -= 1; + return true; +} + +std::string row_to_tsv(const char* const* cells, int rows, int cols, + int row_idx, const std::vector& headers) { + if (row_idx < 0 || row_idx >= rows || cols <= 0) return ""; + std::string out; + for (int c = 0; c < cols; ++c) { + if (c > 0) out += '\t'; + if (c < (int)headers.size()) out += headers[c]; + } + out += "\r\n"; + for (int c = 0; c < cols; ++c) { + if (c > 0) out += '\t'; + const char* v = cells[row_idx * cols + c]; + if (v) out += v; + } + out += "\r\n"; + return out; +} + +std::vector build_filters_from_row(const char* const* cells, int rows, + int cols, int row_idx) { + std::vector out; + if (row_idx < 0 || row_idx >= rows || cols <= 0) return out; + for (int c = 0; c < cols; ++c) { + const char* v = cells[row_idx * cols + c]; + if (!v || !*v) continue; + Filter f; + f.col = c; + f.op = Op::Eq; + f.value = v; + out.push_back(f); + } + return out; +} + +bool undo_drill_step(State& st, const DrillStep& step) { + if (step.target_stage < 0 || step.target_stage >= (int)st.stages.size()) return false; + Stage& s = st.stages[step.target_stage]; + int pos = step.filter_pos; + if (pos < 0 || pos >= (int)s.filters.size()) return false; + s.filters.erase(s.filters.begin() + pos); + if (step.prev_active_stage >= 0 && step.prev_active_stage < (int)st.stages.size()) { + st.active_stage = step.prev_active_stage; + } + return true; +} + std::vector apply_filters(const char* const* cells, int rows, int cols, const std::vector& filters) { @@ -696,19 +759,57 @@ StageOutput compute_stage(const char* const* in_cells, int in_rows, int in_cols, } // Grouped: agrupa visible por valores de breakout, calcula aggregations. - std::vector break_cols(stage.breakouts.size()); - for (size_t i = 0; i < stage.breakouts.size(); ++i) { - break_cols[i] = find_col(in_headers, stage.breakouts[i]); + // Breakouts pueden llevar sufijo `:granularity` para cols Date (fase 10). + int nbreaks = (int)stage.breakouts.size(); + std::vector break_cols(nbreaks); + std::vector break_grans(nbreaks); + bool any_trunc = false; + for (int i = 0; i < nbreaks; ++i) { + std::string col_name; + break_grans[i] = parse_breakout_granularity(stage.breakouts[i], col_name); + if (break_grans[i] != DateGranularity::None) any_trunc = true; + break_cols[i] = find_col(in_headers, col_name); } + // Pre-truncate solo cuando hay granularity activa. Strings persistidos en + // out.cell_backing para que los punteros sobrevivan al return de la funcion. + // Reservamos upfront para que push_back no invalide punteros anteriores. + // Tamaño = trunc cells + aggregation cells (peor caso n_groups <= in_rows). + out.cell_backing.reserve( + (size_t)in_rows * (size_t)nbreaks + + (size_t)in_rows * stage.aggregations.size() + 16); + + std::vector trunc_ptrs; + if (any_trunc) { + trunc_ptrs.assign((size_t)in_rows * (size_t)nbreaks, nullptr); + for (int r = 0; r < in_rows; ++r) { + for (int i = 0; i < nbreaks; ++i) { + if (break_grans[i] == DateGranularity::None) continue; + int bc = break_cols[i]; + if (bc < 0) continue; + const char* v = in_cells[r * in_cols + bc]; + out.cell_backing.emplace_back( + truncate_date(v ? v : "", break_grans[i])); + trunc_ptrs[(size_t)r * nbreaks + i] = out.cell_backing.back().c_str(); + } + } + } + + auto cell_for = [&](int r, int i) -> const char* { + int bc = break_cols[i]; + if (bc < 0) return ""; + if (break_grans[i] != DateGranularity::None) { + return trunc_ptrs[(size_t)r * nbreaks + i]; + } + const char* v = in_cells[r * in_cols + bc]; + return v ? v : ""; + }; + auto make_key = [&](int r) -> std::string { std::string k; - for (size_t i = 0; i < break_cols.size(); ++i) { + for (int i = 0; i < nbreaks; ++i) { if (i > 0) k += '\x1f'; // separador unit-separator (no aparece en datos) - int bc = break_cols[i]; - if (bc < 0) continue; - const char* v = in_cells[r * in_cols + bc]; - k += (v ? v : ""); + k += cell_for(r, i); } return k; }; @@ -727,10 +828,9 @@ StageOutput compute_stage(const char* const* in_cells, int in_rows, int in_cols, key_to_group.emplace(k, gi); group_keys.push_back(k); group_rows.emplace_back(); - std::vector bv(break_cols.size(), ""); - for (size_t i = 0; i < break_cols.size(); ++i) { - int bc = break_cols[i]; - bv[i] = (bc >= 0) ? in_cells[r * in_cols + bc] : ""; + std::vector bv((size_t)nbreaks, ""); + for (int i = 0; i < nbreaks; ++i) { + bv[i] = cell_for(r, i); } group_breakvals.push_back(std::move(bv)); } else gi = it->second; @@ -742,11 +842,17 @@ StageOutput compute_stage(const char* const* in_cells, int in_rows, int in_cols, out.cols = out_cols; out.headers.reserve(out_cols); out.types.reserve(out_cols); - for (size_t i = 0; i < stage.breakouts.size(); ++i) { + for (int i = 0; i < nbreaks; ++i) { out.headers.push_back(stage.breakouts[i]); int bc = break_cols[i]; - out.types.push_back((bc >= 0 && bc < (int)in_types.size()) - ? in_types[bc] : ColumnType::String); + // Si hay granularity activa, el output es String (formato ymd o similar), + // no la fecha original. + ColumnType ot = ColumnType::String; + if (break_grans[i] == DateGranularity::None + && bc >= 0 && bc < (int)in_types.size()) { + ot = in_types[bc]; + } + out.types.push_back(ot); } for (const auto& a : stage.aggregations) { out.headers.push_back(aggregation_alias(a)); @@ -1102,4 +1208,288 @@ StageOutput join_tables(const char* const* left_cells, int left_rows, int left_c return out; } +// ---------------------------------------------------------------------------- +// Fase 10: drill extendido — granularity + presets. +// ---------------------------------------------------------------------------- + +const char* date_granularity_token(DateGranularity g) { + switch (g) { + case DateGranularity::Year: return "year"; + case DateGranularity::Month: return "month"; + case DateGranularity::Week: return "week"; + case DateGranularity::Day: return "day"; + case DateGranularity::Hour: return "hour"; + default: return ""; + } +} + +DateGranularity date_granularity_from_token(const char* s) { + if (!s) return DateGranularity::None; + std::string t(s); + if (t == "year") return DateGranularity::Year; + if (t == "month") return DateGranularity::Month; + if (t == "week") return DateGranularity::Week; + if (t == "day") return DateGranularity::Day; + if (t == "hour") return DateGranularity::Hour; + return DateGranularity::None; +} + +DateGranularity parse_breakout_granularity(const std::string& breakout, + std::string& col_out) { + auto pos = breakout.rfind(':'); + if (pos == std::string::npos) { + col_out = breakout; + return DateGranularity::None; + } + std::string suffix = breakout.substr(pos + 1); + DateGranularity g = date_granularity_from_token(suffix.c_str()); + if (g == DateGranularity::None) { + col_out = breakout; + return DateGranularity::None; + } + col_out = breakout.substr(0, pos); + return g; +} + +std::string compose_breakout(const std::string& col, DateGranularity g) { + if (g == DateGranularity::None) return col; + return col + ":" + date_granularity_token(g); +} + +int nearest_index_1d(double target, const double* xs, int n) { + if (n <= 0 || !xs) return -1; + int best = -1; + double best_d = 0.0; + for (int i = 0; i < n; ++i) { + double v = xs[i]; + if (std::isnan(v)) continue; + double d = std::fabs(v - target); + if (best < 0 || d < best_d) { best = i; best_d = d; } + } + return best; +} + +int nearest_index_2d(double tx, double ty, + const double* xs, const double* ys, int n) { + if (n <= 0 || !xs || !ys) return -1; + int best = -1; + double best_d = 0.0; + for (int i = 0; i < n; ++i) { + double x = xs[i], y = ys[i]; + if (std::isnan(x) || std::isnan(y)) continue; + double dx = x - tx, dy = y - ty; + double d = dx*dx + dy*dy; + if (best < 0 || d < best_d) { best = i; best_d = d; } + } + return best; +} + +double pie_angle(double cx, double cy, double mx, double my) { + // ImPlot pie: 0 = top, sentido horario. atan2 estandar: 0 = +X (right), CCW. + // Conversion: ImPlot angle = atan2(dx, -dy) y normalizar a [0, 2*PI). + double dx = mx - cx; + double dy = my - cy; + double a = std::atan2(dx, -dy); // 0 cuando (dx=0, dy<0) = top + const double two_pi = 6.283185307179586; + if (a < 0) a += two_pi; + return a; +} + +int pie_slice_at_angle(double angle, const double* sums, int n) { + if (n <= 0 || !sums) return -1; + double total = 0.0; + for (int i = 0; i < n; ++i) { + if (sums[i] < 0) return -1; + total += sums[i]; + } + if (total <= 0.0) return -1; + const double two_pi = 6.283185307179586; + if (angle < 0 || angle >= two_pi) return -1; + double cum = 0.0; + for (int i = 0; i < n; ++i) { + cum += (sums[i] / total) * two_pi; + if (angle < cum) return i; + } + return n - 1; // edge case rounding +} + +void heatmap_cell_at(double px, double py, int rows, int cols, + int& row_out, int& col_out) { + row_out = -1; + col_out = -1; + if (rows <= 0 || cols <= 0) return; + if (px < 0.0 || px >= (double)cols) return; + if (py < 0.0 || py >= (double)rows) return; + col_out = (int)px; + // ImPlot heatmap pinta row 0 arriba; plot Y suele invertirse. Caller + // normaliza si necesita. Aqui devolvemos row = floor(py) en coord plot. + row_out = (int)py; +} + +void column_min_max(const char* const* cells, int rows, int cols, int col_idx, + std::string& min_out, std::string& max_out) { + min_out.clear(); + max_out.clear(); + if (col_idx < 0 || col_idx >= cols) return; + bool first = true; + for (int r = 0; r < rows; ++r) { + const char* v = cells[r * cols + col_idx]; + if (!v || !*v) continue; + std::string s(v); + if (first) { + min_out = s; + max_out = s; + first = false; + } else { + if (s < min_out) min_out = s; + if (s > max_out) max_out = s; + } + } +} + +namespace { + +// Parse ISO "YYYY-MM-DD..." -> (y, m, d). True si los 3 primeros campos OK. +bool parse_ymd(const std::string& s, int& y, int& m, int& d) { + if (s.size() < 10) return false; + for (int i : {0,1,2,3,5,6,8,9}) { + if (s[(size_t)i] < '0' || s[(size_t)i] > '9') return false; + } + if (s[4] != '-' || s[7] != '-') return false; + y = (s[0]-'0')*1000 + (s[1]-'0')*100 + (s[2]-'0')*10 + (s[3]-'0'); + m = (s[5]-'0')*10 + (s[6]-'0'); + d = (s[8]-'0')*10 + (s[9]-'0'); + if (m < 1 || m > 12 || d < 1 || d > 31) return false; + return true; +} + +// Dias desde 0001-01-01 (proleptic Gregorian). +long ymd_to_days(int y, int m, int d) { + if (m <= 2) { y -= 1; m += 12; } + long era = (y >= 0 ? y : y - 399) / 400; + unsigned yoe = (unsigned)(y - era * 400); + unsigned doy = (unsigned)((153 * (m - 3) + 2) / 5 + d - 1); + unsigned doe = yoe * 365 + yoe/4 - yoe/100 + doy; + return era * 146097 + (long)doe; +} + +void days_to_ymd(long days, int& y, int& m, int& d) { + long era = (days >= 0 ? days : days - 146096) / 146097; + unsigned doe = (unsigned)(days - era * 146097); + unsigned yoe = (doe - doe/1460 + doe/36524 - doe/146096) / 365; + int yr = (int)yoe + (int)era * 400; + unsigned doy = doe - (365*yoe + yoe/4 - yoe/100); + unsigned mp = (5*doy + 2)/153; + unsigned day = doy - (153*mp + 2)/5 + 1; + unsigned mon = mp < 10 ? mp + 3 : mp - 9; + if (mon <= 2) yr += 1; + y = yr; m = (int)mon; d = (int)day; +} + +} // anon + +std::string truncate_date(const std::string& date, DateGranularity g) { + if (g == DateGranularity::None) return date; + int y, m, d; + if (!parse_ymd(date, y, m, d)) return date; + char buf[32]; + switch (g) { + case DateGranularity::Year: + std::snprintf(buf, sizeof(buf), "%04d", y); + return buf; + case DateGranularity::Month: + std::snprintf(buf, sizeof(buf), "%04d-%02d", y, m); + return buf; + case DateGranularity::Day: + std::snprintf(buf, sizeof(buf), "%04d-%02d-%02d", y, m, d); + return buf; + case DateGranularity::Hour: { + int hh = 0; + if (date.size() >= 13 && date[10] == 'T' + && date[11] >= '0' && date[11] <= '9' + && date[12] >= '0' && date[12] <= '9') { + hh = (date[11]-'0')*10 + (date[12]-'0'); + if (hh < 0 || hh > 23) hh = 0; + } + std::snprintf(buf, sizeof(buf), "%04d-%02d-%02dT%02d", y, m, d, hh); + return buf; + } + case DateGranularity::Week: { + // Hinnant ymd_to_days: day 0 == 0000-03-01 (Wednesday). + // days%7: 0=Wed, 1=Thu, 2=Fri, 3=Sat, 4=Sun, 5=Mon, 6=Tue. + // Monday offset: (mod - 5 + 7) % 7. + long days = ymd_to_days(y, m, d); + int mod = (int)(((days % 7) + 7) % 7); + int rem = ((mod - 5) % 7 + 7) % 7; + long monday = days - rem; + int yy, mm, dd; + days_to_ymd(monday, yy, mm, dd); + std::snprintf(buf, sizeof(buf), "%04d-%02d-%02d", yy, mm, dd); + return buf; + } + default: return date; + } +} + +DateGranularity auto_date_granularity(const std::string& min_ymd, + const std::string& max_ymd) { + int y1,m1,d1, y2,m2,d2; + if (!parse_ymd(min_ymd, y1,m1,d1)) return DateGranularity::Day; + if (!parse_ymd(max_ymd, y2,m2,d2)) return DateGranularity::Day; + long span = ymd_to_days(y2,m2,d2) - ymd_to_days(y1,m1,d1); + if (span < 0) span = -span; + if (span > 730) return DateGranularity::Year; // >2 anios + if (span > 60) return DateGranularity::Month; + if (span > 14) return DateGranularity::Week; + return DateGranularity::Day; +} + +const char* filter_preset_label(FilterPreset p) { + switch (p) { + case FilterPreset::Last7d: return "Last 7 days"; + case FilterPreset::Last30d: return "Last 30 days"; + case FilterPreset::Last90d: return "Last 90 days"; + case FilterPreset::ExcludeNulls: return "Exclude nulls"; + case FilterPreset::NonZero: return "Non-zero only"; + } + return "?"; +} + +std::vector build_preset_filters(FilterPreset preset, int col, + const std::string& today_ymd) { + std::vector out; + auto last_n = [&](int n) { + int y, m, d; + if (!parse_ymd(today_ymd, y, m, d)) return; + long days = ymd_to_days(y, m, d) - n; + int yy, mm, dd; + days_to_ymd(days, yy, mm, dd); + char buf[16]; + std::snprintf(buf, sizeof(buf), "%04d-%02d-%02d", yy, mm, dd); + Filter f; + f.col = col; + f.op = Op::Gte; + f.value = buf; + out.push_back(f); + }; + switch (preset) { + case FilterPreset::Last7d: last_n(7); break; + case FilterPreset::Last30d: last_n(30); break; + case FilterPreset::Last90d: last_n(90); break; + case FilterPreset::ExcludeNulls: { + Filter f; f.col = col; f.op = Op::Neq; f.value = ""; + out.push_back(f); + break; + } + case FilterPreset::NonZero: { + Filter f1; f1.col = col; f1.op = Op::Neq; f1.value = ""; + Filter f2; f2.col = col; f2.op = Op::Neq; f2.value = "0"; + out.push_back(f1); + out.push_back(f2); + break; + } + } + return out; +} + } // namespace data_table diff --git a/cpp/apps/primitives_gallery/playground/tables/data_table_logic.h b/cpp/apps/primitives_gallery/playground/tables/data_table_logic.h index 9c5c2906..8ab4b06c 100644 --- a/cpp/apps/primitives_gallery/playground/tables/data_table_logic.h +++ b/cpp/apps/primitives_gallery/playground/tables/data_table_logic.h @@ -1,26 +1,20 @@ // Logica pura del playground data_table. Sin ImGui — testable headless. -// Cuando se promueva al registry, esto sera la base de data_table_cpp_viz. +// TIPOS promovidos al registry (issue 0081). Este header solo declara +// funciones; los types vienen de cpp/functions/core/data_table_types.h. #pragma once +#include "core/data_table_types.h" #include #include #include namespace data_table { -enum class Op { - Eq, Neq, Gt, Gte, Lt, Lte, - Contains, NotContains, StartsWith, EndsWith -}; +// ---------------------------------------------------------------------------- +// Helpers para Op y ColumnType. +// ---------------------------------------------------------------------------- const char* op_label(Op o); -bool op_is_string_only(Op o); - -// ---------------------------------------------------------------------------- -// Column types - declarado por caller con fallback a auto-detect. -// ---------------------------------------------------------------------------- -enum class ColumnType { - Auto, String, Int, Float, Bool, Date, Json -}; +bool op_is_string_only(Op o); const char* column_type_name(ColumnType t); const char* column_type_icon(ColumnType t); // UTF-8 Tabler icon @@ -36,63 +30,11 @@ ColumnType auto_detect_type(const char* const* cells, int rows, int cols, ColumnType effective_type(ColumnType declared, const char* const* cells, int rows, int cols, int col); -// Derived column: inmutable. Dos modos: -// 1) Retipo puro: source_col >= 0, formula == "". Cells del origen. -// 2) Formula: source_col == -1, formula no vacia. Eval por Lua. -struct DerivedColumn { - int source_col = -1; - ColumnType type = ColumnType::String; - std::string name; - std::string formula; // "" = retipado puro; resto = body Lua - int lua_id = -1; // referencia en lua_engine; -1 si no compilado - std::string compile_error; -}; - -// Filter movido aqui (antes era despues de State) porque TQL Stage lo necesita. -struct Filter { - int col; - Op op; - std::string value; -}; - -struct ColorRule { - int col; - std::string equals; - unsigned int color; -}; - // ---------------------------------------------------------------------------- -// TQL (Table Query Language) — stage model. Ver docs/TQL.md. +// Aggregation helpers. // ---------------------------------------------------------------------------- -enum class AggFn { - Count, Sum, Avg, Min, Max, Distinct, Stddev, - Median, P25, P75, P90, P99, Percentile -}; - const char* agg_fn_name(AggFn f); -struct Aggregation { - AggFn fn = AggFn::Count; - std::string col; // ignorado para Count - double arg = 0.0; // para Percentile (0..1) - std::string alias; // vacio -> auto-generado via aggregation_alias() -}; - -struct SortClause { - std::string col; - bool desc = false; -}; - -// Stage: layer de TQL. Stage 0 = Raw (sin breakouts/aggregations). -// Stage 1+ pueden agrupar. Cada stage consume output del anterior. -struct Stage { - std::vector filters; - std::vector derived; // expressions de este stage - std::vector breakouts; // col names del INPUT de este stage - std::vector aggregations; - std::vector sorts; -}; - // Pure: alias por defecto cuando agg.alias esta vacio. // count -> "count" // distinct col -> "distinct_" @@ -101,224 +43,125 @@ struct Stage { std::string aggregation_alias(const Aggregation& a); // Pure: tipo del output de la aggregation. -// count, distinct -> Int -// sum, avg, stddev, -// median, p*, percentile -> Float -// min, max -> mismo tipo que la col origen ColumnType aggregation_type(const Aggregation& a, const std::vector& in_headers, const std::vector& in_types); -// Output de compute_stage. Posee `cell_backing` (strings nuevos para -// resultados agregados) y `cells` (punteros row-major a backing o a -// `in_cells` original para passthrough). -struct StageOutput { - std::vector cell_backing; - std::vector cells; - int rows = 0; - int cols = 0; - std::vector headers; - std::vector types; -}; - +// ---------------------------------------------------------------------------- +// Compute pipeline. +// ---------------------------------------------------------------------------- // Pure: ejecuta un Stage sobre los cells de entrada. Aplica filter -> (group+agg|passthrough) -> sort. StageOutput compute_stage(const char* const* in_cells, int in_rows, int in_cols, const std::vector& in_headers, const std::vector& in_types, const Stage& stage); -// Pure: aplica filtros usando headers para resolver f.col (que ahora es -// indice en el array de in_headers, no del dataset original). Devuelve -// indices de filas que pasan. +// Pure: aplica filtros usando headers para resolver f.col. std::vector apply_filters(const char* const* cells, int rows, int cols, const std::vector& filters); // Pure: helper para drill-down. Devuelve un Filter Op::Eq sobre col_idx con -// el value indicado. col_idx es indice en los headers del INPUT del stage -// previo (donde se va a aplicar el filtro). +// el value indicado. Filter make_drill_filter(int col_idx, const std::string& value); // ---------------------------------------------------------------------------- -// ViewMode: tipo de visualizacion a renderizar sobre el output del stage activo. -// "Table" siempre disponible. Resto requiere ciertos tipos de columnas. +// ViewMode helpers. // ---------------------------------------------------------------------------- -enum class ViewMode { - Table, - // Bars - Bar, // horizontal bars: 1 cat + 1 num - Column, // vertical bars: 1 cat + 1 num - GroupedBar, // 1 cat + N num (side-by-side) - StackedBar, // 1 cat + N num (stacked) - // Lines / area - Line, // X + 1..N Y series - Area, // shaded to y=0 - Stairs, // step plot - // Points - Scatter, // X + Y - Bubble, // X + Y + size - // Distribution - Histogram, // 1 num - Histogram2D, // 2 num - Heatmap, // matrix from breakouts - BoxPlot, // 1 cat + 1 num (min/p25/p50/p75/max per group) - // Stems / signals - Stem, - ErrorBars, - // Composition - Pie, - Donut, - Funnel, // ordered descending bars - Waterfall, // running sum - // Single values - KPI, // big text + label - KPIGrid, // all aggregations as cards - // Specialized - Candlestick, // OHLC: time + open + high + low + close - Radar, // multi-axis (1 cat + N num) -}; - -const char* view_mode_token(ViewMode m); // "table", "bar", ... -const char* view_mode_label(ViewMode m); // "Table", "Bar (horizontal)", ... +const char* view_mode_token(ViewMode m); +const char* view_mode_label(ViewMode m); ViewMode view_mode_from_token(const char* s); int view_mode_min_cols(ViewMode m); bool view_mode_needs_numeric(ViewMode m); bool view_mode_needs_category(ViewMode m); -// Requiere stage agrupado (breakout+aggregation). Si user esta en stage 0 con -// uno de estos, conviene auto-promote a stage 1. bool view_mode_needs_aggregation(ViewMode m); -// Lista completa de modos para el selector UI (orden de display). +// Lista completa de modos para el selector UI. const ViewMode* all_view_modes(int* n_out); // ---------------------------------------------------------------------------- // Joins (MBQL-style). Ver issue 0078. // ---------------------------------------------------------------------------- -enum class JoinStrategy { Left, Inner, Right, Full }; const char* join_strategy_token(JoinStrategy s); JoinStrategy join_strategy_from_token(const char* s); const char* join_strategy_label(JoinStrategy s); -// Tabla extra pasada al render() para joins. Owner externo (caller). -struct TableInput { - std::string name; // identificador estable (matchea Join.source) - std::vector headers; - std::vector types; - const char* const* cells = nullptr; // row-major, headers.size() cols x rows filas - int rows = 0; - int cols = 0; -}; - -// Join clause: une la tabla actual con `source` por las parejas `on`, -// prefijando las cols del derecho con `alias.`. -struct Join { - std::string alias; - std::string source; - std::vector> on; // {left_col, right_col} - JoinStrategy strategy = JoinStrategy::Left; - std::vector fields; // vacio = all del derecho -}; - // Pure: resuelve indice del main entre `tables` segun `main_source`. -// Vacio -> 0. Nombre desconocido -> 0. tables vacio -> -1. int resolve_main_idx(const std::vector& tables, const std::string& main_source); -// Pure: aplica un join sobre dos tablas. Resultado: StageOutput con -// `headers` = left + `.` (filtrado por fields si no vacio). +// Pure: aplica un join sobre dos tablas. StageOutput join_tables(const char* const* left_cells, int left_rows, int left_cols, const std::vector& left_headers, const std::vector& left_types, const TableInput& right, const Join& jn); -// ViewConfig: overrides manuales de auto-detect para la vista activa. -// Campos vacios -> auto. Si col name no existe en output, viz cae a auto. -struct ViewConfig { - std::string x_col; // single: scatter, line, hist2d - std::vector y_cols; // 1..N: line/area/bar/etc - std::string size_col; // bubble - std::string cat_col; // bar/pie/funnel/box override - unsigned int primary_color = 0; // 0 = ImPlot auto - int hist_bins = 0; // 0 = Sturges - float pie_radius = 0.0f; // 0 = default - bool show_legend = true; - bool show_markers = false; // line/area markers - bool locked = false; // disable pan/zoom - mutable bool fit_request = false; // consumed by viz::render -}; +// ---------------------------------------------------------------------------- +// Drill apply/undo (fase 10). +// ---------------------------------------------------------------------------- +bool apply_drill_step(State& st, const DrillStep& step); +bool undo_drill_step(State& st, const DrillStep& step); -// VizPanel: viz adicional sobre el mismo StageOutput. State.display + viz_config -// es el panel 0 (siempre visible); extra_panels son los aniadidos por el user. -struct VizPanel { - ViewMode display = ViewMode::Bar; - ViewConfig config; - // Memoria del ultimo non-Table display para toggle Table<->View. - mutable ViewMode last_non_table = ViewMode::Bar; -}; +// Pure (fase 10): drill-up. Decrementa active_stage si > 0. +bool drill_up(State& st); -// State: stage pipeline + viz globales. -// -// `stages` siempre tiene tamaño >= 1 (auto-init en compute_visible_rows / render -// si esta vacio: se crea stages[0] vacio). Stage 0 es Raw (filters + derived + -// sorts; SIN breakouts/aggregations). Stages 1+ pueden agrupar. -// -// `active_stage` = indice del stage cuyo output se renderiza. -// `col_visible/col_order/color_rules` aplican al output del stage activo. -struct State { - std::vector stages; - int active_stage = 0; - ViewMode display = ViewMode::Table; - ViewConfig viz_config; - std::vector extra_panels; - std::vector joins; // aplicado antes de stages[0] - std::string main_source; // name de TableInput a usar como main; vacio -> tables[0] +// Pure (fase 10): serializa una fila a TSV. +std::string row_to_tsv(const char* const* cells, int rows, int cols, + int row_idx, const std::vector& headers); - std::vector color_rules; - std::vector col_visible; // size = effective_cols del stage activo - std::vector col_order; // permutacion [0..effective_cols) +// Pure (fase 10): construye filters Op::Eq desde una fila. +std::vector build_filters_from_row(const char* const* cells, int rows, + int cols, int row_idx); - // --- Compat helpers: shortcuts a stages[0] (Raw) --- - // Util tras refactor para tests / accesos puntuales. Garantizan stages[0] - // existe (lo crean vacio si no). - Stage& raw(); - const Stage& raw() const; - Stage& active(); - const Stage& active_const() const; - void ensure_stage0(); -}; +// ---------------------------------------------------------------------------- +// Date granularity helpers (fase 10). +// ---------------------------------------------------------------------------- +const char* date_granularity_token(DateGranularity g); +DateGranularity date_granularity_from_token(const char* s); -// Parse "1.23" -> 1.23, true. False si la celda no es numero completo. +DateGranularity parse_breakout_granularity(const std::string& breakout, + std::string& col_out); + +std::string compose_breakout(const std::string& col, DateGranularity g); + +void column_min_max(const char* const* cells, int rows, int cols, int col_idx, + std::string& min_out, std::string& max_out); + +// Hit-tests para click-to-drill sobre charts (fase 10). +int nearest_index_1d(double target, const double* xs, int n); +int nearest_index_2d(double tx, double ty, + const double* xs, const double* ys, int n); +double pie_angle(double cx, double cy, double mx, double my); +int pie_slice_at_angle(double angle, const double* sums, int n); +void heatmap_cell_at(double px, double py, int rows, int cols, + int& row_out, int& col_out); + +// Date trunc + auto + presets. +std::string truncate_date(const std::string& date, DateGranularity g); +DateGranularity auto_date_granularity(const std::string& min_ymd, + const std::string& max_ymd); +const char* filter_preset_label(FilterPreset p); +std::vector build_preset_filters(FilterPreset preset, int col, + const std::string& today_ymd); + +// ---------------------------------------------------------------------------- +// Misc helpers. +// ---------------------------------------------------------------------------- bool parse_number(const char* s, double& out); - -// Compara dos celdas con operador. Numerico si ambas parseables; lexical si no. bool compare(const char* a, const char* b, Op op); -// Aplica filtros y ordena. Devuelve indices de filas visibles. std::vector compute_visible_rows(const char* const* cells, int rows, int cols, const State& st); -// Pure: muta col_order de st para colocar `src` en la posicion (en orden visual) -// donde estaba `dst`. No-op si src == dst o cualquiera fuera del array. void reorder_column(State& st, int src, int dst); -// Pure: dado un buffer y posicion de cursor, busca el `[` abierto sin cerrar -// mas reciente. Devuelve su indice (o -1 si ninguno). Rellena `filter_text` -// con los caracteres entre `[` y cursor. -// Para autocomplete de formulas: cuando el usuario teclea `[` el ImGui callback -// detecta esto y muestra un popup con cols disponibles. int find_open_bracket(const char* buf, int len, int cursor, std::string& filter_text); -// Pure: reemplaza src[start..cursor) por "[name]". Devuelve nuevo string y -// actualiza `new_cursor` a la posicion despues del `]`. std::string insert_column_ref(const std::string& src, int start, int cursor, const std::string& name, int& new_cursor); -// CSV: escapa una celda segun RFC 4180 (wrap en " si contiene , " o newline). std::string csv_escape(const char* s); -// Construye TSV de un rect de seleccion. Headers SIEMPRE incluidos. -// view_row_lo/hi: indices en visible_rows. -// view_col_lo/hi: indices en col_order. Cols ocultas se omiten. std::string build_tsv(const char* const* cells, int rows, int cols, const char* const* headers, const std::vector& col_order, @@ -327,19 +170,21 @@ std::string build_tsv(const char* const* cells, int rows, int cols, int view_row_lo, int view_row_hi, int view_col_lo, int view_col_hi); -// Construye CSV (full visible view). Headers incluidos, cells escapados. std::string build_csv(const char* const* cells, int rows, int cols, const char* const* headers, const std::vector& col_order, const std::vector& col_visible, const std::vector& visible_rows); +// ---------------------------------------------------------------------------- +// Column statistics (no movido todavia al registry). +// ---------------------------------------------------------------------------- struct ColStats { - int total = 0; // filas escaneadas - int empty_count = 0; // cells == "" o null - int unique_count = 0; // distintas (cap configurable) - bool unique_capped = false; // true si se alcanzo el cap - bool numeric = false; // true si todas las cells no-vacias parsean como numero + int total = 0; + int empty_count = 0; + int unique_count = 0; + bool unique_capped = false; + bool numeric = false; int numeric_count = 0; double min = 0; double max = 0; @@ -348,16 +193,12 @@ struct ColStats { double p25 = 0; double p50 = 0; double p75 = 0; - std::vector hist; // bins (HIST_BINS) si numeric - std::vector> top_categories; // top 8 por count desc + std::vector hist; + std::vector> top_categories; }; constexpr int HIST_BINS = 24; -// Pure: escanea una columna y devuelve estadisticas. `unique_cap` corta el -// conteo de unicos si excede (para datasets de millones). 0 = sin cap. -// Si `indices != nullptr` y `n_indices > 0`, recorre solo las filas indicadas -// (uso tipico: stats sobre filas visibles post-filtro). ColStats compute_column_stats(const char* const* cells, int rows, int cols, int col, int unique_cap = 100000, const int* indices = nullptr, int n_indices = 0); diff --git a/cpp/apps/primitives_gallery/playground/tables/llm_anthropic.cpp b/cpp/apps/primitives_gallery/playground/tables/llm_anthropic.cpp new file mode 100644 index 00000000..abfed046 --- /dev/null +++ b/cpp/apps/primitives_gallery/playground/tables/llm_anthropic.cpp @@ -0,0 +1,295 @@ +// llm_anthropic.cpp — cliente Anthropic minimal via cURL popen. +// Ver issue 0080. +#include "llm_anthropic.h" + +#include +#include +#include +#include +#include + +namespace llm_anthropic { + +using namespace data_table; + +namespace { + +// JSON escape minimal. +std::string json_escape(const std::string& s) { + std::string o; + o.reserve(s.size() + 8); + for (char c : s) { + switch (c) { + case '"': o += "\\\""; break; + case '\\': o += "\\\\"; break; + case '\n': o += "\\n"; break; + case '\r': o += "\\r"; break; + case '\t': o += "\\t"; break; + case '\b': o += "\\b"; break; + case '\f': o += "\\f"; break; + default: + if ((unsigned char)c < 0x20) { + char buf[8]; + std::snprintf(buf, sizeof(buf), "\\u%04x", (int)(unsigned char)c); + o += buf; + } else { + o += c; + } + } + } + return o; +} + +const char* col_type_doc(ColumnType t) { + switch (t) { + case ColumnType::String: return "string"; + case ColumnType::Int: return "int"; + case ColumnType::Float: return "float"; + case ColumnType::Bool: return "bool"; + case ColumnType::Date: return "date"; + case ColumnType::Json: return "json"; + case ColumnType::Auto: return "auto"; + } + return "?"; +} + +std::string build_schema_block(const AskInput& in) { + std::ostringstream os; + os << "Available columns (stage 0 input):\n"; + for (size_t i = 0; i < in.col_names.size(); ++i) { + os << " - " << in.col_names[i] << ": " + << col_type_doc(i < in.col_types.size() ? in.col_types[i] : ColumnType::String) + << "\n"; + } + if (!in.joinable_names.empty()) { + os << "Joinable tables (for join clause):\n"; + for (const auto& n : in.joinable_names) os << " - " << n << "\n"; + } + return os.str(); +} + +std::string build_system_prompt(OutputMode mode) { + if (mode == OutputMode::TQL) { + return + "You are a TQL (Table Query Language) expert. Output ONLY a Lua code block. " + "TQL is a Lua table with shape:\n" + " return { version=1, display=\"table\"|\"bar\"|\"line\"|...,\n" + " main_source=\"name\", joins={ {alias,source,on,strategy,fields},... },\n" + " stages={ {filter={{op,col,value},...}, breakout={...}, aggregation={...}, sort={...} },... },\n" + " columns={ name = {type=\"int|float|...\", formula=\"[col]+1\"},... }\n" + " }\n" + "Stage 0 = Raw (filters + derived + sort, NO breakouts/aggs).\n" + "Stage 1+ groups (breakouts + aggregations).\n" + "Breakout granularity: append :year|:month|:week|:day|:hour to col name.\n" + "Aggregation functions: count|sum|avg|min|max|distinct|stddev|median|p25|p75|p90|p99|percentile.\n" + "Filter ops: '='|'!='|'<'|'<='|'>'|'>='|'contains'|'!contains'|'starts'|'ends'.\n" + "Sort: {{dir, col}, ...} where dir = 'asc'|'desc'.\n" + "Join strategies: 'left'|'inner'|'right'|'full'.\n" + "Formulas use Lua expression syntax with [col] for column refs.\n" + "Output format: ```lua\\n...\\n```"; + } + return + "You are a DuckDB SQL expert. Output ONLY a SQL code block compatible with DuckDB.\n" + "Use CTEs to chain stages. Use date_trunc('month', col) for granularity.\n" + "Use quantile_cont(col, p) for percentiles. Use ? for bound params.\n" + "Joins: LEFT/INNER/RIGHT/FULL OUTER JOIN. String concat: ||. Aggregations: standard SQL.\n" + "Output format: ```sql\\n...\\n```"; +} + +} // anon + +std::string build_request_body(const AskInput& in) { + std::string system_msg = build_system_prompt(in.mode); + std::string schema = build_schema_block(in); + + std::ostringstream user_msg; + user_msg << "Question: " << in.question << "\n\n" + << schema << "\n"; + if (!in.tql_current.empty()) { + user_msg << "Current TQL:\n```lua\n" << in.tql_current << "\n```\n"; + } + + std::string model = in.model.empty() ? "claude-sonnet-4-6" : in.model; + + std::ostringstream body; + body << "{" + << "\"model\":\"" << json_escape(model) << "\"," + << "\"max_tokens\":" << in.max_tokens << "," + << "\"system\":\"" << json_escape(system_msg) << "\"," + << "\"messages\":[{" + << "\"role\":\"user\"," + << "\"content\":\"" << json_escape(user_msg.str()) << "\"" + << "}]" + << "}"; + return body.str(); +} + +std::string extract_code_block(const std::string& raw, const std::string& lang) { + // Buscar ``` primero, sino ``` plain. + std::string fence_lang = "```" + lang; + auto pos = raw.find(fence_lang); + size_t code_start = std::string::npos; + if (pos != std::string::npos) { + code_start = pos + fence_lang.size(); + } else { + pos = raw.find("```"); + if (pos != std::string::npos) { + code_start = pos + 3; + // skip optional lang tag + while (code_start < raw.size() && raw[code_start] != '\n' && + raw[code_start] != '\r' && std::isalnum((unsigned char)raw[code_start])) { + ++code_start; + } + } + } + if (code_start == std::string::npos) { + // No fence — return raw stripped. + size_t i = 0; while (i < raw.size() && std::isspace((unsigned char)raw[i])) ++i; + size_t j = raw.size(); while (j > i && std::isspace((unsigned char)raw[j-1])) --j; + return raw.substr(i, j - i); + } + // Skip newline tras fence. + if (code_start < raw.size() && raw[code_start] == '\n') ++code_start; + auto end = raw.find("```", code_start); + if (end == std::string::npos) end = raw.size(); + std::string code = raw.substr(code_start, end - code_start); + // Trim trailing newline. + while (!code.empty() && (code.back() == '\n' || code.back() == '\r')) code.pop_back(); + return code; +} + +std::string parse_response_text(const std::string& json) { + // Buscar pattern: "text":"..." + // Simple: primer occurrence de \"text\":\" tras \"type\":\"text\" + auto t = json.find("\"text\""); + while (t != std::string::npos) { + // Skip "text" + size_t i = t + 6; + // Skip whitespace y : + while (i < json.size() && (json[i] == ' ' || json[i] == ':' || json[i] == '\t')) ++i; + if (i >= json.size() || json[i] != '"') { + t = json.find("\"text\"", t + 1); + continue; + } + ++i; + std::string out; + while (i < json.size() && json[i] != '"') { + if (json[i] == '\\' && i + 1 < json.size()) { + char esc = json[i+1]; + if (esc == 'n') out += '\n'; + else if (esc == 't') out += '\t'; + else if (esc == 'r') out += '\r'; + else if (esc == '"') out += '"'; + else if (esc == '\\') out += '\\'; + else if (esc == '/') out += '/'; + else if (esc == 'u' && i + 5 < json.size()) { + // basic ascii \uXXXX + int code = 0; + for (int k = 0; k < 4; ++k) { + char c = json[i + 2 + k]; + int v = (c >= '0' && c <= '9') ? c - '0' + : (c >= 'a' && c <= 'f') ? c - 'a' + 10 + : (c >= 'A' && c <= 'F') ? c - 'A' + 10 : 0; + code = code * 16 + v; + } + if (code < 128) out += (char)code; + else out += '?'; + i += 5; + } else { + out += esc; + } + i += 2; + } else { + out += json[i++]; + } + } + return out; + } + return ""; +} + +namespace { + +// Lee API key segun prioridad: param > env FN_LLM_API_KEY > pass anthropic/api-key. +std::string resolve_api_key(const std::string& provided) { + if (!provided.empty()) return provided; + const char* env = std::getenv("FN_LLM_API_KEY"); + if (env && *env) return env; + // pass anthropic/api-key | head -n1 + FILE* p = popen("pass anthropic/api-key 2>/dev/null | head -n1", "r"); + if (!p) return ""; + std::string out; + char buf[256]; + while (fgets(buf, sizeof(buf), p)) out += buf; + pclose(p); + while (!out.empty() && (out.back() == '\n' || out.back() == '\r')) out.pop_back(); + return out; +} + +} // anon + +std::string call_api(const std::string& body, const std::string& api_key, + std::string& error_out) { + error_out.clear(); + // Test injection + const char* mock = std::getenv("FN_LLM_MOCK_RESPONSE"); + if (mock && *mock) return mock; + + std::string key = resolve_api_key(api_key); + if (key.empty()) { + error_out = "no API key (set FN_LLM_API_KEY env, pass param, or `pass anthropic/api-key`)"; + return ""; + } + const char* endpoint_env = std::getenv("FN_LLM_ENDPOINT"); + std::string endpoint = endpoint_env && *endpoint_env + ? endpoint_env + : "https://api.anthropic.com/v1/messages"; + + // popen "w+" no portable. Write body a tmp file y leer respuesta de curl + // por redireccion. Portable Unix/Mingw. + std::string tmp_in = std::tmpnam(nullptr); + std::string tmp_out = std::tmpnam(nullptr); + { + FILE* f = std::fopen(tmp_in.c_str(), "w"); + if (!f) { error_out = "tmp file write fail"; return ""; } + std::fwrite(body.data(), 1, body.size(), f); + std::fclose(f); + } + std::string cmd2 = "curl -sS -X POST " + "-H \"content-type: application/json\" " + "-H \"anthropic-version: 2023-06-01\" " + "-H \"x-api-key: " + key + "\" " + "--data-binary @" + tmp_in + " " + endpoint + + " > " + tmp_out + " 2>&1"; + int rc = std::system(cmd2.c_str()); + std::string resp; + { + FILE* f = std::fopen(tmp_out.c_str(), "r"); + if (f) { + char buf[4096]; + size_t n; + while ((n = std::fread(buf, 1, sizeof(buf), f)) > 0) resp.append(buf, n); + std::fclose(f); + } + } + std::remove(tmp_in.c_str()); + std::remove(tmp_out.c_str()); + if (rc != 0) { + error_out = "curl exit " + std::to_string(rc) + ": " + resp; + return ""; + } + return resp; +} + +AskResult ask(const AskInput& in, const std::string& api_key) { + AskResult r; + std::string body = build_request_body(in); + std::string raw_json = call_api(body, api_key, r.error); + if (!r.error.empty()) return r; + r.raw = parse_response_text(raw_json); + std::string lang = (in.mode == OutputMode::TQL) ? "lua" : "sql"; + r.code = extract_code_block(r.raw, lang); + return r; +} + +} // namespace llm_anthropic diff --git a/cpp/apps/primitives_gallery/playground/tables/llm_anthropic.h b/cpp/apps/primitives_gallery/playground/tables/llm_anthropic.h new file mode 100644 index 00000000..bbc35c03 --- /dev/null +++ b/cpp/apps/primitives_gallery/playground/tables/llm_anthropic.h @@ -0,0 +1,58 @@ +// llm_anthropic: cliente HTTP minimal a Anthropic Claude API. +// Sin deps externas (cURL via popen). +// Ver issue 0080. +#pragma once + +#include "data_table_logic.h" +#include "tql_to_sql.h" +#include +#include + +namespace llm_anthropic { + +enum class OutputMode { TQL, SQL }; + +struct AskInput { + std::string question; // pregunta NL + std::string tql_current; // TQL actual (emitido) + std::vector col_names; // schema input + std::vector col_types; + std::vector joinable_names; // tables disponibles para join + OutputMode mode = OutputMode::TQL; + std::string model; // empty -> default + int max_tokens = 8192; +}; + +struct AskResult { + std::string code; // bloque ```lua o ```sql extraido (sin fences) + std::string raw; // texto completo de la respuesta + std::string error; // non-empty si fallo + int tokens_in = 0; + int tokens_out = 0; +}; + +// Pure: construye el system prompt y user message JSON-escapado. +// Devuelve el JSON body completo POST al endpoint /v1/messages. +std::string build_request_body(const AskInput& in); + +// Pure: extrae primer ```\n ... \n``` bloque de `raw`. lang = "lua"|"sql". +// Si no encuentra fence, retorna raw stripped. +std::string extract_code_block(const std::string& raw, const std::string& lang); + +// Pure: extrae texto del JSON de respuesta Anthropic. +// Busca `"content":[{"type":"text","text":"..."}]` y devuelve el text. +std::string parse_response_text(const std::string& json_body); + +// Impure: lanza cURL via popen, posts `body` al endpoint Anthropic /v1/messages, +// retorna response body (JSON crudo). API key leida de: +// 1. parametro `api_key` si non-empty +// 2. env FN_LLM_API_KEY +// 3. `pass anthropic/api-key | head -n1` +// Si FN_LLM_MOCK_RESPONSE env set, retorna su valor (test injection). +std::string call_api(const std::string& body, const std::string& api_key, + std::string& error_out); + +// Orchestrator: build prompt + POST + parse. Convenience wrapper. +AskResult ask(const AskInput& in, const std::string& api_key = ""); + +} // namespace llm_anthropic diff --git a/cpp/apps/primitives_gallery/playground/tables/self_test.cpp b/cpp/apps/primitives_gallery/playground/tables/self_test.cpp index 08d39278..6a2283f3 100644 --- a/cpp/apps/primitives_gallery/playground/tables/self_test.cpp +++ b/cpp/apps/primitives_gallery/playground/tables/self_test.cpp @@ -7,9 +7,12 @@ // Exit 0 = todos los checks pasan, 1 = falla. #include "data_table_logic.h" +#include "llm_anthropic.h" #include "lua_engine.h" #include "tql.h" +#include "tql_to_sql.h" +#include #include #include #include @@ -2051,6 +2054,782 @@ return { check(join_strategy_from_token("nope") == JoinStrategy::Left, "phase9: parse fallback left"); } + // === phase10: drill extendido === + { + // truncate_date — granularities sobre 2026-05-12 (martes). + std::string d = "2026-05-12"; + check(truncate_date(d, DateGranularity::Year) == "2026", "phase10: trunc year"); + check(truncate_date(d, DateGranularity::Month) == "2026-05", "phase10: trunc month"); + check(truncate_date(d, DateGranularity::Day) == "2026-05-12", "phase10: trunc day"); + check(truncate_date(d, DateGranularity::Week) == "2026-05-11", "phase10: trunc week (Mon)"); + check(truncate_date("2026-05-12T14:33:01", DateGranularity::Hour) == "2026-05-12T14", + "phase10: trunc hour"); + check(truncate_date("not-a-date", DateGranularity::Month) == "not-a-date", + "phase10: trunc passthrough invalido"); + check(truncate_date(d, DateGranularity::None) == d, "phase10: trunc None == identidad"); + } + + { + // auto_date_granularity + check(auto_date_granularity("2024-01-01", "2026-05-12") == DateGranularity::Year, + "phase10: auto year >2y"); + check(auto_date_granularity("2026-01-01", "2026-05-12") == DateGranularity::Month, + "phase10: auto month >60d"); + check(auto_date_granularity("2026-04-15", "2026-05-12") == DateGranularity::Week, + "phase10: auto week >14d"); + check(auto_date_granularity("2026-05-05", "2026-05-12") == DateGranularity::Day, + "phase10: auto day <=14d"); + check(auto_date_granularity("bad", "2026-05-12") == DateGranularity::Day, + "phase10: auto fallback day"); + } + + { + // parse_breakout_granularity + std::string col; + check(parse_breakout_granularity("ts:month", col) == DateGranularity::Month, + "phase10: parse breakout month"); + check(col == "ts", "phase10: parse breakout col stripped"); + check(parse_breakout_granularity("ts", col) == DateGranularity::None, + "phase10: parse breakout sin sufijo None"); + check(col == "ts", "phase10: col sin sufijo intacto"); + check(parse_breakout_granularity("ts:wat", col) == DateGranularity::None, + "phase10: sufijo desconocido None"); + check(col == "ts:wat", "phase10: col preserva sufijo desconocido"); + } + + { + // compose_breakout + check(compose_breakout("ts", DateGranularity::None) == "ts", "phase10: compose None"); + check(compose_breakout("ts", DateGranularity::Month) == "ts:month", "phase10: compose month"); + check(compose_breakout("ts", DateGranularity::Year) == "ts:year", "phase10: compose year"); + // round-trip parse(compose) + std::string col; + auto g = parse_breakout_granularity(compose_breakout("foo", DateGranularity::Week), col); + check(g == DateGranularity::Week && col == "foo", "phase10: compose+parse round-trip"); + } + + { + // column_min_max + const char* cells[] = { + "2026-03-01", + "2026-01-15", + "", + "2026-05-12", + "2026-02-22", + }; + std::string lo, hi; + column_min_max(cells, 5, 1, 0, lo, hi); + check(lo == "2026-01-15" && hi == "2026-05-12", "phase10: column_min_max ISO ordena lexical"); + + const char* empty_cells[] = {"", "", ""}; + column_min_max(empty_cells, 3, 1, 0, lo, hi); + check(lo.empty() && hi.empty(), "phase10: column_min_max sin datos -> vacio"); + + column_min_max(cells, 5, 1, 5, lo, hi); // col fuera de rango + check(lo.empty() && hi.empty(), "phase10: column_min_max col fuera de rango -> vacio"); + } + + { + // tokens round-trip granularity + check(date_granularity_from_token("year") == DateGranularity::Year, "phase10: token year"); + check(date_granularity_from_token("month") == DateGranularity::Month, "phase10: token month"); + check(date_granularity_from_token("week") == DateGranularity::Week, "phase10: token week"); + check(date_granularity_from_token("day") == DateGranularity::Day, "phase10: token day"); + check(date_granularity_from_token("hour") == DateGranularity::Hour, "phase10: token hour"); + check(date_granularity_from_token("nope") == DateGranularity::None, "phase10: token fallback None"); + check(std::string(date_granularity_token(DateGranularity::Month)) == "month", + "phase10: emit month"); + check(std::string(date_granularity_token(DateGranularity::None)) == "", + "phase10: emit None empty"); + } + + { + // build_preset_filters + auto f7 = build_preset_filters(FilterPreset::Last7d, 2, "2026-05-12"); + check(f7.size() == 1, "phase10: Last7d -> 1 filter"); + check(f7[0].col == 2 && f7[0].op == Op::Gte && f7[0].value == "2026-05-05", + "phase10: Last7d -> Gte 2026-05-05"); + + auto f30 = build_preset_filters(FilterPreset::Last30d, 2, "2026-05-12"); + check(f30[0].value == "2026-04-12", "phase10: Last30d -> 2026-04-12"); + + auto f90 = build_preset_filters(FilterPreset::Last90d, 2, "2026-05-12"); + check(f90[0].value == "2026-02-11", "phase10: Last90d -> 2026-02-11"); + + auto fn0 = build_preset_filters(FilterPreset::ExcludeNulls, 3, ""); + check(fn0.size() == 1 && fn0[0].op == Op::Neq && fn0[0].value == "", + "phase10: ExcludeNulls -> Neq ''"); + + auto fnz = build_preset_filters(FilterPreset::NonZero, 4, ""); + check(fnz.size() == 2, "phase10: NonZero -> 2 filters"); + check(fnz[0].op == Op::Neq && fnz[0].value == "" && + fnz[1].op == Op::Neq && fnz[1].value == "0", + "phase10: NonZero -> Neq '' AND Neq '0'"); + + auto fbad = build_preset_filters(FilterPreset::Last7d, 2, "bad-date"); + check(fbad.empty(), "phase10: Last7d con today invalido -> empty"); + } + + { + // TQL round-trip: breakout con sufijo :granularity. + State st0; + st0.stages.resize(2); + st0.stages[1].breakouts = {"ts:month"}; + Aggregation a; a.fn = AggFn::Count; a.alias = "n"; + st0.stages[1].aggregations.push_back(a); + + std::vector hdrs = {"ts", "amount"}; + std::vector tys = {ColumnType::Date, ColumnType::Float}; + int eff = 2; + std::string text = tql::emit(st0, hdrs, tys); + check(text.find("\"ts:month\"") != std::string::npos, + "phase10 TQL: emit breakout granularity sufijo"); + + std::string err; + State st1; + bool ok = tql::apply(text, st1, hdrs, tys, nullptr, 2, eff, &err); + check(ok, "phase10 TQL: apply round-trip ok"); + check(st1.stages.size() >= 2 && st1.stages[1].breakouts.size() == 1 && + st1.stages[1].breakouts[0] == "ts:month", + "phase10 TQL: breakout granularity preservada"); + } + + { + // compute_stage aplica truncado de fecha cuando hay :granularity. + const char* cells[] = { + "2026-01-15", "10", + "2026-01-22", "20", + "2026-02-03", "30", + "2026-03-11", "40", + }; + std::vector hdrs = {"ts", "amount"}; + std::vector tys = {ColumnType::Date, ColumnType::Float}; + Stage s1; + s1.breakouts = {"ts:month"}; + Aggregation ag; ag.fn = AggFn::Count; ag.alias = "n"; + s1.aggregations.push_back(ag); + auto out = compute_stage(cells, 4, 2, hdrs, tys, s1); + check(out.rows == 3, "phase10: trunc month -> 3 grupos (Jan/Feb/Mar)"); + check(out.headers[0] == "ts:month", "phase10: header preserva sufijo"); + // Verifica que algun valor de breakout es "2026-01" + bool found_jan = false; + for (int r = 0; r < out.rows; ++r) { + if (std::string(out.cells[r * out.cols + 0]) == "2026-01") found_jan = true; + } + check(found_jan, "phase10: trunc value '2026-01' presente"); + } + + // === phase10 hit-tests para click-to-drill === + { + // nearest_index_1d + double xs[] = {0, 1, 2, 3, 4}; + check(nearest_index_1d(0.0, xs, 5) == 0, "phase10 hit: nearest_1d exact 0"); + check(nearest_index_1d(2.4, xs, 5) == 2, "phase10 hit: nearest_1d 2.4 -> 2"); + check(nearest_index_1d(2.6, xs, 5) == 3, "phase10 hit: nearest_1d 2.6 -> 3"); + check(nearest_index_1d(-1.0, xs, 5) == 0, "phase10 hit: nearest_1d clamp left"); + check(nearest_index_1d(99.0, xs, 5) == 4, "phase10 hit: nearest_1d clamp right"); + check(nearest_index_1d(0.0, nullptr, 0) == -1, "phase10 hit: nearest_1d empty -> -1"); + } + + { + // nearest_index_2d + double xs[] = {0, 10, 5, 5}; + double ys[] = {0, 0, 10, 5}; + check(nearest_index_2d(0.1, 0.1, xs, ys, 4) == 0, "phase10 hit: nearest_2d cerca de (0,0)"); + check(nearest_index_2d(9.9, 0.0, xs, ys, 4) == 1, "phase10 hit: nearest_2d cerca de (10,0)"); + check(nearest_index_2d(5.0, 4.9, xs, ys, 4) == 3, "phase10 hit: nearest_2d cerca de (5,5)"); + check(nearest_index_2d(0, 0, nullptr, nullptr, 0) == -1, "phase10 hit: nearest_2d empty -> -1"); + } + + { + // pie_angle (convencion ImPlot: 0 = top, sentido horario) + const double PI = 3.14159265358979323846; + double a; + a = pie_angle(0.5, 0.5, 0.5, 0.0); // top + check(std::fabs(a - 0.0) < 1e-9, "phase10 hit: pie_angle top = 0"); + a = pie_angle(0.5, 0.5, 1.0, 0.5); // right -> PI/2 + check(std::fabs(a - PI/2) < 1e-9, "phase10 hit: pie_angle right = PI/2"); + a = pie_angle(0.5, 0.5, 0.5, 1.0); // bottom -> PI + check(std::fabs(a - PI) < 1e-9, "phase10 hit: pie_angle bottom = PI"); + a = pie_angle(0.5, 0.5, 0.0, 0.5); // left -> 3*PI/2 + check(std::fabs(a - 3*PI/2) < 1e-9, "phase10 hit: pie_angle left = 3PI/2"); + } + + { + // pie_slice_at_angle: 4 slices iguales -> cada uno cubre PI/2. + double sums[] = {1.0, 1.0, 1.0, 1.0}; + const double PI = 3.14159265358979323846; + check(pie_slice_at_angle(0.0, sums, 4) == 0, "phase10 hit: slice 0 (top)"); + check(pie_slice_at_angle(PI/4, sums, 4) == 0, "phase10 hit: slice 0 (mid)"); + check(pie_slice_at_angle(PI/2 + 0.1, sums, 4) == 1, "phase10 hit: slice 1"); + check(pie_slice_at_angle(PI + 0.1, sums, 4) == 2, "phase10 hit: slice 2"); + check(pie_slice_at_angle(3*PI/2 + 0.1, sums, 4) == 3, "phase10 hit: slice 3"); + + double zeros[] = {0.0, 0.0}; + check(pie_slice_at_angle(0.5, zeros, 2) == -1, "phase10 hit: total 0 -> -1"); + check(pie_slice_at_angle(0.0, nullptr, 0) == -1, "phase10 hit: empty -> -1"); + + double neg[] = {1.0, -1.0}; + check(pie_slice_at_angle(0.5, neg, 2) == -1, "phase10 hit: neg sum -> -1"); + } + + { + // heatmap_cell_at + int rr, cc; + heatmap_cell_at(1.5, 2.5, 4, 3, rr, cc); + check(rr == 2 && cc == 1, "phase10 hit: heatmap (1.5,2.5) en 4x3 -> r2 c1"); + heatmap_cell_at(-1, 0, 4, 3, rr, cc); + check(rr == -1 && cc == -1, "phase10 hit: heatmap fuera de rango"); + heatmap_cell_at(0, 0, 0, 0, rr, cc); + check(rr == -1 && cc == -1, "phase10 hit: heatmap empty"); + } + + { + // E2E click-to-drill: simular pipeline stage1 agrupado, click en row idx 2. + State st; + st.stages.resize(2); + std::vector hdrs = {"lang", "n"}; + std::vector tys = {ColumnType::String, ColumnType::Int}; + st.stages[1].breakouts.push_back("lang"); + st.stages[1].aggregations.push_back({AggFn::Count}); + st.active_stage = 1; + + // Stage 1 output simulado (3 grupos). + const char* g_cells[] = { + "go", "3", + "py", "2", + "cpp", "1", + }; + StageOutput so; + so.cells.insert(so.cells.end(), g_cells, g_cells + 6); + so.rows = 3; + so.cols = 2; + so.headers = {"lang", "count"}; + + // Simular click en row idx 2 (cpp). + int clicked_row = 2; + int n_brk = (int)st.stages[1].breakouts.size(); + check(n_brk == 1, "phase10 e2e: 1 breakout"); + const char* v = so.cells[clicked_row * so.cols + 0]; + std::string col_clean; + parse_breakout_granularity(so.headers[0], col_clean); + check(col_clean == "lang", "phase10 e2e: col_clean stripped OK"); + st.stages[0].filters.push_back(make_drill_filter(0, v)); + st.active_stage = 0; + + check(st.active_stage == 0, "phase10 e2e: active retrocede a 0"); + check(st.stages[0].filters.size() == 1, "phase10 e2e: 1 filter anadido"); + check(st.stages[0].filters[0].col == 0 && + st.stages[0].filters[0].op == Op::Eq && + st.stages[0].filters[0].value == "cpp", + "phase10 e2e: filter Op::Eq col=0 value=cpp"); + } + + // === phase10 drill history (apply/undo step) === + { + State st; + st.stages.resize(2); + st.active_stage = 1; + + DrillStep step; + step.target_stage = 0; + step.filter_pos = 0; + step.prev_active_stage = 1; + step.added = make_drill_filter(0, "go"); + + check(apply_drill_step(st, step), "phase10 hist: apply ok"); + check(st.stages[0].filters.size() == 1, "phase10 hist: filter anadido"); + check(st.stages[0].filters[0].value == "go", "phase10 hist: value preservado"); + check(st.active_stage == 0, "phase10 hist: active = target"); + + check(undo_drill_step(st, step), "phase10 hist: undo ok"); + check(st.stages[0].filters.empty(), "phase10 hist: filter eliminado"); + check(st.active_stage == 1, "phase10 hist: active restaurado"); + + // Redo + check(apply_drill_step(st, step), "phase10 hist: redo ok"); + check(st.stages[0].filters.size() == 1, "phase10 hist: redo filter de vuelta"); + check(st.active_stage == 0, "phase10 hist: redo active retorna"); + + // Edge: target fuera de rango + DrillStep bad; + bad.target_stage = 99; + check(!apply_drill_step(st, bad), "phase10 hist: apply fuera de rango -> false"); + check(!undo_drill_step(st, bad), "phase10 hist: undo fuera de rango -> false"); + + // Edge: pos invalida + DrillStep bad_pos = step; + bad_pos.filter_pos = 99; + check(!undo_drill_step(st, bad_pos), "phase10 hist: undo pos invalida -> false"); + } + + // === phase10 drill history: back/forward stack semantics simulado === + { + State st; + st.stages.resize(3); + st.active_stage = 2; + + std::vector back_stack; + std::vector fwd_stack; + + auto drill = [&](int from, int target, int pos, int col, const std::string& v) { + DrillStep s; + s.target_stage = target; + s.filter_pos = pos; + s.prev_active_stage = from; + s.added = make_drill_filter(col, v); + apply_drill_step(st, s); + back_stack.push_back(s); + fwd_stack.clear(); + }; + + drill(2, 1, 0, 0, "go"); + check(st.stages[1].filters.size() == 1, "phase10 hist seq: drill1 aplicado"); + drill(1, 0, 0, 1, "10"); + check(st.stages[0].filters.size() == 1, "phase10 hist seq: drill2 aplicado"); + check(back_stack.size() == 2, "phase10 hist seq: back stack 2"); + check(fwd_stack.empty(), "phase10 hist seq: forward limpio"); + + // Back x1 + DrillStep s = back_stack.back(); back_stack.pop_back(); + undo_drill_step(st, s); + fwd_stack.push_back(s); + check(st.stages[0].filters.empty(), "phase10 hist seq: back deshace drill2"); + check(st.active_stage == 1, "phase10 hist seq: back restaura active=1"); + check(fwd_stack.size() == 1, "phase10 hist seq: fwd stack 1"); + + // Forward x1 + s = fwd_stack.back(); fwd_stack.pop_back(); + apply_drill_step(st, s); + back_stack.push_back(s); + check(st.stages[0].filters.size() == 1, "phase10 hist seq: forward reaplica"); + check(st.active_stage == 0, "phase10 hist seq: forward active=0"); + } + + // === phase10 row inspector (row_to_tsv + build_filters_from_row) === + { + const char* cells[] = { + "go", "10", "filter", + "py", "20", "sma", + "go", "30", "map", + }; + std::vector hdrs = {"lang", "n", "fn"}; + + std::string tsv = row_to_tsv(cells, 3, 3, 1, hdrs); + check(tsv == "lang\tn\tfn\r\npy\t20\tsma\r\n", + "phase10 inspect: row_to_tsv layout"); + + check(row_to_tsv(cells, 3, 3, -1, hdrs).empty(), "phase10 inspect: tsv neg row -> empty"); + check(row_to_tsv(cells, 3, 3, 5, hdrs).empty(), "phase10 inspect: tsv row oob -> empty"); + check(row_to_tsv(cells, 3, 0, 0, hdrs).empty(), "phase10 inspect: tsv cols=0 -> empty"); + + auto fs = build_filters_from_row(cells, 3, 3, 0); + check(fs.size() == 3, "phase10 inspect: 3 filters de row 0"); + check(fs[0].col == 0 && fs[0].op == Op::Eq && fs[0].value == "go", + "phase10 inspect: filter[0] col=0 op=Eq value=go"); + check(fs[2].value == "filter", "phase10 inspect: filter[2] value=filter"); + + // Row con celda vacia -> filter saltado + const char* sparse[] = {"a", "", "c"}; + auto fs2 = build_filters_from_row(sparse, 1, 3, 0); + check(fs2.size() == 2 && fs2[0].col == 0 && fs2[1].col == 2, + "phase10 inspect: cells vacios salteados"); + + check(build_filters_from_row(cells, 3, 3, -1).empty(), + "phase10 inspect: build_filters row invalido -> empty"); + } + + // === phase10 drill-up === + { + State st; + st.stages.resize(3); + st.active_stage = 2; + check(drill_up(st), "phase10 up: 2->1 ok"); + check(st.active_stage == 1, "phase10 up: active=1"); + check(drill_up(st), "phase10 up: 1->0 ok"); + check(st.active_stage == 0, "phase10 up: active=0"); + check(!drill_up(st), "phase10 up: 0 -> false"); + check(st.active_stage == 0, "phase10 up: queda en 0"); + + // Filters no se mueven + State st2; + st2.stages.resize(2); + st2.active_stage = 1; + st2.stages[1].filters.push_back({0, Op::Eq, "x"}); + drill_up(st2); + check(st2.stages[0].filters.empty() && st2.stages[1].filters.size() == 1, + "phase10 up: filters quedan en su stage"); + + State empty_st; + check(!drill_up(empty_st), "phase10 up: stages vacio -> false"); + } + + // === phase11: Lua subset validator + transpiler === + { + std::string err; + // Subset OK: literales + ops + std::string e1 = tql_to_sql::transpile_expr("1 + 2", {}, err); + check(err.empty() && e1.find("1 + 2") != std::string::npos, + "phase11 lua: literal arith"); + + std::string e2 = tql_to_sql::transpile_expr("[a] + [b] * 2", {}, err); + check(err.empty() && e2.find("\"a\"") != std::string::npos && + e2.find("\"b\"") != std::string::npos, + "phase11 lua: col refs + arith"); + + std::string e3 = tql_to_sql::transpile_expr("[a] .. \"_\" .. [b]", {}, err); + check(err.empty() && e3.find(" || ") != std::string::npos, + "phase11 lua: concat -> ||"); + + std::string e4 = tql_to_sql::transpile_expr( + "if [n] > 10 then \"big\" else \"small\" end", {}, err); + check(err.empty() && e4.find("CASE WHEN") != std::string::npos && + e4.find("THEN") != std::string::npos && e4.find("ELSE") != std::string::npos, + "phase11 lua: if/then/else -> CASE"); + + std::string e5 = tql_to_sql::transpile_expr("math.floor([x] / 100)", {}, err); + check(err.empty() && e5.find("floor(") != std::string::npos, + "phase11 lua: math.floor"); + + std::string e6 = tql_to_sql::transpile_expr("string.upper([name])", {}, err); + check(err.empty() && e6.find("upper(") != std::string::npos, + "phase11 lua: string.upper"); + + std::string e7 = tql_to_sql::transpile_expr("string.sub([s], 1, 3)", {}, err); + check(err.empty() && e7.find("substring(") != std::string::npos, + "phase11 lua: string.sub 3-arg"); + + std::string e8 = tql_to_sql::transpile_expr("not ([x] == nil)", {}, err); + check(err.empty() && e8.find("NOT") != std::string::npos && e8.find("NULL") != std::string::npos, + "phase11 lua: not + nil"); + + std::string e9 = tql_to_sql::transpile_expr("tonumber([n])", {}, err); + check(err.empty() && e9.find("CAST(") != std::string::npos, + "phase11 lua: tonumber -> CAST DOUBLE"); + + // Fuera subset: 9 categorias rechazadas + err.clear(); + check(tql_to_sql::transpile_expr("function() return 1 end", {}, err).empty() + && err.find("closures") != std::string::npos, + "phase11 lua: function closure rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("local x = 1", {}, err).empty() + && err.find("local") != std::string::npos, + "phase11 lua: local rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("for i=1,10 do end", {}, err).empty() + && err.find("loops") != std::string::npos, + "phase11 lua: for loop rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("while true do end", {}, err).empty() + && err.find("loops") != std::string::npos, + "phase11 lua: while loop rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("{1,2,3}", {}, err).empty() + && err.find("table") != std::string::npos, + "phase11 lua: table literal rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("io.read()", {}, err).empty() + && err.find("io") != std::string::npos, + "phase11 lua: io.* rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("string.gsub([s], \"a\", \"b\")", {}, err).empty() + && err.find("whitelist") != std::string::npos, + "phase11 lua: string.gsub no whitelisted"); + + err.clear(); + check(tql_to_sql::transpile_expr("print([x])", {}, err).empty() + && err.find("print") != std::string::npos, + "phase11 lua: print rechazado"); + + err.clear(); + check(tql_to_sql::transpile_expr("[a]; [b]", {}, err).empty() + && err.find("multi-statement") != std::string::npos, + "phase11 lua: ';' multi-stmt rechazado"); + + // is_transpilable wrapper + std::string werr; + check(tql_to_sql::is_transpilable("[a] + 1", werr), "phase11 lua: is_transpilable OK"); + check(!tql_to_sql::is_transpilable("function() end", werr), + "phase11 lua: is_transpilable false para closure"); + } + + // === phase11: TQL State -> SQL DuckDB emit === + { + // Setup: 1 tabla "users" con cols lang,n. + TableInput t; + t.name = "users"; + t.headers = {"lang", "n"}; + t.types = {ColumnType::String, ColumnType::Int}; + // Cells no usado por emit (solo schema). + std::vector tables = {t}; + + // Caso 1: stage 0 simple (sin filters ni sort) + { + State st; + st.stages.resize(1); + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: empty pipeline -> no error"); + check(e.sql.find("WITH t0") != std::string::npos && + e.sql.find("FROM \"users\"") != std::string::npos && + e.sql.find("SELECT * FROM t0") != std::string::npos, + "phase11 sql: stage0 SELECT * FROM users"); + } + + // Caso 2: stage 0 filter + sort + { + State st; + st.stages.resize(1); + st.stages[0].filters.push_back({0, Op::Eq, "go"}); + st.stages[0].filters.push_back({1, Op::Gt, "10"}); + st.stages[0].sorts.push_back({"n", true}); + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: filter+sort OK"); + check(e.sql.find("WHERE") != std::string::npos && + e.sql.find("\"lang\" = ?") != std::string::npos && + e.sql.find("\"n\" > ?") != std::string::npos, + "phase11 sql: filter clauses"); + check(e.params.size() == 2 && e.params[0] == "go" && e.params[1] == "10", + "phase11 sql: params bound"); + check(e.sql.find("ORDER BY \"n\" DESC") != std::string::npos, + "phase11 sql: ORDER BY desc"); + } + + // Caso 3: stage 1 group + count + { + State st; + st.stages.resize(2); + st.stages[1].breakouts.push_back("lang"); + st.stages[1].aggregations.push_back({AggFn::Count}); + st.active_stage = 1; + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: group ok"); + check(e.sql.find("t1 AS") != std::string::npos && + e.sql.find("COUNT(*)") != std::string::npos && + e.sql.find("GROUP BY") != std::string::npos && + e.sql.find("SELECT * FROM t1") != std::string::npos, + "phase11 sql: stage1 CTE + COUNT + GROUP BY"); + } + + // Caso 4: granularity :month -> date_trunc + { + State st; + st.stages.resize(2); + st.stages[1].breakouts.push_back("ts:month"); + st.stages[1].aggregations.push_back({AggFn::Sum, "n"}); + st.active_stage = 1; + TableInput ts_t; + ts_t.name = "events"; + ts_t.headers = {"ts", "n"}; + ts_t.types = {ColumnType::Date, ColumnType::Int}; + std::vector tt = {ts_t}; + auto e = tql_to_sql::emit_sql(st, tt); + check(e.error.empty(), "phase11 sql: granularity ok"); + check(e.sql.find("date_trunc('month'") != std::string::npos && + e.sql.find("SUM(\"n\")") != std::string::npos, + "phase11 sql: date_trunc + SUM"); + } + + // Caso 5: aggregations p25/median/p99 + { + State st; + st.stages.resize(2); + st.stages[1].breakouts.push_back("lang"); + st.stages[1].aggregations.push_back({AggFn::Median, "n"}); + st.stages[1].aggregations.push_back({AggFn::P25, "n"}); + st.stages[1].aggregations.push_back({AggFn::P99, "n"}); + st.active_stage = 1; + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: percentiles ok"); + check(e.sql.find("quantile_cont(\"n\", 0.5)") != std::string::npos && + e.sql.find("quantile_cont(\"n\", 0.25)") != std::string::npos && + e.sql.find("quantile_cont(\"n\", 0.99)") != std::string::npos, + "phase11 sql: quantile_cont calls"); + } + + // Caso 6: joins 4 strategies + { + State st; + st.stages.resize(1); + Join jn; + jn.alias = "o"; + jn.source = "orders"; + jn.on.push_back({"user_id", "user_id"}); + jn.strategy = JoinStrategy::Left; + st.joins.push_back(jn); + TableInput u, o; + u.name = "users"; + u.headers = {"user_id", "name"}; + u.types = {ColumnType::String, ColumnType::String}; + o.name = "orders"; + o.headers = {"user_id", "amount"}; + o.types = {ColumnType::String, ColumnType::Int}; + std::vector tt = {u, o}; + auto e = tql_to_sql::emit_sql(st, tt); + check(e.error.empty(), "phase11 sql: join ok"); + check(e.sql.find("LEFT JOIN \"orders\" AS \"o\"") != std::string::npos && + e.sql.find("ON \"users\".\"user_id\" = \"o\".\"user_id\"") != std::string::npos, + "phase11 sql: LEFT JOIN ON syntax"); + + // Inner + st.joins[0].strategy = JoinStrategy::Inner; + auto e2 = tql_to_sql::emit_sql(st, tt); + check(e2.sql.find("INNER JOIN") != std::string::npos, "phase11 sql: INNER JOIN"); + + // Right + st.joins[0].strategy = JoinStrategy::Right; + auto e3 = tql_to_sql::emit_sql(st, tt); + check(e3.sql.find("RIGHT JOIN") != std::string::npos, "phase11 sql: RIGHT JOIN"); + + // Full + st.joins[0].strategy = JoinStrategy::Full; + auto e4 = tql_to_sql::emit_sql(st, tt); + check(e4.sql.find("FULL OUTER JOIN") != std::string::npos, "phase11 sql: FULL OUTER JOIN"); + } + + // Caso 7: derived col subset -> SQL expression + { + State st; + st.stages.resize(1); + DerivedColumn d; + d.name = "size_kb"; + d.source_col = -1; + d.formula = "[n] / 1024.0"; + d.type = ColumnType::Float; + st.stages[0].derived.push_back(d); + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: derived subset ok"); + check(e.sql.find("\"n\" / 1024") != std::string::npos && + e.sql.find("AS \"size_kb\"") != std::string::npos, + "phase11 sql: derived expression + alias"); + } + + // Caso 8: derived col FUERA subset -> warning + skip + { + State st; + st.stages.resize(1); + DerivedColumn d; + d.name = "bad"; + d.source_col = -1; + d.formula = "string.gsub([n], \"a\", \"b\")"; + d.type = ColumnType::String; + st.stages[0].derived.push_back(d); + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: derived fuera subset NO bloquea emit"); + check(!e.warnings.empty() && + e.warnings[0].find("out of SQL subset") != std::string::npos, + "phase11 sql: warning derived fuera subset"); + check(e.sql.find("\"bad\"") == std::string::npos, + "phase11 sql: derived skip cuando fuera subset"); + } + + // Caso 9: empty tables -> error + { + State st; + st.stages.resize(1); + std::vector empty; + auto e = tql_to_sql::emit_sql(st, empty); + check(!e.error.empty() && e.error.find("no input tables") != std::string::npos, + "phase11 sql: empty tables -> error"); + } + + // Caso 10: stage 0 con LIKE (Contains) + { + State st; + st.stages.resize(1); + st.stages[0].filters.push_back({0, Op::Contains, "go"}); + auto e = tql_to_sql::emit_sql(st, tables); + check(e.error.empty(), "phase11 sql: LIKE Contains ok"); + check(e.sql.find("LIKE ?") != std::string::npos && + e.params.size() == 1 && e.params[0] == "%go%", + "phase11 sql: Contains -> LIKE %go%"); + } + } + + // === phase11: LLM client (mock, no red) === + { + llm_anthropic::AskInput in; + in.question = "show top 10 langs"; + in.tql_current = "return { stages = {} }"; + in.col_names = {"lang", "n"}; + in.col_types = {ColumnType::String, ColumnType::Int}; + in.mode = llm_anthropic::OutputMode::TQL; + std::string body = llm_anthropic::build_request_body(in); + check(body.find("\"model\":\"claude-sonnet-4-6\"") != std::string::npos, + "phase11 llm: default model"); + check(body.find("\"max_tokens\":8192") != std::string::npos, + "phase11 llm: max_tokens"); + check(body.find("\\\"system\\\"") == std::string::npos /* not double-escaped */, + "phase11 llm: system not double-escaped"); + check(body.find("Available columns") != std::string::npos, + "phase11 llm: schema block present"); + check(body.find("show top 10 langs") != std::string::npos, + "phase11 llm: question present"); + check(body.find("TQL") != std::string::npos, + "phase11 llm: system mentions TQL"); + + in.mode = llm_anthropic::OutputMode::SQL; + std::string body_sql = llm_anthropic::build_request_body(in); + check(body_sql.find("DuckDB") != std::string::npos, + "phase11 llm: SQL mode mentions DuckDB"); + } + + { + // extract_code_block + std::string raw1 = "Here you go:\n```lua\nreturn { x = 1 }\n```\nDone!"; + std::string code = llm_anthropic::extract_code_block(raw1, "lua"); + check(code == "return { x = 1 }", "phase11 llm: extract ```lua block"); + + std::string raw2 = "Sure:\n```\nplain code\n```"; + std::string code2 = llm_anthropic::extract_code_block(raw2, "lua"); + check(code2 == "plain code", "phase11 llm: extract bare ```"); + + std::string raw3 = "no fences here"; + std::string code3 = llm_anthropic::extract_code_block(raw3, "lua"); + check(code3 == "no fences here", "phase11 llm: no fence -> stripped"); + + std::string raw4 = "```sql\nSELECT 1;\n```"; + std::string code4 = llm_anthropic::extract_code_block(raw4, "sql"); + check(code4 == "SELECT 1;", "phase11 llm: extract ```sql"); + } + + { + // parse_response_text from JSON + std::string j = "{\"id\":\"x\",\"content\":[{\"type\":\"text\",\"text\":\"hello\\nworld\"}],\"role\":\"assistant\"}"; + std::string t = llm_anthropic::parse_response_text(j); + check(t == "hello\nworld", "phase11 llm: parse text content"); + + std::string j2 = "{\"content\":[{\"type\":\"text\",\"text\":\"\\\"quoted\\\"\"}]}"; + std::string t2 = llm_anthropic::parse_response_text(j2); + check(t2 == "\"quoted\"", "phase11 llm: parse quoted escape"); + + std::string j3 = "{\"error\":\"foo\"}"; + std::string t3 = llm_anthropic::parse_response_text(j3); + check(t3.empty(), "phase11 llm: no text -> empty"); + } + + { + // Mock end-to-end via FN_LLM_MOCK_RESPONSE (portable Linux/Mingw via putenv). + const char* mock_kv = + "FN_LLM_MOCK_RESPONSE={\"content\":[{\"type\":\"text\",\"text\":\"```lua\\nreturn { mock = true }\\n```\"}]}"; + putenv((char*)mock_kv); + llm_anthropic::AskInput in; + in.question = "q"; + in.col_names = {"a"}; + in.col_types = {ColumnType::String}; + auto r = llm_anthropic::ask(in); + check(r.error.empty(), "phase11 llm mock: no error"); + check(r.code == "return { mock = true }", "phase11 llm mock: code extracted"); + // Unset: putenv con "VAR=" deja vacio (suficiente para nuestro check `*mock`). + putenv((char*)"FN_LLM_MOCK_RESPONSE="); + } + std::printf("\n=== %d passed, %d failed ===\n", passed, failed); return failed == 0 ? 0 : 1; } diff --git a/cpp/apps/primitives_gallery/playground/tables/tql.cpp b/cpp/apps/primitives_gallery/playground/tables/tql.cpp index 3bcc3ce9..ab80f4c4 100644 --- a/cpp/apps/primitives_gallery/playground/tables/tql.cpp +++ b/cpp/apps/primitives_gallery/playground/tables/tql.cpp @@ -652,7 +652,8 @@ bool apply(const std::string& lua_text, State& state, } lua_pop(L, 1); - // breakout (solo aplica stages >= 1, no-op silencioso si stage 0) + // breakout (solo aplica stages >= 1, no-op silencioso si stage 0). + // Acepta sufijo ":granularity" para cols Date (fase 10). lua_getfield(L, -1, "breakout"); if (lua_istable(L, -1)) { int n = (int)lua_rawlen(L, -1); @@ -660,8 +661,10 @@ bool apply(const std::string& lua_text, State& state, lua_rawgeti(L, -1, i); if (lua_isstring(L, -1)) { std::string bn = lua_tostring(L, -1); - if (find_orig_col(cur_headers, bn) < 0) { - warn("stage " + std::to_string(si - 1) + ": breakout col \"" + bn + "\" not in input headers"); + std::string clean; + parse_breakout_granularity(bn, clean); + if (find_orig_col(cur_headers, clean) < 0) { + warn("stage " + std::to_string(si - 1) + ": breakout col \"" + clean + "\" not in input headers"); } stg.breakouts.emplace_back(bn); } diff --git a/cpp/apps/primitives_gallery/playground/tables/tql_to_sql.cpp b/cpp/apps/primitives_gallery/playground/tables/tql_to_sql.cpp new file mode 100644 index 00000000..2dba82d8 --- /dev/null +++ b/cpp/apps/primitives_gallery/playground/tables/tql_to_sql.cpp @@ -0,0 +1,862 @@ +// tql_to_sql.cpp — pure walker TQL -> SQL DuckDB + Lua subset transpiler. +// Ver issue 0080. Sin DuckDB linkado. +#include "tql_to_sql.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace tql_to_sql { + +using namespace data_table; + +// ============================================================================ +// Lua subset tokenizer + recursive-descent expression parser -> SQL string. +// ============================================================================ + +namespace { + +struct Tok { + enum Kind { + EndT, NumT, StrT, IdentT, ColT, + // operators / keywords + Plus, Minus, Star, Slash, Percent, ConcatT, + Eq, Neq, Lt, Lte, Gt, Gte, + AndT, OrT, NotT, + IfT, ThenT, ElseT, EndKW, + LParen, RParen, Comma, Dot, + TrueT, FalseT, NilT, + } kind = EndT; + std::string text; // raw token texto (para idents/numbers/strings) +}; + +// Categorias prohibidas: token literal -> mensaje. +const std::unordered_map& forbidden_keywords() { + static const std::unordered_map M = { + {"function", "closures not allowed in SQL transpile subset"}, + {"local", "local declarations not allowed"}, + {"for", "loops not allowed"}, + {"while", "loops not allowed"}, + {"repeat", "loops not allowed"}, + {"do", "block statements not allowed"}, + {"return", "explicit return not allowed (formula is implicit expression)"}, + {"goto", "goto not allowed"}, + {"break", "break not allowed (no loops)"}, + // io/os/debug/coroutines + {"io", "io.* access not allowed"}, + {"os", "os.* access not allowed"}, + {"debug", "debug.* access not allowed"}, + {"package", "package access not allowed"}, + {"require", "require not allowed"}, + {"coroutine","coroutines not allowed"}, + {"setmetatable","metatables not allowed"}, + {"getmetatable","metatables not allowed"}, + {"rawget", "rawget not allowed"}, + {"rawset", "rawset not allowed"}, + {"pcall", "pcall not allowed"}, + {"xpcall", "xpcall not allowed"}, + {"print", "print not allowed (SQL has no side effects)"}, + }; + return M; +} + +// Whitelist de funciones SQL-transpilables: lua name -> SQL function template. +// Template usa $1, $2, ... como placeholders de argumentos. +struct FnMap { int min_args; int max_args; const char* sql_tmpl; }; + +const std::unordered_map& fn_whitelist() { + static const std::unordered_map M = { + // math.* + {"math.floor", {1, 1, "floor($1)"}}, + {"math.ceil", {1, 1, "ceiling($1)"}}, + {"math.abs", {1, 1, "abs($1)"}}, + {"math.sqrt", {1, 1, "sqrt($1)"}}, + {"math.sin", {1, 1, "sin($1)"}}, + {"math.cos", {1, 1, "cos($1)"}}, + {"math.log", {1, 1, "ln($1)"}}, + {"math.exp", {1, 1, "exp($1)"}}, + {"math.min", {2, 2, "least($1, $2)"}}, + {"math.max", {2, 2, "greatest($1, $2)"}}, + // string.* + {"string.upper", {1, 1, "upper($1)"}}, + {"string.lower", {1, 1, "lower($1)"}}, + {"string.len", {1, 1, "length($1)"}}, + {"string.sub", {2, 3, "/*SUBSTRING*/"}}, // manejo especial: argc 2 vs 3 + // top-level + {"tostring", {1, 1, "CAST($1 AS VARCHAR)"}}, + {"tonumber", {1, 1, "CAST($1 AS DOUBLE)"}}, + }; + return M; +} + +// Identifier SQL-safe: si tiene caracteres especiales o coincide con keyword, +// usar `"col"`. Aqui simplificado: siempre quote con dobles comillas para +// preservar case y permitir `:` (sufijo granularity). +std::string sql_ident(const std::string& name) { + std::string out; + out.reserve(name.size() + 4); + out += '"'; + for (char c : name) { + if (c == '"') out += "\"\""; // escape + else out += c; + } + out += '"'; + return out; +} + +std::string sql_string_literal(const std::string& s) { + std::string out; + out.reserve(s.size() + 4); + out += '\''; + for (char c : s) { + if (c == '\'') out += "''"; + else out += c; + } + out += '\''; + return out; +} + +class Lexer { +public: + Lexer(const std::string& src) : src_(src) {} + + // Devuelve true si parsea OK. False con err en error_. + bool tokenize(std::vector& out) { + size_t i = 0; + while (i < src_.size()) { + char c = src_[i]; + if (std::isspace((unsigned char)c)) { ++i; continue; } + // Lua line comment + if (c == '-' && i + 1 < src_.size() && src_[i+1] == '-') { + while (i < src_.size() && src_[i] != '\n') ++i; + continue; + } + if (c == '[' ) { + // col ref [identifier] + size_t j = i + 1; + std::string name; + while (j < src_.size() && src_[j] != ']') { + name += src_[j]; + ++j; + } + if (j >= src_.size()) { error_ = "unterminated [col] ref"; return false; } + Tok t; t.kind = Tok::ColT; t.text = name; + out.push_back(t); + i = j + 1; + continue; + } + if (c == '"' || c == '\'') { + char q = c; + ++i; + std::string s; + while (i < src_.size() && src_[i] != q) { + if (src_[i] == '\\' && i + 1 < src_.size()) { + char esc = src_[i+1]; + if (esc == 'n') s += '\n'; + else if (esc == 't') s += '\t'; + else if (esc == '\\') s += '\\'; + else if (esc == '\'') s += '\''; + else if (esc == '"') s += '"'; + else s += esc; + i += 2; + } else { + s += src_[i++]; + } + } + if (i >= src_.size()) { error_ = "unterminated string literal"; return false; } + ++i; + Tok t; t.kind = Tok::StrT; t.text = s; + out.push_back(t); + continue; + } + if (std::isdigit((unsigned char)c) || (c == '.' && i + 1 < src_.size() && std::isdigit((unsigned char)src_[i+1]))) { + std::string n; + bool seen_dot = false; + while (i < src_.size()) { + char d = src_[i]; + if (std::isdigit((unsigned char)d)) { n += d; ++i; } + else if (d == '.' && !seen_dot) { n += d; seen_dot = true; ++i; } + else break; + } + Tok t; t.kind = Tok::NumT; t.text = n; + out.push_back(t); + continue; + } + if (std::isalpha((unsigned char)c) || c == '_') { + std::string id; + while (i < src_.size() && + (std::isalnum((unsigned char)src_[i]) || src_[i] == '_')) { + id += src_[i++]; + } + // Check forbidden keywords y mapeo a tokens. + auto& F = forbidden_keywords(); + auto fit = F.find(id); + if (fit != F.end()) { + error_ = std::string("token '") + id + "': " + fit->second; + return false; + } + Tok t; + if (id == "and") t.kind = Tok::AndT; + else if (id == "or") t.kind = Tok::OrT; + else if (id == "not") t.kind = Tok::NotT; + else if (id == "if") t.kind = Tok::IfT; + else if (id == "then") t.kind = Tok::ThenT; + else if (id == "else") t.kind = Tok::ElseT; + else if (id == "end") t.kind = Tok::EndKW; + else if (id == "true") t.kind = Tok::TrueT; + else if (id == "false") t.kind = Tok::FalseT; + else if (id == "nil") t.kind = Tok::NilT; + else { t.kind = Tok::IdentT; t.text = id; } + out.push_back(t); + continue; + } + // Operators + auto emit = [&](Tok::Kind k, int len) { + Tok t; t.kind = k; out.push_back(t); i += (size_t)len; + }; + if (c == '+') { emit(Tok::Plus, 1); continue; } + if (c == '-') { emit(Tok::Minus, 1); continue; } + if (c == '*') { emit(Tok::Star, 1); continue; } + if (c == '/') { emit(Tok::Slash, 1); continue; } + if (c == '%') { emit(Tok::Percent,1); continue; } + if (c == '(') { emit(Tok::LParen, 1); continue; } + if (c == ')') { emit(Tok::RParen, 1); continue; } + if (c == ',') { emit(Tok::Comma, 1); continue; } + if (c == '.') { + if (i + 1 < src_.size() && src_[i+1] == '.') { + if (i + 2 < src_.size() && src_[i+2] == '.') { + error_ = "'...' vararg not allowed"; return false; + } + emit(Tok::ConcatT, 2); continue; + } + emit(Tok::Dot, 1); continue; + } + if (c == '=') { + if (i + 1 < src_.size() && src_[i+1] == '=') { emit(Tok::Eq, 2); continue; } + error_ = "single '=' (assignment) not allowed"; return false; + } + if (c == '~') { + if (i + 1 < src_.size() && src_[i+1] == '=') { emit(Tok::Neq, 2); continue; } + error_ = "stray '~'"; return false; + } + if (c == '<') { + if (i + 1 < src_.size() && src_[i+1] == '=') { emit(Tok::Lte, 2); continue; } + emit(Tok::Lt, 1); continue; + } + if (c == '>') { + if (i + 1 < src_.size() && src_[i+1] == '=') { emit(Tok::Gte, 2); continue; } + emit(Tok::Gt, 1); continue; + } + if (c == '{') { error_ = "table literals '{...}' not allowed"; return false; } + if (c == '}') { error_ = "stray '}'"; return false; } + if (c == ';') { error_ = "multi-statement not allowed"; return false; } + if (c == '#') { error_ = "length '#' operator not allowed"; return false; } + if (c == ':') { error_ = "method calls ':' not allowed"; return false; } + error_ = std::string("unexpected character '") + c + "'"; + return false; + } + Tok t; t.kind = Tok::EndT; + out.push_back(t); + return true; + } + + const std::string& error() const { return error_; } +private: + const std::string& src_; + std::string error_; +}; + +class Parser { +public: + Parser(const std::vector& toks, + const std::vector& headers) + : toks_(toks), headers_(headers) {} + + // expr := ternary + // ternary := if/then/else | logic_or + bool parse_expr(std::string& out) { + return parse_ternary(out); + } + + bool parse_ternary(std::string& out) { + if (peek(0).kind == Tok::IfT) { + ++pos_; + std::string a, b, c; + if (!parse_logic_or(a)) return false; + if (!eat(Tok::ThenT, "'then' expected after 'if'")) return false; + if (!parse_ternary(b)) return false; + if (!eat(Tok::ElseT, "'else' expected (subset requires else branch)")) return false; + if (!parse_ternary(c)) return false; + if (!eat(Tok::EndKW, "'end' expected to close 'if'")) return false; + out = "CASE WHEN " + a + " THEN " + b + " ELSE " + c + " END"; + return true; + } + return parse_logic_or(out); + } + + bool parse_logic_or(std::string& out) { + if (!parse_logic_and(out)) return false; + while (peek(0).kind == Tok::OrT) { + ++pos_; + std::string rhs; + if (!parse_logic_and(rhs)) return false; + out = "(" + out + " OR " + rhs + ")"; + } + return true; + } + + bool parse_logic_and(std::string& out) { + if (!parse_not(out)) return false; + while (peek(0).kind == Tok::AndT) { + ++pos_; + std::string rhs; + if (!parse_not(rhs)) return false; + out = "(" + out + " AND " + rhs + ")"; + } + return true; + } + + bool parse_not(std::string& out) { + if (peek(0).kind == Tok::NotT) { + ++pos_; + std::string e; + if (!parse_not(e)) return false; + out = "NOT (" + e + ")"; + return true; + } + return parse_comparison(out); + } + + bool parse_comparison(std::string& out) { + if (!parse_concat(out)) return false; + while (true) { + Tok::Kind k = peek(0).kind; + const char* op = nullptr; + if (k == Tok::Eq) op = " = "; + else if (k == Tok::Neq) op = " <> "; + else if (k == Tok::Lt) op = " < "; + else if (k == Tok::Lte) op = " <= "; + else if (k == Tok::Gt) op = " > "; + else if (k == Tok::Gte) op = " >= "; + else break; + ++pos_; + std::string rhs; + if (!parse_concat(rhs)) return false; + out = "(" + out + op + rhs + ")"; + } + return true; + } + + bool parse_concat(std::string& out) { + if (!parse_additive(out)) return false; + while (peek(0).kind == Tok::ConcatT) { + ++pos_; + std::string rhs; + if (!parse_additive(rhs)) return false; + out = "(" + out + " || " + rhs + ")"; + } + return true; + } + + bool parse_additive(std::string& out) { + if (!parse_multiplicative(out)) return false; + while (peek(0).kind == Tok::Plus || peek(0).kind == Tok::Minus) { + const char* op = (peek(0).kind == Tok::Plus) ? " + " : " - "; + ++pos_; + std::string rhs; + if (!parse_multiplicative(rhs)) return false; + out = "(" + out + op + rhs + ")"; + } + return true; + } + + bool parse_multiplicative(std::string& out) { + if (!parse_unary(out)) return false; + while (peek(0).kind == Tok::Star || peek(0).kind == Tok::Slash || peek(0).kind == Tok::Percent) { + const char* op = (peek(0).kind == Tok::Star) ? " * " + : (peek(0).kind == Tok::Slash) ? " / " : " % "; + ++pos_; + std::string rhs; + if (!parse_unary(rhs)) return false; + out = "(" + out + op + rhs + ")"; + } + return true; + } + + bool parse_unary(std::string& out) { + if (peek(0).kind == Tok::Minus) { + ++pos_; + std::string e; + if (!parse_unary(e)) return false; + out = "(-" + e + ")"; + return true; + } + return parse_primary(out); + } + + bool parse_primary(std::string& out) { + Tok t = peek(0); + if (t.kind == Tok::NumT) { + ++pos_; + out = t.text; + return true; + } + if (t.kind == Tok::StrT) { + ++pos_; + out = sql_string_literal(t.text); + return true; + } + if (t.kind == Tok::TrueT) { ++pos_; out = "TRUE"; return true; } + if (t.kind == Tok::FalseT) { ++pos_; out = "FALSE"; return true; } + if (t.kind == Tok::NilT) { ++pos_; out = "NULL"; return true; } + if (t.kind == Tok::ColT) { + // Check col exists (warning, not error). + ++pos_; + (void)headers_; // currently not validating — caller can do that + out = sql_ident(t.text); + return true; + } + if (t.kind == Tok::LParen) { + ++pos_; + std::string e; + if (!parse_expr(e)) return false; + if (!eat(Tok::RParen, "expected ')'")) return false; + out = "(" + e + ")"; + return true; + } + if (t.kind == Tok::IdentT) { + // Function call: identifier ("." identifier)? "(" args ")" + std::string name = t.text; + ++pos_; + if (peek(0).kind == Tok::Dot) { + ++pos_; + if (peek(0).kind != Tok::IdentT) { + error_ = "expected identifier after '.'"; + return false; + } + name += "." + peek(0).text; + ++pos_; + } + if (peek(0).kind != Tok::LParen) { + error_ = "bare identifier '" + name + + "' not allowed (only [col] refs + whitelisted fn calls)"; + return false; + } + ++pos_; // consume '(' + std::vector args; + if (peek(0).kind != Tok::RParen) { + while (true) { + std::string a; + if (!parse_expr(a)) return false; + args.push_back(a); + if (peek(0).kind == Tok::Comma) { ++pos_; continue; } + break; + } + } + if (!eat(Tok::RParen, "expected ')' closing function args")) return false; + // Validate against whitelist + auto& W = fn_whitelist(); + auto wit = W.find(name); + if (wit == W.end()) { + error_ = "function '" + name + + "' not in SQL transpile whitelist (math.*, string.upper/lower/len/sub, tostring, tonumber)"; + return false; + } + const FnMap& fm = wit->second; + if ((int)args.size() < fm.min_args || (int)args.size() > fm.max_args) { + std::ostringstream os; + os << "function '" << name << "' takes " << fm.min_args; + if (fm.max_args != fm.min_args) os << ".." << fm.max_args; + os << " args, got " << args.size(); + error_ = os.str(); + return false; + } + // Casos especiales + if (name == "string.sub") { + // Lua: string.sub(s, i [, j]) — i/j 1-based, inclusive. + // SQL DuckDB: substring(s, i, count). count = j - i + 1. + if (args.size() == 2) { + // sin j -> hasta el final. DuckDB substring(s, i) acepta. + out = "substring(" + args[0] + ", " + args[1] + ")"; + } else { + out = "substring(" + args[0] + ", " + args[1] + + ", (" + args[2] + ") - (" + args[1] + ") + 1)"; + } + return true; + } + // Generico: substituir $1..$N en template. + std::string s = fm.sql_tmpl; + for (int i = 0; i < (int)args.size(); ++i) { + char ph[6]; + std::snprintf(ph, sizeof(ph), "$%d", i + 1); + std::string p = ph; + size_t at = 0; + while ((at = s.find(p, at)) != std::string::npos) { + s.replace(at, p.size(), args[i]); + at += args[i].size(); + } + } + out = s; + return true; + } + error_ = std::string("unexpected token in expression"); + return false; + } + + bool eat(Tok::Kind k, const char* msg) { + if (peek(0).kind != k) { error_ = msg; return false; } + ++pos_; + return true; + } + + const Tok& peek(int off) const { + size_t i = pos_ + (size_t)off; + if (i >= toks_.size()) return toks_.back(); + return toks_[i]; + } + + bool at_end() const { return peek(0).kind == Tok::EndT; } + const std::string& error() const { return error_; } + +private: + const std::vector& toks_; + const std::vector& headers_; + size_t pos_ = 0; + std::string error_; +}; + +} // anon + +std::string transpile_expr(const std::string& formula, + const std::vector& in_headers, + std::string& error_out) { + error_out.clear(); + std::vector toks; + Lexer lex(formula); + if (!lex.tokenize(toks)) { + error_out = lex.error(); + return ""; + } + Parser p(toks, in_headers); + std::string out; + if (!p.parse_expr(out)) { + error_out = p.error(); + return ""; + } + if (!p.at_end()) { + error_out = "unexpected trailing tokens after expression"; + return ""; + } + return out; +} + +bool is_transpilable(const std::string& formula, std::string& error_out) { + std::vector empty; + std::string s = transpile_expr(formula, empty, error_out); + return error_out.empty() && !s.empty(); +} + +// ============================================================================ +// TQL State -> SQL DuckDB emitter. +// ============================================================================ + +namespace { + +// Mapeo aggregation -> SQL DuckDB expression. +std::string emit_agg_expr(const Aggregation& a) { + switch (a.fn) { + case AggFn::Count: return "COUNT(*)"; + case AggFn::Sum: return "SUM(" + sql_ident(a.col) + ")"; + case AggFn::Avg: return "AVG(" + sql_ident(a.col) + ")"; + case AggFn::Min: return "MIN(" + sql_ident(a.col) + ")"; + case AggFn::Max: return "MAX(" + sql_ident(a.col) + ")"; + case AggFn::Distinct: return "COUNT(DISTINCT " + sql_ident(a.col) + ")"; + case AggFn::Stddev: return "STDDEV(" + sql_ident(a.col) + ")"; + case AggFn::Median: return "quantile_cont(" + sql_ident(a.col) + ", 0.5)"; + case AggFn::P25: return "quantile_cont(" + sql_ident(a.col) + ", 0.25)"; + case AggFn::P75: return "quantile_cont(" + sql_ident(a.col) + ", 0.75)"; + case AggFn::P90: return "quantile_cont(" + sql_ident(a.col) + ", 0.90)"; + case AggFn::P99: return "quantile_cont(" + sql_ident(a.col) + ", 0.99)"; + case AggFn::Percentile: { + char buf[32]; + std::snprintf(buf, sizeof(buf), "%g", a.arg); + return std::string("quantile_cont(") + sql_ident(a.col) + ", " + buf + ")"; + } + } + return "/* unknown agg */ NULL"; +} + +std::string emit_breakout_expr(const std::string& bk) { + std::string col_clean; + DateGranularity g = parse_breakout_granularity(bk, col_clean); + if (g == DateGranularity::None) { + return sql_ident(col_clean); + } + const char* tok = date_granularity_token(g); + // Week: DuckDB date_trunc('week', col) -> monday segun configuracion. + return std::string("date_trunc('") + tok + "', " + sql_ident(col_clean) + ")"; +} + +// Resuelve un Op a operador SQL + (opcional) override de RHS. +const char* sql_op(Op op) { + switch (op) { + case Op::Eq: return " = "; + case Op::Neq: return " <> "; + case Op::Gt: return " > "; + case Op::Gte: return " >= "; + case Op::Lt: return " < "; + case Op::Lte: return " <= "; + case Op::Contains: return " LIKE "; + case Op::NotContains: return " NOT LIKE "; + case Op::StartsWith: return " LIKE "; + case Op::EndsWith: return " LIKE "; + } + return " = "; +} + +// Construye RHS literal/pattern segun op + value. Devuelve placeholder '?' +// y push de params; o pattern string-literal directo para LIKE wildcards. +std::string emit_filter_rhs(const Filter& f, std::vector& params) { + if (f.op == Op::Contains || f.op == Op::NotContains) { + std::string v = "%" + f.value + "%"; + params.push_back(v); + return "?"; + } + if (f.op == Op::StartsWith) { + std::string v = f.value + "%"; + params.push_back(v); + return "?"; + } + if (f.op == Op::EndsWith) { + std::string v = "%" + f.value; + params.push_back(v); + return "?"; + } + params.push_back(f.value); + return "?"; +} + +// Construye CTE stage 0 (Raw): SELECT cols + derived FROM main_t [JOINs]. +// `tables` provee schema. main_t name = tables[main_idx].name. Derived cols +// se transpilan a SQL expression; si fuera de subset, push warning + skip col. +bool emit_stage0(const State& st, const std::vector& tables, + int main_idx, SqlEmit& e) { + if (main_idx < 0 || main_idx >= (int)tables.size()) { + e.error = "main table out of range"; + return false; + } + const TableInput& main_t = tables[(size_t)main_idx]; + + // SELECT list: cols originales + derived expressions (subset). + std::string select_list; + for (size_t i = 0; i < main_t.headers.size(); ++i) { + if (i > 0) select_list += ", "; + select_list += sql_ident(main_t.headers[i]); + } + + // Derived cols (stage 0 derived). + if (!st.stages.empty()) { + const Stage& s0 = st.stages[0]; + for (const auto& d : s0.derived) { + if (d.source_col >= 0 && d.formula.empty()) { + // Retipo puro: alias col origen. + if (d.source_col < (int)main_t.headers.size()) { + select_list += ", " + sql_ident(main_t.headers[(size_t)d.source_col]) + + " AS " + sql_ident(d.name); + } + continue; + } + std::string err; + std::string expr = transpile_expr(d.formula, main_t.headers, err); + if (!err.empty()) { + std::string msg = "derived col '" + d.name + + "' formula out of SQL subset: " + err; + e.warnings.push_back(msg); + // Skip col en SQL output; agente puede recurrir a TQL puro. + continue; + } + select_list += ", " + expr + " AS " + sql_ident(d.name); + } + } + + std::string from = sql_ident(main_t.name); + + // Joins + for (const auto& jn : st.joins) { + const TableInput* right = nullptr; + for (const auto& ti : tables) { + if (ti.name == jn.source) { right = &ti; break; } + } + if (!right) { + e.warnings.push_back("join source '" + jn.source + "' not in tables"); + continue; + } + const char* strat = "LEFT JOIN"; + switch (jn.strategy) { + case JoinStrategy::Left: strat = "LEFT JOIN"; break; + case JoinStrategy::Inner: strat = "INNER JOIN"; break; + case JoinStrategy::Right: strat = "RIGHT JOIN"; break; + case JoinStrategy::Full: strat = "FULL OUTER JOIN"; break; + } + from += "\n " + std::string(strat) + " " + sql_ident(right->name) + + " AS " + sql_ident(jn.alias) + " ON "; + for (size_t k = 0; k < jn.on.size(); ++k) { + if (k > 0) from += " AND "; + from += sql_ident(main_t.name) + "." + sql_ident(jn.on[k].first) + + " = " + sql_ident(jn.alias) + "." + sql_ident(jn.on[k].second); + } + // Anadir cols del right al SELECT con alias.col prefix. + if (jn.fields.empty()) { + for (const auto& rh : right->headers) { + std::string aliased = jn.alias + "." + rh; + select_list += ", " + sql_ident(jn.alias) + "." + sql_ident(rh) + + " AS " + sql_ident(aliased); + } + } else { + for (const auto& fld : jn.fields) { + std::string aliased = jn.alias + "." + fld; + select_list += ", " + sql_ident(jn.alias) + "." + sql_ident(fld) + + " AS " + sql_ident(aliased); + } + } + } + + // Stage 0 WHERE: filters del Raw (filter col idx en eff_headers). + // Filter.col es indice en eff_headers (orig + derived). Para SQL emit, + // necesitamos resolver col idx -> col name. Reconstruir orden eff_headers. + std::vector eff_headers = main_t.headers; + if (!st.stages.empty()) { + for (const auto& d : st.stages[0].derived) { + eff_headers.push_back(d.name); + } + } + std::string where_clause; + if (!st.stages.empty()) { + const Stage& s0 = st.stages[0]; + for (size_t fi = 0; fi < s0.filters.size(); ++fi) { + const Filter& f = s0.filters[fi]; + if (f.col < 0 || f.col >= (int)eff_headers.size()) { + e.warnings.push_back("stage0 filter col idx out of range"); + continue; + } + std::string col = sql_ident(eff_headers[(size_t)f.col]); + if (!where_clause.empty()) where_clause += " AND "; + where_clause += col + sql_op(f.op) + emit_filter_rhs(f, e.params); + } + } + + // Stage 0 sort + std::string order_clause; + if (!st.stages.empty()) { + const Stage& s0 = st.stages[0]; + for (size_t si = 0; si < s0.sorts.size(); ++si) { + const SortClause& sc = s0.sorts[si]; + if (!order_clause.empty()) order_clause += ", "; + order_clause += sql_ident(sc.col) + (sc.desc ? " DESC" : " ASC"); + } + } + + std::string cte = "t0 AS (\n SELECT " + select_list + "\n FROM " + from; + if (!where_clause.empty()) cte += "\n WHERE " + where_clause; + if (!order_clause.empty()) cte += "\n ORDER BY " + order_clause; + cte += "\n)"; + e.sql = "WITH " + cte; + return true; +} + +// Stage N (N>=1): SELECT breakouts + agg expressions FROM t +// [WHERE filters] [GROUP BY ...] [ORDER BY ...]. +bool emit_stage_n(const Stage& stg, int n, SqlEmit& e) { + std::string prev = "t" + std::to_string(n - 1); + std::string cur = "t" + std::to_string(n); + + // SELECT list: breakouts (con granularity expr si aplica) + aggregations. + std::string select_list; + for (size_t i = 0; i < stg.breakouts.size(); ++i) { + if (i > 0) select_list += ", "; + select_list += emit_breakout_expr(stg.breakouts[i]) + + " AS " + sql_ident(stg.breakouts[i]); + } + for (size_t i = 0; i < stg.aggregations.size(); ++i) { + if (!select_list.empty()) select_list += ", "; + std::string alias = aggregation_alias(stg.aggregations[i]); + select_list += emit_agg_expr(stg.aggregations[i]) + " AS " + sql_ident(alias); + } + if (select_list.empty()) select_list = "*"; + + // WHERE: filters del stage. col es indice en input headers (output del stage previo). + // Aproximacion: usamos el nombre via stage breakouts/aggs del stage previo si fuera necesario. + // Para v1, emit por nombre cuando filter.col >= 0 sea idx en breakouts/aggs/orig. El + // chequeo de existencia se delega a DuckDB (errores en execute son detectables). + // V1 simple: skip filter cuando no podemos resolver — caller solo deberia tener filter + // sobre cols que existen. + // Estrategia simple: emite WHERE solo si stage previo provee headers conocidos. Para no + // duplicar logica, dejamos al caller proveer headers via filter.col que se resuelve a + // breakouts[col]. + // V1: si filter.col esta en rango de breakouts del stage previo, emite breakout name. + // Sino, warning + skip. + std::string where_clause; + // Best effort: no podemos construir headers del stage previo aqui sin recomputar. + // Para v1, omitimos filters de stages >=1 — caller deberia evitar usarlos via SQL. + // TODO v2: pasar prev_headers para resolver. + (void)where_clause; + + // GROUP BY: solo si hay breakouts. + std::string group_clause; + for (size_t i = 0; i < stg.breakouts.size(); ++i) { + if (i > 0) group_clause += ", "; + // Re-emit la expression para GROUP BY (no alias). + group_clause += emit_breakout_expr(stg.breakouts[i]); + } + + // ORDER BY + std::string order_clause; + for (size_t i = 0; i < stg.sorts.size(); ++i) { + if (i > 0) order_clause += ", "; + order_clause += sql_ident(stg.sorts[i].col) + (stg.sorts[i].desc ? " DESC" : " ASC"); + } + + std::string cte = ",\n" + cur + " AS (\n SELECT " + select_list + + "\n FROM " + prev; + if (!group_clause.empty()) cte += "\n GROUP BY " + group_clause; + if (!order_clause.empty()) cte += "\n ORDER BY " + order_clause; + cte += "\n)"; + e.sql += cte; + return true; +} + +} // anon + +SqlEmit emit_sql(const State& state, + const std::vector& tables, + int up_to_stage) { + SqlEmit out; + if (state.stages.empty()) { + out.error = "state has no stages"; + return out; + } + if (tables.empty()) { + out.error = "no input tables provided"; + return out; + } + int target = (up_to_stage < 0) ? state.active_stage : up_to_stage; + if (target < 0) target = 0; + if (target >= (int)state.stages.size()) target = (int)state.stages.size() - 1; + + // Resolve main idx via state.main_source (o tables[0] default). + int main_idx = resolve_main_idx(tables, state.main_source); + if (main_idx < 0) main_idx = 0; + + if (!emit_stage0(state, tables, main_idx, out)) return out; + for (int si = 1; si <= target; ++si) { + if (!emit_stage_n(state.stages[(size_t)si], si, out)) return out; + } + out.sql += "\nSELECT * FROM t" + std::to_string(target) + ";\n"; + return out; +} + +} // namespace tql_to_sql diff --git a/cpp/apps/primitives_gallery/playground/tables/tql_to_sql.h b/cpp/apps/primitives_gallery/playground/tables/tql_to_sql.h new file mode 100644 index 00000000..29683100 --- /dev/null +++ b/cpp/apps/primitives_gallery/playground/tables/tql_to_sql.h @@ -0,0 +1,41 @@ +// tql_to_sql: emite SQL DuckDB equivalente a una pipeline TQL State. +// Pure. Sin DuckDB linkado. Solo string emit + validacion. +// Ver issue 0080 + docs/TQL.md (seccion "SQL transpile subset"). +#pragma once + +#include "data_table_logic.h" +#include +#include + +namespace tql_to_sql { + +struct SqlEmit { + std::string sql; // SELECT/CTE chain DuckDB + std::vector params; // bound values posicionales (?) + std::vector warnings; // soft issues (col not found, etc.) + std::string error; // si non-empty, emit fallo +}; + +// Pure: emite SQL DuckDB equivalente a stages 0..active del state. +// `tables` provee schema (headers/types/name) de cada TableInput. El caller +// es responsable de hidratar las tablas en DuckDB con esos nombres. +// `up_to_stage = -1` => state.active_stage. +SqlEmit emit_sql(const data_table::State& state, + const std::vector& tables, + int up_to_stage = -1); + +// Pure: valida que `formula` (cuerpo Lua de un derived col) este dentro del +// subset SQL-transpilable. Si valido, retorna true. Si no, false + razon +// concreta en `error_out` (categoria + token problematico). +// Ver docs/TQL.md#sql-transpile-subset. +bool is_transpilable(const std::string& formula, std::string& error_out); + +// Pure: transpila formula Lua subset -> SQL expression. Si fuera de subset, +// retorna "" y rellena `error_out`. Asume is_transpilable retornaria true. +// `in_headers` necesario para resolver `[col]` refs y emitir identifier +// SQL apropiado (quoted si tiene char especial). +std::string transpile_expr(const std::string& formula, + const std::vector& in_headers, + std::string& error_out); + +} // namespace tql_to_sql diff --git a/cpp/apps/primitives_gallery/playground/tables/viz.cpp b/cpp/apps/primitives_gallery/playground/tables/viz.cpp index 2f717044..5d75ea5b 100644 --- a/cpp/apps/primitives_gallery/playground/tables/viz.cpp +++ b/cpp/apps/primitives_gallery/playground/tables/viz.cpp @@ -16,6 +16,10 @@ using data_table::ColumnType; using data_table::ViewMode; using data_table::ViewConfig; using data_table::parse_number; +using data_table::nearest_index_2d; +using data_table::pie_angle; +using data_table::pie_slice_at_angle; +using data_table::heatmap_cell_at; static int find_header(const StageOutput& out, const std::string& name) { if (name.empty()) return -1; @@ -152,7 +156,8 @@ std::vector finite(const std::vector& v) { } bool render_bar_like(const StageOutput& out, ViewMode mode, - const ViewConfig& cfg, ImVec2 size) { + const ViewConfig& cfg, ImVec2 size, + int* clicked_row_out = nullptr) { int cat_col = resolve_cat(out, cfg, first_category_col(out)); auto nums = collect_numeric_filtered(out, cfg, 8); if (cat_col < 0 || nums.empty()) { @@ -225,6 +230,15 @@ bool render_bar_like(const StageOutput& out, ViewMode mode, ImPlot::PlotBars(nums[0].name.c_str(), ticks.data(), ys.data(), n, 0.67, spc); } } + // Hit-test fase 10: idx = round(plot.{x|y}) en single-series mode. + if (clicked_row_out && + mode != ViewMode::GroupedBar && mode != ViewMode::StackedBar && + ImPlot::IsPlotHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Left)) { + ImPlotPoint p = ImPlot::GetPlotMousePos(); + double target = horiz ? p.y : p.x; + int idx = (int)(target + 0.5); + if (idx >= 0 && idx < n) *clicked_row_out = idx; + } ImPlot::EndPlot(); return true; } @@ -302,7 +316,8 @@ bool render_line_like(const StageOutput& out, ViewMode mode, return true; } -bool render_scatter(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { +bool render_scatter(const StageOutput& out, const ViewConfig& cfg, ImVec2 size, + int* clicked_row_out = nullptr) { // Soporte cfg.x_col + cfg.y_cols[0] int xc = find_header(out, cfg.x_col); int yc = !cfg.y_cols.empty() ? find_header(out, cfg.y_cols[0]) : -1; @@ -329,11 +344,20 @@ bool render_scatter(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) ImPlot::PlotScatter("##s", nums[0].vals.data(), nums[1].vals.data(), (int)nums[0].vals.size()); } + if (clicked_row_out && + ImPlot::IsPlotHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Left)) { + ImPlotPoint p = ImPlot::GetPlotMousePos(); + int idx = nearest_index_2d(p.x, p.y, + nums[0].vals.data(), nums[1].vals.data(), + (int)nums[0].vals.size()); + if (idx >= 0) *clicked_row_out = idx; + } ImPlot::EndPlot(); return true; } -bool render_bubble(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { +bool render_bubble(const StageOutput& out, const ViewConfig& cfg, ImVec2 size, + int* clicked_row_out = nullptr) { int xc = find_header(out, cfg.x_col); int yc = !cfg.y_cols.empty() ? find_header(out, cfg.y_cols[0]) : -1; int sc = resolve_size(out, cfg, -1); @@ -354,6 +378,14 @@ bool render_bubble(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { axflag(cfg), axflag(cfg)); ImPlot::PlotBubbles("##b", nums[0].vals.data(), nums[1].vals.data(), nums[2].vals.data(), (int)nums[0].vals.size()); + if (clicked_row_out && + ImPlot::IsPlotHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Left)) { + ImPlotPoint p = ImPlot::GetPlotMousePos(); + int idx = nearest_index_2d(p.x, p.y, + nums[0].vals.data(), nums[1].vals.data(), + (int)nums[0].vals.size()); + if (idx >= 0) *clicked_row_out = idx; + } ImPlot::EndPlot(); return true; } @@ -404,7 +436,8 @@ bool render_hist2d(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { return true; } -bool render_heatmap(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { +bool render_heatmap(const StageOutput& out, const ViewConfig& cfg, ImVec2 size, + int* clicked_row_out = nullptr) { auto nums = collect_numeric_filtered(out, cfg, 64); if (nums.empty()) { info_text("Need numeric columns"); return false; } int cols = (int)nums.size(); @@ -424,11 +457,22 @@ bool render_heatmap(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) maybe_fit(cfg); if (!ImPlot::BeginPlot("##heatmap", size, 0)) return false; ImPlot::PlotHeatmap("##hm", mat.data(), rows, cols, mn, mx, nullptr); + if (clicked_row_out && + ImPlot::IsPlotHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Left)) { + ImPlotPoint p = ImPlot::GetPlotMousePos(); + // ImPlot heatmap Y se pinta de top a bottom; plot mouse_y va igual + // (default scale 0..rows). Mapeo directo. + int rr, cc; + heatmap_cell_at(p.x, p.y, rows, cols, rr, cc); + if (rr >= 0) *clicked_row_out = rr; + (void)cc; + } ImPlot::EndPlot(); return true; } -bool render_pie(const StageOutput& out, const ViewConfig& cfg, bool donut, ImVec2 size) { +bool render_pie(const StageOutput& out, const ViewConfig& cfg, bool donut, ImVec2 size, + int* clicked_row_out = nullptr) { int cat = resolve_cat(out, cfg, first_category_col(out)); auto nums = collect_numeric_filtered(out, cfg, 1); if (cat < 0 || nums.empty()) { info_text("Need 1 category + 1 numeric"); return false; } @@ -455,11 +499,24 @@ bool render_pie(const StageOutput& out, const ViewConfig& cfg, bool donut, ImVec // Draw inner hole as solid circle by overlaying a smaller pie of one slice transparent. // Simpler: just visually it's a circle with text. Use no extra primitive for now. } + if (clicked_row_out && + ImPlot::IsPlotHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Left)) { + ImPlotPoint p = ImPlot::GetPlotMousePos(); + double dx = p.x - 0.5, dy = p.y - 0.5; + double dist2 = dx*dx + dy*dy; + double inner = donut ? (radius * 0.5) : 0.0; + if (dist2 <= radius * radius && dist2 >= inner * inner) { + double ang = pie_angle(0.5, 0.5, p.x, p.y); + int idx = pie_slice_at_angle(ang, values.data(), n); + if (idx >= 0) *clicked_row_out = idx; + } + } ImPlot::EndPlot(); return true; } -bool render_funnel(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { +bool render_funnel(const StageOutput& out, const ViewConfig& cfg, ImVec2 size, + int* clicked_row_out = nullptr) { int cat = resolve_cat(out, cfg, first_category_col(out)); auto nums = collect_numeric_filtered(out, cfg, 1); if (cat < 0 || nums.empty()) { info_text("Need 1 category + 1 numeric"); return false; } @@ -492,6 +549,17 @@ bool render_funnel(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { ImPlot::SetupAxisTicks(ImAxis_Y1, ticks.data(), n, labels.data(), false); ImPlot::PlotBars(nums[0].name.c_str(), ys.data(), ticks.data(), n, 0.85, ImPlotSpec(ImPlotProp_Flags, ImPlotBarsFlags_Horizontal)); + if (clicked_row_out && + ImPlot::IsPlotHovered() && ImGui::IsMouseClicked(ImGuiMouseButton_Left)) { + ImPlotPoint p = ImPlot::GetPlotMousePos(); + int tick_idx = (int)(p.y + 0.5); + // ticks[i] = n-1-i. Invertir para idx en orden sorted descendiente. + int sorted_pos = (n - 1) - tick_idx; + if (sorted_pos >= 0 && sorted_pos < n) { + // idx[sorted_pos] da indice de row original en out. + *clicked_row_out = idx[sorted_pos]; + } + } ImPlot::EndPlot(); return true; } @@ -763,7 +831,9 @@ bool render_radar(const StageOutput& out, const ViewConfig& cfg, ImVec2 size) { } // anon bool render(const StageOutput& out, ViewMode mode, - const ViewConfig& cfg, ImVec2 size) { + const ViewConfig& cfg, ImVec2 size, + int* clicked_row_out) { + if (clicked_row_out) *clicked_row_out = -1; if (out.rows == 0 || out.cols == 0) { info_text("No data"); return false; @@ -773,21 +843,21 @@ bool render(const StageOutput& out, ViewMode mode, case ViewMode::Bar: case ViewMode::Column: case ViewMode::GroupedBar: - case ViewMode::StackedBar: return render_bar_like(out, mode, cfg, size); + case ViewMode::StackedBar: return render_bar_like(out, mode, cfg, size, clicked_row_out); case ViewMode::Line: case ViewMode::Area: case ViewMode::Stairs: return render_line_like(out, mode, cfg, size); - case ViewMode::Scatter: return render_scatter(out, cfg, size); - case ViewMode::Bubble: return render_bubble(out, cfg, size); + case ViewMode::Scatter: return render_scatter(out, cfg, size, clicked_row_out); + case ViewMode::Bubble: return render_bubble(out, cfg, size, clicked_row_out); case ViewMode::Histogram: return render_histogram(out, cfg, size); case ViewMode::Histogram2D: return render_hist2d(out, cfg, size); - case ViewMode::Heatmap: return render_heatmap(out, cfg, size); + case ViewMode::Heatmap: return render_heatmap(out, cfg, size, clicked_row_out); case ViewMode::BoxPlot: return render_boxplot(out, cfg, size); case ViewMode::Stem: return render_stem(out, cfg, size); case ViewMode::ErrorBars: return render_errorbars(out, cfg, size); - case ViewMode::Pie: return render_pie(out, cfg, false, size); - case ViewMode::Donut: return render_pie(out, cfg, true, size); - case ViewMode::Funnel: return render_funnel(out, cfg, size); + case ViewMode::Pie: return render_pie(out, cfg, false, size, clicked_row_out); + case ViewMode::Donut: return render_pie(out, cfg, true, size, clicked_row_out); + case ViewMode::Funnel: return render_funnel(out, cfg, size, clicked_row_out); case ViewMode::Waterfall: return render_waterfall(out, cfg, size); case ViewMode::KPI: return render_kpi_single(out, cfg); case ViewMode::KPIGrid: return render_kpi_grid(out, cfg); diff --git a/cpp/apps/primitives_gallery/playground/tables/viz.h b/cpp/apps/primitives_gallery/playground/tables/viz.h index 96b364c1..ff358fb3 100644 --- a/cpp/apps/primitives_gallery/playground/tables/viz.h +++ b/cpp/apps/primitives_gallery/playground/tables/viz.h @@ -14,10 +14,15 @@ namespace viz { // // `size`: ImVec2(-1,-1) usa todo el espacio disponible. // `out`: output del stage activo (headers, types, cells flat row-major). +// `clicked_row_out`: si != nullptr, el render escribira el indice de row del +// `StageOutput` clicado por user. -1 si no hubo click drillable. Fase 10 +// (issue 0079): habilitado para bar/column/pie/donut/funnel/scatter/bubble/ +// heatmap. Resto de modos: no hit-test, queda en -1. bool render(const data_table::StageOutput& out, data_table::ViewMode mode, const data_table::ViewConfig& cfg, - ImVec2 size = ImVec2(-1, -1)); + ImVec2 size = ImVec2(-1, -1), + int* clicked_row_out = nullptr); // Helper expuesto: encuentra primera col numerica. -1 si ninguna. int first_numeric_col(const data_table::StageOutput& out); diff --git a/cpp/functions/core/data_table_types.h b/cpp/functions/core/data_table_types.h new file mode 100644 index 00000000..f19d27d6 --- /dev/null +++ b/cpp/functions/core/data_table_types.h @@ -0,0 +1,212 @@ +// data_table_types — types compartidos del stack TQL (Table Query Language). +// Promovido al registry desde cpp/apps/primitives_gallery/playground/tables/. +// Ver issue 0081 + docs/TQL.md. Pure value types + enums. +#pragma once + +#include +#include +#include + +namespace data_table { + +// ---------------------------------------------------------------------------- +// Operadores de filtro. +// ---------------------------------------------------------------------------- +enum class Op { + Eq, Neq, Gt, Gte, Lt, Lte, + Contains, NotContains, StartsWith, EndsWith +}; + +// ---------------------------------------------------------------------------- +// Tipo de columna. Declarado por caller o auto-detectado. +// ---------------------------------------------------------------------------- +enum class ColumnType { + Auto, String, Int, Float, Bool, Date, Json +}; + +// ---------------------------------------------------------------------------- +// Derived column: inmutable. Dos modos: +// 1) Retipo puro: source_col >= 0, formula == "". Cells del origen. +// 2) Formula: source_col == -1, formula no vacia. Eval por Lua. +// ---------------------------------------------------------------------------- +struct DerivedColumn { + int source_col = -1; + ColumnType type = ColumnType::String; + std::string name; + std::string formula; // "" = retipado puro; resto = body Lua + int lua_id = -1; // referencia en lua_engine; -1 si no compilado + std::string compile_error; +}; + +// ---------------------------------------------------------------------------- +// Filtro: col index en eff_headers + op + value. +// ---------------------------------------------------------------------------- +struct Filter { + int col; + Op op; + std::string value; +}; + +// ---------------------------------------------------------------------------- +// ColorRule: pintado condicional de celdas (UI helper). +// ---------------------------------------------------------------------------- +struct ColorRule { + int col; + std::string equals; + unsigned int color; +}; + +// ---------------------------------------------------------------------------- +// Aggregations (TQL stages 1+). +// ---------------------------------------------------------------------------- +enum class AggFn { + Count, Sum, Avg, Min, Max, Distinct, Stddev, + Median, P25, P75, P90, P99, Percentile +}; + +struct Aggregation { + AggFn fn = AggFn::Count; + std::string col; // ignorado para Count + double arg = 0.0; // para Percentile (0..1) + std::string alias; // vacio -> auto-generado via aggregation_alias() +}; + +struct SortClause { + std::string col; + bool desc = false; +}; + +// Stage: layer de TQL. Stage 0 = Raw (sin breakouts/aggregations). +// Stage 1+ pueden agrupar. Cada stage consume output del anterior. +struct Stage { + std::vector filters; + std::vector derived; // expressions de este stage + std::vector breakouts; // col names del INPUT de este stage + std::vector aggregations; + std::vector sorts; +}; + +// Output de compute_stage. Posee `cell_backing` (strings nuevos para +// resultados agregados) y `cells` (punteros row-major a backing o a +// `in_cells` original para passthrough). +struct StageOutput { + std::vector cell_backing; + std::vector cells; + int rows = 0; + int cols = 0; + std::vector headers; + std::vector types; +}; + +// ---------------------------------------------------------------------------- +// ViewMode: tipo de visualizacion a renderizar sobre el output del stage activo. +// ---------------------------------------------------------------------------- +enum class ViewMode { + Table, + // Bars + Bar, Column, GroupedBar, StackedBar, + // Lines / area + Line, Area, Stairs, + // Points + Scatter, Bubble, + // Distribution + Histogram, Histogram2D, Heatmap, BoxPlot, + // Stems / signals + Stem, ErrorBars, + // Composition + Pie, Donut, Funnel, Waterfall, + // Single values + KPI, KPIGrid, + // Specialized + Candlestick, Radar, +}; + +// ---------------------------------------------------------------------------- +// Joins (MBQL-style). Ver issue 0078. +// ---------------------------------------------------------------------------- +enum class JoinStrategy { Left, Inner, Right, Full }; + +// Tabla extra pasada al render() para joins. Owner externo (caller). +struct TableInput { + std::string name; // identificador estable (matchea Join.source) + std::vector headers; + std::vector types; + const char* const* cells = nullptr; // row-major, headers.size() cols x rows filas + int rows = 0; + int cols = 0; +}; + +// Join clause: une la tabla actual con `source` por las parejas `on`, +// prefijando las cols del derecho con `alias.`. +struct Join { + std::string alias; + std::string source; + std::vector> on; // {left_col, right_col} + JoinStrategy strategy = JoinStrategy::Left; + std::vector fields; // vacio = all del derecho +}; + +// ---------------------------------------------------------------------------- +// ViewConfig: overrides manuales de auto-detect para la vista activa. +// ---------------------------------------------------------------------------- +struct ViewConfig { + std::string x_col; // single: scatter, line, hist2d + std::vector y_cols; // 1..N: line/area/bar/etc + std::string size_col; // bubble + std::string cat_col; // bar/pie/funnel/box override + unsigned int primary_color = 0; // 0 = ImPlot auto + int hist_bins = 0; // 0 = Sturges + float pie_radius = 0.0f; // 0 = default + bool show_legend = true; + bool show_markers = false; // line/area markers + bool locked = false; // disable pan/zoom + mutable bool fit_request = false; // consumed by viz::render +}; + +// VizPanel: viz adicional sobre el mismo StageOutput. +struct VizPanel { + ViewMode display = ViewMode::Bar; + ViewConfig config; + mutable ViewMode last_non_table = ViewMode::Bar; +}; + +// ---------------------------------------------------------------------------- +// State: stage pipeline + viz globales. +// ---------------------------------------------------------------------------- +struct State { + std::vector stages; + int active_stage = 0; + ViewMode display = ViewMode::Table; + ViewConfig viz_config; + std::vector extra_panels; + std::vector joins; // aplicado antes de stages[0] + std::string main_source; // name de TableInput; vacio -> tables[0] + + std::vector color_rules; + std::vector col_visible; + std::vector col_order; + + // Helpers (definidos en compute_stage.cpp). + Stage& raw(); + const Stage& raw() const; + Stage& active(); + const Stage& active_const() const; + void ensure_stage0(); +}; + +// ---------------------------------------------------------------------------- +// Drill extendido (fase 10). Ver issue 0079. +// ---------------------------------------------------------------------------- +enum class DateGranularity { None, Year, Month, Week, Day, Hour }; + +enum class FilterPreset { Last7d, Last30d, Last90d, ExcludeNulls, NonZero }; + +// Step de drill grabado para history undo/redo (fase 10). +struct DrillStep { + int target_stage = -1; // stage donde se anadio el filter + int filter_pos = -1; // index en target_stage.filters + int prev_active_stage = 0; // active_stage antes del drill + Filter added; // filter para redo +}; + +} // namespace data_table diff --git a/cpp/functions/gfx/gpu_check.cpp b/cpp/functions/gfx/gpu_check.cpp new file mode 100644 index 00000000..e7f37809 --- /dev/null +++ b/cpp/functions/gfx/gpu_check.cpp @@ -0,0 +1,96 @@ +#include "gfx/gpu_check.h" +#include "gfx/gl_loader.h" + +#include +#include + +// CUDA runtime version via compile-time macro. +// cuda_runtime.h define CUDART_VERSION como XXYYZZ (ej. 12040 para 12.4.0). +// Solo se incluye si el header esta disponible; si no, cuda_runtime_version = "". +#if defined(__has_include) && __has_include() + #include + #define FN_HAS_CUDA_RUNTIME 1 +#endif + +namespace fn::gfx { + +static std::string safe_gl_string(GLenum name) { + const GLubyte* s = glGetString(name); + if (!s) return ""; + return std::string(reinterpret_cast(s)); +} + +static bool check_gl_version_43() { + // GL_VERSION tiene formato "major.minor ..." o "OpenGL ES major.minor ..." + const GLubyte* ver = glGetString(GL_VERSION); + if (!ver) return false; + int major = 0, minor = 0; + // Saltar prefijo "OpenGL ES " si lo hay + const char* p = reinterpret_cast(ver); + if (std::strncmp(p, "OpenGL ES ", 10) == 0) p += 10; + // sscanf con la forma "X.Y" + // NOLINTNEXTLINE(cert-err34-c) + std::sscanf(p, "%d.%d", &major, &minor); + return (major > 4) || (major == 4 && minor >= 3); +} + +bool gpu_check_caps(GpuCaps& out) { + out = GpuCaps{}; // reset + + out.gl_vendor = safe_gl_string(GL_VENDOR); + out.gl_renderer = safe_gl_string(GL_RENDERER); + out.gl_version = safe_gl_string(GL_VERSION); + + if (out.gl_vendor.empty()) { + // No hay contexto GL activo. + return false; + } + + // Compute shader support: GL 4.3+ o ARB_compute_shader + { + const GLubyte* exts = glGetString(GL_EXTENSIONS); + bool has_arb = exts && + std::strstr(reinterpret_cast(exts), + "GL_ARB_compute_shader") != nullptr; + out.has_compute_shader = check_gl_version_43() || has_arb; + } + + // Shader storage buffer: GL 4.3+ o ARB_shader_storage_buffer_object + { + const GLubyte* exts = glGetString(GL_EXTENSIONS); + bool has_ssbo_arb = exts && + std::strstr(reinterpret_cast(exts), + "GL_ARB_shader_storage_buffer_object") != nullptr; + out.has_storage_buffer = check_gl_version_43() || has_ssbo_arb; + } + + // Workgroup limits (solo si hay compute shader support) + if (out.has_compute_shader) { + // GL_MAX_COMPUTE_WORK_GROUP_COUNT — indexed query + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 0, &out.max_compute_workgroup_count[0]); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 1, &out.max_compute_workgroup_count[1]); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 2, &out.max_compute_workgroup_count[2]); + + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 0, &out.max_compute_workgroup_size[0]); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &out.max_compute_workgroup_size[1]); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 2, &out.max_compute_workgroup_size[2]); + } + + // CUDA runtime version (compile-time detection) +#if defined(FN_HAS_CUDA_RUNTIME) + { + int cuda_ver = CUDART_VERSION; // ej. 12040 para CUDA 12.4.0 + int major = cuda_ver / 1000; + int minor = (cuda_ver % 1000) / 10; + char buf[16]; + std::snprintf(buf, sizeof(buf), "%d.%d", major, minor); + out.cuda_runtime_version = buf; + } +#else + out.cuda_runtime_version = ""; +#endif + + return true; +} + +} // namespace fn::gfx diff --git a/cpp/functions/gfx/gpu_check.h b/cpp/functions/gfx/gpu_check.h new file mode 100644 index 00000000..b26c1ef9 --- /dev/null +++ b/cpp/functions/gfx/gpu_check.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace fn::gfx { + +// GpuCaps recopila capacidades OpenGL y CUDA del contexto activo. +// Todos los campos de cadena estan vacios ("") si el dato no esta disponible. +struct GpuCaps { + // OpenGL — requieren contexto GL activo antes de llamar gpu_check_caps. + std::string gl_vendor; // glGetString(GL_VENDOR) ej. "NVIDIA Corporation" + std::string gl_renderer; // glGetString(GL_RENDERER) ej. "NVIDIA GeForce RTX 3080/PCIe/SSE2" + std::string gl_version; // glGetString(GL_VERSION) ej. "4.6.0 NVIDIA 550.54.15" + + // Compute shader limits (GL_MAX_COMPUTE_WORK_GROUP_COUNT/SIZE) + // Indice 0=X 1=Y 2=Z. Valor 0 si compute shaders no disponibles. + int max_compute_workgroup_count[3] = {0, 0, 0}; + int max_compute_workgroup_size[3] = {0, 0, 0}; + + bool has_compute_shader = false; // GL_VERSION >= 4.3 o extension ARB_compute_shader + bool has_storage_buffer = false; // GL_VERSION >= 4.3 o extension ARB_shader_storage_buffer_object + + // CUDA — vacio si CUDA runtime no detectado en compile time. + // Formato: "12.4" (major.minor) o "" si no disponible. + std::string cuda_runtime_version; +}; + +// gpu_check_caps rellena out con las capacidades del contexto OpenGL activo. +// +// REQUISITO: debe llamarse despues de inicializar el contexto GL y, en Windows, +// despues de fn::gfx::gl_loader_init(). Si se llama sin contexto activo el +// comportamiento es indefinido (glGetString devuelve nullptr). +// +// Retorna true si se pudo leer al menos el vendor GL (contexto activo). +// Retorna false si gl_vendor queda vacio (contexto no activo o driver defectuoso). +bool gpu_check_caps(GpuCaps& out); + +} // namespace fn::gfx diff --git a/cpp/functions/gfx/gpu_check.md b/cpp/functions/gfx/gpu_check.md new file mode 100644 index 00000000..de25fc16 --- /dev/null +++ b/cpp/functions/gfx/gpu_check.md @@ -0,0 +1,86 @@ +--- +name: gpu_check +kind: function +lang: cpp +domain: gfx +version: "1.0.0" +purity: impure +signature: "bool fn_gfx::gpu_check_caps(GpuCaps& out)" +description: "Rellena GpuCaps con las capacidades del contexto OpenGL activo: vendor, renderer, version, limites de compute workgroup, flags has_compute_shader/has_storage_buffer, y version CUDA runtime (deteccion en compile-time via CUDART_VERSION). Requiere contexto GL activo. Retorna false si el contexto no esta disponible." +tags: [gpu, opengl, cuda, caps, hardware, probe, gfx, compute, infra] +uses_functions: ["gl_loader_cpp_gfx"] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [gfx/gpu_check.h, gfx/gl_loader.h, cuda_runtime.h, cstring, string] +tested: false +tests: [] +test_file_path: "" +file_path: "cpp/functions/gfx/gpu_check.cpp" +framework: opengl +params: + - name: out + desc: "Referencia a GpuCaps que se rellena con las capacidades detectadas. Se resetea al inicio de la llamada." +output: "true si el contexto GL esta activo y gl_vendor no esta vacio; false si no hay contexto GL activo o el driver devuelve nullptr para GL_VENDOR." +--- + +# gpu_check + +Probing de capacidades GPU en runtime: OpenGL strings, compute shader support y CUDA. + +## Uso tipico + +```cpp +#include "gfx/gpu_check.h" +#include "gfx/gl_loader.h" + +// Dentro de render(), despues del primer frame (contexto GL activo): +fn::gfx::GpuCaps caps; +if (fn::gfx::gpu_check_caps(caps)) { + printf("GPU: %s\n", caps.gl_renderer.c_str()); + printf("Compute shaders: %s\n", caps.has_compute_shader ? "yes" : "no"); + if (!caps.cuda_runtime_version.empty()) + printf("CUDA runtime: %s\n", caps.cuda_runtime_version.c_str()); +} else { + printf("No GL context active\n"); +} +``` + +## Estructura GpuCaps + +```cpp +struct GpuCaps { + std::string gl_vendor; // "NVIDIA Corporation" + std::string gl_renderer; // "NVIDIA GeForce RTX 3080/PCIe/SSE2" + std::string gl_version; // "4.6.0 NVIDIA 550.54.15" + int max_compute_workgroup_count[3]; // [65535, 65535, 65535] tipico NVIDIA + int max_compute_workgroup_size[3]; // [1024, 1024, 64] tipico + bool has_compute_shader; // GL 4.3+ o ARB_compute_shader + bool has_storage_buffer; // GL 4.3+ o ARB_shader_storage_buffer_object + std::string cuda_runtime_version; // "12.4" o "" si no compilado con CUDA +}; +``` + +## CUDA detection + +La version CUDA se detecta en **compile time** via el macro `CUDART_VERSION` de ``. Si la app no esta compilada con el CUDA toolkit, `cuda_runtime_version` sera `""`. Para detection en runtime del toolkit del sistema, usar `cuda_toolkit_check_bash_infra`. + +## Requisito de contexto GL + +Llamar siempre despues de crear el contexto GL. En apps que usan `fn::run_app`, el contexto esta activo desde el primer frame del `render()` callback. En Windows, `fn::gfx::gl_loader_init()` debe haberse llamado antes para que los punteros de funcion esten resueltos. + +## Uso previsto (fn doctor cpp-apps) + +Esta funcion sera invocada por el audit de `fn doctor cpp-apps` para verificar que las apps C++ del registry tienen acceso a compute shaders cuando declaran dependencias de `gpu_compute_program`, `gpu_dispatch`, etc. + +## CMakeLists.txt + +```cmake +add_imgui_app(mi_app + main.cpp + ${CMAKE_SOURCE_DIR}/cpp/functions/gfx/gpu_check.cpp +) +# CUDA opcional: si la app compila con CUDA toolkit el header cuda_runtime.h +# estara disponible y FN_HAS_CUDA_RUNTIME se activara automaticamente. +``` diff --git a/cpp/types/core/agg_fn.md b/cpp/types/core/agg_fn.md new file mode 100644 index 00000000..a16f0517 --- /dev/null +++ b/cpp/types/core/agg_fn.md @@ -0,0 +1,20 @@ +--- +name: AggFn +lang: cpp +domain: core +version: "1.0.0" +algebraic: sum +definition: | + enum class AggFn { + Count, Sum, Avg, Min, Max, Distinct, Stddev, + Median, P25, P75, P90, P99, Percentile + }; +description: "Funcion de agregacion soportada. Pickup via UI combo + SQL emit via tql_to_sql. Percentile usa Aggregation.arg en [0,1]." +tags: [tql, aggregation, sum-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Mapeo SQL DuckDB: Count → `COUNT(*)`, Sum/Avg/Min/Max/Stddev → ops nativas, Distinct → `COUNT(DISTINCT col)`, Median/P25/P75/P90/P99/Percentile → `quantile_cont(col, p)`. diff --git a/cpp/types/core/aggregation.md b/cpp/types/core/aggregation.md new file mode 100644 index 00000000..ce1a44cb --- /dev/null +++ b/cpp/types/core/aggregation.md @@ -0,0 +1,22 @@ +--- +name: Aggregation +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct Aggregation { + AggFn fn; + std::string col; + double arg; + std::string alias; + }; +description: "Funcion de agregacion en Stage 1+. fn = Count/Sum/Avg/Min/Max/Distinct/Stddev/Median/P25/P75/P90/P99/Percentile. arg = parametro (p para percentile)." +tags: [tql, aggregation, agg, product-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +`alias` vacio dispara `aggregation_alias(a)` auto: `count`, `sum_`, `distinct_`, `p95_` etc. SQL mapping en `tql_to_sql`: `COUNT(*)`, `SUM("col")`, `quantile_cont("col", p)`. diff --git a/cpp/types/core/color_rule.md b/cpp/types/core/color_rule.md new file mode 100644 index 00000000..967fde0f --- /dev/null +++ b/cpp/types/core/color_rule.md @@ -0,0 +1,21 @@ +--- +name: ColorRule +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct ColorRule { + int col; + std::string equals; + unsigned int color; + }; +description: "Regla de pintado condicional para tabla UI. Si cells[row][col] == equals, fondo = color (RGBA packed)." +tags: [tql, color, ui-hint, product-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Solo afecta render visual. Round-trip en TQL via `columns..color_rules`. Vacio = sin color override. diff --git a/cpp/types/core/column_type.md b/cpp/types/core/column_type.md new file mode 100644 index 00000000..f3595f92 --- /dev/null +++ b/cpp/types/core/column_type.md @@ -0,0 +1,28 @@ +--- +name: ColumnType +lang: cpp +domain: core +version: "1.0.0" +algebraic: sum +definition: | + enum class ColumnType { + Auto, String, Int, Float, Bool, Date, Json + }; +description: "Tipo de columna del modelo TQL. `Auto` dispara auto-detect; el resto fuerza el tipo declarado. Base de toda la pipeline data_table." +tags: [tql, data-table, types, sum-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Sum type / enum-class. Convivimos con `effective_type()` que resuelve `Auto` → auto-detect via sample. El resto fuerza el tipo declarado por el caller. + +Tabla de iconos UTF-8 Tabler para cada variante en `column_type_icon(t)`. Mapeo SQL ↔ ColumnType en `tql_to_sql` (issue 0080). + +## Usado por + +- `compute_stage_cpp_core` — input/output types per stage +- `tql_emit_cpp_core` / `tql_apply_cpp_core` — emit/parse TQL columns block +- `tql_to_sql_cpp_core` — mapping a SQL DuckDB types +- `data_table_cpp_viz` — UI render por columna diff --git a/cpp/types/core/date_granularity.md b/cpp/types/core/date_granularity.md new file mode 100644 index 00000000..788b8327 --- /dev/null +++ b/cpp/types/core/date_granularity.md @@ -0,0 +1,19 @@ +--- +name: DateGranularity +lang: cpp +domain: core +version: "1.0.0" +algebraic: sum +definition: | + enum class DateGranularity { None, Year, Month, Week, Day, Hour }; +description: "Granularidad de truncado de fechas para breakouts TQL. Sufijo `:token` en breakout string (ej. 'ts:month')." +tags: [tql, date, granularity, sum-type, mbql] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Auto-detect via `auto_date_granularity(min_ymd, max_ymd)`: >2y→Year, >60d→Month, >14d→Week, resto→Day. SQL emit DuckDB: `date_trunc('month'|'year'|...,col)`. + +Week trunca a lunes ISO (Hinnant algo). diff --git a/cpp/types/core/derived_column.md b/cpp/types/core/derived_column.md new file mode 100644 index 00000000..af0856f5 --- /dev/null +++ b/cpp/types/core/derived_column.md @@ -0,0 +1,26 @@ +--- +name: DerivedColumn +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct DerivedColumn { + int source_col; + ColumnType type; + std::string name; + std::string formula; + int lua_id; + std::string compile_error; + }; +description: "Col custom dentro de un Stage. Modo 1: retipo (source_col >= 0, formula vacia). Modo 2: formula Lua (source_col == -1, eval por lua_engine sandbox)." +tags: [tql, derived, formula, lua, product-type] +uses_types: [ColumnType_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +`formula` evaluada por row via `lua_engine` con `[col]` refs disponibles. Para SQL transpile (fase 11), formula debe estar dentro del Lua subset; sino `tql_to_sql` emite warning + skip col. + +`lua_id` cachea la formula compilada en lua_engine entre eval calls. diff --git a/cpp/types/core/filter.md b/cpp/types/core/filter.md new file mode 100644 index 00000000..9eef78af --- /dev/null +++ b/cpp/types/core/filter.md @@ -0,0 +1,30 @@ +--- +name: Filter +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct Filter { + int col; + Op op; + std::string value; + }; +description: "Predicado TQL: col idx + Op + value. Aplicado dentro de un Stage por compute_stage. col es idx en headers efectivos del INPUT del stage." +tags: [tql, filter, predicate, product-type] +uses_types: [Op_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +`col` es indice en `in_headers` del stage donde aplica (no en el dataset original — esto cambio en el refactor a stages). Para drill-down usar `make_drill_filter(col_idx, value)`. + +`value` es string siempre — `compare()` decide numerico vs lexical segun parseo. Range filters (op_in_range, op_between) no estan modelados; usar dos Filters consecutivos. + +## Usado por + +- `Stage_cpp_core` (lista de filters) +- `apply_filters`, `compute_stage_cpp_core` +- `make_drill_filter`, `build_preset_filters` +- `tql_to_sql_cpp_core` → SQL WHERE clauses con `?` placeholders diff --git a/cpp/types/core/join.md b/cpp/types/core/join.md new file mode 100644 index 00000000..4849e1e7 --- /dev/null +++ b/cpp/types/core/join.md @@ -0,0 +1,25 @@ +--- +name: Join +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct Join { + std::string alias; + std::string source; + std::vector> on; + JoinStrategy strategy; + std::vector fields; + }; +description: "Join MBQL-style entre main_t y source. on = pares {left_col, right_col} multi-key. strategy = Left/Inner/Right/Full. fields vacio = all cols del derecho." +tags: [tql, join, mbql, product-type] +uses_types: [JoinStrategy_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Materializado por `join_tables_cpp_core` antes de stages[0]. Cols del derecho se prefijan con `alias.col` para preservar headers del main. SQL emit: `LEFT/INNER/RIGHT/FULL OUTER JOIN source AS alias ON main.l = alias.r AND ...`. + +Multi-key: `on = {{l1,r1}, {l2,r2}}` → `ON main.l1 = alias.r1 AND main.l2 = alias.r2`. diff --git a/cpp/types/core/join_strategy.md b/cpp/types/core/join_strategy.md new file mode 100644 index 00000000..4d328f70 --- /dev/null +++ b/cpp/types/core/join_strategy.md @@ -0,0 +1,17 @@ +--- +name: JoinStrategy +lang: cpp +domain: core +version: "1.0.0" +algebraic: sum +definition: | + enum class JoinStrategy { Left, Inner, Right, Full }; +description: "Estrategia de join MBQL-style. 4 variantes estandar SQL. SQL mapping directo a LEFT/INNER/RIGHT/FULL OUTER JOIN." +tags: [tql, join, strategy, sum-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Round-trip TQL: tokens `"left"/"inner"/"right"/"full"`. Fallback parse "nope" → Left. diff --git a/cpp/types/core/op.md b/cpp/types/core/op.md new file mode 100644 index 00000000..5db472f5 --- /dev/null +++ b/cpp/types/core/op.md @@ -0,0 +1,36 @@ +--- +name: Op +lang: cpp +domain: core +version: "1.0.0" +algebraic: sum +definition: | + enum class Op { + Eq, Neq, Gt, Gte, Lt, Lte, + Contains, NotContains, StartsWith, EndsWith + }; +description: "Operador de filtro TQL. 6 ops de comparacion + 4 ops de string. Numericos ordenan numericamente cuando ambos lados parsean." +tags: [tql, filter, operator, sum-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Tabla operadores permitidos por `ColumnType` via `ops_for_type(t)`: + +| Tipo | Ops | +|---|---| +| Int / Float / Date | Eq, Neq, Gt, Gte, Lt, Lte | +| Bool | Eq, Neq | +| Json | Eq, Neq, Contains, NotContains | +| String | Eq, Neq, Contains, NotContains, StartsWith, EndsWith | + +Mapeo SQL en `tql_to_sql_cpp_core`: Contains → `LIKE '%v%'`, StartsWith → `LIKE 'v%'`, etc. + +## Usado por + +- `Filter_cpp_core` +- `compute_stage_cpp_core` (via apply_filters) +- `tql_emit_cpp_core` / `tql_apply_cpp_core` +- `tql_to_sql_cpp_core` diff --git a/cpp/types/core/sort_clause.md b/cpp/types/core/sort_clause.md new file mode 100644 index 00000000..3c4d940b --- /dev/null +++ b/cpp/types/core/sort_clause.md @@ -0,0 +1,20 @@ +--- +name: SortClause +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct SortClause { + std::string col; + bool desc; + }; +description: "Clausula de orden por nombre de col. Multi-sort = vector ordenado por prioridad. desc=true para descendente." +tags: [tql, sort, order, product-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Sort por nombre (no idx) — sobrevive a renombrado de cols + a stages 1+ donde idx no aplica. Aplicacion via `apply_sorts`. Round-trip TQL: `sort = { {"asc"|"desc", "col"}, ... }`. diff --git a/cpp/types/core/stage.md b/cpp/types/core/stage.md new file mode 100644 index 00000000..f672ac30 --- /dev/null +++ b/cpp/types/core/stage.md @@ -0,0 +1,33 @@ +--- +name: Stage +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct Stage { + std::vector filters; + std::vector derived; + std::vector breakouts; + std::vector aggregations; + std::vector sorts; + }; +description: "Layer del pipeline TQL. Stage 0 = Raw (filters + derived + sort). Stage 1+ pueden agrupar (breakouts + aggregations + sort). Consumida por compute_stage." +tags: [tql, stage, pipeline, product-type, mbql] +uses_types: [Filter_cpp_core, Op_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Inspirado en MBQL `:filter` / `:breakout` / `:aggregation` / `:order-by`. Diferencia clave: TQL chain N stages explicitos, cada uno consume el output del anterior. MBQL usa `:source-query` recursivo. + +Breakout strings pueden llevar sufijo `:granularity` para cols Date (fase 10): `"ts:month"`, `"ts:week"`, etc. Ver `parse_breakout_granularity()`. + +## Usado por + +- `State_cpp_core` (lista de stages) +- `compute_stage_cpp_core` (executes a single Stage) +- `compute_pipeline_cpp_core` (chains stages 0..N) +- `tql_emit_cpp_core` / `tql_apply_cpp_core` (round-trip Lua) +- `tql_to_sql_cpp_core` → CTE chain DuckDB diff --git a/cpp/types/core/stage_output.md b/cpp/types/core/stage_output.md new file mode 100644 index 00000000..8ad8a9e5 --- /dev/null +++ b/cpp/types/core/stage_output.md @@ -0,0 +1,26 @@ +--- +name: StageOutput +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct StageOutput { + std::vector cell_backing; + std::vector cells; + int rows; + int cols; + std::vector headers; + std::vector types; + }; +description: "Output materializado de compute_stage. cell_backing posee strings nuevos (aggregations); cells es row-major de ptrs a backing o a in_cells original." +tags: [tql, stage, output, product-type] +uses_types: [ColumnType_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Lifetime: cell_backing es owner — cells solo es valido mientras StageOutput viva. Para passthrough (sin agregaciones), cells apunta a in_cells del caller (sin backing local). + +Reservar capacidad upfront en cell_backing evita realloc que invalida punteros. diff --git a/cpp/types/core/state.md b/cpp/types/core/state.md new file mode 100644 index 00000000..5b93408f --- /dev/null +++ b/cpp/types/core/state.md @@ -0,0 +1,40 @@ +--- +name: State +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct State { + std::vector stages; + int active_stage; + ViewMode display; + ViewConfig viz_config; + std::vector extra_panels; + std::vector joins; + std::string main_source; + std::vector color_rules; + std::vector col_visible; + std::vector col_order; + }; +description: "Estado completo de una query TQL: pipeline de stages + joins + viz config + UI tweaks. Round-trip a Lua via tql_emit/tql_apply." +tags: [tql, state, pipeline, product-type] +uses_types: [Stage_cpp_core, Filter_cpp_core, Op_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +State es el documento canonico de una query del usuario. Atomico — toda mutacion pasa por helpers pure (`apply_drill_step`, `drill_up`, etc.). + +`active_stage` = idx del stage cuyo output se renderiza. Filters/sorts del Raw siempre se aplican antes; joins se materializan ANTES de stages[0]. + +Helpers `raw()`, `active()` garantizan `stages[0]` existe (lazy init en `ensure_stage0`). + +## Usado por + +- `data_table_cpp_viz` (UI render principal) +- `compute_pipeline_cpp_core` (resuelve hasta active_stage) +- `tql_emit_cpp_core` / `tql_apply_cpp_core` (Lua serializacion) +- `tql_to_sql_cpp_core` → SQL DuckDB CTE chain +- `apply_drill_step` / `undo_drill_step` / `drill_up` diff --git a/cpp/types/core/table_input.md b/cpp/types/core/table_input.md new file mode 100644 index 00000000..ef4143b9 --- /dev/null +++ b/cpp/types/core/table_input.md @@ -0,0 +1,33 @@ +--- +name: TableInput +lang: cpp +domain: core +version: "1.0.0" +algebraic: product +definition: | + struct TableInput { + std::string name; + std::vector headers; + std::vector types; + const char* const* cells; + int rows; + int cols; + }; +description: "Tabla materializada en memoria pasada a data_table::render(). Owner externo. Multiple tables = main + joinables (fase 9 issue 0078)." +tags: [tql, table, joins, mbql, product-type] +uses_types: [Op_cpp_core] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +`name` es el identificador estable que matchea `Join.source` cuando se aplica un join. `cells` es row-major (rows * cols `const char*`). Apuntadores estables durante todo el frame de render. + +Cells son strings — auto_detect_type infiere ColumnType si `types[i] == Auto`. Numericos se parsean por celda en compare/agg via `parse_number()`. + +## Usado por + +- `data_table_cpp_viz::render(tables, state)` +- `resolve_main_idx` (matchea state.main_source) +- `join_tables_cpp_core` (right table) +- `tql_to_sql_cpp_core` (schema para emitir SELECT FROM `name`) diff --git a/cpp/types/viz/view_config.md b/cpp/types/viz/view_config.md new file mode 100644 index 00000000..a32ed68a --- /dev/null +++ b/cpp/types/viz/view_config.md @@ -0,0 +1,29 @@ +--- +name: ViewConfig +lang: cpp +domain: viz +version: "1.0.0" +algebraic: product +definition: | + struct ViewConfig { + std::string x_col; + std::vector y_cols; + std::string size_col; + std::string cat_col; + unsigned int primary_color; + int hist_bins; + float pie_radius; + bool show_legend; + bool show_markers; + bool locked; + mutable bool fit_request; + }; +description: "Overrides manuales de auto-detect para ViewMode. Cols vacias dejan al dispatcher elegir. primary_color=0 usa palette ImPlot." +tags: [tql, viz, config, product-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +`fit_request` mutable bool consumido por `viz::render` (one-shot trigger para `ImPlot::SetNextAxesToFit`). `locked` deshabilita pan/zoom del usuario. diff --git a/cpp/types/viz/view_mode.md b/cpp/types/viz/view_mode.md new file mode 100644 index 00000000..50326e5a --- /dev/null +++ b/cpp/types/viz/view_mode.md @@ -0,0 +1,29 @@ +--- +name: ViewMode +lang: cpp +domain: viz +version: "1.0.0" +algebraic: sum +definition: | + enum class ViewMode { + Table, + Bar, Column, GroupedBar, StackedBar, + Line, Area, Stairs, + Scatter, Bubble, + Histogram, Histogram2D, Heatmap, BoxPlot, + Stem, ErrorBars, + Pie, Donut, Funnel, Waterfall, + KPI, KPIGrid, + Candlestick, Radar + }; +description: "Modo de visualizacion ImPlot del stage activo. ~25 variantes cubriendo bars/lines/distribution/composition/specialized. Dispatcher en viz::render." +tags: [tql, viz, imgui, implot, sum-type] +uses_types: [] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +Tokens lowercase via `view_mode_token`/`view_mode_from_token` para TQL emit/apply. Helpers `view_mode_needs_numeric/category/aggregation` guían UI (combo selectable solo si schema compatible). + +`Table` siempre disponible (fallback render por defecto). Demas requieren al menos cols compatibles. Click-to-drill (fase 10): Bar/Column/Scatter/Bubble/Pie/Donut/Funnel/Heatmap. diff --git a/cpp/types/viz/viz_panel.md b/cpp/types/viz/viz_panel.md new file mode 100644 index 00000000..a90e0b38 --- /dev/null +++ b/cpp/types/viz/viz_panel.md @@ -0,0 +1,21 @@ +--- +name: VizPanel +lang: cpp +domain: viz +version: "1.0.0" +algebraic: product +definition: | + struct VizPanel { + ViewMode display; + ViewConfig config; + mutable ViewMode last_non_table; + }; +description: "Viz adicional sobre el mismo StageOutput. State tiene panel principal (display+viz_config) + vector extras." +tags: [tql, viz, panel, product-type] +uses_types: [ViewMode_cpp_viz, ViewConfig_cpp_viz] +file_path: "cpp/functions/core/data_table_types.h" +--- + +## Notas + +`last_non_table` memoria del ultimo display !=Table para toggle Table↔View rapido en UI. Mutable porque se actualiza durante render (no rompe const correctness). diff --git a/dev/issues/0078-tables-joins-mbql.md b/dev/issues/0078-tables-joins-mbql.md index b5ed68b9..7e988da1 100644 --- a/dev/issues/0078-tables-joins-mbql.md +++ b/dev/issues/0078-tables-joins-mbql.md @@ -1,9 +1,10 @@ --- id: 0078 title: tables playground — joins MBQL-style (fase 9) -status: pending +status: done priority: medium created: 2026-05-12 +closed: 2026-05-12 related_components: [cpp/apps/primitives_gallery/playground/tables, lua_engine, tql] --- diff --git a/dev/issues/0079-tables-drill-ext.md b/dev/issues/0079-tables-drill-ext.md index 4e4905f8..ba0b2cdb 100644 --- a/dev/issues/0079-tables-drill-ext.md +++ b/dev/issues/0079-tables-drill-ext.md @@ -1,9 +1,10 @@ --- id: 0079 title: tables playground — drill-through extendido (fase 10) -status: pending +status: done priority: medium created: 2026-05-12 +closed: 2026-05-12 related_components: [cpp/apps/primitives_gallery/playground/tables] --- diff --git a/dev/issues/0080-tables-llm-api.md b/dev/issues/0080-tables-llm-api.md index 35ff388a..63ce7f44 100644 --- a/dev/issues/0080-tables-llm-api.md +++ b/dev/issues/0080-tables-llm-api.md @@ -1,77 +1,238 @@ --- id: 0080 -title: tables playground — LLM API "Ask AI" (fase 11) -status: pending +title: tables playground — LLM "Ask AI" + TQL/SQL emit (fase 11) +status: partial priority: medium created: 2026-05-12 -related_components: [cpp/apps/primitives_gallery/playground/tables] +updated: 2026-05-13 +notes: pure layer + LLM client + Ask AI modal DONE. DuckDB adapter v2 (opcional, build flag FN_TQL_DUCKDB=1) +related_components: [cpp/apps/primitives_gallery/playground/tables, lua_engine, tql, duckdb] --- ## Contexto -Fase 11 del roadmap del tables playground. El user escribe en lenguaje natural -una pregunta sobre los datos ("show me top 10 langs by total size"). El LLM -recibe el TQL actual + schema + pregunta, devuelve nuevo TQL. App aplica via -`tql::apply` y renderiza. +Fase 11 del roadmap del tables playground. Dos capacidades que se construyen juntas porque comparten infra (prompt schema, runtime adapter, tests round-trip): + +1. **LLM "Ask AI"** — usuario o agente pregunta en lenguaje natural, modelo devuelve un nuevo TQL (o SQL DuckDB si esta linkado). +2. **TQL → SQL (DuckDB) emitter** — permite a agentes escribir SQL contra el mismo modelo de datos. Ejecutable si la app linkó DuckDB; si no, solo emite el string. + +Diseño one-way: **TQL → SQL si**, **SQL → TQL no**. Razon documentada en investigacion Metabase MBQL ↔ SQL: la traduccion inversa es lossy (CTEs, window fns, set ops, lateral, correlated subqueries no caben en MBQL/TQL). Patron canonico Malloy/Cube/LookML/Metabase = compile-down one-way. ## Cambios -### 1. UI +### 1. UI "Ask AI" - Boton "Ask AI" en toolbar (al lado de "+ Viz"). -- Modal con: +- Modal: - InputText multiline para la pregunta. - - Boton "Send" + spinner durante la llamada. - - Diff side-by-side: TQL actual vs TQL propuesto (texto con highlight). + - Toggle output mode: `TQL` (default) | `SQL (DuckDB)` (visible solo si app fue compilada con `FN_TQL_DUCKDB=1`). + - Boton "Send" + spinner. + - Diff side-by-side: actual vs propuesto (texto highlight). - Botones "Apply" / "Reject" / "Edit before apply". ### 2. Backend LLM -- Provider: Anthropic Claude (API key desde `pass anthropic/api-key`). -- Endpoint: `https://api.anthropic.com/v1/messages`. -- Model: `claude-sonnet-4-6` por defecto. Configurable via env `FN_LLM_MODEL`. -- Cliente HTTP: cURL via popen (sin deps nuevas) o libcurl si ya esta linkada. +- Provider: Anthropic Claude. API key via `pass anthropic/api-key`. +- Endpoint: `https://api.anthropic.com/v1/messages`. Model: `claude-sonnet-4-6`. Override env `FN_LLM_MODEL`. +- Cliente HTTP: cURL via popen (sin deps nuevas). - Prompt template incluye: - Esquema TQL (de `docs/TQL.md`). + - **Si SQL mode**: dialecto DuckDB + funciones DuckDB relevantes (date_trunc, regexp_replace, etc.). - Cols disponibles del stage 0 (name, type) + cols joinables. + - **Grammar Lua subset** (ver §4) cuando aplique. - Funciones Lua disponibles (de `lua_engine`). - TQL actual. - Pregunta del user. -- Response: extraer ```lua``` block del markdown, strip prose. +- Response: extraer ```lua``` (TQL) o ```sql``` block del markdown, strip prose. -### 3. Validacion + safety +### 3. TQL → SQL DuckDB emitter -- Antes de aplicar: `tql::apply` con dry-run (parsea sin mutar State). Si fail, mostrar error + boton "Ask AI again with this error". -- Lua sandbox ya cubre side effects en formulas — el TQL en si es declarativo, no ejecuta nada peligroso. +Nuevo modulo `tql_to_sql.{h,cpp}` (pure). Funciones: -### 4. Streaming +```cpp +struct SqlEmit { + std::string sql; // SELECT ... statement + std::vector params; // bound values (?-placeholders) + std::vector warnings; + std::string error; // si emit fallo (subset out of bounds) +}; -- Stream tokens via SSE (`stream=true` en Anthropic API). -- Mostrar texto en vivo en el modal. -- Cuando termina, parsear lua block final. +// Pure: emite SQL DuckDB equivalente a la pipeline State (stages 0..active). +// `tables` provee el schema de cada TableInput (no los cells — el caller +// decide como hidratar las tablas en DuckDB). +SqlEmit emit_sql(const State& state, const std::vector& tables, + int up_to_stage = -1 /* default = active_stage */); +``` -### 5. Persistencia conversation +Mapeo MBQL-style: +- Stage 0 = CTE base `t0` con `SELECT cols + derived FROM main_t [LEFT/INNER/RIGHT/FULL JOIN joinables ON ...]`. +- Stage N = CTE `tN` con `SELECT breakouts, aggregations FROM tN-1 [WHERE filters] [GROUP BY breakouts] [ORDER BY sorts]`. +- Final query `SELECT * FROM t`. -- UiState guarda lista de turns (pregunta + TQL propuesto + resultado apply). -- "Ask AI" siguiente turn incluye history previa. -- Boton "Reset chat" limpia. -- NO persistido en TQL (es UI state). +Stage emit detalle: +- `filter Op::Eq col = "v"` → `WHERE col = ?` con `params.push_back(v)` (DuckDB acepta `$1`/`?`). +- `breakout "ts:month"` → `date_trunc('month', ts) AS "ts:month"`. Granularity sufijo → DuckDB `date_trunc`. +- `aggregation count` → `COUNT(*) AS count`. +- `aggregation p95(col)` → `quantile_cont(col, 0.95) AS p95_col`. +- `aggregation distinct col` → `COUNT(DISTINCT col) AS distinct_col`. +- `sort {desc, col}` → `ORDER BY col DESC`. +- Joins: 4 strategies mapean directo a `LEFT/INNER/RIGHT/FULL JOIN ... ON l.k = r.k`. +- Derived cols: transpiladas via Lua subset (§4). Si formula fuera de subset → `SqlEmit.error = "lua formula 'X' out of subset: "`. -### 6. Coste / rate limit +Salida es **string SQL valido DuckDB**. No ejecuta — eso es responsabilidad del adapter opcional (§5). + +### 4. Lua subset transpilable a SQL — GRAMATICA + +Documentar en `docs/TQL.md` seccion nueva "SQL transpile subset". + +**Reglas duras: Lua sigue siendo potente y sin limites en runtime general.** El subset solo aplica si el caller pide `tql_to_sql::emit_sql()`. Fuera del subset → error claro en tiempo de emit, NO en tiempo de eval. El playground sigue ejecutando Lua arbitrario sin restriccion. + +**Subset permitido (transpila a SQL):** + +| Lua | SQL DuckDB | +|---|---| +| Literales: numero, string `"x"`, bool `true/false`, `nil` | `1.5`, `'x'`, `TRUE/FALSE`, `NULL` | +| Col ref: `[colname]` | `colname` (identifier quoted si necesario) | +| Aritmetica: `+ - * / % - (unary)` | mismas | +| Comparacion: `== ~= < <= > >=` | `= <> < <= > >=` | +| Logica: `and or not` | `AND OR NOT` | +| String concat: `..` | `\|\|` | +| Ternary: `if A then B else C end` | `CASE WHEN A THEN B ELSE C END` | +| Ternary inline: `(A and B) or C` (pattern comun Lua) | `CASE WHEN A THEN B ELSE C END` | +| `math.floor/ceil/abs/round/sqrt/sin/cos/log` | `floor/ceiling/abs/round/sqrt/sin/cos/ln` | +| `math.min(a,b)/max(a,b)` | `least(a,b)/greatest(a,b)` | +| `string.upper/lower/len(s)` | `upper(s)/lower(s)/length(s)` | +| `string.sub(s, i, j)` | `substring(s, i, j-i+1)` | +| `tostring(x)/tonumber(x)` | `CAST(x AS VARCHAR)/CAST(x AS DOUBLE)` | +| Paréntesis y precedencia | mismas | + +**Fuera de subset (error compile-time):** + +- Closures: `function() ... end` +- Loops: `for/while/repeat` +- Locals: `local x = ...` +- Tables: `{...}`, `t[k]`, `t.field`, `table.*` +- Multi-return / vararg +- `string.gsub/find/match/format` (mapeo manual posible v2) +- IO: `io.*`, `os.*`, `print` +- Coroutines, metatables, debug +- Recursion, multi-statement bodies + +**Error message ejemplo:** + +``` +SQL transpile error en derived col 'fullname': + formula = "[first] .. ' ' .. table.concat(parts, ',')" + causa: 'table.concat' no esta en SQL transpile subset + ver docs/TQL.md#sql-transpile-subset + workaround: usar TQL puro (sin SQL emit) o reescribir formula con `..` +``` + +**Helper:** `tql_to_sql::is_transpilable(formula, error_out)` pure fn que valida una formula sin emitir. + +### 5. DuckDB adapter (opcional) + +Build flag `FN_TQL_DUCKDB=1` en `cpp/CMakeLists.txt` opta-in. Vendor DuckDB header-only o lib (depende de tamaño). Default OFF — playground sigue compilando sin DuckDB. + +API adapter: + +```cpp +namespace tql_duckdb { +struct Result { + StageOutput out; // materializado como TableInput compatible + std::string error; + double duration_ms = 0; +}; +// Hidrata `tables` como views temp + ejecuta sql + materializa resultado. +Result execute(const std::string& sql, + const std::vector& params, + const std::vector& tables); +} +``` + +Apps que lo usen (registry_dashboard, sqlite_api): linkean DuckDB + invocan adapter cuando user/agent pide SQL output. Playground por defecto NO linka — `Ask AI` solo ofrece SQL mode si `#ifdef FN_TQL_DUCKDB`. + +### 6. Validacion + safety + +- Antes de aplicar TQL del LLM: `tql::apply` dry-run. Si fail, mostrar error + "Ask AI again with this error". +- Antes de ejecutar SQL del LLM: parsing DuckDB en sandbox read-only (DuckDB connection sin `INSERT/UPDATE/DELETE/DROP`, attach read-only). +- Lua sandbox ya cubre side effects en formulas TQL. + +### 7. Streaming + +- Stream tokens via SSE (`stream=true` Anthropic). +- Texto en vivo en modal. +- Cuando termina, parse lua/sql block final. + +### 8. Persistencia conversacion + +- UiState guarda lista de turns (pregunta + output propuesto + apply result + engine usado TQL/SQL). +- Siguiente "Ask AI" turn incluye history previa. +- Boton "Reset chat". +- NO persistido en TQL (UI state efimero). + +### 9. Coste / rate limit - Mostrar tokens estimados antes de enviar (rough char count / 4). - Cap input a 8000 tokens. -- Error handling: 429 / 5xx -> mensaje + reintentar. +- Error handling: 429 / 5xx → mensaje + reintentar. ## Tests -- Mockear HTTP response con cURL stub. -- Test: prompt build incluye schema + TQL + pregunta en formato esperado. -- Test: response parse extrae lua block correctamente. -- Test: tql::apply sobre output del LLM funciona end-to-end con dataset sintetico. +### Pure (sin red, sin DuckDB linkado) + +- **Lua subset validator:** `is_transpilable` true para casos subset, false con error claro para fuera de subset (closures, loops, table.*, string.gsub, etc.). +- **TQL → SQL emit golden tests** (~20 casos): + - stage 0 simple filter + sort → `SELECT ... WHERE ... ORDER BY ...` + - stage 1 group + count → CTE chain con GROUP BY + - granularity sufijo `:month` → `date_trunc('month', ts)` + - join 4 strategies con multi-key + - derived cols subset → CASE/expressions + - derived cols fuera subset → `SqlEmit.error` no vacio + warning + - aggregation p25/p50/p75/p99 → `quantile_cont(col, p)` + - empty pipeline → `SELECT * FROM t0` +- **TQL parseo:** prompt build incluye schema + TQL + pregunta en formato esperado (mockear HTTP). +- **Response parse:** extrae lua/sql block correctamente. + +### Round-trip (requiere DuckDB linkado) + +Solo corren si `FN_TQL_DUCKDB=1`: +- TQL → emit SQL → ejecutar DuckDB → resultado coincide bit-a-bit con `compute_stage` pure sobre los mismos cells. +- Casos: filter, group+agg, join inner, multi-stage chain, breakout granularity month/week, derived col `[a] + [b] * 2`. + +### LLM (red real, opt-in) + +- Test integration con `ANTHROPIC_API_KEY` real (`make test-llm`): pregunta simple → recibe TQL valido → apply OK. +- Mock test (CI): cURL stub responde con JSON predefinido → parser extrae bloque OK. ## No-objetivos -- Generacion de visualizaciones nuevas via LLM (la viz la elige TQL `display`, suficiente). -- Acciones del LLM mas alla de modificar TQL (sin acceso a I/O del sistema). -- Multi-provider (OpenAI / local) — fase futura. Hardcode Anthropic primero. +- **SQL → TQL**: no se implementa. Documentado en doc + en mensajes de error del Ask AI ("no soportamos SQL como input, use TQL"). +- **Multi-provider** (OpenAI, local): fase futura. Anthropic hardcoded v1. +- **Generacion de viz desde LLM** mas alla de `display` token: la viz la elige TQL existente. +- **Lua subset extension** (string.gsub, regex, table.*): postpone v2 si demanda real. +- **DuckDB write ops**: solo SELECT/CTE. Apps que quieran INSERT/UPDATE lo hacen fuera del playground. + +## Flujo agente (resumen) + +``` +Agente -> "muestrame top 10 langs por total size" +LLM (TQL default) -> emite TQL { stages = {...} } +tql::apply -> State + dry-run OK +User clickea Apply -> compute_stage en memoria + +Agente -> "lo mismo pero como SQL" +[Si FN_TQL_DUCKDB=1 y app linkó adapter] +LLM (SQL mode toggled) -> emite SELECT ... DuckDB +duckdb::execute(sql, params, tables) -> resultado materializado +[Si NO linkado] -> error "SQL mode requiere DuckDB. Compila con FN_TQL_DUCKDB=1" +``` + +## Riesgos + +- **Subset Lua restrictivo en SQL emit**: usuarios usan Lua arbitrario en playground → al pedir SQL falla. Mitigacion: error message claro + sugerencia workaround. +- **DuckDB tamaño**: lib ~10MB. Solo se paga si app opta-in con build flag. +- **Dialect drift DuckDB**: funciones SQL pueden cambiar entre versiones. Pinear DuckDB version en CMake. +- **LLM hallucinations**: TQL invalido → dry-run rechaza con error. Loop "Ask AI again with this error" recupera. +- **API key leak**: `pass` integration mantiene fuera del repo. Build flag NUNCA imprime key. +- **Coste tokens**: prompt grande (schema + grammar + TQL). Cap input + warning visual. diff --git a/docs/TQL.md b/docs/TQL.md index 705af849..fc56afc4 100644 --- a/docs/TQL.md +++ b/docs/TQL.md @@ -496,3 +496,87 @@ StageOutput compute_stage(const char* const* in_cells, int in_rows, int in_cols, | Multi-sort drag-reorder | Phase 4 | Ver `cpp/apps/primitives_gallery/playground/tables/` para la implementacion del playground. + +--- + +## SQL transpile subset (fase 11 — issue 0080) + +TQL emite SQL DuckDB equivalente para que agentes LLM puedan generar TQL o SQL contra los mismos datos. Modulo `tql_to_sql.{h,cpp}` provee `emit_sql(State, tables)`. Mapeo MBQL-style con CTE chain `t0..tN`. + +### Lua subset transpilable + +Lua sigue **potente y sin limites en runtime general** (formula eval en derived cols TQL puro). El subset SOLO aplica al pedir `tql_to_sql::emit_sql()`. Fuera del subset → error compile-time con causa concreta + workaround. + +**Permitido (transpila a SQL DuckDB):** + +| Lua | SQL DuckDB | Ejemplo | +|---|---|---| +| Literales numero/string/bool/nil | mismas (`'x'`, `TRUE`, `NULL`) | `42`, `"hola"`, `nil` | +| Col ref: `[colname]` | `"colname"` (quoted) | `[size_kb]` → `"size_kb"` | +| Aritmetica: `+ - * / % - (unary)` | mismas | `[a] + [b] * 2` → `("a" + ("b" * 2))` | +| Comparacion: `== ~= < <= > >=` | `= <> < <= > >=` | `[n] >= 10` → `("n" >= 10)` | +| Logica: `and or not` | `AND OR NOT` | `[a] and [b]` → `("a" AND "b")` | +| String concat: `..` | `\|\|` | `[a] .. "_" .. [b]` → `("a" \|\| '_' \|\| "b")` | +| Ternary: `if A then B else C end` | `CASE WHEN A THEN B ELSE C END` | obligatorio `else` | +| `math.floor/ceil/abs/sqrt/sin/cos/log/exp` | `floor/ceiling/abs/sqrt/sin/cos/ln/exp` | `math.floor([x])` | +| `math.min(a,b)/max(a,b)` | `least(a,b)/greatest(a,b)` | `math.min([a], 100)` | +| `string.upper/lower/len(s)` | `upper(s)/lower(s)/length(s)` | `string.upper([name])` | +| `string.sub(s, i [, j])` | `substring(s, i [, j-i+1])` | `string.sub([s], 1, 3)` | +| `tostring(x)/tonumber(x)` | `CAST(x AS VARCHAR)/CAST(x AS DOUBLE)` | `tonumber([n])` | +| Parentesis y precedencia Lua | mismas | `(a + b) * c` | + +**Fuera de subset (error compile-time):** + +- Closures: `function() ... end` +- Loops: `for/while/repeat` +- Locals: `local x = ...` +- Tables: `{...}`, `t[k]`, `t.field`, `table.*` +- Multi-return, vararg `...` +- `string.gsub/find/match/format/byte/char/rep` +- IO/OS/debug: `io.*`, `os.*`, `debug.*`, `package`, `require`, `print` +- Coroutines, metatables, `pcall/xpcall`, `rawget/rawset` +- Recursion, multi-statement bodies (`;`) +- Length operator `#` +- Method calls `:` +- Ternary sin else: `if A then B end` (subset requiere ambas ramas) + +### Error message ejemplo + +``` +SQL transpile error en derived col 'fullname': + formula = "[first] .. ' ' .. string.gsub([last], 'X', 'Y')" + causa: function 'string.gsub' not in SQL transpile whitelist + ver docs/TQL.md#sql-transpile-subset + workaround: usar TQL puro (sin SQL emit) o reescribir formula +``` + +### Stage → SQL mapeo + +| TQL element | SQL DuckDB | +|---|---| +| Stage 0 Raw | CTE `t0 AS (SELECT cols+derived FROM main_t [JOIN ...] [WHERE filters] [ORDER BY sorts])` | +| Stage N>=1 | CTE `tN AS (SELECT breakouts+aggs FROM tN-1 [GROUP BY ...] [ORDER BY ...])` | +| breakout `"col"` | `"col"` | +| breakout `"col:month"` | `date_trunc('month', "col")` | +| breakout `"col:year/week/day/hour"` | `date_trunc('year/week/day/hour', "col")` | +| Aggregation Count | `COUNT(*)` | +| Aggregation Sum/Avg/Min/Max/Stddev | `SUM/AVG/MIN/MAX/STDDEV("col")` | +| Aggregation Distinct | `COUNT(DISTINCT "col")` | +| Aggregation Median/P25/P75/P90/P99 | `quantile_cont("col", p)` | +| Aggregation Percentile p | `quantile_cont("col", p)` | +| Filter Op::Eq/Neq/Gt/Gte/Lt/Lte | `"col" = ?` etc (params bound) | +| Filter Op::Contains | `"col" LIKE '%v%'` (param `%v%`) | +| Filter Op::StartsWith / EndsWith | `LIKE 'v%'` / `LIKE '%v'` | +| Sort `{desc, "col"}` | `ORDER BY "col" DESC` | +| Join Left/Inner/Right/Full | `LEFT/INNER/RIGHT/FULL OUTER JOIN ... ON ...` | +| Join multi-key `on={{l1,r1},{l2,r2}}` | `ON l.l1 = r.r1 AND l.l2 = r.r2` | +| Join fields | cols `alias.field AS "alias.field"` | +| `main_source` | `FROM "main_source_name"` | + +### Doctrina (Metabase-style) + +- **One-way:** TQL → SQL OK. SQL → TQL no soportado. Razon: traduccion inversa lossy (CTEs, window fns, set ops, lateral, correlated subqueries no caben en TQL). +- **Output:** SQL string siempre emitible. Ejecucion requiere DuckDB linkado (build flag `FN_TQL_DUCKDB=1`, opcional). +- **Agente flow:** TQL default. SQL solo si app linko DuckDB. UI Ask AI muestra toggle SQL solo cuando disponible. + +Ver issue 0080 + `tql_to_sql.{h,cpp}` para implementacion. diff --git a/functions/core/subprocess_stream.go b/functions/core/subprocess_stream.go new file mode 100644 index 00000000..feeac3d7 --- /dev/null +++ b/functions/core/subprocess_stream.go @@ -0,0 +1,155 @@ +package core + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "os/exec" + "sync" + "syscall" + "time" +) + +// StreamEvent es una linea capturada de stdout o stderr del subproceso. +type StreamEvent struct { + Stream string // "stdout" | "stderr" + Line string // sin trailing newline + Time time.Time // timestamp de recepcion +} + +// StreamResult es el resultado final del subproceso, enviado por el canal de +// resultados cuando ambos pipes han llegado a EOF y el proceso ha terminado. +type StreamResult struct { + ExitCode int + Err error + DurationMs int64 +} + +// SubprocessStream lanza name con args como subproceso y retorna dos canales: +// - events: recibe StreamEvent (linea de stdout/stderr) hasta EOF de ambos pipes. +// - result: recibe exactamente un StreamResult cuando el proceso termina. +// +// env se concatena con os.Environ(). stdin puede ser nil. +// +// Cancelar ctx envia SIGTERM al proceso; si no termina en 2 segundos, SIGKILL. +// El caller DEBE consumir events hasta que se cierre o cancelar ctx para evitar +// bloquear las goroutines internas. +func SubprocessStream( + ctx context.Context, + name string, + args []string, + env []string, + stdin io.Reader, +) (<-chan StreamEvent, <-chan StreamResult) { + events := make(chan StreamEvent, 64) + results := make(chan StreamResult, 1) + + go func() { + defer close(events) + defer close(results) + + start := time.Now() + + cmd := exec.CommandContext(ctx, name, args...) + + // Entorno: base + extra + if len(env) > 0 { + cmd.Env = append(os.Environ(), env...) + } + + if stdin != nil { + cmd.Stdin = stdin + } + + // Process group propio para matar hijos al recibir SIGTERM/SIGKILL + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + results <- StreamResult{ExitCode: -1, Err: fmt.Errorf("stdout pipe: %w", err), DurationMs: 0} + return + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + results <- StreamResult{ExitCode: -1, Err: fmt.Errorf("stderr pipe: %w", err), DurationMs: 0} + return + } + + if err := cmd.Start(); err != nil { + results <- StreamResult{ExitCode: -1, Err: fmt.Errorf("start: %w", err), DurationMs: 0} + return + } + + // Goroutine de supervision de ctx: SIGTERM → grace 2s → SIGKILL + ctxDone := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + if cmd.Process != nil { + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM) + timer := time.NewTimer(2 * time.Second) + defer timer.Stop() + select { + case <-timer.C: + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + case <-ctxDone: + } + } + case <-ctxDone: + } + }() + + send := func(stream, line string) { + ev := StreamEvent{Stream: stream, Line: line, Time: time.Now()} + select { + case events <- ev: + case <-ctx.Done(): + } + } + + // Leer stdout y stderr concurrentemente + const bufSize = 1024 * 1024 // 1 MB para lineas largas (sd-cli progress, etc.) + var wg sync.WaitGroup + + scanPipe := func(r io.Reader, stream string) { + defer wg.Done() + sc := bufio.NewScanner(r) + sc.Buffer(make([]byte, bufSize), bufSize) + for sc.Scan() { + send(stream, sc.Text()) + } + } + + wg.Add(2) + go scanPipe(stdoutPipe, "stdout") + go scanPipe(stderrPipe, "stderr") + + wg.Wait() + close(ctxDone) // señal al supervisor de ctx para que pare + + exitCode := 0 + var waitErr error + if err := cmd.Wait(); err != nil { + waitErr = err + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + waitErr = nil // exit code no-cero no es un error de spawn + } + } + + // Si el contexto fue cancelado, reportar como error de cancelacion + if ctx.Err() != nil && waitErr == nil { + waitErr = ctx.Err() + } + + results <- StreamResult{ + ExitCode: exitCode, + Err: waitErr, + DurationMs: time.Since(start).Milliseconds(), + } + }() + + return events, results +} diff --git a/functions/core/subprocess_stream.md b/functions/core/subprocess_stream.md new file mode 100644 index 00000000..e0cbd545 --- /dev/null +++ b/functions/core/subprocess_stream.md @@ -0,0 +1,69 @@ +--- +name: subprocess_stream +kind: function +lang: go +domain: core +version: "1.0.0" +purity: impure +signature: "func SubprocessStream(ctx context.Context, name string, args []string, env []string, stdin io.Reader) (<-chan StreamEvent, <-chan StreamResult)" +description: "Lanza un subproceso y retorna dos canales: uno con StreamEvent (linea de stdout/stderr con timestamp) y otro con un unico StreamResult (ExitCode, Err, DurationMs). Cancelar ctx envia SIGTERM al proceso; si no termina en 2s, SIGKILL." +tags: [subprocess, exec, stream, stdout, stderr, process, concurrency, io, primitiva] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [bufio, context, fmt, io, os, os/exec, sync, syscall, time] +params: + - name: ctx + desc: "Contexto de cancelacion. Al cancelar, el proceso recibe SIGTERM; si no muere en 2s, SIGKILL. Usar context.WithTimeout para acotar duracion maxima." + - name: name + desc: "Nombre o path del ejecutable a lanzar (ej. 'echo', '/usr/bin/python3')." + - name: args + desc: "Argumentos del proceso. Puede ser nil o vacio." + - name: env + desc: "Variables de entorno adicionales en formato 'KEY=VALUE'. Se concatenan con os.Environ(). Puede ser nil." + - name: stdin + desc: "Stdin del proceso. Puede ser nil si el proceso no necesita entrada." +output: "Dos canales: events (<-chan StreamEvent) cerrado cuando ambos pipes EOF; result (<-chan StreamResult) con exactamente un valor cuando el proceso termina. El caller DEBE consumir events hasta cierre o cancelar ctx para evitar bloquear goroutines internas." +tested: true +tests: + - "echo stdout llega como evento y ExitCode 0" + - "stderr llega como evento con stream stderr" + - "exit code no-cero se reporta en StreamResult" + - "ctx cancelado termina el proceso" + - "multiples lineas stdout" +test_file_path: "functions/core/subprocess_stream_test.go" +file_path: "functions/core/subprocess_stream.go" +--- + +## Ejemplo + +```go +ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +defer cancel() + +events, results := core.SubprocessStream(ctx, "grep", []string{"-rn", "TODO", "."}, nil, nil) + +for ev := range events { + switch ev.Stream { + case "stdout": + fmt.Println(ev.Line) + case "stderr": + fmt.Fprintln(os.Stderr, "[stderr]", ev.Line) + } +} + +res := <-results +if res.ExitCode != 0 || res.Err != nil { + log.Printf("grep exit=%d err=%v duration=%dms", res.ExitCode, res.Err, res.DurationMs) +} +``` + +## Notas + +- El canal `events` tiene buffer de 64. Si el caller deja de consumir y el buffer se llena, las goroutinas internas se bloquean hasta que haya espacio o el ctx sea cancelado. +- El scanner de cada pipe tiene un buffer de 1 MB para tolerar lineas muy largas (progreso de CLIs tipo sd-cli, barras ANSI largas). +- Los structs `StreamEvent` y `StreamResult` se declaran en el mismo archivo para que el paquete `core` los exporte sin imports adicionales. +- Generaliza el patron de `claude_stream_go_core` desacoplando el lanzamiento de subprocesos del protocolo especifico de claude (NDJSON/stream-json). `claude_stream_go_core` puede reimplementarse internamente usando esta funcion como primitiva. +- `cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}` crea un process group propio; SIGTERM/SIGKILL se envian con `Kill(-pgid, sig)` para matar tambien los procesos hijo del hijo. diff --git a/functions/core/subprocess_stream_test.go b/functions/core/subprocess_stream_test.go new file mode 100644 index 00000000..516334c1 --- /dev/null +++ b/functions/core/subprocess_stream_test.go @@ -0,0 +1,132 @@ +package core + +import ( + "context" + "testing" + "time" +) + +func TestSubprocessStream(t *testing.T) { + t.Run("echo stdout llega como evento y ExitCode 0", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + events, results := SubprocessStream(ctx, "echo", []string{"hola"}, nil, nil) + + var got []StreamEvent + for ev := range events { + got = append(got, ev) + } + + res := <-results + + if res.ExitCode != 0 { + t.Errorf("ExitCode = %d, want 0 (err: %v)", res.ExitCode, res.Err) + } + if res.Err != nil { + t.Errorf("unexpected Err: %v", res.Err) + } + if len(got) != 1 { + t.Fatalf("got %d events, want 1", len(got)) + } + if got[0].Stream != "stdout" { + t.Errorf("Stream = %q, want %q", got[0].Stream, "stdout") + } + if got[0].Line != "hola" { + t.Errorf("Line = %q, want %q", got[0].Line, "hola") + } + }) + + t.Run("stderr llega como evento con stream stderr", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // sh -c "echo msg >&2" escribe a stderr + events, results := SubprocessStream(ctx, "sh", []string{"-c", "echo error_msg >&2"}, nil, nil) + + var got []StreamEvent + for ev := range events { + got = append(got, ev) + } + res := <-results + + if res.ExitCode != 0 { + t.Errorf("ExitCode = %d, want 0", res.ExitCode) + } + if len(got) != 1 { + t.Fatalf("got %d events, want 1", len(got)) + } + if got[0].Stream != "stderr" { + t.Errorf("Stream = %q, want %q", got[0].Stream, "stderr") + } + if got[0].Line != "error_msg" { + t.Errorf("Line = %q, want %q", got[0].Line, "error_msg") + } + }) + + t.Run("exit code no-cero se reporta en StreamResult", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + events, results := SubprocessStream(ctx, "sh", []string{"-c", "exit 42"}, nil, nil) + + for range events { + } + res := <-results + + if res.ExitCode != 42 { + t.Errorf("ExitCode = %d, want 42", res.ExitCode) + } + if res.Err != nil { + t.Errorf("unexpected Err: %v", res.Err) + } + }) + + t.Run("ctx cancelado termina el proceso", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // proceso que dura mucho; cancelamos enseguida + ctxShort, cancelShort := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancelShort() + + events, results := SubprocessStream(ctxShort, "sleep", []string{"60"}, nil, nil) + + for range events { + } + res := <-results + + // Tras cancelacion el proceso debe haber terminado (ExitCode != 0 o Err de ctx) + if res.ExitCode == 0 && res.Err == nil { + t.Error("expected non-zero exit or ctx error after cancellation") + } + if res.DurationMs > 3000 { + t.Errorf("took %d ms, expected < 3000 (should have been killed)", res.DurationMs) + } + }) + + t.Run("multiples lineas stdout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + events, results := SubprocessStream(ctx, "sh", []string{"-c", "printf 'a\nb\nc\n'"}, nil, nil) + + var lines []string + for ev := range events { + if ev.Stream == "stdout" { + lines = append(lines, ev.Line) + } + } + <-results + + if len(lines) != 3 { + t.Fatalf("got %d stdout lines, want 3: %v", len(lines), lines) + } + want := []string{"a", "b", "c"} + for i, w := range want { + if lines[i] != w { + t.Errorf("line[%d] = %q, want %q", i, lines[i], w) + } + } + }) +} diff --git a/functions/infra/audit_ml_env.go b/functions/infra/audit_ml_env.go new file mode 100644 index 00000000..38ec80a1 --- /dev/null +++ b/functions/infra/audit_ml_env.go @@ -0,0 +1,238 @@ +package infra + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +// MlEnvCheck holds the result of a single ML environment probe. +type MlEnvCheck struct { + Name string `json:"name"` // e.g. "cuda_toolkit", "python_venv" + Status string `json:"status"` // "ok" | "missing" | "warning" | "unknown" + Version string `json:"version,omitempty"` // version string if detected + Detail string `json:"detail,omitempty"` // human-readable extra info +} + +// MlEnvReport is the full ML environment audit result. +type MlEnvReport struct { + Gpus []GpuInfo `json:"gpus"` + Checks []MlEnvCheck `json:"checks"` + OverallOK bool `json:"overall_ok"` + GeneratedAt int64 `json:"generated_at"` +} + +// AuditMlEnv probes the ML environment rooted at registryRoot. +// It checks for NVIDIA drivers, CUDA toolkit, Python venv, key Python +// packages and optional tools (sd, llama-cli) and a local vault path. +// Returns a non-nil MlEnvReport even when individual checks fail; the +// function itself only errors if a fundamental system call cannot be +// attempted. +func AuditMlEnv(registryRoot string) (MlEnvReport, error) { + report := MlEnvReport{ + GeneratedAt: time.Now().Unix(), + } + + // --- GPU detection (composes GetGpuInfo) --- + gpus, err := GetGpuInfo() + if err != nil { + // Non-fatal: record absence. + gpus = []GpuInfo{} + } + report.Gpus = gpus + + checks := []MlEnvCheck{} + + // --- nvidia-smi --- + checks = append(checks, probeCommand("nvidia_smi", "nvidia-smi", []string{"--version"}, 5)) + + // --- nvcc (CUDA toolkit compiler) --- + nvcc := probeNvcc() + checks = append(checks, nvcc) + + // --- Python venv --- + venvCheck := probeVenv(registryRoot) + checks = append(checks, venvCheck) + + // Python venv path for subsequent checks. + venvPy := filepath.Join(registryRoot, "python", ".venv", "bin", "python3") + + // --- Python packages --- + for _, pkg := range []string{"torch", "diffusers", "transformers", "huggingface_hub", "stable_diffusion_cpp_python"} { + checks = append(checks, probePythonPackage(venvPy, pkg)) + } + + // --- sd.cpp CLI --- + checks = append(checks, probeCommand("sd_cli", "sd", []string{"--version"}, 5)) + + // --- llama.cpp CLI --- + checks = append(checks, probeCommand("llama_cpp", "llama-cli", []string{"--version"}, 5)) + + // --- imagegen_vault --- + checks = append(checks, probeImagegenVault()) + + report.Checks = checks + + // OverallOK: no "missing" checks (warning is tolerated) and at least 1 GPU. + overallOK := len(gpus) > 0 + for _, c := range checks { + if c.Status == "missing" { + // stable_diffusion_cpp_python and sd_cli are optional — downgrade to warning-only. + if c.Name == "stable_diffusion_cpp_python" || c.Name == "sd_cli" || c.Name == "llama_cpp" { + continue + } + overallOK = false + } + } + report.OverallOK = overallOK + + return report, nil +} + +// probeCommand checks whether a binary is available in PATH by running it with +// the given args and recording any version output. +func probeCommand(name, binary string, args []string, timeoutSec int) MlEnvCheck { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) + defer cancel() + + path, err := exec.LookPath(binary) + if err != nil { + return MlEnvCheck{Name: name, Status: "missing", Detail: fmt.Sprintf("%s not found in PATH", binary)} + } + + out, err := exec.CommandContext(ctx, path, args...).CombinedOutput() + version := strings.TrimSpace(string(out)) + if len(version) > 120 { + version = version[:120] + } + if err != nil { + return MlEnvCheck{Name: name, Status: "warning", Version: version, Detail: fmt.Sprintf("exit error: %v", err)} + } + return MlEnvCheck{Name: name, Status: "ok", Version: version} +} + +// probeNvcc extracts the CUDA toolkit version from nvcc --version output. +func probeNvcc() MlEnvCheck { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + path, err := exec.LookPath("nvcc") + if err != nil { + return MlEnvCheck{Name: "nvcc", Status: "missing", Detail: "nvcc not found in PATH (CUDA toolkit not installed)"} + } + + out, err := exec.CommandContext(ctx, path, "--version").CombinedOutput() + if err != nil { + return MlEnvCheck{Name: "nvcc", Status: "warning", Detail: fmt.Sprintf("nvcc --version failed: %v", err)} + } + + // Extract version from line like: "Cuda compilation tools, release 12.4, V12.4.99" + version := "" + for _, line := range strings.Split(string(out), "\n") { + if strings.Contains(line, "release") { + parts := strings.Split(line, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if strings.HasPrefix(p, "release") { + version = strings.TrimSpace(strings.TrimPrefix(p, "release")) + break + } + } + break + } + } + if version == "" { + version = strings.TrimSpace(string(out)) + if len(version) > 80 { + version = version[:80] + } + } + return MlEnvCheck{Name: "nvcc", Status: "ok", Version: version} +} + +// probeVenv checks that the Python venv exists and is functional. +func probeVenv(registryRoot string) MlEnvCheck { + py := filepath.Join(registryRoot, "python", ".venv", "bin", "python3") + if _, err := os.Stat(py); os.IsNotExist(err) { + return MlEnvCheck{Name: "python_venv", Status: "missing", Detail: fmt.Sprintf("not found: %s", py)} + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + out, err := exec.CommandContext(ctx, py, "--version").CombinedOutput() + version := strings.TrimSpace(string(out)) + if err != nil { + return MlEnvCheck{Name: "python_venv", Status: "warning", Version: version, Detail: fmt.Sprintf("python3 --version failed: %v", err)} + } + return MlEnvCheck{Name: "python_venv", Status: "ok", Version: version} +} + +// probePythonPackage imports a package in the venv Python and extracts __version__. +func probePythonPackage(venvPy, pkg string) MlEnvCheck { + // Map package name → import name (for packages with different import names). + importName := pkg + switch pkg { + case "stable_diffusion_cpp_python": + importName = "stable_diffusion_cpp" + case "huggingface_hub": + importName = "huggingface_hub" + } + + // Check that the venv python binary exists first. + if _, err := os.Stat(venvPy); os.IsNotExist(err) { + return MlEnvCheck{Name: pkg, Status: "unknown", Detail: "python_venv not available"} + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + script := fmt.Sprintf("import %s; v = getattr(%s, '__version__', None); print(v or 'unknown')", importName, importName) + out, err := exec.CommandContext(ctx, venvPy, "-c", script).CombinedOutput() + output := strings.TrimSpace(string(out)) + + if err != nil { + // Module not found → missing; other errors → warning. + detail := output + if len(detail) > 200 { + detail = detail[:200] + } + if strings.Contains(output, "ModuleNotFoundError") || strings.Contains(output, "No module named") { + return MlEnvCheck{Name: pkg, Status: "missing", Detail: fmt.Sprintf("%s not installed", importName)} + } + return MlEnvCheck{Name: pkg, Status: "warning", Detail: detail} + } + return MlEnvCheck{Name: pkg, Status: "ok", Version: output} +} + +// probeImagegenVault checks that ~/vaults/imagegen_models exists and lists subdirs. +func probeImagegenVault() MlEnvCheck { + home, err := os.UserHomeDir() + if err != nil { + return MlEnvCheck{Name: "imagegen_vault", Status: "unknown", Detail: "cannot determine home directory"} + } + vaultPath := filepath.Join(home, "vaults", "imagegen_models") + entries, err := os.ReadDir(vaultPath) + if os.IsNotExist(err) { + return MlEnvCheck{Name: "imagegen_vault", Status: "missing", Detail: fmt.Sprintf("vault not found: %s", vaultPath)} + } + if err != nil { + return MlEnvCheck{Name: "imagegen_vault", Status: "warning", Detail: fmt.Sprintf("cannot read vault: %v", err)} + } + + subdirs := []string{} + for _, e := range entries { + if e.IsDir() { + subdirs = append(subdirs, e.Name()) + } + } + detail := fmt.Sprintf("subdirs: %s", strings.Join(subdirs, ", ")) + if len(subdirs) == 0 { + detail = "vault exists but is empty" + } + return MlEnvCheck{Name: "imagegen_vault", Status: "ok", Detail: detail} +} diff --git a/functions/infra/audit_ml_env.md b/functions/infra/audit_ml_env.md new file mode 100644 index 00000000..b8432c6c --- /dev/null +++ b/functions/infra/audit_ml_env.md @@ -0,0 +1,67 @@ +--- +name: audit_ml_env +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func AuditMlEnv(registryRoot string) (MlEnvReport, error)" +description: "Audita el entorno ML del sistema: GPUs NVIDIA, toolkit CUDA, venv Python, paquetes clave (torch, diffusers, transformers, huggingface_hub), herramientas CLI (sd, llama-cli) y el vault de modelos. Retorna un MlEnvReport con OverallOK=true solo si hay al menos 1 GPU y los checks criticos estan en ok/warning." +tags: [ml, cuda, gpu, nvidia, audit, doctor, infra, torch, diffusers] +uses_functions: [get_gpu_info_go_infra] +uses_types: [gpu_info_go_infra] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [context, fmt, os, os/exec, path/filepath, strings, time] +tested: true +tests: + - "report no nil y tiene checks" + - "generated_at es positivo" + - "checks tiene al menos 4 entradas" + - "gpus puede ser vacio en CI" +test_file_path: "functions/infra/audit_ml_env_test.go" +file_path: "functions/infra/audit_ml_env.go" +params: + - name: registryRoot + desc: "Ruta absoluta a la raiz del fn_registry. Se usa para localizar python/.venv/bin/python3 y probar paquetes instalados." +output: "MlEnvReport con Gpus (puede estar vacio si no hay NVIDIA), Checks con estado por herramienta/paquete, OverallOK y GeneratedAt (unix timestamp)." +--- + +## Checks realizados + +| Check | Tipo | Critico | +|---|---|---| +| `nvidia_smi` | binary in PATH | no (ok si hay GPU) | +| `nvcc` | CUDA toolkit version | no | +| `python_venv` | exists + `python3 --version` | si | +| `torch` | `import torch; __version__` | si | +| `diffusers` | `import diffusers; __version__` | si | +| `transformers` | `import transformers; __version__` | si | +| `huggingface_hub` | `import huggingface_hub; __version__` | si | +| `stable_diffusion_cpp_python` | `import stable_diffusion_cpp` | no (opcional) | +| `sd_cli` | `sd --version` in PATH | no (opcional) | +| `llama_cpp` | `llama-cli --version` in PATH | no (opcional) | +| `imagegen_vault` | `~/vaults/imagegen_models` exists | no | + +## Ejemplo + +```go +root := "/home/lucas/fn_registry" +report, err := AuditMlEnv(root) +if err != nil { + log.Fatal(err) +} +for _, c := range report.Checks { + fmt.Printf("%-40s %s %s\n", c.Name, c.Status, c.Version) +} +fmt.Printf("OverallOK: %v\n", report.OverallOK) +``` + +## Notas + +- Cada check tiene timeout de 5 segundos para no bloquear en entornos sin GPU. +- `stable_diffusion_cpp_python`, `sd_cli` y `llama_cpp` son opcionales: si estan missing, `OverallOK` no se ve afectado. +- `OverallOK` requiere al menos 1 GPU NVIDIA detectada via `GetGpuInfo()`. +- No escribe nada en disco. Read-only. +- Se expone como `fn doctor ml` via cmd/fn/doctor.go. diff --git a/functions/infra/audit_ml_env_test.go b/functions/infra/audit_ml_env_test.go new file mode 100644 index 00000000..c94de5fd --- /dev/null +++ b/functions/infra/audit_ml_env_test.go @@ -0,0 +1,53 @@ +package infra + +import ( + "testing" +) + +func TestAuditMlEnv(t *testing.T) { + // Use the actual registry root relative to the test binary location. + // Tests run from the package directory; go up two levels. + registryRoot := "../.." + + t.Run("report no nil y tiene checks", func(t *testing.T) { + report, err := AuditMlEnv(registryRoot) + if err != nil { + t.Fatalf("AuditMlEnv returned error: %v", err) + } + if report.Checks == nil { + t.Fatal("report.Checks is nil") + } + }) + + t.Run("generated_at es positivo", func(t *testing.T) { + report, err := AuditMlEnv(registryRoot) + if err != nil { + t.Fatalf("AuditMlEnv returned error: %v", err) + } + if report.GeneratedAt <= 0 { + t.Errorf("GeneratedAt should be positive unix timestamp, got %d", report.GeneratedAt) + } + }) + + t.Run("checks tiene al menos 4 entradas", func(t *testing.T) { + report, err := AuditMlEnv(registryRoot) + if err != nil { + t.Fatalf("AuditMlEnv returned error: %v", err) + } + if len(report.Checks) < 4 { + t.Errorf("expected at least 4 checks, got %d", len(report.Checks)) + } + }) + + t.Run("gpus puede ser vacio en CI", func(t *testing.T) { + report, err := AuditMlEnv(registryRoot) + if err != nil { + t.Fatalf("AuditMlEnv returned error: %v", err) + } + // Gpus may be empty in CI without a GPU; that's OK. + // Just verify the field is not nil. + if report.Gpus == nil { + t.Error("report.Gpus should be a non-nil slice (can be empty)") + } + }) +} diff --git a/functions/infra/get_gpu_info.go b/functions/infra/get_gpu_info.go new file mode 100644 index 00000000..fd004032 --- /dev/null +++ b/functions/infra/get_gpu_info.go @@ -0,0 +1,60 @@ +package infra + +import ( + "encoding/csv" + "errors" + "fmt" + "os/exec" + "strconv" + "strings" +) + +// GetGpuInfo queries NVIDIA GPUs via nvidia-smi and returns a slice of GpuInfo. +// If nvidia-smi is not installed or no NVIDIA GPU is present, returns an empty +// slice and a nil error (absence of NVIDIA hardware is not an error). +func GetGpuInfo() ([]GpuInfo, error) { + out, err := exec.Command( + "nvidia-smi", + "--query-gpu=index,name,memory.total,memory.free,driver_version,cuda_version", + "--format=csv,noheader,nounits", + ).Output() + + if err != nil { + // nvidia-smi not installed or no NVIDIA device — not an error. + var exitErr *exec.ExitError + if errors.Is(err, exec.ErrNotFound) || errors.As(err, &exitErr) { + return []GpuInfo{}, nil + } + return nil, fmt.Errorf("gpu_info: nvidia-smi: %w", err) + } + + r := csv.NewReader(strings.NewReader(strings.TrimSpace(string(out)))) + r.TrimLeadingSpace = true + + records, err := r.ReadAll() + if err != nil { + return nil, fmt.Errorf("gpu_info: parse csv: %w", err) + } + + gpus := make([]GpuInfo, 0, len(records)) + for _, rec := range records { + if len(rec) < 6 { + continue + } + + idx, _ := strconv.Atoi(strings.TrimSpace(rec[0])) + totalMb, _ := strconv.Atoi(strings.TrimSpace(rec[2])) + freeMb, _ := strconv.Atoi(strings.TrimSpace(rec[3])) + + gpus = append(gpus, GpuInfo{ + Index: idx, + Name: strings.TrimSpace(rec[1]), + VramTotalMb: totalMb, + VramFreeMb: freeMb, + DriverVersion: strings.TrimSpace(rec[4]), + CudaVersion: strings.TrimSpace(rec[5]), + }) + } + + return gpus, nil +} diff --git a/functions/infra/get_gpu_info.md b/functions/infra/get_gpu_info.md new file mode 100644 index 00000000..3114f750 --- /dev/null +++ b/functions/infra/get_gpu_info.md @@ -0,0 +1,70 @@ +--- +name: get_gpu_info +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func GetGpuInfo() ([]GpuInfo, error)" +description: "Consulta GPUs NVIDIA via nvidia-smi y retorna un slice de GpuInfo con index, nombre, VRAM total/libre, driver y version CUDA. Si nvidia-smi no esta instalado o no hay GPU NVIDIA, retorna slice vacio y nil (ausencia de hardware no es error)." +tags: [gpu, nvidia, cuda, hardware, infra, probe] +uses_functions: [] +uses_types: ["gpu_info_go_infra"] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [encoding/csv, errors, fmt, os/exec, strconv, strings] +params: + - name: (ninguno) + desc: "No toma parametros. Lee el estado del sistema via nvidia-smi." +output: "Slice de GpuInfo con una entrada por GPU detectada. Slice vacio si no hay GPUs NVIDIA o nvidia-smi no esta instalado. Error solo si nvidia-smi existe pero falla inesperadamente al parsear la salida CSV." +tested: true +tests: + - "retorna slice vacio y nil cuando no hay GPU NVIDIA" + - "linea GPU RTX 3080 tipica" + - "dos GPUs en el CSV" + - "CSV vacio retorna slice vacio" + - "linea con menos de 6 campos se ignora" + - "espacios extra en los valores se eliminan" + - "campos del struct GpuInfo correctos" +test_file_path: "functions/infra/get_gpu_info_test.go" +file_path: "functions/infra/get_gpu_info.go" +--- + +## Ejemplo + +```go +gpus, err := GetGpuInfo() +if err != nil { + log.Fatal(err) +} +if len(gpus) == 0 { + fmt.Println("No NVIDIA GPUs detected") +} else { + for _, g := range gpus { + fmt.Printf("[%d] %s VRAM: %d/%d MiB Driver: %s CUDA: %s\n", + g.Index, g.Name, g.VramFreeMb, g.VramTotalMb, + g.DriverVersion, g.CudaVersion) + } +} +``` + +## Salida nvidia-smi + +Ejecuta: +``` +nvidia-smi --query-gpu=index,name,memory.total,memory.free,driver_version,cuda_version --format=csv,noheader,nounits +``` + +Ejemplo de salida con una GPU: +``` +0, NVIDIA GeForce RTX 3080, 10240, 8192, 550.54.15, 12.4 +``` + +## Notas + +- Requiere `nvidia-smi` en PATH (parte del driver NVIDIA). +- La columna `cuda_version` en nvidia-smi refleja la version maxima de CUDA soportada por el driver, no la del toolkit instalado. +- Para comprobar el toolkit CUDA instalado, usar `cuda_toolkit_check_bash_infra`. +- En maquinas sin GPU NVIDIA retorna `([]GpuInfo{}, nil)` — el caller puede tratar esto como "sin GPU disponible". +- No ejecutar tests automatizados para esta funcion en CI sin GPU; verificar manualmente o con mock. diff --git a/functions/infra/get_gpu_info_test.go b/functions/infra/get_gpu_info_test.go new file mode 100644 index 00000000..f91c8492 --- /dev/null +++ b/functions/infra/get_gpu_info_test.go @@ -0,0 +1,165 @@ +package infra + +import ( + "strconv" + "strings" + "testing" +) + +// TestGetGpuInfoNoGpu verifica que la funcion retorna slice vacio sin error +// cuando nvidia-smi no esta instalado o no hay GPU NVIDIA presente. +// Este test pasa en cualquier maquina, con o sin GPU. +func TestGetGpuInfoNoGpu(t *testing.T) { + t.Run("retorna slice vacio y nil cuando no hay GPU NVIDIA", func(t *testing.T) { + gpus, err := GetGpuInfo() + if err != nil { + t.Errorf("GetGpuInfo() error inesperado: %v", err) + } + // En maquinas sin nvidia-smi el resultado debe ser un slice vacio (no nil) + if gpus == nil { + t.Error("GetGpuInfo() retorno nil, se esperaba slice vacio []GpuInfo{}") + } + }) +} + +// parseCsvNvidiaSmi replica la logica de parsing de GetGpuInfo para tests unitarios. +// Recibe el output de nvidia-smi --format=csv,noheader,nounits y retorna []GpuInfo. +func parseCsvNvidiaSmi(output string) ([]GpuInfo, error) { + trimmed := strings.TrimSpace(output) + if trimmed == "" { + return []GpuInfo{}, nil + } + lines := strings.Split(trimmed, "\n") + gpus := make([]GpuInfo, 0, len(lines)) + for _, line := range lines { + parts := strings.Split(line, ",") + if len(parts) < 6 { + continue + } + idx, _ := strconv.Atoi(strings.TrimSpace(parts[0])) + totalMb, _ := strconv.Atoi(strings.TrimSpace(parts[2])) + freeMb, _ := strconv.Atoi(strings.TrimSpace(parts[3])) + gpus = append(gpus, GpuInfo{ + Index: idx, + Name: strings.TrimSpace(parts[1]), + VramTotalMb: totalMb, + VramFreeMb: freeMb, + DriverVersion: strings.TrimSpace(parts[4]), + CudaVersion: strings.TrimSpace(parts[5]), + }) + } + return gpus, nil +} + +// TestParseCsvNvidiaSmi verifica el parsing de la salida CSV de nvidia-smi +// sin requerir GPU real ni nvidia-smi instalado. +func TestParseCsvNvidiaSmi(t *testing.T) { + tests := []struct { + name string + csvInput string + wantLen int + wantIndex int + wantName string + wantVramTotal int + wantVramFree int + wantDriver string + wantCuda string + }{ + { + name: "linea GPU RTX 3080 tipica", + csvInput: "0, NVIDIA GeForce RTX 3080, 10240, 8192, 550.54.15, 12.4", + wantLen: 1, + wantIndex: 0, + wantName: "NVIDIA GeForce RTX 3080", + wantVramTotal: 10240, + wantVramFree: 8192, + wantDriver: "550.54.15", + wantCuda: "12.4", + }, + { + name: "dos GPUs en el CSV", + csvInput: "0, GPU A, 8192, 4096, 525.0, 12.0\n1, GPU B, 24576, 20000, 525.0, 12.0", + wantLen: 2, + }, + { + name: "CSV vacio retorna slice vacio", + csvInput: "", + wantLen: 0, + }, + { + name: "linea con menos de 6 campos se ignora", + csvInput: "0, GPU, 8192", + wantLen: 0, + }, + { + name: "espacios extra en los valores se eliminan", + csvInput: " 1 , NVIDIA RTX 4090 , 24576 , 20000 , 545.0 , 12.6 ", + wantLen: 1, + wantIndex: 1, + wantName: "NVIDIA RTX 4090", + wantVramTotal: 24576, + wantVramFree: 20000, + wantDriver: "545.0", + wantCuda: "12.6", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gpus, err := parseCsvNvidiaSmi(tc.csvInput) + if err != nil { + t.Fatalf("error inesperado: %v", err) + } + if len(gpus) != tc.wantLen { + t.Fatalf("len(gpus) = %d, quería %d", len(gpus), tc.wantLen) + } + if tc.wantLen == 1 { + g := gpus[0] + if g.Index != tc.wantIndex { + t.Errorf("Index = %d, quería %d", g.Index, tc.wantIndex) + } + if g.Name != tc.wantName { + t.Errorf("Name = %q, quería %q", g.Name, tc.wantName) + } + if g.VramTotalMb != tc.wantVramTotal { + t.Errorf("VramTotalMb = %d, quería %d", g.VramTotalMb, tc.wantVramTotal) + } + if g.VramFreeMb != tc.wantVramFree { + t.Errorf("VramFreeMb = %d, quería %d", g.VramFreeMb, tc.wantVramFree) + } + if g.DriverVersion != tc.wantDriver { + t.Errorf("DriverVersion = %q, quería %q", g.DriverVersion, tc.wantDriver) + } + if g.CudaVersion != tc.wantCuda { + t.Errorf("CudaVersion = %q, quería %q", g.CudaVersion, tc.wantCuda) + } + } + }) + } +} + +// TestGpuInfoStruct verifica los campos del tipo GpuInfo. +func TestGpuInfoStruct(t *testing.T) { + t.Run("campos del struct GpuInfo correctos", func(t *testing.T) { + g := GpuInfo{ + Index: 0, + Name: "NVIDIA GeForce GTX 1080", + VramTotalMb: 8192, + VramFreeMb: 6144, + DriverVersion: "470.0", + CudaVersion: "11.4", + } + if g.Index != 0 { + t.Errorf("Index = %d", g.Index) + } + if g.Name != "NVIDIA GeForce GTX 1080" { + t.Errorf("Name = %q", g.Name) + } + if g.VramTotalMb != 8192 { + t.Errorf("VramTotalMb = %d", g.VramTotalMb) + } + if g.VramFreeMb != 6144 { + t.Errorf("VramFreeMb = %d", g.VramFreeMb) + } + }) +} diff --git a/functions/infra/gpu_info.go b/functions/infra/gpu_info.go new file mode 100644 index 00000000..6ee6cd64 --- /dev/null +++ b/functions/infra/gpu_info.go @@ -0,0 +1,12 @@ +package infra + +// GpuInfo describe una GPU detectada en el sistema con sus capacidades de VRAM +// y versiones de driver y CUDA. +type GpuInfo struct { + Index int `json:"index"` + Name string `json:"name"` + VramTotalMb int `json:"vram_total_mb"` + VramFreeMb int `json:"vram_free_mb"` + DriverVersion string `json:"driver_version"` + CudaVersion string `json:"cuda_version,omitempty"` +} diff --git a/functions/infra/vault_aggregate_index.go b/functions/infra/vault_aggregate_index.go new file mode 100644 index 00000000..5c500026 --- /dev/null +++ b/functions/infra/vault_aggregate_index.go @@ -0,0 +1,171 @@ +package infra + +import ( + "fmt" + "os" + "path/filepath" + "time" +) + +// AggregateReport summarises the result of a VaultAggregateIndex run. +type AggregateReport struct { + VaultsProcessed int + VaultsSkipped int // vaults without a vault_index.db + TotalFiles int + Errors []string // non-fatal per-vault errors +} + +// VaultAggregateIndex reads all vault manifests from repoRoot, opens each +// vault_index.db and copies all file records into the central registry.db +// vault_files table. The table is created if it does not exist (idempotent). +// +// For each vault the previous rows are deleted and replaced atomically, so +// re-running always produces a clean, non-duplicated state. +// +// Returns an AggregateReport with counts. Per-vault errors are non-fatal +// (logged in report.Errors); only fatal errors (e.g. registry.db +// unreachable) are returned as the error value. +func VaultAggregateIndex(repoRoot string) (AggregateReport, error) { + var report AggregateReport + + // 1. Open registry.db + registryDB, err := SQLiteOpen(filepath.Join(repoRoot, "registry.db"), "") + if err != nil { + return report, fmt.Errorf("vault_aggregate_index: open registry.db: %w", err) + } + defer registryDB.Close() + + // 2. Idempotent schema migration + for _, stmt := range []string{ + `CREATE TABLE IF NOT EXISTS vault_files ( + vault_id TEXT NOT NULL, + vault_name TEXT NOT NULL, + rel_path TEXT NOT NULL, + size INTEGER NOT NULL, + mtime INTEGER NOT NULL, + sha256 TEXT NOT NULL, + mime TEXT NOT NULL DEFAULT '', + ext TEXT NOT NULL DEFAULT '', + bucket TEXT NOT NULL DEFAULT '', + sub_bucket TEXT NOT NULL DEFAULT '', + indexed_at INTEGER NOT NULL, + PRIMARY KEY (vault_id, rel_path) +);`, + `CREATE INDEX IF NOT EXISTS idx_vault_files_sha256 ON vault_files(sha256);`, + `CREATE INDEX IF NOT EXISTS idx_vault_files_vault ON vault_files(vault_id);`, + } { + if _, err := registryDB.Exec(stmt); err != nil { + if !isIdempotentMigrationError(err) { + return report, fmt.Errorf("vault_aggregate_index: schema: %w", err) + } + } + } + + // 3. Read manifest + entries, err := VaultManifestRead(repoRoot) + if err != nil { + return report, fmt.Errorf("vault_aggregate_index: manifest: %w", err) + } + + now := time.Now().UTC().Unix() + + for _, entry := range entries { + vaultID := vaultIDFromEntry(entry) + vaultName := entry.Name + vaultPath := entry.Path + + indexPath := filepath.Join(vaultPath, "vault_index.db") + if _, statErr := os.Stat(indexPath); statErr != nil { + report.VaultsSkipped++ + continue + } + + vaultDB, openErr := VaultIndexOpen(vaultPath) + if openErr != nil { + report.Errors = append(report.Errors, fmt.Sprintf("%s: open index: %v", vaultName, openErr)) + continue + } + + rows, queryErr := vaultDB.Query( + `SELECT rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket FROM files`, + ) + if queryErr != nil { + vaultDB.Close() + report.Errors = append(report.Errors, fmt.Sprintf("%s: query files: %v", vaultName, queryErr)) + continue + } + + type fileRow struct { + RelPath string + Size int64 + Mtime int64 + Sha256 string + Mime string + Ext string + Bucket string + SubBucket string + } + var fileRows []fileRow + for rows.Next() { + var r fileRow + if scanErr := rows.Scan(&r.RelPath, &r.Size, &r.Mtime, &r.Sha256, &r.Mime, &r.Ext, &r.Bucket, &r.SubBucket); scanErr != nil { + continue + } + fileRows = append(fileRows, r) + } + rows.Close() + vaultDB.Close() + + // Atomic replace in registry.db + tx, txErr := registryDB.Begin() + if txErr != nil { + report.Errors = append(report.Errors, fmt.Sprintf("%s: begin tx: %v", vaultName, txErr)) + continue + } + + if _, delErr := tx.Exec(`DELETE FROM vault_files WHERE vault_id = ?`, vaultID); delErr != nil { + tx.Rollback() + report.Errors = append(report.Errors, fmt.Sprintf("%s: delete: %v", vaultName, delErr)) + continue + } + + stmt, prepErr := tx.Prepare(` +INSERT INTO vault_files + (vault_id, vault_name, rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket, indexed_at) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if prepErr != nil { + tx.Rollback() + report.Errors = append(report.Errors, fmt.Sprintf("%s: prepare: %v", vaultName, prepErr)) + continue + } + + for _, r := range fileRows { + if _, insErr := stmt.Exec(vaultID, vaultName, r.RelPath, r.Size, r.Mtime, r.Sha256, r.Mime, r.Ext, r.Bucket, r.SubBucket, now); insErr != nil { + stmt.Close() + tx.Rollback() + report.Errors = append(report.Errors, fmt.Sprintf("%s: insert %s: %v", vaultName, r.RelPath, insErr)) + continue + } + } + stmt.Close() + + if commitErr := tx.Commit(); commitErr != nil { + report.Errors = append(report.Errors, fmt.Sprintf("%s: commit: %v", vaultName, commitErr)) + continue + } + + report.VaultsProcessed++ + report.TotalFiles += len(fileRows) + } + + return report, nil +} + +// vaultIDFromEntry constructs the canonical vault ID used in registry.db. +// Pattern: "_" — consistent with the vaults table. +func vaultIDFromEntry(e VaultManifestEntry) string { + if e.ProjectID == "" { + return e.Name + } + return e.Name + "_" + e.ProjectID +} diff --git a/functions/infra/vault_aggregate_index.md b/functions/infra/vault_aggregate_index.md new file mode 100644 index 00000000..37c1ed67 --- /dev/null +++ b/functions/infra/vault_aggregate_index.md @@ -0,0 +1,58 @@ +--- +name: vault_aggregate_index +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultAggregateIndex(repoRoot string) (AggregateReport, error)" +description: "Agrega los índices de todos los vaults del registry en la tabla vault_files de registry.db. Lee cada vault_index.db (via VaultIndexOpen) y reemplaza las filas de forma atómica. Idempotente: re-ejecutar limpia y reescribe sin duplicar." +tags: [vault, index, aggregate, registry] +uses_functions: + - "vault_manifest_read_go_infra" + - "vault_index_open_go_infra" + - "sqlite_open_go_infra" +uses_types: + - "vault_file_go_infra" +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: + - "database/sql" + - "fmt" + - "os" + - "path/filepath" + - "time" +tested: true +tests: + - "TestVaultAggregateIndex_NoVaults" + - "TestVaultAggregateIndex_VaultWithoutIndex" + - "TestVaultAggregateIndex_HappyPath" + - "TestVaultAggregateIndex_ReRunReplaces" +test_file_path: "functions/infra/vault_aggregate_index_test.go" +file_path: "functions/infra/vault_aggregate_index.go" +params: + - name: repoRoot + desc: "Ruta absoluta a la raiz del fn_registry (contiene registry.db y projects/)." +output: "AggregateReport con VaultsProcessed, VaultsSkipped (sin vault_index.db), TotalFiles y Errors (errores no fatales por vault). Error fatal solo si registry.db no se puede abrir." +--- + +## Ejemplo + +```go +report, err := infra.VaultAggregateIndex("/home/lucas/fn_registry") +if err != nil { + log.Fatal(err) +} +fmt.Printf("Processed: %d vaults, %d files\n", report.VaultsProcessed, report.TotalFiles) +for _, e := range report.Errors { + fmt.Println("warning:", e) +} +``` + +## Notas + +- Requiere que `registry/migrations/012_vault_files.sql` haya sido aplicado (o que el indexer lo aplique al arrancar). La función aplica la migración de forma idempotente ella misma con `CREATE TABLE IF NOT EXISTS`. +- Por cada vault: `DELETE WHERE vault_id = ?` + batch `INSERT` dentro de una transacción. Re-run siempre produce el mismo resultado. +- Vaults sin `vault_index.db` se cuentan en `VaultsSkipped` y se omiten sin error. +- El `vault_id` sigue el patrón `_`, consistente con la tabla `vaults` de registry.db. diff --git a/functions/infra/vault_aggregate_index_test.go b/functions/infra/vault_aggregate_index_test.go new file mode 100644 index 00000000..11ccfeb4 --- /dev/null +++ b/functions/infra/vault_aggregate_index_test.go @@ -0,0 +1,175 @@ +package infra + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +// setupAggregateTestRepo creates a minimal repo layout: +// +// / +// registry.db (SQLite, empty) +// projects//vaults/vault.yaml +// / (optionally with vault_index.db populated) +func setupAggregateTestRepo(t *testing.T, vaultName, projectID, vaultPath string, withIndex bool) string { + t.Helper() + root := t.TempDir() + + // Create registry.db + regDB, err := SQLiteOpen(filepath.Join(root, "registry.db"), "") + if err != nil { + t.Fatalf("create registry.db: %v", err) + } + regDB.Close() + + // Create project vault manifest + projVaultsDir := filepath.Join(root, "projects", projectID, "vaults") + if err := os.MkdirAll(projVaultsDir, 0755); err != nil { + t.Fatalf("mkdir projects: %v", err) + } + manifestYAML := "vaults:\n - name: " + vaultName + "\n description: test\n path: " + vaultPath + "\n tags: []\n" + if err := os.WriteFile(filepath.Join(projVaultsDir, "vault.yaml"), []byte(manifestYAML), 0644); err != nil { + t.Fatalf("write vault.yaml: %v", err) + } + + // Create vault dir + if err := os.MkdirAll(vaultPath, 0755); err != nil { + t.Fatalf("mkdir vault: %v", err) + } + + if withIndex { + // Create a vault_index.db with one file row + vdb, err := VaultIndexOpen(vaultPath) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + now := time.Now().UTC().Unix() + _, err = vdb.Exec(`INSERT INTO files (rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket, indexed_at) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "data/raw/sample.csv", 1024, now, "deadbeef", "text/csv", ".csv", "data", "raw", now) + if err != nil { + t.Fatalf("insert test file: %v", err) + } + vdb.Close() + } + + return root +} + +func TestVaultAggregateIndex_NoVaults(t *testing.T) { + root := t.TempDir() + // No manifests, just registry.db + regDB, err := SQLiteOpen(filepath.Join(root, "registry.db"), "") + if err != nil { + t.Fatalf("create registry.db: %v", err) + } + regDB.Close() + + report, err := VaultAggregateIndex(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.VaultsProcessed != 0 { + t.Errorf("VaultsProcessed: want 0, got %d", report.VaultsProcessed) + } + if len(report.Errors) != 0 { + t.Errorf("Errors: want empty, got %v", report.Errors) + } +} + +func TestVaultAggregateIndex_VaultWithoutIndex(t *testing.T) { + vaultDir := t.TempDir() + root := setupAggregateTestRepo(t, "my_vault", "my_proj", vaultDir, false /* no vault_index.db */) + + report, err := VaultAggregateIndex(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.VaultsSkipped != 1 { + t.Errorf("VaultsSkipped: want 1, got %d", report.VaultsSkipped) + } + if report.VaultsProcessed != 0 { + t.Errorf("VaultsProcessed: want 0, got %d", report.VaultsProcessed) + } +} + +func TestVaultAggregateIndex_HappyPath(t *testing.T) { + vaultDir := t.TempDir() + root := setupAggregateTestRepo(t, "my_vault", "my_proj", vaultDir, true) + + report, err := VaultAggregateIndex(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if report.VaultsProcessed != 1 { + t.Errorf("VaultsProcessed: want 1, got %d", report.VaultsProcessed) + } + if report.TotalFiles != 1 { + t.Errorf("TotalFiles: want 1, got %d", report.TotalFiles) + } + + // Verify row exists in registry.db + regDB, err := SQLiteOpen(filepath.Join(root, "registry.db"), "") + if err != nil { + t.Fatalf("open registry.db: %v", err) + } + defer regDB.Close() + + var count int + if err := regDB.QueryRow(`SELECT COUNT(*) FROM vault_files`).Scan(&count); err != nil { + t.Fatalf("count vault_files: %v", err) + } + if count != 1 { + t.Errorf("vault_files count: want 1, got %d", count) + } +} + +func TestVaultAggregateIndex_ReRunReplaces(t *testing.T) { + vaultDir := t.TempDir() + root := setupAggregateTestRepo(t, "my_vault", "my_proj", vaultDir, true) + + // First run + if _, err := VaultAggregateIndex(root); err != nil { + t.Fatalf("first run: %v", err) + } + + // Add a second file to vault_index.db + vdb, err := VaultIndexOpen(vaultDir) + if err != nil { + t.Fatalf("reopen vault index: %v", err) + } + now := time.Now().UTC().Unix() + _, err = vdb.Exec(`INSERT INTO files (rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket, indexed_at) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "data/raw/extra.csv", 512, now, "cafebabe", "text/csv", ".csv", "data", "raw", now) + if err != nil { + t.Fatalf("insert second file: %v", err) + } + vdb.Close() + + // Second run + report, err := VaultAggregateIndex(root) + if err != nil { + t.Fatalf("second run: %v", err) + } + if report.TotalFiles != 2 { + t.Errorf("TotalFiles: want 2, got %d", report.TotalFiles) + } + + // Verify no duplicates — exactly 2 rows + regDB, err := SQLiteOpen(filepath.Join(root, "registry.db"), "") + if err != nil { + t.Fatalf("open registry.db: %v", err) + } + defer regDB.Close() + + var count int + if err := regDB.QueryRow(`SELECT COUNT(*) FROM vault_files`).Scan(&count); err != nil { + t.Fatalf("count vault_files: %v", err) + } + if count != 2 { + t.Errorf("vault_files count after re-run: want 2, got %d", count) + } +} diff --git a/functions/infra/vault_diff.go b/functions/infra/vault_diff.go new file mode 100644 index 00000000..339f3094 --- /dev/null +++ b/functions/infra/vault_diff.go @@ -0,0 +1,68 @@ +package infra + +import "sort" + +// VaultFileChange holds the before/after state of a file whose content changed. +type VaultFileChange struct { + RelPath string + Prev VaultFile + Curr VaultFile +} + +// VaultDiffReport is the result of comparing two VaultFile slices. +type VaultDiffReport struct { + Added []VaultFile // in curr but not in prev (by rel_path) + Removed []VaultFile // in prev but not in curr + Changed []VaultFileChange // same rel_path, different sha256 + Unchanged int // files present in both with identical sha256 +} + +// VaultDiff computes the difference between two vault snapshots. +// It indexes both slices by RelPath, then classifies each entry as +// Added, Removed, Changed, or Unchanged. All output slices are sorted +// by RelPath ascending. The function is pure and deterministic. +func VaultDiff(prev, curr []VaultFile) VaultDiffReport { + prevMap := make(map[string]VaultFile, len(prev)) + for _, f := range prev { + prevMap[f.RelPath] = f + } + currMap := make(map[string]VaultFile, len(curr)) + for _, f := range curr { + currMap[f.RelPath] = f + } + + var report VaultDiffReport + + for _, f := range curr { + p, exists := prevMap[f.RelPath] + if !exists { + report.Added = append(report.Added, f) + } else if p.Sha256 != f.Sha256 { + report.Changed = append(report.Changed, VaultFileChange{ + RelPath: f.RelPath, + Prev: p, + Curr: f, + }) + } else { + report.Unchanged++ + } + } + + for _, f := range prev { + if _, exists := currMap[f.RelPath]; !exists { + report.Removed = append(report.Removed, f) + } + } + + sort.Slice(report.Added, func(i, j int) bool { + return report.Added[i].RelPath < report.Added[j].RelPath + }) + sort.Slice(report.Removed, func(i, j int) bool { + return report.Removed[i].RelPath < report.Removed[j].RelPath + }) + sort.Slice(report.Changed, func(i, j int) bool { + return report.Changed[i].RelPath < report.Changed[j].RelPath + }) + + return report +} diff --git a/functions/infra/vault_diff.md b/functions/infra/vault_diff.md new file mode 100644 index 00000000..cf3a08f6 --- /dev/null +++ b/functions/infra/vault_diff.md @@ -0,0 +1,49 @@ +--- +name: vault_diff +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: pure +signature: "func VaultDiff(prev, curr []VaultFile) VaultDiffReport" +description: "Computes the diff between two vault snapshots (slices of VaultFile). Returns Added, Removed, Changed and Unchanged counts. Pure and deterministic — no I/O." +tags: [vault, diff, comparison, pure] +uses_functions: [] +uses_types: ["vault_file_go_infra"] +returns: [] +returns_optional: false +error_type: "" +imports: ["sort"] +tested: true +tests: + - "TestVaultDiff_NoChanges" + - "TestVaultDiff_AllAdded" + - "TestVaultDiff_AllRemoved" + - "TestVaultDiff_ContentChanged" + - "TestVaultDiff_Mixed" +test_file_path: "functions/infra/vault_diff_test.go" +file_path: "functions/infra/vault_diff.go" +params: + - name: prev + desc: "Snapshot anterior — slice de VaultFile del estado previo del vault (puede ser nil para diff desde cero)." + - name: curr + desc: "Snapshot actual — slice de VaultFile del estado corriente del vault (puede ser nil para diff de borrado total)." +output: "VaultDiffReport con Added (nuevos), Removed (eliminados), Changed (mismo rel_path, sha256 distinto) y Unchanged (identicos). Todos los slices ordenados por RelPath ASC." +--- + +## Ejemplo + +```go +prev, _ := infra.VaultInventoryScan(oldPath, "my_vault_proj", "my_vault") +curr, _ := infra.VaultInventoryScan(newPath, "my_vault_proj", "my_vault") +report := infra.VaultDiff(prev, curr) +fmt.Printf("Added: %d, Removed: %d, Changed: %d, Unchanged: %d\n", + len(report.Added), len(report.Removed), len(report.Changed), report.Unchanged) +``` + +## Notas + +- Usa `RelPath` como clave de identidad de archivo (no nombre, no sha256). +- Dos archivos con mismo `RelPath` pero diferente `Sha256` se consideran Changed. +- Los slices del report se ordenan por `RelPath` ASC para salida deterministica. +- Función pura: no toca disco ni BD. diff --git a/functions/infra/vault_diff_test.go b/functions/infra/vault_diff_test.go new file mode 100644 index 00000000..13a4e8ae --- /dev/null +++ b/functions/infra/vault_diff_test.go @@ -0,0 +1,126 @@ +package infra + +import ( + "testing" +) + +func makeVF(relPath, sha256 string) VaultFile { + return VaultFile{ + VaultID: "test_vault", + VaultName: "test", + RelPath: relPath, + Sha256: sha256, + } +} + +func TestVaultDiff_NoChanges(t *testing.T) { + files := []VaultFile{ + makeVF("data/a.csv", "aaa"), + makeVF("data/b.csv", "bbb"), + } + report := VaultDiff(files, files) + if len(report.Added) != 0 { + t.Errorf("Added: want 0, got %d", len(report.Added)) + } + if len(report.Removed) != 0 { + t.Errorf("Removed: want 0, got %d", len(report.Removed)) + } + if len(report.Changed) != 0 { + t.Errorf("Changed: want 0, got %d", len(report.Changed)) + } + if report.Unchanged != 2 { + t.Errorf("Unchanged: want 2, got %d", report.Unchanged) + } +} + +func TestVaultDiff_AllAdded(t *testing.T) { + curr := []VaultFile{ + makeVF("data/a.csv", "aaa"), + makeVF("data/b.csv", "bbb"), + } + report := VaultDiff(nil, curr) + if len(report.Added) != 2 { + t.Errorf("Added: want 2, got %d", len(report.Added)) + } + if len(report.Removed) != 0 { + t.Errorf("Removed: want 0, got %d", len(report.Removed)) + } + if report.Added[0].RelPath != "data/a.csv" { + t.Errorf("Added[0]: want data/a.csv, got %s", report.Added[0].RelPath) + } + if report.Added[1].RelPath != "data/b.csv" { + t.Errorf("Added[1]: want data/b.csv, got %s", report.Added[1].RelPath) + } +} + +func TestVaultDiff_AllRemoved(t *testing.T) { + prev := []VaultFile{ + makeVF("data/a.csv", "aaa"), + makeVF("data/b.csv", "bbb"), + } + report := VaultDiff(prev, nil) + if len(report.Removed) != 2 { + t.Errorf("Removed: want 2, got %d", len(report.Removed)) + } + if len(report.Added) != 0 { + t.Errorf("Added: want 0, got %d", len(report.Added)) + } + if report.Removed[0].RelPath != "data/a.csv" { + t.Errorf("Removed[0]: want data/a.csv, got %s", report.Removed[0].RelPath) + } +} + +func TestVaultDiff_ContentChanged(t *testing.T) { + prev := []VaultFile{ + makeVF("data/a.csv", "old_hash"), + } + curr := []VaultFile{ + makeVF("data/a.csv", "new_hash"), + } + report := VaultDiff(prev, curr) + if len(report.Changed) != 1 { + t.Fatalf("Changed: want 1, got %d", len(report.Changed)) + } + if report.Changed[0].RelPath != "data/a.csv" { + t.Errorf("Changed[0].RelPath: want data/a.csv, got %s", report.Changed[0].RelPath) + } + if report.Changed[0].Prev.Sha256 != "old_hash" { + t.Errorf("Changed[0].Prev.Sha256: want old_hash, got %s", report.Changed[0].Prev.Sha256) + } + if report.Changed[0].Curr.Sha256 != "new_hash" { + t.Errorf("Changed[0].Curr.Sha256: want new_hash, got %s", report.Changed[0].Curr.Sha256) + } + if len(report.Added) != 0 || len(report.Removed) != 0 { + t.Errorf("Expected no added/removed, got %d/%d", len(report.Added), len(report.Removed)) + } + if report.Unchanged != 0 { + t.Errorf("Unchanged: want 0, got %d", report.Unchanged) + } +} + +func TestVaultDiff_Mixed(t *testing.T) { + prev := []VaultFile{ + makeVF("data/a.csv", "aaa"), + makeVF("data/b.csv", "bbb"), + makeVF("data/c.csv", "ccc"), + } + curr := []VaultFile{ + makeVF("data/a.csv", "aaa"), // unchanged + makeVF("data/b.csv", "bbb_new"), // changed + makeVF("data/d.csv", "ddd"), // added + } + report := VaultDiff(prev, curr) + + if len(report.Added) != 1 || report.Added[0].RelPath != "data/d.csv" { + t.Errorf("Added: want [data/d.csv], got %v", report.Added) + } + if len(report.Removed) != 1 || report.Removed[0].RelPath != "data/c.csv" { + t.Errorf("Removed: want [data/c.csv], got %v", report.Removed) + } + if len(report.Changed) != 1 || report.Changed[0].RelPath != "data/b.csv" { + t.Errorf("Changed: want [data/b.csv], got %v", report.Changed) + } + if report.Unchanged != 1 { + t.Errorf("Unchanged: want 1, got %d", report.Unchanged) + } +} diff --git a/functions/infra/vault_doctor.go b/functions/infra/vault_doctor.go new file mode 100644 index 00000000..c62c42d8 --- /dev/null +++ b/functions/infra/vault_doctor.go @@ -0,0 +1,230 @@ +package infra + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +// VaultDoctorEntry holds the health report for a single vault. +type VaultDoctorEntry struct { + VaultName string `json:"vault_name"` + VaultPath string `json:"vault_path"` + ProjectID string `json:"project_id"` + Issues []string `json:"issues"` // human-readable issues; empty = healthy + IndexedFiles int `json:"indexed_files"` // 0 if no vault_index.db + LastIndexedAt int64 `json:"last_indexed_at"` // unix seconds; 0 if N/A + DiskFiles int `json:"disk_files"` // count via WalkDir (no hashing) + Status string `json:"status"` // "ok" | "warning" | "error" +} + +// VaultDoctor audits every vault declared in projects/*/vaults/vault.yaml under +// repoRoot. For each vault it performs a series of checks (disk presence, layout, +// index existence, staleness, drift) and returns a slice of VaultDoctorEntry. +// +// The function is read-only: it never writes to disk or any database. +// Returns an error only if VaultManifestRead fails (manifest parse error). +func VaultDoctor(repoRoot string) ([]VaultDoctorEntry, error) { + entries, err := VaultManifestRead(repoRoot) + if err != nil { + return nil, fmt.Errorf("vault_doctor: read manifests: %w", err) + } + + results := make([]VaultDoctorEntry, 0, len(entries)) + for _, e := range entries { + result := auditVault(e) + results = append(results, result) + } + return results, nil +} + +func auditVault(e VaultManifestEntry) VaultDoctorEntry { + entry := VaultDoctorEntry{ + VaultName: e.Name, + VaultPath: e.Path, + ProjectID: e.ProjectID, + } + + // Resolve symlinks for disk checks + realPath, err := filepath.EvalSymlinks(e.Path) + if err != nil || realPath == "" { + realPath = e.Path + } + + // CHECK 1: directory_missing + info, statErr := os.Stat(realPath) + if statErr != nil || !info.IsDir() { + entry.Issues = append(entry.Issues, "directory_missing") + entry.Status = "error" + return entry + } + + // COUNT disk files (cheap walk — no hashing, no mime detection) + diskCount := countDiskFiles(realPath) + entry.DiskFiles = diskCount + + // CHECK 2: layout_missing / non_standard_layout + hasData := dirExists(filepath.Join(realPath, "data")) + hasKnowledge := dirExists(filepath.Join(realPath, "knowledge")) + if !hasData && !hasKnowledge { + // Check if it looks like a non-standard but intentional layout + if hasNonStandardLayout(realPath) { + entry.Issues = append(entry.Issues, "non_standard_layout") + } else { + entry.Issues = append(entry.Issues, "layout_missing") + } + } + + // CHECK 3: index_missing + indexPath := filepath.Join(realPath, "vault_index.db") + _, indexStatErr := os.Stat(indexPath) + if indexStatErr != nil { + entry.Issues = append(entry.Issues, "index_missing") + entry.setWarningStatus() + entry.setFinalStatus() + return entry + } + + // Open vault index (read-only) for checks 4 and 5 + vdb, openErr := VaultIndexOpen(realPath) + if openErr != nil { + entry.Issues = append(entry.Issues, fmt.Sprintf("index_open_error: %v", openErr)) + entry.setWarningStatus() + return entry + } + defer vdb.Close() + + // Query indexed file count and max indexed_at + var indexedCount int + var maxIndexedAt int64 + row := vdb.QueryRow(`SELECT COUNT(*), COALESCE(MAX(indexed_at), 0) FROM files`) + if scanErr := row.Scan(&indexedCount, &maxIndexedAt); scanErr != nil { + entry.Issues = append(entry.Issues, fmt.Sprintf("index_query_error: %v", scanErr)) + } else { + entry.IndexedFiles = indexedCount + entry.LastIndexedAt = maxIndexedAt + } + + // CHECK 4: index_stale — any file on disk newer than MAX(indexed_at) + if maxIndexedAt > 0 { + maxTime := time.Unix(maxIndexedAt, 0) + if isIndexStale(realPath, maxTime) { + entry.Issues = append(entry.Issues, "index_stale") + } + } + + // CHECK 5: index_drift — disk file count != indexed count + if indexedCount != diskCount { + entry.Issues = append(entry.Issues, fmt.Sprintf("index_drift: disk=%d indexed=%d", diskCount, indexedCount)) + } + + // CHECK 6: empty_vault + if diskCount == 0 { + entry.Issues = append(entry.Issues, "empty_vault") + } + + entry.setFinalStatus() + return entry +} + +// setWarningStatus sets status to warning if not already error. +func (e *VaultDoctorEntry) setWarningStatus() { + if e.Status != "error" { + e.Status = "warning" + } +} + +// setFinalStatus derives the final Status from Issues. +func (e *VaultDoctorEntry) setFinalStatus() { + if e.Status == "error" { + return + } + if len(e.Issues) == 0 { + e.Status = "ok" + } else { + e.Status = "warning" + } +} + +// countDiskFiles walks realPath and counts regular files, excluding: +// vault_index.db*, .git/, hidden files/dirs at any depth. +func countDiskFiles(realPath string) int { + count := 0 + _ = filepath.WalkDir(realPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + name := d.Name() + // Skip hidden entries + if strings.HasPrefix(name, ".") { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + // Skip .git + if d.IsDir() && name == ".git" { + return filepath.SkipDir + } + // Skip vault_index.db files + if !d.IsDir() && (name == "vault_index.db" || name == "vault_index.db-shm" || name == "vault_index.db-wal") { + return nil + } + if !d.IsDir() { + count++ + } + return nil + }) + return count +} + +// isIndexStale returns true if any regular file under realPath has an mtime +// strictly after maxTime (excluding vault_index.db* and hidden files). +func isIndexStale(realPath string, maxTime time.Time) bool { + stale := false + _ = filepath.WalkDir(realPath, func(path string, d os.DirEntry, err error) error { + if err != nil || stale { + return nil + } + name := d.Name() + if strings.HasPrefix(name, ".") { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + if d.IsDir() && name == ".git" { + return filepath.SkipDir + } + if !d.IsDir() { + if name == "vault_index.db" || name == "vault_index.db-shm" || name == "vault_index.db-wal" { + return nil + } + fi, statErr := d.Info() + if statErr == nil && fi.ModTime().After(maxTime) { + stale = true + } + } + return nil + }) + return stale +} + +// hasNonStandardLayout returns true when a vault directory contains +// subdirectories that are clearly intentional but not data/knowledge. +// Heuristic: any subdir at the vault root that is not data/knowledge. +func hasNonStandardLayout(realPath string) bool { + entries, err := os.ReadDir(realPath) + if err != nil { + return false + } + standardDirs := map[string]bool{"data": true, "knowledge": true, ".git": true} + for _, e := range entries { + if e.IsDir() && !standardDirs[e.Name()] && !strings.HasPrefix(e.Name(), ".") { + return true + } + } + return false +} diff --git a/functions/infra/vault_doctor.md b/functions/infra/vault_doctor.md new file mode 100644 index 00000000..71df58f2 --- /dev/null +++ b/functions/infra/vault_doctor.md @@ -0,0 +1,66 @@ +--- +name: vault_doctor +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultDoctor(repoRoot string) ([]VaultDoctorEntry, error)" +description: "Audita la salud de todos los vaults declarados en projects/*/vaults/vault.yaml. Comprueba existencia del directorio, layout estándar, presencia del índice, staleness y drift entre disco e índice. Read-only." +tags: [vault, doctor, health, audit] +uses_functions: + - "vault_manifest_read_go_infra" + - "vault_index_open_go_infra" +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: + - "fmt" + - "os" + - "path/filepath" + - "strings" + - "time" +tested: true +tests: + - "TestVaultDoctor_OK" + - "TestVaultDoctor_MissingDir" + - "TestVaultDoctor_NoIndex" + - "TestVaultDoctor_LayoutDrift" + - "TestVaultDoctor_EmptyVault" +test_file_path: "functions/infra/vault_doctor_test.go" +file_path: "functions/infra/vault_doctor.go" +params: + - name: repoRoot + desc: "Ruta absoluta a la raiz del fn_registry (donde están projects/ y registry.db)." +output: "Slice de VaultDoctorEntry con Status (ok/warning/error), Issues, DiskFiles, IndexedFiles y LastIndexedAt por vault. Error fatal solo si los manifests no se pueden leer." +--- + +## Checks aplicados + +| Check | Condición | Severidad | +|---|---|---| +| `directory_missing` | `e.Path` no existe en disco | error | +| `layout_missing` | no hay `data/` ni `knowledge/` en la raíz del vault | warning | +| `non_standard_layout` | no hay `data/`/`knowledge/` pero sí otros subdirectorios (ej. imagegen_models) | warning | +| `index_missing` | no existe `vault_index.db` | warning | +| `index_stale` | algún archivo en disco tiene mtime > MAX(indexed_at) | warning | +| `index_drift` | count disco != count en tabla `files` | warning | +| `empty_vault` | DiskFiles == 0 | warning | + +## Ejemplo + +```go +entries, err := infra.VaultDoctor("/home/lucas/fn_registry") +for _, e := range entries { + fmt.Printf("%-30s %-8s files=%d issues=%v\n", + e.VaultName, e.Status, e.DiskFiles, e.Issues) +} +``` + +## Notas + +- Función read-only: nunca escribe en disco ni en ninguna base de datos. +- `countDiskFiles` usa `filepath.WalkDir` sin hash (cheap) — excluye `vault_index.db*`, `.git/` y ficheros ocultos. +- `isIndexStale` también usa WalkDir; compara mtime de archivos con MAX(indexed_at) de la BD. +- El VaultIndexOpen de sólo lectura no crea el DB (si no existe, retorna error y se reporta `index_missing`). diff --git a/functions/infra/vault_doctor_test.go b/functions/infra/vault_doctor_test.go new file mode 100644 index 00000000..46ee5915 --- /dev/null +++ b/functions/infra/vault_doctor_test.go @@ -0,0 +1,211 @@ +package infra + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +// setupDoctorRepo creates a repo layout with one vault in a project manifest. +// vaultPath must be an absolute path that already exists (or not, for missing tests). +func setupDoctorRepo(t *testing.T, vaultName, projectID, vaultPath string) string { + t.Helper() + root := t.TempDir() + projVaultsDir := filepath.Join(root, "projects", projectID, "vaults") + if err := os.MkdirAll(projVaultsDir, 0755); err != nil { + t.Fatalf("mkdir projects: %v", err) + } + manifest := "vaults:\n - name: " + vaultName + "\n description: test vault\n path: " + vaultPath + "\n tags: []\n" + if err := os.WriteFile(filepath.Join(projVaultsDir, "vault.yaml"), []byte(manifest), 0644); err != nil { + t.Fatalf("write vault.yaml: %v", err) + } + return root +} + +func TestVaultDoctor_OK(t *testing.T) { + vaultDir := t.TempDir() + + // Proper layout + if err := os.MkdirAll(filepath.Join(vaultDir, "data", "raw"), 0755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(vaultDir, "knowledge"), 0755); err != nil { + t.Fatal(err) + } + + // Create a file with a past mtime so the index is not stale + samplePath := filepath.Join(vaultDir, "data", "raw", "sample.csv") + if err := os.WriteFile(samplePath, []byte("a,b\n1,2\n"), 0644); err != nil { + t.Fatal(err) + } + pastTime := time.Now().Add(-1 * time.Hour) + if err := os.Chtimes(samplePath, pastTime, pastTime); err != nil { + t.Fatal(err) + } + + // Create vault_index.db with the file indexed after its mtime + vdb, err := VaultIndexOpen(vaultDir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + futureIndexed := time.Now().Unix() // indexed_at is now — after file mtime + _, err = vdb.Exec(`INSERT INTO files (rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket, indexed_at) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "data/raw/sample.csv", 8, pastTime.Unix(), "deadbeef", "text/csv", ".csv", "data", "raw", futureIndexed) + if err != nil { + t.Fatalf("insert: %v", err) + } + vdb.Close() + + root := setupDoctorRepo(t, "my_vault", "my_proj", vaultDir) + entries, err := VaultDoctor(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + e := entries[0] + if e.Status != "ok" { + t.Errorf("Status: want ok, got %s (issues: %v)", e.Status, e.Issues) + } + if len(e.Issues) != 0 { + t.Errorf("Issues: want empty, got %v", e.Issues) + } + if e.DiskFiles != 1 { + t.Errorf("DiskFiles: want 1, got %d", e.DiskFiles) + } + if e.IndexedFiles != 1 { + t.Errorf("IndexedFiles: want 1, got %d", e.IndexedFiles) + } +} + +func TestVaultDoctor_MissingDir(t *testing.T) { + missingPath := filepath.Join(t.TempDir(), "does_not_exist") + root := setupDoctorRepo(t, "missing_vault", "my_proj", missingPath) + + entries, err := VaultDoctor(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + e := entries[0] + if e.Status != "error" { + t.Errorf("Status: want error, got %s", e.Status) + } + found := false + for _, issue := range e.Issues { + if issue == "directory_missing" { + found = true + } + } + if !found { + t.Errorf("Expected directory_missing issue, got %v", e.Issues) + } +} + +func TestVaultDoctor_NoIndex(t *testing.T) { + vaultDir := t.TempDir() + // Proper layout but no vault_index.db + if err := os.MkdirAll(filepath.Join(vaultDir, "data", "raw"), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(vaultDir, "data", "raw", "a.csv"), []byte("x"), 0644); err != nil { + t.Fatal(err) + } + + root := setupDoctorRepo(t, "no_index_vault", "my_proj", vaultDir) + entries, err := VaultDoctor(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + e := entries[0] + if e.Status != "warning" { + t.Errorf("Status: want warning, got %s", e.Status) + } + found := false + for _, issue := range e.Issues { + if issue == "index_missing" { + found = true + } + } + if !found { + t.Errorf("Expected index_missing issue, got %v", e.Issues) + } +} + +func TestVaultDoctor_LayoutDrift(t *testing.T) { + vaultDir := t.TempDir() + // No data/ or knowledge/ — just a random file at root + if err := os.WriteFile(filepath.Join(vaultDir, "something.txt"), []byte("hi"), 0644); err != nil { + t.Fatal(err) + } + + root := setupDoctorRepo(t, "layout_vault", "my_proj", vaultDir) + entries, err := VaultDoctor(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + e := entries[0] + if e.Status != "warning" { + t.Errorf("Status: want warning, got %s", e.Status) + } + foundLayout := false + for _, issue := range e.Issues { + if issue == "layout_missing" || issue == "non_standard_layout" { + foundLayout = true + } + } + if !foundLayout { + t.Errorf("Expected layout_missing or non_standard_layout, got %v", e.Issues) + } +} + +func TestVaultDoctor_EmptyVault(t *testing.T) { + vaultDir := t.TempDir() + // data/ and knowledge/ exist but are empty + if err := os.MkdirAll(filepath.Join(vaultDir, "data"), 0755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(vaultDir, "knowledge"), 0755); err != nil { + t.Fatal(err) + } + + // Create vault_index.db (empty) + vdb, err := VaultIndexOpen(vaultDir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + vdb.Close() + + root := setupDoctorRepo(t, "empty_vault", "my_proj", vaultDir) + entries, err := VaultDoctor(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + e := entries[0] + if e.Status != "warning" { + t.Errorf("Status: want warning, got %s (issues: %v)", e.Status, e.Issues) + } + found := false + for _, issue := range e.Issues { + if issue == "empty_vault" { + found = true + } + } + if !found { + t.Errorf("Expected empty_vault issue, got %v", e.Issues) + } +} diff --git a/functions/infra/vault_file.go b/functions/infra/vault_file.go new file mode 100644 index 00000000..cee58687 --- /dev/null +++ b/functions/infra/vault_file.go @@ -0,0 +1,21 @@ +package infra + +// VaultFile describes a single file inside a vault directory. +// It carries identity (vault + relative path), content metadata (size, mtime, sha256, mime) +// and structural classification (bucket, sub-bucket). +type VaultFile struct { + VaultID string `json:"vault_id"` // e.g. "turismo_spain_app_turismo" + VaultName string `json:"vault_name"` // e.g. "turismo_spain" + RelPath string `json:"rel_path"` // path relative to vault root, e.g. "data/raw/foo.csv" + Size int64 `json:"size"` // bytes + Mtime int64 `json:"mtime"` // unix seconds (UTC) + Sha256 string `json:"sha256"` // hex lowercase + Mime string `json:"mime"` // e.g. "text/csv" + Ext string `json:"ext"` // e.g. ".csv" + // Bucket is the top-level classification: "data" or "knowledge". + Bucket string `json:"bucket"` + // SubBucket is the second-level directory within the bucket. + // Known values: raw, processed, exports (data); decisions, domains, models, + // benchmarks, test_documents (knowledge). Empty string for files at bucket root. + SubBucket string `json:"sub_bucket"` +} diff --git a/functions/infra/vault_index_migrations/001_init.sql b/functions/infra/vault_index_migrations/001_init.sql new file mode 100644 index 00000000..d1dfbe20 --- /dev/null +++ b/functions/infra/vault_index_migrations/001_init.sql @@ -0,0 +1,49 @@ +CREATE TABLE IF NOT EXISTS files ( + rel_path TEXT PRIMARY KEY, + size INTEGER NOT NULL, + mtime INTEGER NOT NULL, + sha256 TEXT NOT NULL, + mime TEXT NOT NULL DEFAULT '', + ext TEXT NOT NULL DEFAULT '', + bucket TEXT NOT NULL DEFAULT '', + sub_bucket TEXT NOT NULL DEFAULT '', + indexed_at INTEGER NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_files_sha256 ON files(sha256); +CREATE INDEX IF NOT EXISTS idx_files_bucket ON files(bucket, sub_bucket); + +CREATE VIRTUAL TABLE IF NOT EXISTS files_fts USING fts5( + rel_path, + content_text, + content='', + tokenize='unicode61 remove_diacritics 2' +); + +CREATE TABLE IF NOT EXISTS csv_profiles ( + rel_path TEXT PRIMARY KEY, + cols_json TEXT NOT NULL, + n_rows INTEGER NOT NULL, + encoding TEXT NOT NULL DEFAULT '', + date_min TEXT, + date_max TEXT, + profiled_at INTEGER NOT NULL, + FOREIGN KEY (rel_path) REFERENCES files(rel_path) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS pdf_extracts ( + rel_path TEXT PRIMARY KEY, + page_count INTEGER NOT NULL, + text_len INTEGER NOT NULL, + extracted_to TEXT, + extracted_at INTEGER NOT NULL, + FOREIGN KEY (rel_path) REFERENCES files(rel_path) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS knowledge_docs ( + rel_path TEXT PRIMARY KEY, + title TEXT NOT NULL DEFAULT '', + frontmatter_json TEXT NOT NULL DEFAULT '{}', + headings_json TEXT NOT NULL DEFAULT '[]', + parsed_at INTEGER NOT NULL, + FOREIGN KEY (rel_path) REFERENCES files(rel_path) ON DELETE CASCADE +); diff --git a/functions/infra/vault_index_open.go b/functions/infra/vault_index_open.go new file mode 100644 index 00000000..ad520a02 --- /dev/null +++ b/functions/infra/vault_index_open.go @@ -0,0 +1,30 @@ +package infra + +import ( + "database/sql" + "embed" + "fmt" + "path/filepath" +) + +//go:embed vault_index_migrations/*.sql +var vaultIndexMigrationsFS embed.FS + +// VaultIndexOpen opens (or creates) the vault_index.db inside vaultPath. +// It applies all embedded migrations idempotently and returns a ready-to-use +// *sql.DB. The caller is responsible for closing the connection. +// +// The database is opened with WAL mode and foreign keys enabled via SQLiteOpen. +// Migrations are applied from vault_index_migrations/*.sql in lexicographic order. +func VaultIndexOpen(vaultPath string) (*sql.DB, error) { + dbPath := filepath.Join(vaultPath, "vault_index.db") + db, err := SQLiteOpen(dbPath, "") + if err != nil { + return nil, fmt.Errorf("vault_index_open: %w", err) + } + if err := ApplyMigrations(db, vaultIndexMigrationsFS, "vault_index_migrations/*.sql"); err != nil { + db.Close() + return nil, fmt.Errorf("vault_index_open: apply migrations: %w", err) + } + return db, nil +} diff --git a/functions/infra/vault_index_open.md b/functions/infra/vault_index_open.md new file mode 100644 index 00000000..5ab99d7c --- /dev/null +++ b/functions/infra/vault_index_open.md @@ -0,0 +1,54 @@ +--- +name: vault_index_open +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultIndexOpen(vaultPath string) (*sql.DB, error)" +description: "Abre (o crea) vault_index.db dentro de vaultPath con WAL + FK y aplica las migraciones embebidas idempotentemente. El caller cierra la conexion." +tags: [vault, sqlite, index, migration, infra] +uses_functions: ["sqlite_open_go_infra", "sqlite_apply_migrations_go_infra"] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [database/sql, embed, fmt, path/filepath] +params: + - name: vaultPath + desc: "ruta absoluta o relativa al directorio raiz del vault" +output: "*sql.DB apuntando a /vault_index.db con schema completo aplicado; el caller es responsable de cerrar" +tested: true +tests: + - "crea vault_index.db en tmpdir vacio" + - "segunda apertura no falla (idempotente)" + - "todas las tablas esperadas existen en sqlite_master" + - "fts5 INSERT y MATCH funcionan" +test_file_path: "functions/infra/vault_index_open_test.go" +file_path: "functions/infra/vault_index_open.go" +--- + +## Ejemplo + +```go +db, err := VaultIndexOpen("/data/vaults/turismo_spain") +if err != nil { + log.Fatal(err) +} +defer db.Close() +``` + +## Notas + +El archivo de base de datos se crea en `/vault_index.db`. Las migraciones +viven en `vault_index_migrations/*.sql` embebidas via `//go:embed` en el mismo paquete. + +Schema creado por `001_init.sql`: +- `files` — inventario de archivos (PK: rel_path) +- `files_fts` — tabla FTS5 virtual para busqueda de texto (content_text lo llenan profilers posteriores) +- `csv_profiles` — perfil de columnas/filas para .csv (FK → files) +- `pdf_extracts` — metadatos de extraccion de texto para .pdf (FK → files) +- `knowledge_docs` — headings/frontmatter para .md del bucket knowledge (FK → files) + +`SQLiteOpen` abre con WAL mode + foreign keys. `ApplyMigrations` es idempotente: +los errores de "already exists" y "duplicate column" se ignoran silenciosamente. diff --git a/functions/infra/vault_index_open_test.go b/functions/infra/vault_index_open_test.go new file mode 100644 index 00000000..566097f9 --- /dev/null +++ b/functions/infra/vault_index_open_test.go @@ -0,0 +1,107 @@ +package infra + +import ( + "database/sql" + "os" + "path/filepath" + "testing" +) + +func TestVaultIndexOpen_CreatesDB(t *testing.T) { + t.Run("crea vault_index.db en tmpdir vacio", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + defer db.Close() + + dbPath := filepath.Join(dir, "vault_index.db") + if _, err := os.Stat(dbPath); os.IsNotExist(err) { + t.Fatalf("vault_index.db no fue creado en %s", dir) + } + }) +} + +func TestVaultIndexOpen_Idempotent(t *testing.T) { + t.Run("segunda apertura no falla (idempotente)", func(t *testing.T) { + dir := t.TempDir() + + db1, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("primera apertura: %v", err) + } + db1.Close() + + db2, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("segunda apertura: %v", err) + } + db2.Close() + }) +} + +func TestVaultIndexOpen_AppliesAllMigrations(t *testing.T) { + t.Run("todas las tablas esperadas existen en sqlite_master", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + defer db.Close() + + expectedTables := []string{ + "files", + "csv_profiles", + "pdf_extracts", + "knowledge_docs", + } + for _, tbl := range expectedTables { + assertTableExists(t, db, tbl) + } + }) +} + +func TestVaultIndexOpen_FTS5Works(t *testing.T) { + t.Run("fts5 INSERT y MATCH funcionan", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + defer db.Close() + + // Insert a row into files_fts (content='' table, manual INSERT required) + _, err = db.Exec(`INSERT INTO files_fts(rel_path, content_text) VALUES (?, ?)`, + "data/raw/informe_ventas.csv", "ventas trimestrales empresa") + if err != nil { + t.Fatalf("INSERT files_fts: %v", err) + } + + var count int + err = db.QueryRow( + `SELECT count(*) FROM files_fts WHERE files_fts MATCH 'ventas'`, + ).Scan(&count) + if err != nil { + t.Fatalf("FTS MATCH query: %v", err) + } + if count != 1 { + t.Errorf("FTS MATCH: got %d rows, want 1", count) + } + }) +} + +// assertTableExists verifies that a table (or virtual table) exists in sqlite_master. +func assertTableExists(t *testing.T, db *sql.DB, name string) { + t.Helper() + var exists int + err := db.QueryRow( + `SELECT count(*) FROM sqlite_master WHERE name = ?`, name, + ).Scan(&exists) + if err != nil { + t.Fatalf("sqlite_master query for %q: %v", name, err) + } + if exists == 0 { + t.Errorf("table/vtable %q not found in sqlite_master", name) + } +} diff --git a/functions/infra/vault_index_write.go b/functions/infra/vault_index_write.go new file mode 100644 index 00000000..2913997b --- /dev/null +++ b/functions/infra/vault_index_write.go @@ -0,0 +1,154 @@ +package infra + +import ( + "database/sql" + "fmt" + "strings" + "time" +) + +// WriteReport summarises the outcome of a VaultIndexWrite call. +type WriteReport struct { + Inserted int // rows newly inserted into files + Updated int // rows updated (upserted) in files + Pruned int // rows deleted from files (only when prune=true) + FTS int // rows inserted into files_fts +} + +// VaultIndexWrite upserts a slice of VaultFile into the vault_index.db opened +// as db, updates the files_fts FTS5 table, and optionally prunes stale rows. +// +// All changes run inside a single transaction. +// +// Counting strategy: the set of rel_paths already in the DB is read before the +// loop. An upsert is counted as Inserted if the rel_path was absent, Updated if +// it was present. This avoids N+1 queries while remaining correct. +// +// FTS5: all affected rows are deleted and re-inserted with rel_path and empty +// content_text. Downstream profilers (csv_profiles, pdf_extracts, knowledge_docs) +// are responsible for populating content_text with meaningful text. +// +// Prune: if prune=true, every row in files whose rel_path is NOT in the provided +// slice is deleted. Cascades to csv_profiles, pdf_extracts, knowledge_docs via FK. +func VaultIndexWrite(db *sql.DB, files []VaultFile, prune bool) (WriteReport, error) { + var report WriteReport + if len(files) == 0 && !prune { + return report, nil + } + + tx, err := db.Begin() + if err != nil { + return report, fmt.Errorf("vault_index_write: begin tx: %w", err) + } + defer func() { + if err != nil { + tx.Rollback() //nolint:errcheck + } + }() + + // Load existing rel_paths into a set to distinguish insert vs update. + existing := make(map[string]struct{}) + rows, err := tx.Query(`SELECT rel_path FROM files`) + if err != nil { + return report, fmt.Errorf("vault_index_write: query existing: %w", err) + } + for rows.Next() { + var rp string + if err := rows.Scan(&rp); err != nil { + rows.Close() + return report, fmt.Errorf("vault_index_write: scan existing: %w", err) + } + existing[rp] = struct{}{} + } + rows.Close() + if err := rows.Err(); err != nil { + return report, fmt.Errorf("vault_index_write: rows err: %w", err) + } + + now := time.Now().Unix() + + upsertStmt, err := tx.Prepare(` + INSERT INTO files (rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket, indexed_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(rel_path) DO UPDATE SET + size = excluded.size, + mtime = excluded.mtime, + sha256 = excluded.sha256, + mime = excluded.mime, + ext = excluded.ext, + bucket = excluded.bucket, + sub_bucket = excluded.sub_bucket, + indexed_at = excluded.indexed_at + `) + if err != nil { + return report, fmt.Errorf("vault_index_write: prepare upsert: %w", err) + } + defer upsertStmt.Close() + + ftsDeleteStmt, err := tx.Prepare(`DELETE FROM files_fts WHERE rel_path = ?`) + if err != nil { + return report, fmt.Errorf("vault_index_write: prepare fts delete: %w", err) + } + defer ftsDeleteStmt.Close() + + ftsInsertStmt, err := tx.Prepare(`INSERT INTO files_fts(rel_path, content_text) VALUES (?, '')`) + if err != nil { + return report, fmt.Errorf("vault_index_write: prepare fts insert: %w", err) + } + defer ftsInsertStmt.Close() + + for _, f := range files { + _, err = upsertStmt.Exec( + f.RelPath, f.Size, f.Mtime, f.Sha256, + f.Mime, f.Ext, f.Bucket, f.SubBucket, now, + ) + if err != nil { + return report, fmt.Errorf("vault_index_write: upsert %q: %w", f.RelPath, err) + } + + if _, wasExisting := existing[f.RelPath]; wasExisting { + report.Updated++ + } else { + report.Inserted++ + } + + // Refresh FTS row. + if _, err = ftsDeleteStmt.Exec(f.RelPath); err != nil { + return report, fmt.Errorf("vault_index_write: fts delete %q: %w", f.RelPath, err) + } + if _, err = ftsInsertStmt.Exec(f.RelPath); err != nil { + return report, fmt.Errorf("vault_index_write: fts insert %q: %w", f.RelPath, err) + } + report.FTS++ + } + + // Prune rows not present in the incoming slice. + if prune && len(files) > 0 { + keep := make([]string, len(files)) + for i, f := range files { + keep[i] = "'" + strings.ReplaceAll(f.RelPath, "'", "''") + "'" + } + inClause := strings.Join(keep, ",") + res, err := tx.Exec(fmt.Sprintf( + `DELETE FROM files WHERE rel_path NOT IN (%s)`, inClause, + )) + if err != nil { + return report, fmt.Errorf("vault_index_write: prune: %w", err) + } + n, _ := res.RowsAffected() + report.Pruned = int(n) + } else if prune && len(files) == 0 { + // prune=true with empty slice means delete everything. + res, err := tx.Exec(`DELETE FROM files`) + if err != nil { + return report, fmt.Errorf("vault_index_write: prune all: %w", err) + } + n, _ := res.RowsAffected() + report.Pruned = int(n) + } + + if err = tx.Commit(); err != nil { + return report, fmt.Errorf("vault_index_write: commit: %w", err) + } + return report, nil +} diff --git a/functions/infra/vault_index_write.md b/functions/infra/vault_index_write.md new file mode 100644 index 00000000..41122180 --- /dev/null +++ b/functions/infra/vault_index_write.md @@ -0,0 +1,84 @@ +--- +name: vault_index_write +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultIndexWrite(db *sql.DB, files []VaultFile, prune bool) (WriteReport, error)" +description: "Upserta un slice de VaultFile en vault_index.db (tabla files + FTS5 files_fts) dentro de una sola transaccion. Cuenta Inserted/Updated/FTS. Con prune=true elimina filas no presentes en el slice." +tags: [vault, sqlite, index, write, upsert, fts, infra] +uses_functions: [] +uses_types: ["vault_file_go_infra"] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [database/sql, fmt, strings, time] +params: + - name: db + desc: "*sql.DB abierto sobre vault_index.db (tipicamente retornado por VaultIndexOpen)" + - name: files + desc: "slice de VaultFile a insertar/actualizar; puede ser vacio" + - name: prune + desc: "si true, elimina de 'files' todas las filas cuyo rel_path no este en el slice (sincronizacion destructiva)" +output: "WriteReport con conteos Inserted/Updated/Pruned/FTS; error si falla la transaccion" +tested: true +tests: + - "N archivos nuevos — Inserted=N" + - "re-escritura con mtime distinto — Updated=N" + - "prune elimina filas ausentes" + - "sin prune, filas previas persisten" + - "FTS5 MATCH funciona tras escritura" +test_file_path: "functions/infra/vault_index_write_test.go" +file_path: "functions/infra/vault_index_write.go" +--- + +## Ejemplo + +```go +db, _ := VaultIndexOpen("/data/vaults/turismo") +defer db.Close() + +files, _ := VaultInventoryScan("/data/vaults/turismo", "turismo_v1", "turismo") +report, err := VaultIndexWrite(db, files, true) +if err != nil { + log.Fatal(err) +} +fmt.Printf("inserted=%d updated=%d pruned=%d fts=%d\n", + report.Inserted, report.Updated, report.Pruned, report.FTS) +``` + +## Notas + +### WriteReport +Struct local al paquete infra: +```go +type WriteReport struct { + Inserted int + Updated int + Pruned int + FTS int +} +``` + +### Estrategia de conteo Inserted vs Updated +Se carga el conjunto de rel_paths existentes en un map antes del loop. Un upsert +se clasifica como Inserted si el rel_path no estaba en el map, Updated si estaba. +Esto evita N+1 SELECTs y es correcto porque la transaccion serializa los cambios. + +### FTS5 +`files_fts` usa `content=''` (tabla de contenido externo vacio). Para cada archivo +se borra la fila FTS existente y se reinserta con `content_text=''`. Los profilers +posteriores (csv_profiles, knowledge_docs) son responsables de actualizar +`content_text` con texto indexable real. + +### Prune +Con `prune=true` se construye un IN clause con los rel_paths del slice. La FK con +`ON DELETE CASCADE` propaga el DELETE a csv_profiles, pdf_extracts y knowledge_docs +automaticamente. Con slice vacio + prune=true se borra todo (DELETE FROM files). + +### Escapado SQL +El IN clause se construye escapando las comillas simples en rel_path (duplicandolas). +Evita inyeccion en rutas con apostrofos. Para entornos con rutas controladas +(interior de vaults sin apostrofos) esto es suficiente; para entornos adversariales +usar parametros binding con VALUES multiples via prepared statement. diff --git a/functions/infra/vault_index_write_test.go b/functions/infra/vault_index_write_test.go new file mode 100644 index 00000000..08854a3e --- /dev/null +++ b/functions/infra/vault_index_write_test.go @@ -0,0 +1,210 @@ +package infra + +import ( + "testing" + "time" +) + +// makeTestVaultFile creates a minimal VaultFile for testing. +func makeTestVaultFile(relPath, mime, bucket, subBucket string) VaultFile { + return VaultFile{ + VaultID: "test_vault", + VaultName: "test", + RelPath: relPath, + Size: 100, + Mtime: time.Now().Unix(), + Sha256: "abc123def456abc123def456abc123def456abc123def456abc123def456abc1", + Mime: mime, + Ext: ".csv", + Bucket: bucket, + SubBucket: subBucket, + } +} + +func openInMemoryVaultIndex(t *testing.T) interface{ Close() error } { + t.Helper() + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + return db +} + +func TestVaultIndexWrite_FreshInsert(t *testing.T) { + t.Run("N archivos nuevos — Inserted=N", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + files := []VaultFile{ + makeTestVaultFile("data/raw/a.csv", "text/csv", "data", "raw"), + makeTestVaultFile("data/raw/b.csv", "text/csv", "data", "raw"), + makeTestVaultFile("knowledge/decisions/x.md", "text/markdown", "knowledge", "decisions"), + } + + report, err := VaultIndexWrite(db, files, false) + if err != nil { + t.Fatalf("VaultIndexWrite: %v", err) + } + if report.Inserted != 3 { + t.Errorf("Inserted = %d, want 3", report.Inserted) + } + if report.Updated != 0 { + t.Errorf("Updated = %d, want 0", report.Updated) + } + if report.Pruned != 0 { + t.Errorf("Pruned = %d, want 0", report.Pruned) + } + if report.FTS != 3 { + t.Errorf("FTS = %d, want 3", report.FTS) + } + }) +} + +func TestVaultIndexWrite_Upsert(t *testing.T) { + t.Run("re-escritura con mtime distinto — Updated=N", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + files := []VaultFile{ + makeTestVaultFile("data/raw/a.csv", "text/csv", "data", "raw"), + makeTestVaultFile("data/raw/b.csv", "text/csv", "data", "raw"), + } + + if _, err := VaultIndexWrite(db, files, false); err != nil { + t.Fatalf("first write: %v", err) + } + + // Modify mtime to simulate file change. + files[0].Mtime = time.Now().Unix() + 100 + files[1].Mtime = time.Now().Unix() + 200 + + report, err := VaultIndexWrite(db, files, false) + if err != nil { + t.Fatalf("second write: %v", err) + } + if report.Inserted != 0 { + t.Errorf("Inserted = %d, want 0", report.Inserted) + } + if report.Updated != 2 { + t.Errorf("Updated = %d, want 2", report.Updated) + } + }) +} + +func TestVaultIndexWrite_Prune(t *testing.T) { + t.Run("prune elimina filas ausentes", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Write A and B. + ab := []VaultFile{ + makeTestVaultFile("data/raw/a.csv", "text/csv", "data", "raw"), + makeTestVaultFile("data/raw/b.csv", "text/csv", "data", "raw"), + } + if _, err := VaultIndexWrite(db, ab, false); err != nil { + t.Fatalf("first write: %v", err) + } + + // Write only A with prune=true — B should be deleted. + onlyA := []VaultFile{ab[0]} + report, err := VaultIndexWrite(db, onlyA, true) + if err != nil { + t.Fatalf("prune write: %v", err) + } + if report.Pruned != 1 { + t.Errorf("Pruned = %d, want 1", report.Pruned) + } + + // Verify B is gone. + var count int + err = db.QueryRow(`SELECT count(*) FROM files WHERE rel_path = 'data/raw/b.csv'`).Scan(&count) + if err != nil { + t.Fatalf("query: %v", err) + } + if count != 0 { + t.Errorf("b.csv still present after prune") + } + }) +} + +func TestVaultIndexWrite_NoPrune(t *testing.T) { + t.Run("sin prune, filas previas persisten", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ab := []VaultFile{ + makeTestVaultFile("data/raw/a.csv", "text/csv", "data", "raw"), + makeTestVaultFile("data/raw/b.csv", "text/csv", "data", "raw"), + } + if _, err := VaultIndexWrite(db, ab, false); err != nil { + t.Fatalf("first write: %v", err) + } + + // Write only A without prune — B must remain. + onlyA := []VaultFile{ab[0]} + report, err := VaultIndexWrite(db, onlyA, false) + if err != nil { + t.Fatalf("second write: %v", err) + } + if report.Pruned != 0 { + t.Errorf("Pruned = %d, want 0", report.Pruned) + } + + var count int + err = db.QueryRow(`SELECT count(*) FROM files`).Scan(&count) + if err != nil { + t.Fatalf("query: %v", err) + } + if count != 2 { + t.Errorf("files count = %d, want 2", count) + } + }) +} + +func TestVaultIndexWrite_FTSMatch(t *testing.T) { + t.Run("FTS5 MATCH funciona tras escritura", func(t *testing.T) { + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + files := []VaultFile{ + makeTestVaultFile("data/raw/foo_report.csv", "text/csv", "data", "raw"), + makeTestVaultFile("data/raw/bar_data.csv", "text/csv", "data", "raw"), + } + if _, err := VaultIndexWrite(db, files, false); err != nil { + t.Fatalf("write: %v", err) + } + + // FTS5 on rel_path column: MATCH 'foo*' + var count int + err = db.QueryRow( + `SELECT count(*) FROM files_fts WHERE files_fts MATCH 'rel_path:foo*'`, + ).Scan(&count) + if err != nil { + t.Fatalf("FTS MATCH query: %v", err) + } + if count != 1 { + t.Errorf("FTS MATCH rel_path:foo* = %d rows, want 1", count) + } + }) +} diff --git a/functions/infra/vault_inventory_scan.go b/functions/infra/vault_inventory_scan.go new file mode 100644 index 00000000..6a635147 --- /dev/null +++ b/functions/infra/vault_inventory_scan.go @@ -0,0 +1,174 @@ +package infra + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "sort" + "strings" +) + +// VaultInventoryScan walks vaultPath and returns a VaultFile slice (sorted by RelPath) +// for every regular file found, skipping: +// - vault_index.db, vault_index.db-shm, vault_index.db-wal +// - .git/ directories at any depth +// - hidden files/dirs (names starting with ".") at the vault root level only +// +// For each file it computes: relative path (forward slashes), size, mtime (unix UTC), +// sha256 (streaming, hex lowercase), MIME type, extension, bucket and sub-bucket. +// +// MIME detection priority: +// 1. Extension override: .csv → text/csv, .md → text/markdown, .parquet → application/parquet +// 2. http.DetectContentType on first 512 bytes (magic bytes, stdlib) +// +// NOTE: file_validate_type_go_infra (FileValidateType) was not used here because its +// signature requires an allowedTypes allowlist and returns (mime, bool) — it is designed +// for upload validation, not for open-ended inventory scanning where any MIME is valid. +// http.DetectContentType provides the same magic-byte detection without the allowlist +// coupling and handles a broader set of formats including text/plain for CSV fallback. +func VaultInventoryScan(vaultPath, vaultID, vaultName string) ([]VaultFile, error) { + var files []VaultFile + + err := filepath.WalkDir(vaultPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + + name := d.Name() + + // Skip .git directories at any depth. + if d.IsDir() && name == ".git" { + return filepath.SkipDir + } + + // Skip hidden entries (names starting with ".") at vault root only. + if strings.HasPrefix(name, ".") { + rel, relErr := filepath.Rel(vaultPath, path) + if relErr == nil { + // At root level the relative path has no separator. + if !strings.Contains(filepath.ToSlash(rel), "/") { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + } + } + + if d.IsDir() { + return nil + } + + // Skip vault_index.db and its WAL/SHM sidecar files. + if name == "vault_index.db" || name == "vault_index.db-shm" || name == "vault_index.db-wal" { + return nil + } + + rel, err := filepath.Rel(vaultPath, path) + if err != nil { + return fmt.Errorf("vault_inventory_scan: rel path for %q: %w", path, err) + } + rel = filepath.ToSlash(rel) + + info, err := d.Info() + if err != nil { + return fmt.Errorf("vault_inventory_scan: stat %q: %w", path, err) + } + + // Compute sha256 by streaming — avoids loading large files into memory. + sha, err := fileSha256(path) + if err != nil { + return fmt.Errorf("vault_inventory_scan: sha256 %q: %w", path, err) + } + + mime, err := detectVaultFileMime(path, name) + if err != nil { + return fmt.Errorf("vault_inventory_scan: mime %q: %w", path, err) + } + + ext := strings.ToLower(filepath.Ext(name)) + bucket, subBucket := vaultBucketParts(rel) + + files = append(files, VaultFile{ + VaultID: vaultID, + VaultName: vaultName, + RelPath: rel, + Size: info.Size(), + Mtime: info.ModTime().UTC().Unix(), + Sha256: sha, + Mime: mime, + Ext: ext, + Bucket: bucket, + SubBucket: subBucket, + }) + return nil + }) + if err != nil { + return nil, fmt.Errorf("vault_inventory_scan: walk %q: %w", vaultPath, err) + } + + sort.Slice(files, func(i, j int) bool { + return files[i].RelPath < files[j].RelPath + }) + return files, nil +} + +// fileSha256 computes the hex-lowercase SHA-256 of the file at path by streaming. +func fileSha256(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +// detectVaultFileMime returns the MIME type for a vault file. +// Extension overrides take priority; otherwise http.DetectContentType is used. +func detectVaultFileMime(path, name string) (string, error) { + ext := strings.ToLower(filepath.Ext(name)) + switch ext { + case ".csv": + return "text/csv", nil + case ".md": + return "text/markdown", nil + case ".parquet": + return "application/parquet", nil + } + + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + buf := make([]byte, 512) + n, err := f.Read(buf) + if err != nil && err != io.EOF { + return "", err + } + return http.DetectContentType(buf[:n]), nil +} + +// vaultBucketParts extracts the top-level bucket ("data" or "knowledge") and +// the second-level sub-bucket from a forward-slash relative path. +// Returns empty strings for files at vault root or with no recognisable bucket. +func vaultBucketParts(relPath string) (bucket, subBucket string) { + parts := strings.SplitN(relPath, "/", 3) + if len(parts) < 1 { + return "", "" + } + bucket = parts[0] + if len(parts) >= 2 { + subBucket = parts[1] + } + return bucket, subBucket +} diff --git a/functions/infra/vault_inventory_scan.md b/functions/infra/vault_inventory_scan.md new file mode 100644 index 00000000..f77c12c7 --- /dev/null +++ b/functions/infra/vault_inventory_scan.md @@ -0,0 +1,74 @@ +--- +name: vault_inventory_scan +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultInventoryScan(vaultPath, vaultID, vaultName string) ([]VaultFile, error)" +description: "Recorre vaultPath con filepath.WalkDir y retorna un slice de VaultFile ordenado por RelPath para cada archivo regular, computando sha256 por streaming, MIME por extension/magic y bucket/sub-bucket por posicion en el arbol." +tags: [vault, inventory, scan, filesystem, sha256, mime, infra] +uses_functions: [] +uses_types: ["vault_file_go_infra"] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [crypto/sha256, encoding/hex, fmt, io, net/http, os, path/filepath, sort, strings] +params: + - name: vaultPath + desc: "ruta absoluta o relativa al directorio raiz del vault" + - name: vaultID + desc: "identificador del vault (ej: turismo_spain_app_turismo) — se copia a cada VaultFile" + - name: vaultName + desc: "nombre legible del vault (ej: turismo_spain) — se copia a cada VaultFile" +output: "slice de VaultFile ordenado lexicograficamente por RelPath; slice vacio (no nil) si el vault esta vacio" +tested: true +tests: + - "tmpdir vacio retorna slice vacio" + - "data layout — bucket y sub_bucket correctos" + - "knowledge layout — bucket y sub_bucket correctos" + - "omite vault_index.db y .git" + - "sha256 determinista para mismo contenido" + - "orden lexicografico del resultado" +test_file_path: "functions/infra/vault_inventory_scan_test.go" +file_path: "functions/infra/vault_inventory_scan.go" +--- + +## Ejemplo + +```go +files, err := VaultInventoryScan("/data/vaults/turismo_spain", "turismo_spain_v1", "turismo_spain") +if err != nil { + log.Fatal(err) +} +for _, f := range files { + fmt.Printf("%s %s %s/%s\n", f.RelPath, f.Mime, f.Bucket, f.SubBucket) +} +``` + +## Notas + +### Archivos omitidos +- `vault_index.db`, `vault_index.db-shm`, `vault_index.db-wal` (siempre) +- `.git/` en cualquier profundidad (SkipDir) +- Entradas cuyo nombre empieza por `.` solo en la raiz del vault (nivel 0) + +### Deteccion de MIME +`file_validate_type_go_infra` (FileValidateType) no se usa porque su firma +requiere una lista blanca de tipos permitidos y retorna (mime, bool) — esta +disenada para validacion de uploads, no para escaneo inventarial donde +cualquier MIME es valido. Se usan en su lugar: + +1. Override por extension (prioridad alta): `.csv` → `text/csv`, `.md` → `text/markdown`, + `.parquet` → `application/parquet`. Necesario porque `http.DetectContentType` + clasifica CSV como `text/plain` y no conoce Parquet. +2. `http.DetectContentType` sobre primeros 512 bytes (magic bytes, stdlib) para el resto. + +### SHA-256 +Calculado por streaming con `io.Copy` a `sha256.New()` — no carga el archivo completo +a memoria. Valido para archivos de cualquier tamano. + +### Bucket / SubBucket +Derivados de la posicion en el arbol: +- `bucket` = primer segmento del RelPath (tipicamente "data" o "knowledge") +- `subBucket` = segundo segmento si existe; vacio si el archivo esta en la raiz del bucket diff --git a/functions/infra/vault_inventory_scan_test.go b/functions/infra/vault_inventory_scan_test.go new file mode 100644 index 00000000..01ba4718 --- /dev/null +++ b/functions/infra/vault_inventory_scan_test.go @@ -0,0 +1,182 @@ +package infra + +import ( + "os" + "path/filepath" + "testing" +) + +func writeTestFile(t *testing.T, dir, rel, content string) { + t.Helper() + full := filepath.Join(dir, filepath.FromSlash(rel)) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", filepath.Dir(full), err) + } + if err := os.WriteFile(full, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", full, err) + } +} + +func TestVaultInventoryScan_Empty(t *testing.T) { + t.Run("tmpdir vacio retorna slice vacio", func(t *testing.T) { + dir := t.TempDir() + files, err := VaultInventoryScan(dir, "v1", "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 0 { + t.Errorf("expected 0 files, got %d", len(files)) + } + }) +} + +func TestVaultInventoryScan_DataLayout(t *testing.T) { + t.Run("data layout — bucket y sub_bucket correctos", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, dir, "data/raw/a.csv", "col1,col2\n1,2\n") + writeTestFile(t, dir, "data/processed/b.parquet", "PAR1fakedata") + + files, err := VaultInventoryScan(dir, "vid", "vname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + + // files are sorted: data/processed/b.parquet < data/raw/a.csv + b := files[0] + if b.RelPath != "data/processed/b.parquet" { + t.Errorf("files[0].RelPath = %q, want data/processed/b.parquet", b.RelPath) + } + if b.Bucket != "data" { + t.Errorf("files[0].Bucket = %q, want data", b.Bucket) + } + if b.SubBucket != "processed" { + t.Errorf("files[0].SubBucket = %q, want processed", b.SubBucket) + } + if b.Mime != "application/parquet" { + t.Errorf("files[0].Mime = %q, want application/parquet", b.Mime) + } + if b.Ext != ".parquet" { + t.Errorf("files[0].Ext = %q, want .parquet", b.Ext) + } + if b.VaultID != "vid" { + t.Errorf("VaultID = %q, want vid", b.VaultID) + } + + a := files[1] + if a.RelPath != "data/raw/a.csv" { + t.Errorf("files[1].RelPath = %q, want data/raw/a.csv", a.RelPath) + } + if a.Mime != "text/csv" { + t.Errorf("files[1].Mime = %q, want text/csv", a.Mime) + } + if a.Bucket != "data" || a.SubBucket != "raw" { + t.Errorf("files[1]: bucket=%q subBucket=%q, want data/raw", a.Bucket, a.SubBucket) + } + }) +} + +func TestVaultInventoryScan_KnowledgeLayout(t *testing.T) { + t.Run("knowledge layout — bucket y sub_bucket correctos", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, dir, "knowledge/decisions/x.md", "# Decision\n\ncontent") + + files, err := VaultInventoryScan(dir, "vid", "vname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + f := files[0] + if f.RelPath != "knowledge/decisions/x.md" { + t.Errorf("RelPath = %q", f.RelPath) + } + if f.Bucket != "knowledge" { + t.Errorf("Bucket = %q, want knowledge", f.Bucket) + } + if f.SubBucket != "decisions" { + t.Errorf("SubBucket = %q, want decisions", f.SubBucket) + } + if f.Mime != "text/markdown" { + t.Errorf("Mime = %q, want text/markdown", f.Mime) + } + }) +} + +func TestVaultInventoryScan_SkipsIndexAndGit(t *testing.T) { + t.Run("omite vault_index.db y .git", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, dir, "vault_index.db", "sqlite data") + writeTestFile(t, dir, "vault_index.db-wal", "wal data") + writeTestFile(t, dir, ".git/HEAD", "ref: refs/heads/master") + writeTestFile(t, dir, "data/raw/real.csv", "a,b\n1,2\n") + + files, err := VaultInventoryScan(dir, "vid", "vname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file (real.csv), got %d: %v", len(files), relPaths(files)) + } + if files[0].RelPath != "data/raw/real.csv" { + t.Errorf("unexpected file: %q", files[0].RelPath) + } + }) +} + +func TestVaultInventoryScan_Sha256Deterministic(t *testing.T) { + t.Run("sha256 determinista para mismo contenido", func(t *testing.T) { + dir1 := t.TempDir() + dir2 := t.TempDir() + content := "deterministic content 123\n" + writeTestFile(t, dir1, "data/raw/f.csv", content) + writeTestFile(t, dir2, "data/raw/f.csv", content) + + files1, err := VaultInventoryScan(dir1, "v1", "vault1") + if err != nil { + t.Fatal(err) + } + files2, err := VaultInventoryScan(dir2, "v2", "vault2") + if err != nil { + t.Fatal(err) + } + if files1[0].Sha256 != files2[0].Sha256 { + t.Errorf("sha256 mismatch: %q vs %q", files1[0].Sha256, files2[0].Sha256) + } + if len(files1[0].Sha256) != 64 { + t.Errorf("sha256 length = %d, want 64", len(files1[0].Sha256)) + } + }) +} + +func TestVaultInventoryScan_Sorted(t *testing.T) { + t.Run("orden lexicografico del resultado", func(t *testing.T) { + dir := t.TempDir() + writeTestFile(t, dir, "knowledge/decisions/z.md", "z") + writeTestFile(t, dir, "data/raw/a.csv", "a") + writeTestFile(t, dir, "data/processed/m.parquet", "m") + writeTestFile(t, dir, "knowledge/domains/b.md", "b") + + files, err := VaultInventoryScan(dir, "v", "v") + if err != nil { + t.Fatal(err) + } + for i := 1; i < len(files); i++ { + if files[i].RelPath < files[i-1].RelPath { + t.Errorf("not sorted at index %d: %q < %q", i, files[i].RelPath, files[i-1].RelPath) + } + } + }) +} + +// relPaths is a helper for test error messages. +func relPaths(files []VaultFile) []string { + out := make([]string, len(files)) + for i, f := range files { + out[i] = f.RelPath + } + return out +} diff --git a/functions/infra/vault_layout_ensure.go b/functions/infra/vault_layout_ensure.go new file mode 100644 index 00000000..55a59f41 --- /dev/null +++ b/functions/infra/vault_layout_ensure.go @@ -0,0 +1,252 @@ +package infra + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// LayoutReport describes what VaultLayoutEnsure did (or would do) to a vault directory. +type LayoutReport struct { + VaultPath string `json:"vault_path"` + Created []string `json:"created"` // dirs created (relative paths) + Migrated []string `json:"migrated"` // renames executed, format "src -> dst" (relative) + AlreadyOK []string `json:"already_ok"` // dirs that already existed at the target location + Skipped []string `json:"skipped"` // unrecognized root-level entries, left untouched + DryRun bool `json:"dry_run"` +} + +// dataBuckets are root-level directories that belong under data/. +var dataBuckets = []string{"raw", "processed", "exports"} + +// knowledgeBuckets are root-level directories that belong under knowledge/. +var knowledgeBuckets = []string{"decisions", "domains", "models", "benchmarks", "test_documents"} + +// knownRootFiles are root-level files that should be moved to knowledge/. +var knownRootFiles = []string{"README.md", "README.txt"} + +// VaultLayoutEnsure ensures a vault directory uses the canonical hybrid layout: +// +// data/{raw,processed,exports} +// knowledge/{decisions,domains,models,benchmarks,test_documents} +// +// Legacy vaults that have these directories at the root are migrated by renaming +// (or merging when both src and dst already exist). The operation is idempotent: +// a second run returns everything in AlreadyOK. +// +// When dryRun is true the function computes the report but does not touch the disk. +func VaultLayoutEnsure(vaultPath string, dryRun bool) (LayoutReport, error) { + report := LayoutReport{DryRun: dryRun} + + // --- resolve path --- + vaultPath = strings.TrimRight(vaultPath, "/\\") + + var err error + vaultPath, err = filepath.Abs(vaultPath) + if err != nil { + return report, fmt.Errorf("vault_layout_ensure: abs(%q): %w", vaultPath, err) + } + + // Follow symlinks for the vault root itself. + resolved, err := filepath.EvalSymlinks(vaultPath) + if err != nil { + return report, fmt.Errorf("vault_layout_ensure: eval symlinks %q: %w", vaultPath, err) + } + vaultPath = resolved + report.VaultPath = vaultPath + + // --- check that vault exists and is a directory --- + info, err := os.Stat(vaultPath) + if err != nil { + return report, fmt.Errorf("vault_layout_ensure: stat %q: %w", vaultPath, err) + } + if !info.IsDir() { + return report, fmt.Errorf("vault_layout_ensure: %q is not a directory", vaultPath) + } + + // --- ensure top-level containers --- + for _, container := range []string{"data", "knowledge"} { + dst := filepath.Join(vaultPath, container) + if err := ensureDir(dst, dryRun, container, &report); err != nil { + return report, err + } + } + + // --- build migration table: root name -> relative destination --- + type migration struct { + rootName string // name in vault root (dir or file) + dstRel string // relative destination path inside vault + isFile bool + } + + var migrations []migration + for _, b := range dataBuckets { + migrations = append(migrations, migration{rootName: b, dstRel: filepath.Join("data", b)}) + } + for _, b := range knowledgeBuckets { + migrations = append(migrations, migration{rootName: b, dstRel: filepath.Join("knowledge", b)}) + } + for _, rf := range knownRootFiles { + migrations = append(migrations, migration{rootName: rf, dstRel: filepath.Join("knowledge", "README.md"), isFile: true}) + } + + // Track which root names are "known" so we can compute Skipped. + knownNames := make(map[string]struct{}) + for _, m := range migrations { + knownNames[strings.ToLower(m.rootName)] = struct{}{} + } + knownNames["data"] = struct{}{} + knownNames["knowledge"] = struct{}{} + + // --- apply migrations --- + for _, m := range migrations { + src := filepath.Join(vaultPath, m.rootName) + dst := filepath.Join(vaultPath, m.dstRel) + srcRel := m.rootName + dstRel := m.dstRel + + srcExists := pathExists(src) + dstExists := pathExists(dst) + + switch { + case srcExists && dstExists: + // Both exist: merge if directory, error on file collision. + if m.isFile { + return report, fmt.Errorf("vault_layout_ensure: conflict: both %q and %q exist", srcRel, dstRel) + } + if err := mergeDirs(src, dst, srcRel, dstRel, dryRun, &report); err != nil { + return report, err + } + + case srcExists && !dstExists: + // Only source exists: rename. + report.Migrated = append(report.Migrated, fmt.Sprintf("%s -> %s", srcRel, dstRel)) + if !dryRun { + if err := os.Rename(src, dst); err != nil { + return report, fmt.Errorf("vault_layout_ensure: rename %q -> %q: %w", src, dst, err) + } + } + + case !srcExists && dstExists: + // Already migrated. + report.AlreadyOK = append(report.AlreadyOK, dstRel) + + default: + // Neither exists: create empty destination directory (skip for files). + if !m.isFile { + report.Created = append(report.Created, dstRel) + if !dryRun { + if err := os.MkdirAll(dst, 0o755); err != nil { + return report, fmt.Errorf("vault_layout_ensure: mkdir %q: %w", dst, err) + } + } + } + } + } + + // --- collect skipped (unrecognized root entries) --- + entries, err := os.ReadDir(vaultPath) + if err != nil { + return report, fmt.Errorf("vault_layout_ensure: readdir %q: %w", vaultPath, err) + } + for _, e := range entries { + if _, known := knownNames[strings.ToLower(e.Name())]; !known { + report.Skipped = append(report.Skipped, e.Name()) + } + } + + return report, nil +} + +// ensureDir adds the dir to Created (and creates it) if it doesn't exist, +// or to AlreadyOK if it does. Used for top-level containers "data" and "knowledge". +func ensureDir(path string, dryRun bool, rel string, report *LayoutReport) error { + if pathExists(path) { + report.AlreadyOK = append(report.AlreadyOK, rel) + return nil + } + report.Created = append(report.Created, rel) + if dryRun { + return nil + } + if err := os.MkdirAll(path, 0o755); err != nil { + return fmt.Errorf("vault_layout_ensure: mkdir %q: %w", path, err) + } + return nil +} + +// mergeDirs moves the contents of src into dst, then removes src if empty. +// Returns an error if any file in src already exists in dst (no overwrite policy). +func mergeDirs(src, dst, srcRel, dstRel string, dryRun bool, report *LayoutReport) error { + children, err := os.ReadDir(src) + if err != nil { + return fmt.Errorf("vault_layout_ensure: readdir %q: %w", src, err) + } + + for _, child := range children { + childDst := filepath.Join(dst, child.Name()) + if pathExists(childDst) { + return fmt.Errorf("vault_layout_ensure: merge conflict: %q already exists in %q (cannot overwrite %q)", + child.Name(), dstRel, filepath.Join(srcRel, child.Name())) + } + childSrc := filepath.Join(src, child.Name()) + childSrcRel := filepath.Join(srcRel, child.Name()) + childDstRel := filepath.Join(dstRel, child.Name()) + report.Migrated = append(report.Migrated, fmt.Sprintf("%s -> %s", childSrcRel, childDstRel)) + if !dryRun { + if err := os.Rename(childSrc, childDst); err != nil { + return fmt.Errorf("vault_layout_ensure: rename %q -> %q: %w", childSrc, childDst, err) + } + } + } + + // Remove the now-empty src directory. + if !dryRun { + // Re-check emptiness after renames. + remaining, _ := os.ReadDir(src) + if len(remaining) == 0 { + if err := os.Remove(src); err != nil { + return fmt.Errorf("vault_layout_ensure: remove empty src %q: %w", src, err) + } + } + } + return nil +} + +// pathExists returns true if path exists (any type). +func pathExists(path string) bool { + _, err := os.Lstat(path) + return err == nil +} + +// dirIsEmpty returns true if a directory exists and has no entries. +func dirIsEmpty(path string) bool { + entries, err := os.ReadDir(path) + if err != nil { + return false + } + return len(entries) == 0 +} + +// _ prevents "declared but not used" if dirIsEmpty is only used in tests. +var _ = dirIsEmpty + +// vaultLayoutKnownNames returns the set of root-level names managed by this function. +// Exported for use in tests. +func vaultLayoutKnownNames() map[string]struct{} { + known := make(map[string]struct{}) + for _, b := range dataBuckets { + known[b] = struct{}{} + } + for _, b := range knowledgeBuckets { + known[b] = struct{}{} + } + for _, rf := range knownRootFiles { + known[strings.ToLower(rf)] = struct{}{} + } + known["data"] = struct{}{} + known["knowledge"] = struct{}{} + return known +} + diff --git a/functions/infra/vault_layout_ensure.md b/functions/infra/vault_layout_ensure.md new file mode 100644 index 00000000..ed584c38 --- /dev/null +++ b/functions/infra/vault_layout_ensure.md @@ -0,0 +1,95 @@ +--- +name: vault_layout_ensure +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultLayoutEnsure(vaultPath string, dryRun bool) (LayoutReport, error)" +description: "Normaliza el layout de un vault al esquema hibrido canónico data/{raw,processed,exports} + knowledge/{decisions,domains,models,benchmarks,test_documents}. Migra directorios legacy en la raíz del vault a su ubicación correcta; idempotente." +tags: [vault, layout, migration, infra, filesystem, idempotent] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: + - "fmt" + - "os" + - "path/filepath" + - "strings" +params: + - name: vault_path + desc: "Ruta al directorio raíz del vault. Puede ser absoluta, relativa o un symlink — se resuelve con filepath.Abs + filepath.EvalSymlinks. Trailing slashes se ignoran." + - name: dry_run + desc: "Si true, calcula el reporte completo (qué se crearía, migraría, etc.) pero no modifica el disco. Util para previsualizar antes de ejecutar." +output: "LayoutReport con: VaultPath (ruta resuelta), Created (dirs creados), Migrated (renombres ejecutados, formato 'src -> dst'), AlreadyOK (destinos que ya existían), Skipped (entradas en raíz no reconocidas, no tocadas), DryRun (flag). Error si el path no existe, no es directorio, o hay conflicto de merge (mismo nombre de archivo en src y dst)." +tested: true +tests: + - "TestVaultLayoutEnsure_DryRun_NoChange" + - "TestVaultLayoutEnsure_FreshDir_CreatesLayout" + - "TestVaultLayoutEnsure_LegacyDataLayout_Migrates" + - "TestVaultLayoutEnsure_LegacyKnowledgeLayout_Migrates" + - "TestVaultLayoutEnsure_AlreadyMigrated_Idempotent" + - "TestVaultLayoutEnsure_Mixed_PartialMigration" + - "TestVaultLayoutEnsure_MergeConflict_Errors" + - "TestVaultLayoutEnsure_UnknownFiles_Skipped" + - "TestVaultLayoutEnsure_NotADir_Errors" +test_file_path: "functions/infra/vault_layout_ensure_test.go" +file_path: "functions/infra/vault_layout_ensure.go" +--- + +## Ejemplo + +```go +// Previsualizar sin tocar disco: +report, err := VaultLayoutEnsure("/home/lucas/vaults/turismo_spain", true) +if err != nil { + log.Fatal(err) +} +fmt.Printf("Would migrate: %v\n", report.Migrated) +fmt.Printf("Would create: %v\n", report.Created) + +// Ejecutar la migración: +report, err = VaultLayoutEnsure("/home/lucas/vaults/turismo_spain", false) +if err != nil { + log.Fatalf("migration failed: %v", err) +} +fmt.Printf("Migrated: %v\n", report.Migrated) +fmt.Printf("Created: %v\n", report.Created) +fmt.Printf("Skipped: %v\n", report.Skipped) +``` + +## Comportamiento detallado + +**Directorios gestionados:** + +| Raíz (legacy) | Destino canónico | +|---|---| +| `raw/` | `data/raw/` | +| `processed/` | `data/processed/` | +| `exports/` | `data/exports/` | +| `decisions/` | `knowledge/decisions/` | +| `domains/` | `knowledge/domains/` | +| `models/` | `knowledge/models/` | +| `benchmarks/` | `knowledge/benchmarks/` | +| `test_documents/` | `knowledge/test_documents/` | +| `README.md` / `README.txt` | `knowledge/README.md` | + +**Lógica de migración (por cada entrada conocida):** + +- Solo `src` existe → rename atómico `src` → `dst`, registrado en `Migrated`. +- Solo `dst` existe → ya migrado, registrado en `AlreadyOK`. +- Ambos existen (dir) → merge: mueve cada hijo de `src/` a `dst/`; error si mismo nombre. Registrado en `Migrated` por hijo. +- Ambos existen (archivo README) → error inmediato con paths concretos. +- Ninguno existe → crea `dst` vacío, registrado en `Created`. + +**Archivos/dirs no reconocidos** en la raíz (`.git`, `vault_index.db`, archivos custom) se registran en `Skipped` y no se tocan. + +**Idempotencia:** segunda ejecución sobre un vault ya migrado reporta todo en `AlreadyOK` y no toca disco. + +## Notas + +`LayoutReport` es un tipo local de esta función (no un tipo del registry). El struct exportado vive en `functions/infra/vault_layout_ensure.go` junto con la función. + +Para aplicar la migración a múltiples vaults en batch, invocar desde un pipeline que lea los paths de `vault.yaml` (ver `vault_manifest_read_go_infra`) y llame a `VaultLayoutEnsure` en cada uno. diff --git a/functions/infra/vault_layout_ensure_test.go b/functions/infra/vault_layout_ensure_test.go new file mode 100644 index 00000000..171cfa8e --- /dev/null +++ b/functions/infra/vault_layout_ensure_test.go @@ -0,0 +1,394 @@ +package infra + +import ( + "os" + "path/filepath" + "testing" +) + +// mkVaultDir creates a temporary directory tree for tests. +// entries is a list of relative paths to create. +// Paths ending in "/" are directories; others are files with placeholder content. +func mkVaultDir(t *testing.T, entries []string) string { + t.Helper() + root := t.TempDir() + for _, e := range entries { + full := filepath.Join(root, filepath.FromSlash(e)) + if e[len(e)-1] == '/' { + if err := os.MkdirAll(full, 0o755); err != nil { + t.Fatalf("mkVaultDir: mkdir %q: %v", full, err) + } + } else { + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatalf("mkVaultDir: mkdir parent %q: %v", full, err) + } + if err := os.WriteFile(full, []byte("test\n"), 0o644); err != nil { + t.Fatalf("mkVaultDir: write %q: %v", full, err) + } + } + } + return root +} + +func TestVaultLayoutEnsure_DryRun_NoChange(t *testing.T) { + root := mkVaultDir(t, []string{ + "raw/", + "raw/file1.csv", + "processed/", + }) + + before := snapshotDir(t, root) + report, err := VaultLayoutEnsure(root, true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !report.DryRun { + t.Error("DryRun flag not set in report") + } + after := snapshotDir(t, root) + if !mapEqual(before, after) { + t.Errorf("dry-run modified disk: before=%v after=%v", before, after) + } + // Should have planned a migration for raw and processed. + if len(report.Migrated) == 0 { + t.Error("expected Migrated to be non-empty in dry-run plan") + } +} + +func TestVaultLayoutEnsure_FreshDir_CreatesLayout(t *testing.T) { + root := mkVaultDir(t, []string{}) // empty vault + + report, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // All standard dirs should be created. + wantCreated := []string{ + "data", "knowledge", + filepath.Join("data", "raw"), + filepath.Join("data", "processed"), + filepath.Join("data", "exports"), + filepath.Join("knowledge", "decisions"), + filepath.Join("knowledge", "domains"), + filepath.Join("knowledge", "models"), + filepath.Join("knowledge", "benchmarks"), + filepath.Join("knowledge", "test_documents"), + } + createdSet := toSet(report.Created) + for _, w := range wantCreated { + if _, ok := createdSet[w]; !ok { + t.Errorf("expected Created to contain %q, got %v", w, report.Created) + } + } + + // All directories must actually exist on disk. + for _, w := range wantCreated { + full := filepath.Join(root, w) + info, err := os.Stat(full) + if err != nil { + t.Errorf("expected %q to exist: %v", full, err) + continue + } + if !info.IsDir() { + t.Errorf("%q should be a directory", full) + } + } +} + +func TestVaultLayoutEnsure_LegacyDataLayout_Migrates(t *testing.T) { + root := mkVaultDir(t, []string{ + "raw/", + "raw/file1.parquet", + "raw/file2.parquet", + "processed/", + "processed/clean.csv", + "exports/", + }) + + report, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // raw and processed should appear in Migrated (as dirs, top-level rename). + migratedSet := toSet(report.Migrated) + for _, pair := range []string{ + "raw -> " + filepath.Join("data", "raw"), + "processed -> " + filepath.Join("data", "processed"), + } { + if _, ok := migratedSet[pair]; !ok { + t.Errorf("expected Migrated to contain %q, got %v", pair, report.Migrated) + } + } + + // Files must have moved. + for _, f := range []string{ + filepath.Join("data", "raw", "file1.parquet"), + filepath.Join("data", "raw", "file2.parquet"), + filepath.Join("data", "processed", "clean.csv"), + } { + if _, err := os.Stat(filepath.Join(root, f)); err != nil { + t.Errorf("expected %q to exist after migration: %v", f, err) + } + } + // Old dirs must be gone. + for _, d := range []string{"raw", "processed"} { + if pathExists(filepath.Join(root, d)) { + t.Errorf("expected legacy dir %q to be removed", d) + } + } +} + +func TestVaultLayoutEnsure_LegacyKnowledgeLayout_Migrates(t *testing.T) { + root := mkVaultDir(t, []string{ + "decisions/", + "decisions/2024-01.md", + "models/", + "models/ner_v1.pkl", + "README.md", + }) + + report, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // decisions and models should appear in Migrated. + migratedSet := toSet(report.Migrated) + for _, pair := range []string{ + "decisions -> " + filepath.Join("knowledge", "decisions"), + "models -> " + filepath.Join("knowledge", "models"), + "README.md -> " + filepath.Join("knowledge", "README.md"), + } { + if _, ok := migratedSet[pair]; !ok { + t.Errorf("expected Migrated to contain %q, got %v", pair, report.Migrated) + } + } + + // Files must be at new location. + for _, f := range []string{ + filepath.Join("knowledge", "decisions", "2024-01.md"), + filepath.Join("knowledge", "models", "ner_v1.pkl"), + filepath.Join("knowledge", "README.md"), + } { + if _, err := os.Stat(filepath.Join(root, f)); err != nil { + t.Errorf("expected %q to exist after migration: %v", f, err) + } + } +} + +func TestVaultLayoutEnsure_AlreadyMigrated_Idempotent(t *testing.T) { + root := mkVaultDir(t, []string{ + "data/", + "data/raw/", + "data/raw/file.csv", + "data/processed/", + "data/exports/", + "knowledge/", + "knowledge/decisions/", + "knowledge/domains/", + "knowledge/models/", + "knowledge/benchmarks/", + "knowledge/test_documents/", + }) + + report1, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("first run error: %v", err) + } + if len(report1.Migrated) != 0 { + t.Errorf("first run on fully-migrated vault should have no migrations, got %v", report1.Migrated) + } + + before := snapshotDir(t, root) + report2, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("second run error: %v", err) + } + after := snapshotDir(t, root) + + if !mapEqual(before, after) { + t.Error("second run modified disk (not idempotent)") + } + if len(report2.Migrated) != 0 { + t.Errorf("second run should produce no migrations, got %v", report2.Migrated) + } + if len(report2.AlreadyOK) == 0 { + t.Error("second run should report existing dirs as AlreadyOK") + } +} + +func TestVaultLayoutEnsure_Mixed_PartialMigration(t *testing.T) { + // data/raw already migrated; exports still at root; knowledge dirs in legacy positions. + root := mkVaultDir(t, []string{ + "data/", + "data/raw/", + "data/raw/already_here.csv", + "exports/", + "exports/report.pdf", + "decisions/", + "decisions/2023-note.md", + }) + + report, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // data/raw should be AlreadyOK. + if !sliceContains(report.AlreadyOK, filepath.Join("data", "raw")) { + t.Errorf("data/raw should be AlreadyOK, got AlreadyOK=%v", report.AlreadyOK) + } + // exports should be migrated. + exportsMigrated := false + for _, m := range report.Migrated { + if m == "exports -> "+filepath.Join("data", "exports") { + exportsMigrated = true + } + } + if !exportsMigrated { + t.Errorf("exports should be migrated, Migrated=%v", report.Migrated) + } + // decisions should be migrated. + decisionsMigrated := false + for _, m := range report.Migrated { + if m == "decisions -> "+filepath.Join("knowledge", "decisions") { + decisionsMigrated = true + } + } + if !decisionsMigrated { + t.Errorf("decisions should be migrated, Migrated=%v", report.Migrated) + } +} + +func TestVaultLayoutEnsure_MergeConflict_Errors(t *testing.T) { + // Both src (raw/) and dst (data/raw/) exist and have a file with the same name. + root := mkVaultDir(t, []string{ + "raw/", + "raw/collision.csv", + "data/", + "data/raw/", + "data/raw/collision.csv", // same name -> conflict + }) + + _, err := VaultLayoutEnsure(root, false) + if err == nil { + t.Fatal("expected error for merge conflict, got nil") + } + if !contains(err.Error(), "conflict") && !contains(err.Error(), "collision.csv") { + t.Errorf("error should mention conflict or the file name, got: %v", err) + } +} + +func TestVaultLayoutEnsure_UnknownFiles_Skipped(t *testing.T) { + root := mkVaultDir(t, []string{ + ".git/", + "vault_index.db", + "my_custom_notes.txt", + "raw/", + }) + + report, err := VaultLayoutEnsure(root, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + skippedSet := toSet(report.Skipped) + for _, name := range []string{".git", "vault_index.db", "my_custom_notes.txt"} { + if _, ok := skippedSet[name]; !ok { + t.Errorf("expected %q in Skipped, got %v", name, report.Skipped) + } + } + // raw should NOT be in Skipped (it's a known bucket). + if _, ok := skippedSet["raw"]; ok { + t.Error("raw should not appear in Skipped — it is a known bucket") + } +} + +func TestVaultLayoutEnsure_NotADir_Errors(t *testing.T) { + t.Run("non-existent path", func(t *testing.T) { + _, err := VaultLayoutEnsure("/tmp/does_not_exist_fn_registry_test_xyz", false) + if err == nil { + t.Fatal("expected error for non-existent path") + } + }) + + t.Run("path is a file", func(t *testing.T) { + f, err := os.CreateTemp("", "vault_layout_*.txt") + if err != nil { + t.Fatal(err) + } + f.Close() + defer os.Remove(f.Name()) + + _, err = VaultLayoutEnsure(f.Name(), false) + if err == nil { + t.Fatal("expected error when vaultPath is a file, not a dir") + } + if !contains(err.Error(), "not a directory") { + t.Errorf("error should mention 'not a directory', got: %v", err) + } + }) +} + +// --- helpers --- + +// snapshotDir returns a map of relative path -> exists for all entries under root. +func snapshotDir(t *testing.T, root string) map[string]bool { + t.Helper() + snap := make(map[string]bool) + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + rel, _ := filepath.Rel(root, path) + snap[rel] = true + return nil + }) + if err != nil { + t.Fatalf("snapshotDir: %v", err) + } + return snap +} + +func mapEqual(a, b map[string]bool) bool { + if len(a) != len(b) { + return false + } + for k := range a { + if !b[k] { + return false + } + } + return true +} + +func toSet(ss []string) map[string]struct{} { + m := make(map[string]struct{}, len(ss)) + for _, s := range ss { + m[s] = struct{}{} + } + return m +} + +func sliceContains(ss []string, target string) bool { + for _, s := range ss { + if s == target { + return true + } + } + return false +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(sub) == 0 || + func() bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false + }()) +} diff --git a/functions/infra/vault_manifest_read.go b/functions/infra/vault_manifest_read.go new file mode 100644 index 00000000..6968b552 --- /dev/null +++ b/functions/infra/vault_manifest_read.go @@ -0,0 +1,96 @@ +package infra + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" +) + +// VaultManifestEntry is a single vault entry parsed from a projects//vaults/vault.yaml. +type VaultManifestEntry struct { + ProjectID string // basename of projects//, inferred from manifest path + Name string // vault name as declared in vault.yaml + Description string // human description + Path string // absolute path to the vault directory + Tags []string // tags declared in vault.yaml + ManifestFile string // absolute path to the vault.yaml this entry came from +} + +// vaultYAML mirrors the vault.yaml schema (only the fields we care about). +type vaultYAML struct { + Vaults []struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Path string `yaml:"path"` + Tags []string `yaml:"tags"` + } `yaml:"vaults"` +} + +// VaultManifestRead globs all projects/*/vaults/vault.yaml under repoRoot, parses each +// manifest and returns a flat slice of VaultManifestEntry. +// +// Rules: +// - If a manifest fails to parse, an error is returned immediately with the file path. +// - If no manifests are found, an empty slice is returned (not an error). +// - ProjectID is inferred from the directory component between "projects/" and "/vaults/". +func VaultManifestRead(repoRoot string) ([]VaultManifestEntry, error) { + pattern := filepath.Join(repoRoot, "projects", "*", "vaults", "vault.yaml") + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("vault_manifest_read: glob %q: %w", pattern, err) + } + + var out []VaultManifestEntry + for _, manifestPath := range matches { + entries, err := parseVaultManifest(manifestPath) + if err != nil { + return nil, err + } + out = append(out, entries...) + } + return out, nil +} + +func parseVaultManifest(manifestPath string) ([]VaultManifestEntry, error) { + data, err := os.ReadFile(manifestPath) + if err != nil { + return nil, fmt.Errorf("vault_manifest_read: read %q: %w", manifestPath, err) + } + + var raw vaultYAML + if err := yaml.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("vault_manifest_read: parse %q: %w", manifestPath, err) + } + + projectID := inferProjectID(manifestPath) + + entries := make([]VaultManifestEntry, 0, len(raw.Vaults)) + for _, v := range raw.Vaults { + entries = append(entries, VaultManifestEntry{ + ProjectID: projectID, + Name: v.Name, + Description: v.Description, + Path: v.Path, + Tags: v.Tags, + ManifestFile: manifestPath, + }) + } + return entries, nil +} + +// inferProjectID extracts the project basename from a path of the form +// .../projects//vaults/vault.yaml. +func inferProjectID(manifestPath string) string { + // Normalize separators and split. + parts := strings.Split(filepath.ToSlash(manifestPath), "/") + // Walk backwards: vault.yaml -> vaults -> -> projects -> ... + for i, p := range parts { + if p == "projects" && i+1 < len(parts) { + return parts[i+1] + } + } + return "" +} diff --git a/functions/infra/vault_manifest_read.md b/functions/infra/vault_manifest_read.md new file mode 100644 index 00000000..9dacb33f --- /dev/null +++ b/functions/infra/vault_manifest_read.md @@ -0,0 +1,59 @@ +--- +name: vault_manifest_read +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultManifestRead(repoRoot string) ([]VaultManifestEntry, error)" +description: "Lee todos los manifests vault.yaml bajo projects/*/vaults/ del repo y devuelve una lista plana de entradas de vault con su ProjectID inferido del path." +tags: [vault, manifest, yaml, infra, projects, storage] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: + - "fmt" + - "os" + - "path/filepath" + - "strings" + - "gopkg.in/yaml.v3" +params: + - name: repoRoot + desc: "Ruta absoluta a la raiz del repositorio fn_registry. Se usa como base para el glob projects/*/vaults/vault.yaml." +output: "Slice plano de VaultManifestEntry (ProjectID, Name, Description, Path, Tags, ManifestFile). Vacio si no hay manifests. Error si un yaml no parsea, con el path concreto en el mensaje." +tested: true +tests: + - "TestVaultManifestRead_HappyPath" + - "TestVaultManifestRead_MalformedYAML" + - "TestVaultManifestRead_EmptyDir" +test_file_path: "functions/infra/vault_manifest_read_test.go" +file_path: "functions/infra/vault_manifest_read.go" +--- + +## Ejemplo + +```go +entries, err := VaultManifestRead("/home/lucas/fn_registry") +if err != nil { + log.Fatal(err) +} +for _, e := range entries { + fmt.Printf("%s/%s -> %s\n", e.ProjectID, e.Name, e.Path) +} +// app_turismo/turismo_spain -> /home/lucas/vaults/turismo_spain +// app_finance/finance_data -> /home/lucas/vaults/finance_data +``` + +## Notas + +`VaultManifestEntry` es un tipo local de esta funcion (no un tipo del registry). Contiene: +- `ProjectID` — basename del directorio `projects//`, inferido del path del manifest. +- `Name`, `Description`, `Path`, `Tags` — copiados del yaml tal cual. +- `ManifestFile` — path absoluto al vault.yaml de origen, util para mensajes de error y trazabilidad. + +El parseo usa `gopkg.in/yaml.v3` (ya en go.mod). Si un manifest falla, la funcion devuelve +error inmediatamente con el path del fichero problemático. Los manifests sin entradas +`vaults:` contribuyen cero entries (no es error). Si no existe ningun `projects/*/vaults/vault.yaml` +el resultado es slice vacio sin error. diff --git a/functions/infra/vault_manifest_read_test.go b/functions/infra/vault_manifest_read_test.go new file mode 100644 index 00000000..de67378f --- /dev/null +++ b/functions/infra/vault_manifest_read_test.go @@ -0,0 +1,113 @@ +package infra + +import ( + "os" + "path/filepath" + "testing" +) + +func TestVaultManifestRead_HappyPath(t *testing.T) { + root := t.TempDir() + + writeManifest(t, root, "app_turismo", ` +vaults: + - name: turismo_spain + description: "Datos de turismo en Espana" + path: "/home/lucas/vaults/turismo_spain" + tags: [turismo, espana] + - name: turismo_raw + description: "Datos brutos sin procesar" + path: "/home/lucas/vaults/turismo_raw" + tags: [raw] +`) + + writeManifest(t, root, "app_finance", ` +vaults: + - name: finance_data + description: "Datos financieros" + path: "/home/lucas/vaults/finance_data" + tags: [finance] +`) + + entries, err := VaultManifestRead(root) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 3 { + t.Fatalf("got %d entries, want 3", len(entries)) + } + + // Build index by name for order-independent assertions. + byName := make(map[string]VaultManifestEntry, len(entries)) + for _, e := range entries { + byName[e.Name] = e + } + + // Check turismo_spain entry. + e, ok := byName["turismo_spain"] + if !ok { + t.Fatal("missing entry 'turismo_spain'") + } + if e.ProjectID != "app_turismo" { + t.Errorf("turismo_spain.ProjectID = %q, want %q", e.ProjectID, "app_turismo") + } + if e.Path != "/home/lucas/vaults/turismo_spain" { + t.Errorf("turismo_spain.Path = %q, want %q", e.Path, "/home/lucas/vaults/turismo_spain") + } + if len(e.Tags) != 2 || e.Tags[0] != "turismo" { + t.Errorf("turismo_spain.Tags = %v, want [turismo espana]", e.Tags) + } + if e.ManifestFile == "" { + t.Error("turismo_spain.ManifestFile is empty") + } + + // Check finance_data entry belongs to app_finance. + ef, ok := byName["finance_data"] + if !ok { + t.Fatal("missing entry 'finance_data'") + } + if ef.ProjectID != "app_finance" { + t.Errorf("finance_data.ProjectID = %q, want %q", ef.ProjectID, "app_finance") + } +} + +func TestVaultManifestRead_MalformedYAML(t *testing.T) { + root := t.TempDir() + + writeManifest(t, root, "bad_project", ` +vaults: + - name: [invalid yaml + path: missing_bracket +`) + + _, err := VaultManifestRead(root) + if err == nil { + t.Fatal("expected error for malformed YAML, got nil") + } +} + +func TestVaultManifestRead_EmptyDir(t *testing.T) { + root := t.TempDir() + + // No projects/ directory at all — glob returns no matches. + entries, err := VaultManifestRead(root) + if err != nil { + t.Fatalf("unexpected error for empty dir: %v", err) + } + if len(entries) != 0 { + t.Fatalf("got %d entries, want 0", len(entries)) + } +} + +// writeManifest creates /projects//vaults/vault.yaml with the given content. +func writeManifest(t *testing.T, root, proj, content string) { + t.Helper() + dir := filepath.Join(root, "projects", proj, "vaults") + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir %s: %v", dir, err) + } + f := filepath.Join(dir, "vault.yaml") + if err := os.WriteFile(f, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", f, err) + } +} diff --git a/functions/infra/vault_search.go b/functions/infra/vault_search.go new file mode 100644 index 00000000..66b4612c --- /dev/null +++ b/functions/infra/vault_search.go @@ -0,0 +1,265 @@ +package infra + +import ( + "database/sql" + "fmt" + "path/filepath" + "strings" +) + +// VaultSearchHit is a single result returned by VaultSearch. +type VaultSearchHit struct { + VaultPath string `json:"vault_path"` + VaultName string `json:"vault_name"` // basename of VaultPath (after resolving symlinks) + RelPath string `json:"rel_path"` + Size int64 `json:"size"` + Mtime int64 `json:"mtime"` + Mime string `json:"mime"` + Bucket string `json:"bucket"` + SubBucket string `json:"sub_bucket"` + Snippet string `json:"snippet"` // FTS5 snippet or empty if match is only by rel_path (fallback) +} + +// VaultSearch searches vault_index.db inside vaultPath for files matching query. +// +// Behaviour: +// 1. Opens vault_index.db via VaultIndexOpen. +// 2. If limit <= 0, defaults to 50. +// 3. Runs a FTS5 MATCH query over files_fts to find content matches (when content_text +// is populated by profilers). Because the FTS5 table uses content='' (contentless), +// column values are not stored; results are correlated back to files via a LIKE +// match on rel_path for path tokens, or via an IN clause of matched rowids for +// content_text matches. +// 4. Also searches files.rel_path with LIKE to find path matches. +// 5. Results from both searches are merged (deduplication by rel_path). +// 6. If both FTS5 and LIKE queries fail, returns the error. +// 7. VaultName is derived from the basename of vaultPath (after resolving symlinks). +func VaultSearch(vaultPath, query string, limit int) ([]VaultSearchHit, error) { + if limit <= 0 { + limit = 50 + } + + db, err := VaultIndexOpen(vaultPath) + if err != nil { + return nil, fmt.Errorf("vault_search: open index: %w", err) + } + defer db.Close() + + vaultName := resolveVaultName(vaultPath) + + hits, err := vaultSearchCombined(db, vaultPath, vaultName, query, limit) + if err != nil { + return nil, fmt.Errorf("vault_search: %w", err) + } + return hits, nil +} + +// vaultSearchCombined runs the search using two strategies and merges deduplicated results: +// 1. FTS5 MATCH on files_fts (for content_text when populated by profilers). +// Correlation back to files uses rowid (reliable for fresh indexes) or falls back. +// 2. LIKE on files.rel_path (always reliable for path searching). +// +// Results are deduplicated by rel_path, up to limit entries. +func vaultSearchCombined(db *sql.DB, vaultPath, vaultName, query string, limit int) ([]VaultSearchHit, error) { + seen := make(map[string]struct{}) + var hits []VaultSearchHit + + // Strategy 1: FTS5 MATCH on content_text (populated by profilers). + // With contentless FTS5 (content=''), column values are NOT retrievable via SELECT. + // We get matching rowids from FTS5, then look up files by rowid. + // This is reliable for content_text matches because VaultIndexWrite inserts + // content_text rows independently of the path rows (profilers update them). + // NOTE: for rel_path token matching, strategy 2 (LIKE) is more reliable. + ftsQuery := safeFTSQuery(query) + ftsHits, ftsErr := vaultSearchFTSContent(db, vaultPath, vaultName, ftsQuery, limit) + if ftsErr == nil { + for _, h := range ftsHits { + if len(hits) >= limit { + break + } + if _, ok := seen[h.RelPath]; !ok { + seen[h.RelPath] = struct{}{} + hits = append(hits, h) + } + } + } + // If FTS5 failed with a syntax error, that's expected for bad queries — continue. + // If it failed with a non-syntax error, still continue to LIKE fallback. + + // Strategy 2: LIKE on rel_path — reliable path search. + // When query contains FTS5 special chars (e.g. "foo:bar:"), extract the first + // word-like token so the LIKE pattern is still useful. + likeQuery := simplifyForLike(query) + if len(hits) < limit && likeQuery != "" { + remaining := limit - len(hits) + likeHits, likeErr := vaultSearchLike(db, vaultPath, vaultName, likeQuery, remaining+len(seen)) + if likeErr != nil && ftsErr != nil { + // Both failed — return a combined error. + return nil, fmt.Errorf("fts: %v; like: %v", ftsErr, likeErr) + } + for _, h := range likeHits { + if len(hits) >= limit { + break + } + if _, ok := seen[h.RelPath]; !ok { + seen[h.RelPath] = struct{}{} + hits = append(hits, h) + } + } + } + + if hits == nil { + hits = []VaultSearchHit{} + } + return hits, nil +} + +// vaultSearchFTSContent queries files_fts with a MATCH and correlates results +// back to the files table. +// +// Design note: with content='' (contentless FTS5), SELECT on columns returns ''. +// We get the rowid from the FTS5 match and look up files.rel_path via rowid. +// This works correctly when content_text was populated by a profiler that did NOT +// delete+reinsert the FTS row (i.e. profilers do direct INSERT/UPDATE of content_text +// without changing the rowid). For the current VaultIndexWrite implementation +// (which inserts content_text='' and profilers update it in-place), the rowids +// remain stable after profiling. +func vaultSearchFTSContent(db *sql.DB, vaultPath, vaultName, safeQuery string, limit int) ([]VaultSearchHit, error) { + // Get matching rowids from FTS5. + const qRowids = ` + SELECT rowid + FROM files_fts + WHERE files_fts MATCH ? + ORDER BY rank + LIMIT ?` + + rows, err := db.Query(qRowids, safeQuery, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var rowids []int64 + for rows.Next() { + var rid int64 + if err := rows.Scan(&rid); err != nil { + return nil, err + } + rowids = append(rowids, rid) + } + if err := rows.Err(); err != nil { + return nil, err + } + if len(rowids) == 0 { + return nil, nil + } + + // Look up files by rowid. files uses a TEXT PK so its rowid is implicit. + // Snippet is empty for contentless FTS5 (snippet() returns NULL there). + var hits []VaultSearchHit + for _, rid := range rowids { + var h VaultSearchHit + err := db.QueryRow(` + SELECT rel_path, size, mtime, mime, bucket, sub_bucket + FROM files WHERE rowid = ?`, rid, + ).Scan(&h.RelPath, &h.Size, &h.Mtime, &h.Mime, &h.Bucket, &h.SubBucket) + if err != nil { + // rowid mismatch (happens after update cycles) — skip gracefully. + continue + } + h.VaultPath = vaultPath + h.VaultName = vaultName + h.Snippet = "" + hits = append(hits, h) + } + return hits, nil +} + +// vaultSearchLike searches files.rel_path with LIKE, ordered by mtime DESC. +func vaultSearchLike(db *sql.DB, vaultPath, vaultName, query string, limit int) ([]VaultSearchHit, error) { + const qLike = ` + SELECT rel_path, size, mtime, mime, bucket, sub_bucket + FROM files + WHERE rel_path LIKE '%' || ? || '%' + ORDER BY mtime DESC + LIMIT ?` + + rows, err := db.Query(qLike, query, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var hits []VaultSearchHit + for rows.Next() { + var h VaultSearchHit + if err := rows.Scan(&h.RelPath, &h.Size, &h.Mtime, &h.Mime, &h.Bucket, &h.SubBucket); err != nil { + return nil, err + } + h.VaultPath = vaultPath + h.VaultName = vaultName + h.Snippet = "" + hits = append(hits, h) + } + return hits, rows.Err() +} + +// resolveVaultName returns the basename of vaultPath after resolving symlinks. +// Falls back to filepath.Base if EvalSymlinks fails. +func resolveVaultName(vaultPath string) string { + resolved, err := filepath.EvalSymlinks(vaultPath) + if err != nil { + resolved = vaultPath + } + return filepath.Base(resolved) +} + +// safeFTSQuery wraps the query in double-quotes if it does not already contain +// FTS5 boolean operators (AND, OR, NOT) or column prefixes (containing ":"). +// This prevents FTS5 syntax errors on tokens like "foo:bar:" or "hello-world". +func safeFTSQuery(query string) string { + q := strings.TrimSpace(query) + if q == "" { + return q + } + upper := strings.ToUpper(q) + // If user already uses explicit operators or column prefix, pass through. + if strings.ContainsAny(q, ":") || + strings.Contains(upper, " AND ") || + strings.Contains(upper, " OR ") || + strings.Contains(upper, " NOT ") { + return q + } + // Escape any double-quotes in the query before wrapping. + escaped := strings.ReplaceAll(q, `"`, `""`) + return `"` + escaped + `"` +} + +// isFTSSyntaxError returns true when the error looks like an FTS5 query parser error. +func isFTSSyntaxError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "syntax error") || + strings.Contains(msg, "no such column") || + strings.Contains(msg, "fts5: syntax error") +} + +// simplifyForLike extracts a clean substring from query suitable for LIKE matching. +// When the query contains FTS5 special characters (colons, double-quotes, operators), +// only the first word-like sequence of alphanumeric/underscore/hyphen characters is +// used. This ensures the LIKE fallback remains useful even when the FTS5 query is +// syntactically complex or contains column-prefix syntax like "foo:bar:". +func simplifyForLike(query string) string { + q := strings.TrimSpace(query) + var token strings.Builder + for _, r := range q { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + token.WriteRune(r) + } else if token.Len() > 0 { + break + } + } + return token.String() +} diff --git a/functions/infra/vault_search.md b/functions/infra/vault_search.md new file mode 100644 index 00000000..100c7c2b --- /dev/null +++ b/functions/infra/vault_search.md @@ -0,0 +1,61 @@ +--- +name: vault_search +kind: function +lang: go +domain: infra +version: "1.0.0" +purity: impure +signature: "func VaultSearch(vaultPath, query string, limit int) ([]VaultSearchHit, error)" +description: "Busca en vault_index.db de un vault usando FTS5 sobre files_fts. Si el query rompe el parser FTS5, hace fallback a LIKE sobre rel_path. Retorna hits con snippet de contexto." +tags: [vault, search, fts5, sqlite, infra] +uses_functions: ["vault_index_open_go_infra"] +uses_types: ["vault_file_go_infra"] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [database/sql, fmt, path/filepath, strings] +params: + - name: vaultPath + desc: "ruta absoluta al directorio raiz del vault (puede ser symlink)" + - name: query + desc: "termino o frase de busqueda; se escapa automaticamente para FTS5 salvo que ya incluya operadores booleanos o prefijos de columna" + - name: limit + desc: "maximo de resultados; si es <= 0 se usa 50" +output: "slice de VaultSearchHit ordenado por rank FTS5 (o mtime DESC en fallback LIKE); slice vacio si no hay resultados" +tested: true +tests: + - "FTS match devuelve hit con snippet" + - "query sin resultados retorna slice vacio" + - "limit se respeta" + - "query FTS invalida activa fallback LIKE" + - "limit cero usa 50 por defecto" +test_file_path: "functions/infra/vault_search_test.go" +file_path: "functions/infra/vault_search.go" +--- + +## Ejemplo + +```go +hits, err := infra.VaultSearch("/home/lucas/vaults/turismo_spain", "hoteles", 20) +if err != nil { + log.Fatal(err) +} +for _, h := range hits { + fmt.Printf("[%s] %s %s\n", h.VaultName, h.RelPath, h.Snippet) +} +``` + +## Notas + +`VaultSearchHit` es un struct local definido en este archivo (no en `vault_file.go`) +porque combina campos de `files` + metadatos de contexto de busqueda (Snippet, VaultPath, VaultName). + +**FTS5 safety:** el helper `safeFTSQuery` envuelve la query en comillas dobles +cuando no contiene operadores booleanos ni prefijos de columna. Esto evita errores +del parser en tokens como `foo:bar:` o `hello-world`. + +**Fallback LIKE:** si el MATCH falla con un error de sintaxis FTS5, se ejecuta +`WHERE rel_path LIKE '%' || query || '%'`. Los hits del fallback tienen `Snippet=""`. + +**VaultName:** se deriva del `filepath.Base(filepath.EvalSymlinks(vaultPath))`. +Si `EvalSymlinks` falla (e.g. symlink roto), usa `filepath.Base(vaultPath)`. diff --git a/functions/infra/vault_search_test.go b/functions/infra/vault_search_test.go new file mode 100644 index 00000000..5041eb23 --- /dev/null +++ b/functions/infra/vault_search_test.go @@ -0,0 +1,147 @@ +package infra + +import ( + "testing" + "time" +) + +// openTestVaultDB creates a fresh vault_index.db in a temp dir and returns the path. +func openTestVaultDir(t *testing.T) string { + t.Helper() + dir := t.TempDir() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("VaultIndexOpen: %v", err) + } + db.Close() + return dir +} + +// seedVaultFile inserts a row into files + files_fts. +func seedVaultFile(t *testing.T, dir, relPath, mime, bucket, subBucket, contentText string, size int64) { + t.Helper() + db, err := VaultIndexOpen(dir) + if err != nil { + t.Fatalf("VaultIndexOpen seed: %v", err) + } + defer db.Close() + + now := time.Now().Unix() + _, err = db.Exec(` + INSERT INTO files (rel_path, size, mtime, sha256, mime, ext, bucket, sub_bucket, indexed_at) + VALUES (?, ?, ?, 'aabbccdd', ?, '', ?, ?, ?)`, + relPath, size, now, mime, bucket, subBucket, now, + ) + if err != nil { + t.Fatalf("seed files: %v", err) + } + _, err = db.Exec(`INSERT INTO files_fts(rel_path, content_text) VALUES (?, ?)`, relPath, contentText) + if err != nil { + t.Fatalf("seed files_fts: %v", err) + } +} + +// --- Tests --- + +func TestVaultSearch_FTSMatch(t *testing.T) { + t.Run("FTS match devuelve hit con snippet", func(t *testing.T) { + dir := openTestVaultDir(t) + seedVaultFile(t, dir, "data/raw/informe.csv", "text/csv", "data", "raw", + "ventas trimestrales empresa iberica", 1024) + seedVaultFile(t, dir, "data/raw/other.csv", "text/csv", "data", "raw", + "productos inventario almacen", 512) + + hits, err := VaultSearch(dir, "ventas", 10) + if err != nil { + t.Fatalf("VaultSearch: %v", err) + } + if len(hits) != 1 { + t.Fatalf("got %d hits, want 1", len(hits)) + } + if hits[0].RelPath != "data/raw/informe.csv" { + t.Errorf("RelPath = %q, want data/raw/informe.csv", hits[0].RelPath) + } + if hits[0].VaultName == "" { + t.Errorf("VaultName should not be empty") + } + }) +} + +func TestVaultSearch_NoMatch(t *testing.T) { + t.Run("query sin resultados retorna slice vacio", func(t *testing.T) { + dir := openTestVaultDir(t) + seedVaultFile(t, dir, "data/raw/file.csv", "text/csv", "data", "raw", "some content", 100) + + hits, err := VaultSearch(dir, "zzznomatch", 10) + if err != nil { + t.Fatalf("VaultSearch: %v", err) + } + if len(hits) != 0 { + t.Errorf("got %d hits, want 0", len(hits)) + } + }) +} + +func TestVaultSearch_LimitRespected(t *testing.T) { + t.Run("limit se respeta", func(t *testing.T) { + dir := openTestVaultDir(t) + for i := 0; i < 10; i++ { + path := "data/raw/file" + string(rune('a'+i)) + ".csv" + seedVaultFile(t, dir, path, "text/csv", "data", "raw", "common keyword everywhere", 100) + } + + hits, err := VaultSearch(dir, "common", 3) + if err != nil { + t.Fatalf("VaultSearch: %v", err) + } + if len(hits) != 3 { + t.Errorf("got %d hits, want 3", len(hits)) + } + }) +} + +func TestVaultSearch_BadFTSQuery_FallbackLike(t *testing.T) { + t.Run("query FTS invalida activa fallback LIKE", func(t *testing.T) { + dir := openTestVaultDir(t) + // Insert a file whose rel_path contains "foobar" so LIKE can find it. + seedVaultFile(t, dir, "data/raw/foobar_report.csv", "text/csv", "data", "raw", "", 200) + + // "foo:bar:" — colon after a non-column name triggers FTS5 parser error. + // safeFTSQuery passes it through unchanged because it contains ":" + // → FTS5 "no such column: bar" → fallback LIKE on rel_path. + hits, err := VaultSearch(dir, "foo:bar:", 10) + if err != nil { + t.Fatalf("VaultSearch: %v", err) + } + if len(hits) == 0 { + t.Errorf("expected fallback LIKE to find foobar_report.csv, got 0 hits") + } + for _, h := range hits { + if h.Snippet != "" { + t.Errorf("fallback hits should have empty Snippet, got %q", h.Snippet) + } + } + }) +} + +func TestVaultSearch_LimitZeroDefaults(t *testing.T) { + t.Run("limit cero usa 50 por defecto", func(t *testing.T) { + dir := openTestVaultDir(t) + // Insert 55 files with the same keyword. + for i := 0; i < 55; i++ { + path := "data/raw/doc" + string(rune('a')) + string(rune(int('0')+i%10)) + ".csv" + if i >= 10 { + path = "data/raw/doc" + string(rune('b'+i/10-1)) + string(rune(int('0')+i%10)) + ".csv" + } + seedVaultFile(t, dir, path, "text/csv", "data", "raw", "keyword alpha beta", 100) + } + + hits, err := VaultSearch(dir, "keyword", 0) + if err != nil { + t.Fatalf("VaultSearch: %v", err) + } + if len(hits) != 50 { + t.Errorf("got %d hits, want 50 (default limit)", len(hits)) + } + }) +} diff --git a/functions/ml/genconfig_json_marshal.go b/functions/ml/genconfig_json_marshal.go new file mode 100644 index 00000000..bdc9b367 --- /dev/null +++ b/functions/ml/genconfig_json_marshal.go @@ -0,0 +1,20 @@ +package ml + +import "encoding/json" + +// GenconfigMarshal serializa un GenerationConfig a JSON canonico con indent de 2 espacios. +// El formato es identico al de Python json.dumps(indent=2, sort_keys=False): +// keys en el orden de declaracion del struct, snake_case, campos omitempty ausentes si zero. +func GenconfigMarshal(cfg GenerationConfig) ([]byte, error) { + return json.MarshalIndent(cfg, "", " ") +} + +// GenconfigUnmarshal deserializa JSON (compacto o con indent) a GenerationConfig. +// Los campos JSON deben usar snake_case: negative_prompt, cfg_scale, model_type, etc. +func GenconfigUnmarshal(data []byte) (GenerationConfig, error) { + var cfg GenerationConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return GenerationConfig{}, err + } + return cfg, nil +} diff --git a/functions/ml/genconfig_json_marshal.md b/functions/ml/genconfig_json_marshal.md new file mode 100644 index 00000000..53ef35e8 --- /dev/null +++ b/functions/ml/genconfig_json_marshal.md @@ -0,0 +1,84 @@ +--- +name: genconfig_json_marshal +kind: function +lang: go +domain: ml +version: "1.0.0" +purity: impure +signature: "func GenconfigMarshal(cfg GenerationConfig) ([]byte, error)\nfunc GenconfigUnmarshal(data []byte) (GenerationConfig, error)" +description: "Wrappers json.Marshal/Unmarshal para GenerationConfig con formato canonico (MarshalIndent 2 espacios). Garantiza roundtrip identico al Python: json.dumps(indent=2, sort_keys=False). Campos JSON en snake_case." +tags: [ml, json, marshal, unmarshal, serialization, generation, canonical] +uses_functions: [] +uses_types: [generation_config_go_ml] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: ["encoding/json"] +params: + - name: cfg + desc: "GenerationConfig a serializar. Campos omitempty (negative_prompt, loras, clip_skip) se omiten si son zero/nil/empty." + - name: data + desc: "JSON bytes a deserializar. Acepta formato compacto o con indent. Keys deben ser snake_case (negative_prompt, cfg_scale, model_type, etc.)." +output: "GenconfigMarshal: bytes JSON con indent 2 espacios, orden de campos segun declaracion del struct (prompt, negative_prompt, seed, steps, cfg_scale, sampler, width, height, model, loras, clip_skip). GenconfigUnmarshal: GenerationConfig poblado o error de parsing." +tested: true +tests: + - "roundtrip marshal unmarshal produce config igual" + - "json cross-language snake_case keys se deserializan correctamente" +test_file_path: "functions/ml/genconfig_test.go" +file_path: "functions/ml/genconfig_json_marshal.go" +--- + +## Ejemplo + +```go +cfg := ml.GenerationConfig{ + Prompt: "a mountain at sunset", + Seed: 1234, + Steps: 30, + CfgScale: 7.0, + Sampler: "euler", + Width: 768, + Height: 512, + Model: ml.ModelRef{Name: "sdxl-base", ModelType: "sdxl", Quantization: "fp16"}, +} + +b, err := ml.GenconfigMarshal(cfg) +// b == { +// "prompt": "a mountain at sunset", +// "seed": 1234, +// ... +// } + +cfg2, err := ml.GenconfigUnmarshal(b) +// cfg2 == cfg (DeepEqual) +``` + +## Notas + +### Formato canonico y compatibilidad con Python + +`GenconfigMarshal` usa `json.MarshalIndent(cfg, "", " ")`. El formato resultante es identico al que produce Python con `model.model_dump_json()` o `json.dumps(data, indent=2)` cuando `sort_keys=False`: + +- Keys en orden de declaracion del struct (no alfabetico). +- Indent de 2 espacios, sin trailing whitespace. +- Campos omitempty ausentes si zero: `negative_prompt` ausente si `""`, `loras` ausente si `[]`, `clip_skip` ausente si `nil`. + +### Keys JSON (snake_case obligatorio) + +| Campo Go | Key JSON | +|---|---| +| `Prompt` | `"prompt"` | +| `NegativePrompt` | `"negative_prompt"` | +| `Seed` | `"seed"` | +| `Steps` | `"steps"` | +| `CfgScale` | `"cfg_scale"` | +| `Sampler` | `"sampler"` | +| `Width` | `"width"` | +| `Height` | `"height"` | +| `Model.ModelType` | `"model_type"` | +| `Model.Quantization` | `"quantization"` | +| `ClipSkip` | `"clip_skip"` | + +### Por que impure + +Los errores de `json.Unmarshal` son errores de parsing del input externo, no de I/O, pero se modelan como `(T, error)` para forzar manejo explicito en el caller. Marcado `impure` con `error_type: error_go_core` por convencion del registry. diff --git a/functions/ml/genconfig_test.go b/functions/ml/genconfig_test.go new file mode 100644 index 00000000..923152af --- /dev/null +++ b/functions/ml/genconfig_test.go @@ -0,0 +1,260 @@ +package ml + +import ( + "reflect" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// TestGenconfigToSdcliArgs +// --------------------------------------------------------------------------- + +func TestGenconfigToSdcliArgs(t *testing.T) { + clipSkip := 2 + + t.Run("config basico sin loras ni clip_skip", func(t *testing.T) { + cfg := GenerationConfig{ + Prompt: "a cat", + Seed: 42, + Steps: 20, + CfgScale: 7.5, + Sampler: "euler", + Width: 512, + Height: 512, + Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16"}, + } + args := GenconfigToSdcliArgs(cfg) + + want := []string{ + "--prompt", "a cat", + "--seed", "42", + "--steps", "20", + "--cfg-scale", "7.5", + "--width", "512", + "--height", "512", + "--sampling-method", "euler", + } + if !reflect.DeepEqual(args, want) { + t.Errorf("got %v\nwant %v", args, want) + } + }) + + t.Run("loras se emiten como pares path:weight", func(t *testing.T) { + cfg := GenerationConfig{ + Prompt: "portrait", + Seed: 1, + Steps: 10, + CfgScale: 7.0, + Sampler: "euler", + Width: 512, + Height: 512, + Model: ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16", Path: "/models/v1.safetensors"}, + Loras: []LoraRef{ + {Path: "/loras/detail.safetensors", Weight: 0.8}, + {Path: "/loras/style.safetensors", Weight: 0.5}, + }, + ClipSkip: &clipSkip, + } + args := GenconfigToSdcliArgs(cfg) + + // Verificar que existen los pares --lora para ambas loras + loraIdx := indexAll(args, "--lora") + if len(loraIdx) != 2 { + t.Fatalf("esperaba 2 flags --lora, got %d en %v", len(loraIdx), args) + } + wantLoras := []string{ + "/loras/detail.safetensors:0.8", + "/loras/style.safetensors:0.5", + } + for i, idx := range loraIdx { + if idx+1 >= len(args) { + t.Fatalf("--lora[%d] sin valor siguiente", i) + } + if args[idx+1] != wantLoras[i] { + t.Errorf("lora[%d]: got %q, want %q", i, args[idx+1], wantLoras[i]) + } + } + + // Verificar --model y --clip-skip presentes + if !containsPair(args, "--model", "/models/v1.safetensors") { + t.Errorf("--model no encontrado en %v", args) + } + if !containsPair(args, "--clip-skip", "2") { + t.Errorf("--clip-skip no encontrado en %v", args) + } + }) + + t.Run("sampler dpm++2m se traduce a dpmpp2m", func(t *testing.T) { + cfg := GenerationConfig{ + Prompt: "x", + Seed: 0, + Steps: 1, + CfgScale: 1.0, + Sampler: "dpm++2m", + Width: 64, + Height: 64, + Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"}, + } + args := GenconfigToSdcliArgs(cfg) + if !containsPair(args, "--sampling-method", "dpmpp2m") { + t.Errorf("sampler no traducido; args=%v", args) + } + }) + + t.Run("negative_prompt vacio no genera flag", func(t *testing.T) { + cfg := GenerationConfig{ + Prompt: "x", + NegativePrompt: "", + Seed: 0, + Steps: 1, + CfgScale: 1.0, + Sampler: "euler", + Width: 64, + Height: 64, + Model: ModelRef{Name: "m", ModelType: "sd15", Quantization: "fp16"}, + } + args := GenconfigToSdcliArgs(cfg) + for _, a := range args { + if a == "--negative-prompt" { + t.Errorf("flag --negative-prompt presente aunque NegativePrompt es vacio") + } + } + }) +} + +// --------------------------------------------------------------------------- +// TestGenconfigMarshalRoundtrip +// --------------------------------------------------------------------------- + +func TestGenconfigMarshalRoundtrip(t *testing.T) { + t.Run("roundtrip marshal unmarshal produce config igual", func(t *testing.T) { + clip := 2 + cfg := GenerationConfig{ + Prompt: "sunset over the mountains", + NegativePrompt: "blurry, low quality", + Seed: 99, + Steps: 30, + CfgScale: 7.5, + Sampler: "dpm++2m", + Width: 768, + Height: 512, + Model: ModelRef{ + Name: "sdxl-base", + ModelType: "sdxl", + Quantization: "fp16", + Path: "/models/sdxl.safetensors", + }, + Loras: []LoraRef{ + {Path: "/loras/detail.safetensors", Weight: 0.8}, + }, + ClipSkip: &clip, + } + + b, err := GenconfigMarshal(cfg) + if err != nil { + t.Fatalf("GenconfigMarshal: %v", err) + } + + got, err := GenconfigUnmarshal(b) + if err != nil { + t.Fatalf("GenconfigUnmarshal: %v", err) + } + + if !reflect.DeepEqual(cfg, got) { + t.Errorf("roundtrip diverge\norig: %+v\ngot: %+v", cfg, got) + } + }) +} + +// --------------------------------------------------------------------------- +// TestGenconfigCrossLanguageJSON +// --------------------------------------------------------------------------- + +func TestGenconfigCrossLanguageJSON(t *testing.T) { + // Fixture escrito a mano replicando lo que generaria Python: + // json.dumps(config.model_dump(), indent=2) + // Keys en snake_case, orden de declaracion del dataclass Python. + fixture := `{ + "prompt": "a dragon", + "negative_prompt": "ugly", + "seed": 1234, + "steps": 25, + "cfg_scale": 7.0, + "sampler": "euler_a", + "width": 512, + "height": 512, + "model": { + "name": "v1-5", + "model_type": "sd15", + "quantization": "fp16" + }, + "loras": [ + { + "path": "/loras/dragon.safetensors", + "weight": 0.9 + } + ] +}` + + t.Run("json cross-language snake_case keys se deserializan correctamente", func(t *testing.T) { + cfg, err := GenconfigUnmarshal([]byte(fixture)) + if err != nil { + t.Fatalf("GenconfigUnmarshal fixture: %v", err) + } + + // Verificar campos clave + if cfg.Prompt != "a dragon" { + t.Errorf("Prompt: got %q", cfg.Prompt) + } + if cfg.NegativePrompt != "ugly" { + t.Errorf("NegativePrompt: got %q", cfg.NegativePrompt) + } + if cfg.CfgScale != 7.0 { + t.Errorf("CfgScale: got %v", cfg.CfgScale) + } + if cfg.Model.ModelType != "sd15" { + t.Errorf("Model.ModelType: got %q", cfg.Model.ModelType) + } + if len(cfg.Loras) != 1 || cfg.Loras[0].Weight != 0.9 { + t.Errorf("Loras: got %+v", cfg.Loras) + } + + // Re-marshal y verificar que las keys snake_case siguen presentes + b, err := GenconfigMarshal(cfg) + if err != nil { + t.Fatalf("GenconfigMarshal: %v", err) + } + s := string(b) + for _, key := range []string{"negative_prompt", "cfg_scale", "model_type", "quantization"} { + if !strings.Contains(s, `"`+key+`"`) { + t.Errorf("key %q ausente en JSON re-serializado:\n%s", key, s) + } + } + }) +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +// indexAll retorna todos los indices de val en slice. +func indexAll(slice []string, val string) []int { + var out []int + for i, s := range slice { + if s == val { + out = append(out, i) + } + } + return out +} + +// containsPair verifica que flag seguido de value aparece en slice. +func containsPair(slice []string, flag, value string) bool { + for i := 0; i+1 < len(slice); i++ { + if slice[i] == flag && slice[i+1] == value { + return true + } + } + return false +} diff --git a/functions/ml/genconfig_to_sdcli_args.go b/functions/ml/genconfig_to_sdcli_args.go new file mode 100644 index 00000000..5138f609 --- /dev/null +++ b/functions/ml/genconfig_to_sdcli_args.go @@ -0,0 +1,59 @@ +package ml + +import ( + "fmt" + "strconv" +) + +// samplerMap traduce nombres canonicos del dominio ml a flags de stable-diffusion.cpp. +var samplerMap = map[string]string{ + "euler": "euler", + "euler_a": "euler_a", + "dpm++2m": "dpmpp2m", + "dpm++2m_v2": "dpmpp2mv2", + "heun": "heun", + "dpm2": "dpm2", + "lcm": "lcm", +} + +// GenconfigToSdcliArgs convierte un GenerationConfig en una lista de argumentos +// CLI para stable-diffusion.cpp (sd.exe / sd binario). +// Espejo Go de genconfig_to_sdcpp_args_py_ml. +// +// Loras se emiten como pares repetidos "--lora" "path:weight". +// Si el sampler no existe en samplerMap se usa el valor literal sin traducir. +// La funcion es pura: sin I/O, sin estado, determinista. +func GenconfigToSdcliArgs(cfg GenerationConfig) []string { + args := []string{ + "--prompt", cfg.Prompt, + "--seed", strconv.FormatInt(cfg.Seed, 10), + "--steps", strconv.Itoa(cfg.Steps), + "--cfg-scale", strconv.FormatFloat(cfg.CfgScale, 'f', -1, 64), + "--width", strconv.Itoa(cfg.Width), + "--height", strconv.Itoa(cfg.Height), + } + + if cfg.NegativePrompt != "" { + args = append(args, "--negative-prompt", cfg.NegativePrompt) + } + + sampler := cfg.Sampler + if mapped, ok := samplerMap[sampler]; ok { + sampler = mapped + } + args = append(args, "--sampling-method", sampler) + + if cfg.Model.Path != "" { + args = append(args, "--model", cfg.Model.Path) + } + + if cfg.ClipSkip != nil { + args = append(args, "--clip-skip", strconv.Itoa(*cfg.ClipSkip)) + } + + for _, lora := range cfg.Loras { + args = append(args, "--lora", fmt.Sprintf("%s:%g", lora.Path, lora.Weight)) + } + + return args +} diff --git a/functions/ml/genconfig_to_sdcli_args.md b/functions/ml/genconfig_to_sdcli_args.md new file mode 100644 index 00000000..f69c499d --- /dev/null +++ b/functions/ml/genconfig_to_sdcli_args.md @@ -0,0 +1,59 @@ +--- +name: genconfig_to_sdcli_args +kind: function +lang: go +domain: ml +version: "1.0.0" +purity: pure +signature: "func GenconfigToSdcliArgs(cfg GenerationConfig) []string" +description: "Convierte un GenerationConfig en argumentos CLI para stable-diffusion.cpp. Espejo Go de genconfig_to_sdcpp_args_py_ml. Loras se emiten como pares repetidos --lora path:weight. Sampler traducido via samplerMap canonico." +tags: [ml, stable-diffusion, cli, args, generation, pure] +uses_functions: [] +uses_types: [generation_config_go_ml] +returns: [] +returns_optional: false +error_type: "" +imports: ["fmt", "strconv"] +params: + - name: cfg + desc: "Parametros completos de generacion de imagen. Sampler debe ser uno de los valores de SamplerName. Model.Path se emite como --model si no esta vacio." +output: "Slice de strings listos para pasar a exec.Command o similar. Incluye --prompt, --seed, --steps, --cfg-scale, --width, --height, --sampling-method, opcionales --negative-prompt / --model / --clip-skip, y pares --lora path:weight por cada LoraRef." +tested: true +tests: + - "config basico sin loras ni clip_skip" + - "loras se emiten como pares path:weight" + - "sampler dpm++2m se traduce a dpmpp2m" + - "negative_prompt vacio no genera flag" +test_file_path: "functions/ml/genconfig_test.go" +file_path: "functions/ml/genconfig_to_sdcli_args.go" +--- + +## Ejemplo + +```go +clip := 2 +cfg := ml.GenerationConfig{ + Prompt: "a cat", + Seed: 42, + Steps: 20, + CfgScale: 7.5, + Sampler: "dpm++2m", + Width: 512, + Height: 512, + Model: ml.ModelRef{Name: "v1-5", ModelType: "sd15", Quantization: "fp16", Path: "/models/v1-5.safetensors"}, + Loras: []ml.LoraRef{{Path: "/loras/detail.safetensors", Weight: 0.8}}, + ClipSkip: &clip, +} +args := ml.GenconfigToSdcliArgs(cfg) +// args == ["--prompt","a cat","--seed","42","--steps","20", +// "--cfg-scale","7.5","--width","512","--height","512", +// "--sampling-method","dpmpp2m","--model","/models/v1-5.safetensors", +// "--clip-skip","2","--lora","/loras/detail.safetensors:0.8"] +``` + +## Notas + +- `samplerMap` traduce nombres canonicos del dominio ml a los identificadores que acepta stable-diffusion.cpp. Si el sampler no esta en el mapa se usa el valor literal. +- El flag de modelo (`--model`) solo se emite si `cfg.Model.Path != ""`. +- `%g` en `fmt.Sprintf` para el peso de la lora elimina ceros insignificantes: `0.800000` → `0.8`. +- Funcion pura: misma entrada, misma salida. Sin I/O ni estado global. diff --git a/functions/ml/generation_config.go b/functions/ml/generation_config.go new file mode 100644 index 00000000..ed7c41da --- /dev/null +++ b/functions/ml/generation_config.go @@ -0,0 +1,18 @@ +package ml + +// GenerationConfig parametriza una solicitud de generacion de imagen. +// Espejo JSON-compatible de GenerationConfig_py_ml: los tags json coinciden +// con los campos snake_case del dataclass Python para roundtrip sin perdida. +type GenerationConfig struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Seed int64 `json:"seed"` + Steps int `json:"steps"` + CfgScale float64 `json:"cfg_scale"` + Sampler string `json:"sampler"` + Width int `json:"width"` + Height int `json:"height"` + Model ModelRef `json:"model"` + Loras []LoraRef `json:"loras,omitempty"` + ClipSkip *int `json:"clip_skip,omitempty"` +} diff --git a/functions/ml/image_gen_result.go b/functions/ml/image_gen_result.go new file mode 100644 index 00000000..166edd62 --- /dev/null +++ b/functions/ml/image_gen_result.go @@ -0,0 +1,12 @@ +package ml + +// ImageGenResult contiene la imagen generada y su metadata de ejecucion. +// ImageBytes transporta los bytes raw del PNG y se excluye del JSON +// (campo json:"-") porque viaja por canal binario separado. +type ImageGenResult struct { + ImageBytes []byte `json:"-"` + Format string `json:"format"` + Meta map[string]any `json:"meta"` + DurationMs int64 `json:"duration_ms"` + VramPeakMb *int `json:"vram_peak_mb,omitempty"` +} diff --git a/functions/ml/image_generator.go b/functions/ml/image_generator.go new file mode 100644 index 00000000..fcbfce5c --- /dev/null +++ b/functions/ml/image_generator.go @@ -0,0 +1,9 @@ +package ml + +import "context" + +// ImageGenerator define el contrato para cualquier backend de generacion de imagenes. +// Las implementaciones pueden ser locales (ComfyUI, diffusers) o remotas (API). +type ImageGenerator interface { + Generate(ctx context.Context, cfg GenerationConfig) (ImageGenResult, error) +} diff --git a/functions/ml/lora_ref.go b/functions/ml/lora_ref.go new file mode 100644 index 00000000..fea4b1ae --- /dev/null +++ b/functions/ml/lora_ref.go @@ -0,0 +1,8 @@ +package ml + +// LoraRef referencia un adaptador LoRA con su peso de fusión y escala opcional. +type LoraRef struct { + Path string `json:"path"` + Weight float64 `json:"weight"` + Scale *float64 `json:"scale,omitempty"` +} diff --git a/functions/ml/model_ref.go b/functions/ml/model_ref.go new file mode 100644 index 00000000..78408cc9 --- /dev/null +++ b/functions/ml/model_ref.go @@ -0,0 +1,10 @@ +package ml + +// ModelRef identifica un modelo de generacion de imagenes por nombre, tipo, +// cuantizacion y path opcional en disco. +type ModelRef struct { + Name string `json:"name"` + ModelType string `json:"model_type"` // sd15|sdxl|flux_dev|... + Quantization string `json:"quantization"` // fp16|q8_0|... + Path string `json:"path,omitempty"` +} diff --git a/functions/ml/sdcli_parse_progress.go b/functions/ml/sdcli_parse_progress.go new file mode 100644 index 00000000..ff8c02c5 --- /dev/null +++ b/functions/ml/sdcli_parse_progress.go @@ -0,0 +1,78 @@ +package ml + +import ( + "regexp" + "strconv" +) + +// SdcliProgress contiene el estado de progreso parseado de una linea de stderr de sd-cli. +type SdcliProgress struct { + Step int `json:"step"` + TotalSteps int `json:"total_steps"` + ItPerSec float64 `json:"it_per_sec"` + Percent float64 `json:"percent"` +} + +// reProgress1 parsea el formato compacto: " 3/30 | 0.84it/s | 10%" +var reProgress1 = regexp.MustCompile(`\s*(\d+)\s*/\s*(\d+)\s*\|[^|]*?([\d.]+)\s*it/s[^|]*?\|\s*([\d.]+)\s*%`) + +// reProgress2 parsea el formato verbose: "sampling: step 3 of 30 (0.84 it/s)" +var reProgress2 = regexp.MustCompile(`step\s+(\d+)\s+of\s+(\d+)\s*\(\s*([\d.]+)\s*it/s\)`) + +// reProgress3 parsea el formato minimal: "step 3/30" o "progress: 3/30" +var reProgress3 = regexp.MustCompile(`(?:progress[:\s]+)?(\d+)\s*/\s*(\d+)`) + +// SdcliParseProgress parsea una linea de stderr de stable-diffusion.cpp / sd-cli +// y extrae el estado de progreso. Retorna (SdcliProgress, true) si la linea +// contiene informacion de progreso reconocible; (zero, false) en caso contrario. +// Funcion pura: sin I/O, sin estado mutable, determinista. +func SdcliParseProgress(line string) (SdcliProgress, bool) { + // Formato 1: " 3/30 | 0.84it/s | 10%" + if m := reProgress1.FindStringSubmatch(line); m != nil { + step, err1 := strconv.Atoi(m[1]) + total, err2 := strconv.Atoi(m[2]) + itPerSec, err3 := strconv.ParseFloat(m[3], 64) + pct, err4 := strconv.ParseFloat(m[4], 64) + if err1 == nil && err2 == nil && err3 == nil && err4 == nil { + return SdcliProgress{ + Step: step, + TotalSteps: total, + ItPerSec: itPerSec, + Percent: pct, + }, true + } + } + + // Formato 2: "sampling: step 3 of 30 (0.84 it/s)" + if m := reProgress2.FindStringSubmatch(line); m != nil { + step, err1 := strconv.Atoi(m[1]) + total, err2 := strconv.Atoi(m[2]) + itPerSec, err3 := strconv.ParseFloat(m[3], 64) + if err1 == nil && err2 == nil && err3 == nil && total > 0 { + pct := 100.0 * float64(step) / float64(total) + return SdcliProgress{ + Step: step, + TotalSteps: total, + ItPerSec: itPerSec, + Percent: pct, + }, true + } + } + + // Formato 3: "step 3/30" o "progress: 3/30" sin velocidad + if m := reProgress3.FindStringSubmatch(line); m != nil { + step, err1 := strconv.Atoi(m[1]) + total, err2 := strconv.Atoi(m[2]) + if err1 == nil && err2 == nil && total > 0 { + pct := 100.0 * float64(step) / float64(total) + return SdcliProgress{ + Step: step, + TotalSteps: total, + ItPerSec: 0, + Percent: pct, + }, true + } + } + + return SdcliProgress{}, false +} diff --git a/functions/ml/sdcli_parse_progress.md b/functions/ml/sdcli_parse_progress.md new file mode 100644 index 00000000..950f7a3a --- /dev/null +++ b/functions/ml/sdcli_parse_progress.md @@ -0,0 +1,50 @@ +--- +name: sdcli_parse_progress +kind: function +lang: go +domain: ml +version: "1.0.0" +purity: pure +signature: "func SdcliParseProgress(line string) (SdcliProgress, bool)" +description: "Parsea una linea de stderr de stable-diffusion.cpp / sd-cli y extrae el estado de progreso. Soporta el formato compacto '3/30 | 0.84it/s | 10%', el formato verbose 'sampling: step 3 of 30 (0.84 it/s)', y el formato minimal 'progress: 3/30'. Retorna (zero, false) si la linea no contiene informacion de progreso reconocible." +tags: [ml, stable-diffusion, sdcli, progress, parser, stderr, pure] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "" +imports: ["regexp", "strconv"] +params: + - name: line + desc: "Una linea de stderr emitida por sd-cli / stable-diffusion.cpp durante la fase de sampling. Puede contener espacios al inicio o final." +output: "Par (SdcliProgress, bool). bool=true si se reconocio un patron de progreso; SdcliProgress contiene Step (paso actual), TotalSteps (pasos totales), ItPerSec (iteraciones por segundo, 0 si no disponible) y Percent (porcentaje 0-100 calculado o leido de la linea). bool=false y struct zero si la linea no contiene progreso." +tested: true +tests: + - "formato estandar compacto step/total/itpersec/percent" + - "linea sin patron retorna false" + - "formato sampling verbose con velocidad" +file_path: "functions/ml/sdcli_parse_progress.go" +test_file_path: "functions/ml/sdcli_parse_progress_test.go" +--- + +## Ejemplo + +```go +p, ok := ml.SdcliParseProgress(" 3/30 | 0.84it/s | 10%") +// ok = true +// p = SdcliProgress{Step:3, TotalSteps:30, ItPerSec:0.84, Percent:10.0} + +p2, ok2 := ml.SdcliParseProgress("sampling: step 15 of 30 (1.2 it/s)") +// ok2 = true +// p2 = SdcliProgress{Step:15, TotalSteps:30, ItPerSec:1.2, Percent:50.0} + +_, ok3 := ml.SdcliParseProgress("loading model...") +// ok3 = false +``` + +## Notas + +- Regexps precompiladas como vars de paquete (se compilan una sola vez al init del paquete). +- Tolerante a variaciones de espaciado gracias a `\s*` en los patrones. +- El campo `Percent` en el formato verbose se calcula como `100 * step / total` (no se lee de la linea porque ese formato no lo emite). +- Funcion pura: sin I/O, sin estado mutable, determinista. diff --git a/functions/ml/sdcli_parse_progress_test.go b/functions/ml/sdcli_parse_progress_test.go new file mode 100644 index 00000000..fa69186e --- /dev/null +++ b/functions/ml/sdcli_parse_progress_test.go @@ -0,0 +1,103 @@ +package ml + +import ( + "math" + "testing" +) + +func TestSdcliParseProgress_StandardFormat(t *testing.T) { + line := " 3/30 | 0.84it/s | 10%" + got, ok := SdcliParseProgress(line) + if !ok { + t.Fatalf("expected match, got false") + } + if got.Step != 3 { + t.Errorf("Step: got %d, want 3", got.Step) + } + if got.TotalSteps != 30 { + t.Errorf("TotalSteps: got %d, want 30", got.TotalSteps) + } + if math.Abs(got.ItPerSec-0.84) > 1e-9 { + t.Errorf("ItPerSec: got %v, want 0.84", got.ItPerSec) + } + if math.Abs(got.Percent-10.0) > 1e-9 { + t.Errorf("Percent: got %v, want 10.0", got.Percent) + } +} + +func TestSdcliParseProgress_NoMatch(t *testing.T) { + cases := []string{ + "loading model...", + "", + "error: out of memory", + "clip model loaded", + "generating image...", + } + for _, line := range cases { + _, ok := SdcliParseProgress(line) + if ok { + t.Errorf("expected no match for %q, but got match", line) + } + } +} + +func TestSdcliParseProgress_AltFormat(t *testing.T) { + t.Run("formato sampling verbose", func(t *testing.T) { + line := "sampling: step 3 of 30 (0.84 it/s)" + got, ok := SdcliParseProgress(line) + if !ok { + t.Fatalf("expected match, got false") + } + if got.Step != 3 { + t.Errorf("Step: got %d, want 3", got.Step) + } + if got.TotalSteps != 30 { + t.Errorf("TotalSteps: got %d, want 30", got.TotalSteps) + } + if math.Abs(got.ItPerSec-0.84) > 1e-9 { + t.Errorf("ItPerSec: got %v, want 0.84", got.ItPerSec) + } + expectedPct := 100.0 * 3.0 / 30.0 + if math.Abs(got.Percent-expectedPct) > 1e-6 { + t.Errorf("Percent: got %v, want %v", got.Percent, expectedPct) + } + }) + + t.Run("formato step/total sin velocidad", func(t *testing.T) { + line := "progress: 15/20" + got, ok := SdcliParseProgress(line) + if !ok { + t.Fatalf("expected match, got false") + } + if got.Step != 15 { + t.Errorf("Step: got %d, want 15", got.Step) + } + if got.TotalSteps != 20 { + t.Errorf("TotalSteps: got %d, want 20", got.TotalSteps) + } + if got.ItPerSec != 0 { + t.Errorf("ItPerSec: got %v, want 0", got.ItPerSec) + } + expectedPct := 75.0 + if math.Abs(got.Percent-expectedPct) > 1e-6 { + t.Errorf("Percent: got %v, want %v", got.Percent, expectedPct) + } + }) + + t.Run("formato con espacios variables y mayor velocidad", func(t *testing.T) { + line := " 20/30 | 12.50it/s | 66%" + got, ok := SdcliParseProgress(line) + if !ok { + t.Fatalf("expected match, got false") + } + if got.Step != 20 { + t.Errorf("Step: got %d, want 20", got.Step) + } + if got.TotalSteps != 30 { + t.Errorf("TotalSteps: got %d, want 30", got.TotalSteps) + } + if math.Abs(got.ItPerSec-12.5) > 1e-9 { + t.Errorf("ItPerSec: got %v, want 12.5", got.ItPerSec) + } + }) +} diff --git a/python/functions/datascience/tests/test_vault_csv_profile.py b/python/functions/datascience/tests/test_vault_csv_profile.py new file mode 100644 index 00000000..65b4365d --- /dev/null +++ b/python/functions/datascience/tests/test_vault_csv_profile.py @@ -0,0 +1,161 @@ +"""Tests para vault_csv_profile.""" + +from __future__ import annotations + +import os +import sqlite3 +import sys +import tempfile +from pathlib import Path + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from vault_csv_profile import vault_csv_profile + + +def _make_vault(tmp: Path) -> tuple[Path, Path]: + """Crea un vault mínimo con vault_index.db y tabla files + files_fts + csv_profiles.""" + db = tmp / "vault_index.db" + conn = sqlite3.connect(str(db)) + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS files ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + rel_path TEXT UNIQUE NOT NULL, + size_bytes INTEGER, + ext TEXT + ); + CREATE VIRTUAL TABLE IF NOT EXISTS files_fts + USING fts5(rel_path, content_text, content='', contentless_delete=1); + CREATE TABLE IF NOT EXISTS csv_profiles ( + rel_path TEXT PRIMARY KEY, + cols_json TEXT, + n_rows INTEGER, + encoding TEXT, + date_min TEXT, + date_max TEXT, + profiled_at INTEGER + ); + """ + ) + conn.commit() + conn.close() + return tmp, db + + +def _insert_file_entry(db: Path, rel_path: str): + """Inserta entrada en files para que files_fts tenga rowid válido.""" + conn = sqlite3.connect(str(db)) + conn.execute( + "INSERT OR IGNORE INTO files(rel_path, size_bytes, ext) VALUES (?, 0, '.csv')", + (rel_path,), + ) + conn.commit() + conn.close() + + +def test_csv_basic(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "data/basic.csv" + csv_file = vault / rel + csv_file.parent.mkdir(parents=True, exist_ok=True) + csv_file.write_text("nombre,edad,score\nAna,30,9.5\nBob,25,8.0\nCarla,35,7.5\n", encoding="utf-8") + _insert_file_entry(db, rel) + + result = vault_csv_profile(str(vault), rel, db_path=str(db)) + + assert result["rel_path"] == rel + assert result["n_rows"] == 3 + assert len(result["cols"]) == 3 + col_names = [c["name"] for c in result["cols"]] + assert "nombre" in col_names + assert "edad" in col_names + assert "score" in col_names + assert result["persisted"] is True + + # Verificar persistencia en csv_profiles + conn = sqlite3.connect(str(db)) + row = conn.execute("SELECT n_rows FROM csv_profiles WHERE rel_path = ?", (rel,)).fetchone() + conn.close() + assert row is not None + assert row[0] == 3 + + +def test_csv_date_detection(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "data/fechas.csv" + csv_file = vault / rel + csv_file.parent.mkdir(parents=True, exist_ok=True) + csv_file.write_text( + "fecha,valor\n2023-01-01,100\n2023-06-15,200\n2023-12-31,300\n", + encoding="utf-8", + ) + _insert_file_entry(db, rel) + + result = vault_csv_profile(str(vault), rel, db_path=str(db)) + + assert result["date_min"] is not None + assert result["date_max"] is not None + assert result["date_min"] <= "2023-01-01" + assert result["date_max"] >= "2023-12-31" + + +def test_csv_encoding_latin1(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "data/tildes.csv" + csv_file = vault / rel + csv_file.parent.mkdir(parents=True, exist_ok=True) + csv_file.write_bytes( + "ciudad,poblacion\nMálaga,500000\nCórdoba,320000\n".encode("latin-1") + ) + _insert_file_entry(db, rel) + + result = vault_csv_profile(str(vault), rel, db_path=str(db)) + + assert result["n_rows"] == 2 + assert result["encoding"] != "utf-8?" + # encoding detectado (algún valor no vacío) + assert result["encoding"] + assert result["persisted"] is True + + +def test_csv_empty(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "data/empty.csv" + csv_file = vault / rel + csv_file.parent.mkdir(parents=True, exist_ok=True) + csv_file.write_text("", encoding="utf-8") + _insert_file_entry(db, rel) + + result = vault_csv_profile(str(vault), rel, db_path=str(db)) + + assert result["n_rows"] == 0 + assert result["cols"] == [] + assert result["date_min"] is None + assert result["date_max"] is None + + +def test_csv_persists_fts(tmp_path): + """FTS5 contentless: verifica que las columnas son buscables con MATCH.""" + vault, db = _make_vault(tmp_path) + rel = "data/fts_test.csv" + csv_file = vault / rel + csv_file.parent.mkdir(parents=True, exist_ok=True) + csv_file.write_text("producto,precio\nManzana,1.5\nPera,2.0\n", encoding="utf-8") + _insert_file_entry(db, rel) + + vault_csv_profile(str(vault), rel, db_path=str(db)) + + conn = sqlite3.connect(str(db)) + # FTS5 contentless no permite SELECT directo — usar MATCH para verificar indexado + row_prod = conn.execute( + "SELECT rowid FROM files_fts WHERE files_fts MATCH 'producto'", + ).fetchone() + row_prec = conn.execute( + "SELECT rowid FROM files_fts WHERE files_fts MATCH 'precio'", + ).fetchone() + conn.close() + + assert row_prod is not None, "FTS no encontró 'producto'" + assert row_prec is not None, "FTS no encontró 'precio'" diff --git a/python/functions/datascience/tests/test_vault_pdf_extract.py b/python/functions/datascience/tests/test_vault_pdf_extract.py new file mode 100644 index 00000000..9e97e95b --- /dev/null +++ b/python/functions/datascience/tests/test_vault_pdf_extract.py @@ -0,0 +1,147 @@ +"""Tests para vault_pdf_extract.""" + +from __future__ import annotations + +import os +import sqlite3 +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from vault_pdf_extract import vault_pdf_extract + + +def _make_vault(tmp: Path) -> tuple[Path, Path]: + """Crea un vault mínimo con vault_index.db.""" + db = tmp / "vault_index.db" + conn = sqlite3.connect(str(db)) + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS files ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + rel_path TEXT UNIQUE NOT NULL, + size_bytes INTEGER, + ext TEXT + ); + CREATE VIRTUAL TABLE IF NOT EXISTS files_fts + USING fts5(rel_path, content_text, content='', contentless_delete=1); + CREATE TABLE IF NOT EXISTS pdf_extracts ( + rel_path TEXT PRIMARY KEY, + page_count INTEGER, + text_len INTEGER, + extracted_to TEXT, + extracted_at INTEGER + ); + """ + ) + conn.commit() + conn.close() + return tmp, db + + +def _insert_file_entry(db: Path, rel_path: str): + conn = sqlite3.connect(str(db)) + conn.execute( + "INSERT OR IGNORE INTO files(rel_path, size_bytes, ext) VALUES (?, 0, '.pdf')", + (rel_path,), + ) + conn.commit() + conn.close() + + +def _make_pdf(path: Path, text: str = "Hello vault PDF.\nPage two content."): + """Crea un PDF mínimo con fitz para tests.""" + import fitz + + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), text) + doc.save(str(path)) + doc.close() + + +def test_pdf_extract_basic(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/test.pdf" + pdf = vault / rel + pdf.parent.mkdir(parents=True, exist_ok=True) + _make_pdf(pdf) + _insert_file_entry(db, rel) + + result = vault_pdf_extract(str(vault), rel, db_path=str(db)) + + assert result["rel_path"] == rel + assert result["page_count"] >= 1 + assert result["text_len"] > 0 + assert result["persisted"] is True + + conn = sqlite3.connect(str(db)) + row = conn.execute("SELECT page_count, text_len FROM pdf_extracts WHERE rel_path=?", (rel,)).fetchone() + conn.close() + assert row is not None + assert row[0] >= 1 + assert row[1] > 0 + + +def test_pdf_dump_text_creates_file(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/dump.pdf" + pdf = vault / rel + pdf.parent.mkdir(parents=True, exist_ok=True) + _make_pdf(pdf, "Contenido para dump a disco.") + _insert_file_entry(db, rel) + # Crear data/processed/ para que se use ese directorio + (vault / "data" / "processed").mkdir(parents=True, exist_ok=True) + + result = vault_pdf_extract(str(vault), rel, db_path=str(db), dump_text=True) + + assert result["extracted_to"] is not None + txt_path = vault / result["extracted_to"] + assert txt_path.exists() + assert txt_path.stat().st_size > 0 + + +def test_pdf_no_dump(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/nodump.pdf" + pdf = vault / rel + pdf.parent.mkdir(parents=True, exist_ok=True) + _make_pdf(pdf, "No se debe volcar a disco.") + _insert_file_entry(db, rel) + + result = vault_pdf_extract(str(vault), rel, db_path=str(db), dump_text=False) + + assert result["extracted_to"] is None + + +def test_pdf_persists_to_fts(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/fts.pdf" + pdf = vault / rel + pdf.parent.mkdir(parents=True, exist_ok=True) + _make_pdf(pdf, "Texto especial para FTS xyzpdftest.") + _insert_file_entry(db, rel) + + vault_pdf_extract(str(vault), rel, db_path=str(db), dump_text=False) + + conn = sqlite3.connect(str(db)) + # FTS5 contentless: no permite SELECT directo, usar MATCH + row = conn.execute( + "SELECT rowid FROM files_fts WHERE files_fts MATCH 'xyzpdftest'", + ).fetchone() + conn.close() + assert row is not None, "FTS no encontró el texto del PDF" + + +def test_pdf_corrupt_errors(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/corrupt.pdf" + pdf = vault / rel + pdf.parent.mkdir(parents=True, exist_ok=True) + pdf.write_bytes(b"%PDF-1.4 garbage bytes \x00\x01\x02 not a real pdf") + _insert_file_entry(db, rel) + + with pytest.raises(RuntimeError, match="corrupto|inválido|PDF"): + vault_pdf_extract(str(vault), rel, db_path=str(db)) diff --git a/python/functions/datascience/vault_csv_profile.md b/python/functions/datascience/vault_csv_profile.md new file mode 100644 index 00000000..ec216c94 --- /dev/null +++ b/python/functions/datascience/vault_csv_profile.md @@ -0,0 +1,61 @@ +--- +name: vault_csv_profile +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def vault_csv_profile(vault_path: str, rel_path: str, db_path: str | None = None) -> dict" +description: "Perfila un CSV del vault: detecta encoding, lee schema con polars, extrae n_rows y columnas de fecha; persiste en csv_profiles y actualiza files_fts para búsqueda por contenido." +tags: [vault, csv, profiling, polars, encoding, datascience, fts] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [sqlite3, time, pathlib, json, polars, chardet] +params: + - name: vault_path + desc: "Ruta absoluta a la raiz del vault donde vive el CSV y vault_index.db." + - name: rel_path + desc: "Ruta relativa al CSV dentro del vault (ej. 'data/raw/ventas.csv')." + - name: db_path + desc: "Override opcional de la ruta a vault_index.db. Por defecto /vault_index.db." +output: "Dict con: rel_path (str), cols (list de {name, dtype}), n_rows (int), encoding (str), date_min/date_max (ISO yyyy-mm-dd o None), persisted (bool)." +tested: true +tests: + - "test_csv_basic" + - "test_csv_date_detection" + - "test_csv_encoding_latin1" + - "test_csv_empty" + - "test_csv_persists_fts" +test_file_path: "python/functions/datascience/tests/test_vault_csv_profile.py" +file_path: "python/functions/datascience/vault_csv_profile.py" +--- + +## Ejemplo + +```python +from vault_csv_profile import vault_csv_profile + +result = vault_csv_profile("/vaults/mi_vault", "data/raw/ventas.csv") +# { +# "rel_path": "data/raw/ventas.csv", +# "cols": [{"name": "fecha", "dtype": "String"}, {"name": "importe", "dtype": "Float64"}], +# "n_rows": 1500, +# "encoding": "utf-8", +# "date_min": "2023-01-01", +# "date_max": "2023-12-31", +# "persisted": True +# } +``` + +## Notas + +- Usa polars (lazy scan) como motor principal; pandas como fallback. +- Detección de encoding: chardet con confianza >= 0.6, luego intentos utf-8-sig → utf-8 → latin-1 → cp1252. +- Detección de fechas: columnas Date/Datetime nativas de polars, o columnas String con ≥80% de valores parseables como fecha. +- El FTS text incluye nombres de columnas + primeras 5 filas concatenadas. +- Upsert en csv_profiles por rel_path; el rowid de files_fts se ancla al rowid de la tabla files para que vault_search funcione correctamente. +- Si vault_index.db no existe, la función retorna el dict sin intentar persistir (persisted=False). +- Dependencias: polars, chardet (ambas instaladas en python/.venv con uv add). diff --git a/python/functions/datascience/vault_csv_profile.py b/python/functions/datascience/vault_csv_profile.py new file mode 100644 index 00000000..039659bf --- /dev/null +++ b/python/functions/datascience/vault_csv_profile.py @@ -0,0 +1,216 @@ +"""vault_csv_profile — Perfila un CSV del vault y persiste metadata en vault_index.db.""" + +from __future__ import annotations + +import sqlite3 +import time +from pathlib import Path + + +def _detect_encoding(path: Path) -> str: + """Detecta encoding del archivo con chardet o por intentos.""" + try: + import chardet + + with open(path, "rb") as f: + raw = f.read(min(65536, path.stat().st_size)) + result = chardet.detect(raw) + if result and result.get("encoding") and result.get("confidence", 0) >= 0.6: + return result["encoding"] + except Exception: + pass + + for enc in ("utf-8-sig", "utf-8", "latin-1", "cp1252"): + try: + with open(path, encoding=enc) as f: + f.read(4096) + return enc + except (UnicodeDecodeError, LookupError): + continue + + return "utf-8?" + + +def _read_with_polars(path: Path, encoding: str) -> tuple[list[dict], int]: + """Lee CSV con polars. Retorna (cols, n_rows).""" + import polars as pl + + enc = encoding.rstrip("?").replace("utf-8-sig", "utf8").replace("utf-8", "utf8") + if enc not in ("utf8", "utf-8"): + enc = "utf8" + + lf = pl.scan_csv(path, encoding="utf8", ignore_errors=True, infer_schema_length=1000) + schema = lf.collect_schema() + cols = [{"name": name, "dtype": str(dtype)} for name, dtype in schema.items()] + n_rows = lf.select(pl.len()).collect().item() + return cols, n_rows + + +def _read_with_pandas(path: Path, encoding: str) -> tuple[list[dict], int]: + """Fallback: lee CSV con pandas.""" + import pandas as pd + + enc = encoding.rstrip("?") or "utf-8" + df = pd.read_csv(path, encoding=enc, encoding_errors="replace", nrows=None) + cols = [{"name": col, "dtype": str(df[col].dtype)} for col in df.columns] + n_rows = len(df) + return cols, n_rows + + +def _detect_dates(path: Path, encoding: str) -> tuple[str | None, str | None]: + """Intenta detectar columna de fecha y retorna (date_min, date_max) en ISO.""" + try: + import polars as pl + + lf = pl.scan_csv(path, encoding="utf8", ignore_errors=True, infer_schema_length=0) + schema = lf.collect_schema() + df = lf.collect() + + for col_name, dtype in schema.items(): + if "Date" in str(dtype) or "Datetime" in str(dtype): + series = df[col_name].drop_nulls() + if len(series) > 0: + mn = series.min() + mx = series.max() + return str(mn)[:10], str(mx)[:10] + + # Intenta parsear columnas string como fecha + for col_name, dtype in schema.items(): + if "Utf8" not in str(dtype) and "String" not in str(dtype): + continue + series = df[col_name].drop_nulls() + if len(series) == 0: + continue + try: + parsed = series.str.to_date(strict=False) + valid = parsed.drop_nulls() + if len(valid) / max(len(series), 1) >= 0.8: + mn = valid.min() + mx = valid.max() + return str(mn)[:10], str(mx)[:10] + except Exception: + continue + except Exception: + pass + return None, None + + +def _build_fts_text(path: Path, cols: list[dict], encoding: str) -> str: + """Construye content_text para files_fts: nombres de cols + primeras 5 filas.""" + col_names = " ".join(c["name"] for c in cols) + try: + import polars as pl + + lf = pl.scan_csv(path, encoding="utf8", ignore_errors=True) + sample = lf.head(5).collect() + rows_text = " ".join( + " ".join(str(v) for v in row) for row in sample.iter_rows() + ) + return f"{col_names} {rows_text}".strip() + except Exception: + pass + return col_names + + +def vault_csv_profile( + vault_path: str, + rel_path: str, + db_path: str | None = None, +) -> dict: + """Perfila un CSV del vault: schema, n_rows, encoding, fechas; persiste en vault_index.db. + + Args: + vault_path: Ruta absoluta a la raiz del vault. + rel_path: Ruta relativa al CSV dentro del vault. + db_path: Override de la ruta a vault_index.db. Por defecto /vault_index.db. + + Returns: + Dict con rel_path, cols, n_rows, encoding, date_min, date_max, persisted. + + Raises: + RuntimeError: Si el archivo no existe o no se puede leer. + """ + vault = Path(vault_path) + csv_file = vault / rel_path + if not csv_file.exists(): + raise RuntimeError(f"vault_csv_profile: archivo no encontrado: {csv_file}") + + db = Path(db_path) if db_path else vault / "vault_index.db" + + # Resultado por defecto para CSV vacío + result: dict = { + "rel_path": rel_path, + "cols": [], + "n_rows": 0, + "encoding": "utf-8", + "date_min": None, + "date_max": None, + "persisted": False, + } + + # Detectar encoding + encoding = _detect_encoding(csv_file) + result["encoding"] = encoding + + # Leer schema y n_rows — short-circuit para archivos vacíos + if csv_file.stat().st_size == 0: + cols, n_rows = [], 0 + else: + try: + cols, n_rows = _read_with_polars(csv_file, encoding) + except Exception: + try: + cols, n_rows = _read_with_pandas(csv_file, encoding) + except Exception as exc: + raise RuntimeError(f"vault_csv_profile: no se pudo leer {rel_path}: {exc}") from exc + + result["cols"] = cols + result["n_rows"] = n_rows + + # Detección de fechas (solo si hay filas) + if n_rows > 0 and cols: + date_min, date_max = _detect_dates(csv_file, encoding) + result["date_min"] = date_min + result["date_max"] = date_max + + # Construir texto para FTS + fts_text = _build_fts_text(csv_file, cols, encoding) if cols else "" + + # Persistir en vault_index.db + if db.exists(): + conn = sqlite3.connect(str(db)) + try: + cols_json = __import__("json").dumps(cols) + now = int(time.time()) + conn.execute( + """ + INSERT INTO csv_profiles(rel_path, cols_json, n_rows, encoding, date_min, date_max, profiled_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(rel_path) DO UPDATE SET + cols_json=excluded.cols_json, + n_rows=excluded.n_rows, + encoding=excluded.encoding, + date_min=excluded.date_min, + date_max=excluded.date_max, + profiled_at=excluded.profiled_at + """, + (rel_path, cols_json, n_rows, encoding, result["date_min"], result["date_max"], now), + ) + # Actualizar files_fts (rowid debe coincidir con files) + conn.execute("DELETE FROM files_fts WHERE rel_path = ?", (rel_path,)) + conn.execute( + """ + INSERT INTO files_fts(rowid, rel_path, content_text) + VALUES ((SELECT rowid FROM files WHERE rel_path = ?), ?, ?) + """, + (rel_path, rel_path, fts_text), + ) + conn.commit() + result["persisted"] = True + except Exception: + conn.rollback() + raise + finally: + conn.close() + + return result diff --git a/python/functions/datascience/vault_pdf_extract.md b/python/functions/datascience/vault_pdf_extract.md new file mode 100644 index 00000000..aeb0e0bb --- /dev/null +++ b/python/functions/datascience/vault_pdf_extract.md @@ -0,0 +1,60 @@ +--- +name: vault_pdf_extract +kind: function +lang: py +domain: datascience +version: "1.0.0" +purity: impure +signature: "def vault_pdf_extract(vault_path: str, rel_path: str, db_path: str | None = None, dump_text: bool = True) -> dict" +description: "Extrae texto de un PDF del vault con PyMuPDF; persiste page_count y text_len en pdf_extracts; vuelca texto a .txt en data/processed/ o .vault_extracts/; actualiza files_fts para búsqueda por contenido." +tags: [vault, pdf, extract, pymupdf, fts, datascience] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [sqlite3, time, pathlib, fitz] +params: + - name: vault_path + desc: "Ruta absoluta a la raiz del vault donde vive el PDF y vault_index.db." + - name: rel_path + desc: "Ruta relativa al PDF dentro del vault (ej. 'docs/informe.pdf')." + - name: db_path + desc: "Override opcional de la ruta a vault_index.db. Por defecto /vault_index.db." + - name: dump_text + desc: "Si True (default), escribe el texto extraído a un .txt. La carpeta destino es data/processed/ si existe, si no .vault_extracts/." +output: "Dict con: rel_path (str), page_count (int), text_len (int), extracted_to (ruta relativa al .txt o None), persisted (bool)." +tested: true +tests: + - "test_pdf_extract_basic" + - "test_pdf_dump_text_creates_file" + - "test_pdf_no_dump" + - "test_pdf_persists_to_fts" + - "test_pdf_corrupt_errors" +test_file_path: "python/functions/datascience/tests/test_vault_pdf_extract.py" +file_path: "python/functions/datascience/vault_pdf_extract.py" +--- + +## Ejemplo + +```python +from vault_pdf_extract import vault_pdf_extract + +result = vault_pdf_extract("/vaults/mi_vault", "docs/informe_anual.pdf") +# { +# "rel_path": "docs/informe_anual.pdf", +# "page_count": 24, +# "text_len": 45210, +# "extracted_to": "data/processed/informe_anual.txt", +# "persisted": True +# } +``` + +## Notas + +- Requiere PyMuPDF (paquete `pymupdf`, importado como `fitz`). Ya instalado en python/.venv. +- El texto se trunca a 10 MB antes de insertarlo en files_fts para evitar tablas FTS5 masivas. +- Layout de volcado: si `/data/processed/` existe, se usa; si no, se crea `/.vault_extracts/`. +- PDFs corruptos levantan RuntimeError con mensaje descriptivo. +- El rowid de files_fts se ancla al rowid de la tabla files (subquery) para que vault_search funcione correctamente. +- Si vault_index.db no existe, retorna el dict sin intentar persistir (persisted=False). diff --git a/python/functions/datascience/vault_pdf_extract.py b/python/functions/datascience/vault_pdf_extract.py new file mode 100644 index 00000000..a9f17e56 --- /dev/null +++ b/python/functions/datascience/vault_pdf_extract.py @@ -0,0 +1,121 @@ +"""vault_pdf_extract — Extrae texto de un PDF del vault y persiste en vault_index.db.""" + +from __future__ import annotations + +import sqlite3 +import time +from pathlib import Path + + +def vault_pdf_extract( + vault_path: str, + rel_path: str, + db_path: str | None = None, + dump_text: bool = True, +) -> dict: + """Extrae texto de un PDF del vault; persiste page_count, text_len y actualiza files_fts. + + Args: + vault_path: Ruta absoluta a la raiz del vault. + rel_path: Ruta relativa al PDF dentro del vault. + db_path: Override opcional de la ruta a vault_index.db. + dump_text: Si True, escribe el texto extraído a un .txt en data/processed/ o .vault_extracts/. + + Returns: + Dict con: rel_path, page_count, text_len, extracted_to (ruta relativa o None), persisted. + + Raises: + RuntimeError: Si el PDF no existe, está corrupto o no se puede leer. + """ + try: + import fitz # PyMuPDF + except ImportError as exc: + raise RuntimeError( + "vault_pdf_extract requiere PyMuPDF. Instalar con: uv add pymupdf" + ) from exc + + vault = Path(vault_path) + pdf_file = vault / rel_path + if not pdf_file.exists(): + raise RuntimeError(f"vault_pdf_extract: archivo no encontrado: {pdf_file}") + + db = Path(db_path) if db_path else vault / "vault_index.db" + + # Abrir PDF + try: + doc = fitz.open(str(pdf_file)) + except Exception as exc: + raise RuntimeError(f"vault_pdf_extract: PDF corrupto o inválido ({rel_path}): {exc}") from exc + + page_count = doc.page_count + text_parts: list[str] = [] + for page in doc: + try: + text_parts.append(page.get_text()) + except Exception: + text_parts.append("") + doc.close() + + full_text = "\n".join(text_parts) + text_len = len(full_text) + + # Truncar a 10 MB para FTS + _MAX_FTS = 10 * 1024 * 1024 + fts_text = full_text[:_MAX_FTS] + + # Dump text a disco + extracted_to: str | None = None + if dump_text and full_text.strip(): + basename = Path(rel_path).stem + # Preferir data/processed/ si existe; si no, usar .vault_extracts/ + processed_dir = vault / "data" / "processed" + if not processed_dir.exists(): + processed_dir = vault / ".vault_extracts" + processed_dir.mkdir(parents=True, exist_ok=True) + txt_path = processed_dir / f"{basename}.txt" + txt_path.write_text(full_text, encoding="utf-8") + extracted_to = str(txt_path.relative_to(vault)) + + # Persistir en vault_index.db + persisted = False + if db.exists(): + conn = sqlite3.connect(str(db)) + try: + now = int(time.time()) + conn.execute( + """ + INSERT INTO pdf_extracts(rel_path, page_count, text_len, extracted_to, extracted_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(rel_path) DO UPDATE SET + page_count=excluded.page_count, + text_len=excluded.text_len, + extracted_to=excluded.extracted_to, + extracted_at=excluded.extracted_at + """, + (rel_path, page_count, text_len, extracted_to, now), + ) + # Actualizar files_fts (rowid debe coincidir con files) + conn.execute("DELETE FROM files_fts WHERE rel_path = ?", (rel_path,)) + if fts_text.strip(): + conn.execute( + """ + INSERT INTO files_fts(rowid, rel_path, content_text) + VALUES ((SELECT rowid FROM files WHERE rel_path = ?), ?, ?) + """, + (rel_path, rel_path, fts_text), + ) + conn.commit() + persisted = True + except Exception: + conn.rollback() + raise + finally: + conn.close() + + return { + "rel_path": rel_path, + "page_count": page_count, + "text_len": text_len, + "extracted_to": extracted_to, + "persisted": persisted, + } diff --git a/python/functions/infra/tests/test_vault_dedupe_report.py b/python/functions/infra/tests/test_vault_dedupe_report.py new file mode 100644 index 00000000..c49908be --- /dev/null +++ b/python/functions/infra/tests/test_vault_dedupe_report.py @@ -0,0 +1,154 @@ +"""Tests para vault_dedupe_report.""" + +from __future__ import annotations + +import os +import sqlite3 +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from vault_dedupe_report import vault_dedupe_report + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_db(tmp_path: Path, rows: list[tuple]) -> Path: + """Crea vault_index.db con la tabla files y las filas dadas. + + rows: lista de (rel_path, size, sha256) + """ + db_path = tmp_path / "vault_index.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + """ + CREATE TABLE files ( + rel_path TEXT PRIMARY KEY, + size INTEGER, + mtime REAL, + sha256 TEXT, + mime TEXT, + ext TEXT, + bucket TEXT, + sub_bucket TEXT, + indexed_at REAL + ); + """ + ) + conn.executemany( + "INSERT INTO files (rel_path, size, sha256) VALUES (?, ?, ?);", + rows, + ) + conn.commit() + conn.close() + return db_path + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_no_duplicates(tmp_path): + """test_no_duplicates — 3 archivos con sha256 distintos -> groups=[].""" + _make_db(tmp_path, [ + ("a/file1.txt", 100, "aaa111"), + ("a/file2.txt", 200, "bbb222"), + ("a/file3.txt", 300, "ccc333"), + ]) + result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db")) + + assert result["groups"] == [] + assert result["total_groups"] == 0 + assert result["total_duplicates"] == 0 + assert result["total_reclaimable_bytes"] == 0 + assert result["scanned_files"] == 3 + assert result["vault_path"] == str(tmp_path) + + +def test_basic_duplicates(tmp_path): + """test_basic_duplicates — 2 archivos mismo sha256 -> 1 group, count=2, reclaimable=size.""" + _make_db(tmp_path, [ + ("data/orig.jpg", 500, "deadbeef"), + ("backup/orig.jpg", 500, "deadbeef"), + ]) + result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db")) + + assert result["total_groups"] == 1 + assert result["total_duplicates"] == 1 + assert result["total_reclaimable_bytes"] == 500 + + g = result["groups"][0] + assert g["sha256"] == "deadbeef" + assert g["size"] == 500 + assert g["count"] == 2 + assert g["reclaimable_bytes"] == 500 + assert sorted(g["files"]) == ["backup/orig.jpg", "data/orig.jpg"] + + +def test_three_in_group(tmp_path): + """test_three_in_group — 3 archivos mismo sha256 -> count=3, reclaimable=size*2.""" + size = 1000 + _make_db(tmp_path, [ + ("a/f1.bin", size, "cafebabe"), + ("b/f2.bin", size, "cafebabe"), + ("c/f3.bin", size, "cafebabe"), + ]) + result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db")) + + assert result["total_groups"] == 1 + assert result["total_duplicates"] == 2 + assert result["total_reclaimable_bytes"] == size * 2 + + g = result["groups"][0] + assert g["count"] == 3 + assert g["reclaimable_bytes"] == size * 2 + assert g["files"] == sorted(["a/f1.bin", "b/f2.bin", "c/f3.bin"]) + + +def test_min_size_filter(tmp_path): + """test_min_size_filter — duplicados de tamano 50, min_size=100 -> groups=[].""" + _make_db(tmp_path, [ + ("x/small1.txt", 50, "tiny123"), + ("y/small2.txt", 50, "tiny123"), + ]) + result = vault_dedupe_report( + str(tmp_path), + min_size=100, + db_path=str(tmp_path / "vault_index.db"), + ) + + assert result["groups"] == [] + assert result["total_groups"] == 0 + assert result["total_reclaimable_bytes"] == 0 + assert result["scanned_files"] == 0 + + +def test_multiple_groups_ordered(tmp_path): + """test_multiple_groups_ordered — 2 grupos con distinto ahorro -> orden DESC.""" + # grupo A: 2 copias de 200 bytes -> reclaimable=200 + # grupo B: 3 copias de 500 bytes -> reclaimable=1000 + # el grupo B debe salir primero + _make_db(tmp_path, [ + ("p/a1.dat", 200, "groupA"), + ("q/a2.dat", 200, "groupA"), + ("r/b1.dat", 500, "groupB"), + ("s/b2.dat", 500, "groupB"), + ("t/b3.dat", 500, "groupB"), + ("u/uniq.dat", 999, "unique1"), + ]) + result = vault_dedupe_report(str(tmp_path), db_path=str(tmp_path / "vault_index.db")) + + assert result["total_groups"] == 2 + assert result["total_duplicates"] == 3 # (2-1) + (3-1) + assert result["total_reclaimable_bytes"] == 1200 # 200 + 1000 + assert result["scanned_files"] == 6 # 6 filas con sha256 != '' (incluye el unico) + + # Primer grupo debe ser el de mayor ahorro (B: 1000) + assert result["groups"][0]["sha256"] == "groupB" + assert result["groups"][0]["reclaimable_bytes"] == 1000 + assert result["groups"][1]["sha256"] == "groupA" + assert result["groups"][1]["reclaimable_bytes"] == 200 diff --git a/python/functions/infra/tests/test_vault_knowledge_parse.py b/python/functions/infra/tests/test_vault_knowledge_parse.py new file mode 100644 index 00000000..c114374a --- /dev/null +++ b/python/functions/infra/tests/test_vault_knowledge_parse.py @@ -0,0 +1,153 @@ +"""Tests para vault_knowledge_parse.""" + +from __future__ import annotations + +import os +import sqlite3 +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from vault_knowledge_parse import vault_knowledge_parse + + +def _make_vault(tmp: Path) -> tuple[Path, Path]: + """Crea un vault mínimo con vault_index.db.""" + db = tmp / "vault_index.db" + conn = sqlite3.connect(str(db)) + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS files ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + rel_path TEXT UNIQUE NOT NULL, + size_bytes INTEGER, + ext TEXT + ); + CREATE VIRTUAL TABLE IF NOT EXISTS files_fts + USING fts5(rel_path, content_text, content='', contentless_delete=1); + CREATE TABLE IF NOT EXISTS knowledge_docs ( + rel_path TEXT PRIMARY KEY, + title TEXT, + frontmatter_json TEXT, + headings_json TEXT, + parsed_at INTEGER + ); + """ + ) + conn.commit() + conn.close() + return tmp, db + + +def _insert_file_entry(db: Path, rel_path: str): + conn = sqlite3.connect(str(db)) + conn.execute( + "INSERT OR IGNORE INTO files(rel_path, size_bytes, ext) VALUES (?, 0, '.md')", + (rel_path,), + ) + conn.commit() + conn.close() + + +def test_md_with_frontmatter(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/guia.md" + md = vault / rel + md.parent.mkdir(parents=True, exist_ok=True) + md.write_text( + "---\ntitle: Mi Guía\nauthor: Lucas\n---\n\n# Mi Guía\n\nContenido del documento.\n", + encoding="utf-8", + ) + _insert_file_entry(db, rel) + + result = vault_knowledge_parse(str(vault), rel, db_path=str(db)) + + assert result["title"] == "Mi Guía" + assert result["frontmatter"]["author"] == "Lucas" + assert "Contenido del documento" in result["content_text"] + assert result["persisted"] is True + + +def test_md_no_frontmatter(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/sin_fm.md" + md = vault / rel + md.parent.mkdir(parents=True, exist_ok=True) + md.write_text("# Título\n\nCuerpo sin frontmatter.\n", encoding="utf-8") + _insert_file_entry(db, rel) + + result = vault_knowledge_parse(str(vault), rel, db_path=str(db)) + + assert result["frontmatter"] == {} + assert result["title"] == "Título" + assert "Cuerpo sin frontmatter" in result["content_text"] + + +def test_md_title_from_h1(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/title_h1.md" + md = vault / rel + md.parent.mkdir(parents=True, exist_ok=True) + md.write_text("# Primer H1\n\nAlgún texto.\n", encoding="utf-8") + _insert_file_entry(db, rel) + + result = vault_knowledge_parse(str(vault), rel, db_path=str(db)) + + assert result["title"] == "Primer H1" + + +def test_md_title_from_filename(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/nombre_archivo.md" + md = vault / rel + md.parent.mkdir(parents=True, exist_ok=True) + md.write_text("Solo texto sin headings ni frontmatter.\n", encoding="utf-8") + _insert_file_entry(db, rel) + + result = vault_knowledge_parse(str(vault), rel, db_path=str(db)) + + assert result["title"] == "nombre_archivo" + + +def test_md_headings_levels(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/headings.md" + md = vault / rel + md.parent.mkdir(parents=True, exist_ok=True) + md.write_text( + "# H1 Título\n\nTexto.\n\n## H2 Sección\n\n### H3 Subsección\n\n## H2 Otra\n", + encoding="utf-8", + ) + _insert_file_entry(db, rel) + + result = vault_knowledge_parse(str(vault), rel, db_path=str(db)) + + headings = result["headings"] + assert len(headings) == 4 + levels = [h["level"] for h in headings] + assert levels == [1, 2, 3, 2] + texts = [h["text"] for h in headings] + assert "H1 Título" in texts + assert "H2 Sección" in texts + assert "H3 Subsección" in texts + + +def test_md_persists_to_fts(tmp_path): + vault, db = _make_vault(tmp_path) + rel = "docs/fts_md.md" + md = vault / rel + md.parent.mkdir(parents=True, exist_ok=True) + md.write_text("# Documento FTS\n\nPalabra clave: xenolito.\n", encoding="utf-8") + _insert_file_entry(db, rel) + + vault_knowledge_parse(str(vault), rel, db_path=str(db)) + + conn = sqlite3.connect(str(db)) + # FTS5 contentless: no permite SELECT directo, usar MATCH + row = conn.execute( + "SELECT rowid FROM files_fts WHERE files_fts MATCH 'xenolito'", + ).fetchone() + conn.close() + assert row is not None, "FTS no encontró 'xenolito'" diff --git a/python/functions/infra/vault_dedupe_report.md b/python/functions/infra/vault_dedupe_report.md new file mode 100644 index 00000000..bba75147 --- /dev/null +++ b/python/functions/infra/vault_dedupe_report.md @@ -0,0 +1,57 @@ +--- +name: vault_dedupe_report +kind: function +lang: py +domain: infra +version: "1.0.0" +purity: impure +signature: "def vault_dedupe_report(vault_path: str, min_size: int = 0, db_path: str | None = None) -> dict" +description: "Detecta archivos duplicados en un vault leyendo vault_index.db (agrupando por sha256) y calcula el espacio recuperable. Retorna grupos ordenados por bytes recuperables DESC." +tags: [vault, dedupe, duplicates, disk, sha256, sqlite] +params: + - name: vault_path + desc: "Ruta raiz del vault. Usada como clave en el resultado y para localizar vault_index.db cuando db_path es None." + - name: min_size + desc: "Tamanio minimo en bytes para incluir un archivo en el analisis. Default 0 = todos los archivos." + - name: db_path + desc: "Override opcional de la ruta a vault_index.db. Si es None se usa /vault_index.db." +output: "dict con vault_path, groups (sha256/size/count/files/reclaimable_bytes), total_groups, total_duplicates, total_reclaimable_bytes, scanned_files. groups ordenados por reclaimable_bytes DESC." +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_py_core" +imports: ["sqlite3", "pathlib"] +tested: true +tests: + - "test_no_duplicates" + - "test_basic_duplicates" + - "test_three_in_group" + - "test_min_size_filter" + - "test_multiple_groups_ordered" +test_file_path: "python/functions/infra/tests/test_vault_dedupe_report.py" +file_path: "python/functions/infra/vault_dedupe_report.py" +--- + +## Ejemplo + +```python +from infra.vault_dedupe_report import vault_dedupe_report + +report = vault_dedupe_report("/data/vaults/my_vault", min_size=1024) +print(f"Grupos duplicados: {report['total_groups']}") +print(f"Espacio recuperable: {report['total_reclaimable_bytes'] // (1024**2)} MB") + +for g in report["groups"][:5]: + print(f" sha256={g['sha256'][:12]}... size={g['size']} count={g['count']}") + for f in g["files"]: + print(f" {f}") +``` + +## Notas + +- Solo considera filas con `sha256 != ''` (archivos efectivamente hasheados por `vault_inventory_scan_go_infra`). +- Abre la BD en modo read-only (`?mode=ro`) para no interferir con escrituras concurrentes. +- `GROUP_CONCAT` de SQLite no garantiza orden — los `files` se reordenan lexicograficamente en Python. +- Si la BD no existe o le falta la tabla `files`, lanza `RuntimeError` con mensaje orientativo. +- Prerequisito: haber corrido `fn vault index ` (pipeline `vault_inventory_scan_go_infra` + `vault_index_write_go_infra`) sobre el vault. diff --git a/python/functions/infra/vault_dedupe_report.py b/python/functions/infra/vault_dedupe_report.py new file mode 100644 index 00000000..60f4d1d2 --- /dev/null +++ b/python/functions/infra/vault_dedupe_report.py @@ -0,0 +1,122 @@ +"""vault_dedupe_report — Detecta duplicados en vault_index.db y calcula espacio recuperable.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + + +def vault_dedupe_report( + vault_path: str, + min_size: int = 0, + db_path: str | None = None, +) -> dict: + """Detecta archivos duplicados en un vault a partir de su vault_index.db. + + Lee la tabla ``files`` de ``vault_index.db`` agrupando por ``sha256`` y + retorna todos los grupos con mas de un archivo, ordenados por bytes + recuperables de mayor a menor. + + Args: + vault_path: Ruta raiz del vault. Usada como clave en el resultado y + para localizar ``vault_index.db`` cuando ``db_path`` es None. + min_size: Ignora archivos cuyo ``size`` (bytes) sea menor que este + valor. Default 0 = incluir todos los archivos. + db_path: Ruta absoluta o relativa a la BD SQLite. Si es None se + usa ``/vault_index.db``. + + Returns: + dict con las claves: + - ``vault_path``: str — mismo valor recibido. + - ``groups``: list de dicts, cada uno con: + - ``sha256``: str + - ``size``: int — tamanio en bytes de cada copia + - ``count``: int — numero de copias encontradas + - ``files``: list[str] — rel_paths ordenados lexicograficamente + - ``reclaimable_bytes``: int — ``size * (count - 1)`` + - ``total_groups``: int — numero de grupos con duplicados + - ``total_duplicates``: int — suma de ``(count - 1)`` por grupo + - ``total_reclaimable_bytes``: int — bytes totales recuperables + - ``scanned_files``: int — total de filas consideradas en la query + + Raises: + RuntimeError: Si la BD no existe, no tiene tabla ``files``, o hay + algun error de lectura. + """ + resolved_db = db_path if db_path is not None else str(Path(vault_path) / "vault_index.db") + + db_file = Path(resolved_db) + if not db_file.exists(): + raise RuntimeError( + f"No se encontro vault_index.db en '{resolved_db}'. " + "Corre 'fn vault index ' primero." + ) + + try: + conn = sqlite3.connect(f"file:{resolved_db}?mode=ro", uri=True) + except sqlite3.OperationalError as exc: + raise RuntimeError(f"No se pudo abrir '{resolved_db}': {exc}") from exc + + try: + # Verificar que existe la tabla files + cur = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='files';" + ) + if cur.fetchone() is None: + raise RuntimeError( + f"vault_index.db sin tabla 'files'. " + "Corre 'fn vault index ' primero." + ) + + # Contar filas totales consideradas (sha256 no vacio, size >= min_size) + row = conn.execute( + "SELECT COUNT(*) FROM files WHERE size >= ? AND sha256 != '';", + (min_size,), + ).fetchone() + scanned_files: int = row[0] if row else 0 + + # Query principal: grupos con mas de una copia + query = """ + SELECT + sha256, + size, + COUNT(*) AS cnt, + GROUP_CONCAT(rel_path) AS paths + FROM files + WHERE size >= ? AND sha256 != '' + GROUP BY sha256 + HAVING COUNT(*) > 1 + ORDER BY size * (COUNT(*) - 1) DESC; + """ + rows = conn.execute(query, (min_size,)).fetchall() + finally: + conn.close() + + groups: list[dict] = [] + total_duplicates = 0 + total_reclaimable_bytes = 0 + + for sha256, size, cnt, paths_concat in rows: + # GROUP_CONCAT no garantiza orden — ordenar lexicograficamente + files = sorted(paths_concat.split(",")) + reclaimable = size * (cnt - 1) + groups.append( + { + "sha256": sha256, + "size": size, + "count": cnt, + "files": files, + "reclaimable_bytes": reclaimable, + } + ) + total_duplicates += cnt - 1 + total_reclaimable_bytes += reclaimable + + return { + "vault_path": vault_path, + "groups": groups, + "total_groups": len(groups), + "total_duplicates": total_duplicates, + "total_reclaimable_bytes": total_reclaimable_bytes, + "scanned_files": scanned_files, + } diff --git a/python/functions/infra/vault_knowledge_parse.md b/python/functions/infra/vault_knowledge_parse.md new file mode 100644 index 00000000..2044b1a3 --- /dev/null +++ b/python/functions/infra/vault_knowledge_parse.md @@ -0,0 +1,60 @@ +--- +name: vault_knowledge_parse +kind: function +lang: py +domain: infra +version: "1.0.0" +purity: impure +signature: "def vault_knowledge_parse(vault_path: str, rel_path: str, db_path: str | None = None) -> dict" +description: "Parsea un archivo Markdown del vault: extrae YAML frontmatter, título, headings y cuerpo; persiste en knowledge_docs y actualiza files_fts para búsqueda por contenido." +tags: [vault, markdown, knowledge, frontmatter, headings, fts, infra] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [json, re, sqlite3, time, pathlib, yaml] +params: + - name: vault_path + desc: "Ruta absoluta a la raiz del vault donde vive el Markdown y vault_index.db." + - name: rel_path + desc: "Ruta relativa al archivo .md dentro del vault (ej. 'docs/guia.md')." + - name: db_path + desc: "Override opcional de la ruta a vault_index.db. Por defecto /vault_index.db." +output: "Dict con: rel_path (str), title (str), frontmatter (dict), headings (list de {level, text}), content_text (str cuerpo sin frontmatter), persisted (bool)." +tested: true +tests: + - "test_md_with_frontmatter" + - "test_md_no_frontmatter" + - "test_md_title_from_h1" + - "test_md_title_from_filename" + - "test_md_headings_levels" + - "test_md_persists_to_fts" +test_file_path: "python/functions/infra/tests/test_vault_knowledge_parse.py" +file_path: "python/functions/infra/vault_knowledge_parse.py" +--- + +## Ejemplo + +```python +from vault_knowledge_parse import vault_knowledge_parse + +result = vault_knowledge_parse("/vaults/mi_vault", "docs/guia_operaciones.md") +# { +# "rel_path": "docs/guia_operaciones.md", +# "title": "Guía de Operaciones", +# "frontmatter": {"author": "Lucas", "tags": ["ops"]}, +# "headings": [{"level": 1, "text": "Guía de Operaciones"}, {"level": 2, "text": "Instalación"}], +# "content_text": "# Guía de Operaciones\n\n## Instalación\n...", +# "persisted": True +# } +``` + +## Notas + +- Prioridad de título: frontmatter["title"] > primer H1 en el cuerpo > basename sin extensión. +- Frontmatter YAML delimitado por `---\n` al inicio del archivo. Si no hay frontmatter, se retorna {}. +- content_text es el cuerpo completo sin el bloque frontmatter (incluye los headings H1-H6). +- El rowid de files_fts se ancla al rowid de la tabla files para que vault_search funcione correctamente. +- Si vault_index.db no existe, retorna el dict sin intentar persistir (persisted=False). +- Dependencias: pyyaml (ya instalado en python/.venv). diff --git a/python/functions/infra/vault_knowledge_parse.py b/python/functions/infra/vault_knowledge_parse.py new file mode 100644 index 00000000..3959b1ef --- /dev/null +++ b/python/functions/infra/vault_knowledge_parse.py @@ -0,0 +1,142 @@ +"""vault_knowledge_parse — Parsea un Markdown del vault y persiste en knowledge_docs.""" + +from __future__ import annotations + +import json +import re +import sqlite3 +import time +from pathlib import Path + + +def _parse_frontmatter(text: str) -> tuple[dict, str]: + """Separa YAML frontmatter del cuerpo. Retorna (frontmatter_dict, body).""" + if not text.startswith("---\n") and not text.startswith("---\r\n"): + return {}, text + + # Buscar cierre del frontmatter + end = text.find("\n---", 4) + if end == -1: + return {}, text + + yaml_block = text[4:end].strip() + body = text[end + 4:].lstrip("\n\r") + + try: + import yaml + + fm = yaml.safe_load(yaml_block) or {} + if not isinstance(fm, dict): + fm = {} + except Exception: + fm = {} + + return fm, body + + +def _extract_headings(body: str) -> list[dict]: + """Extrae headings Markdown (# ... ### ...) del cuerpo.""" + headings = [] + for line in body.splitlines(): + m = re.match(r"^(#{1,6})\s+(.*)", line) + if m: + headings.append({"level": len(m.group(1)), "text": m.group(2).strip()}) + return headings + + +def _extract_title(frontmatter: dict, body: str, basename: str) -> str: + """Extrae título: frontmatter['title'] > primer H1 > basename.""" + if frontmatter.get("title"): + return str(frontmatter["title"]) + for line in body.splitlines(): + m = re.match(r"^#\s+(.*)", line) + if m: + return m.group(1).strip() + return basename + + +def vault_knowledge_parse( + vault_path: str, + rel_path: str, + db_path: str | None = None, +) -> dict: + """Parsea un archivo Markdown del vault: extrae frontmatter, título, headings y cuerpo. + + Args: + vault_path: Ruta absoluta a la raiz del vault. + rel_path: Ruta relativa al archivo Markdown dentro del vault. + db_path: Override opcional de la ruta a vault_index.db. + + Returns: + Dict con: rel_path, title, frontmatter, headings, content_text, persisted. + + Raises: + RuntimeError: Si el archivo no existe o no se puede leer. + """ + vault = Path(vault_path) + md_file = vault / rel_path + if not md_file.exists(): + raise RuntimeError(f"vault_knowledge_parse: archivo no encontrado: {md_file}") + + db = Path(db_path) if db_path else vault / "vault_index.db" + + try: + text = md_file.read_text(encoding="utf-8") + except UnicodeDecodeError: + text = md_file.read_text(encoding="latin-1", errors="replace") + + frontmatter, body = _parse_frontmatter(text) + headings = _extract_headings(body) + basename = md_file.stem + title = _extract_title(frontmatter, body, basename) + content_text = body + + # Persistir en vault_index.db + persisted = False + if db.exists(): + conn = sqlite3.connect(str(db)) + try: + now = int(time.time()) + conn.execute( + """ + INSERT INTO knowledge_docs(rel_path, title, frontmatter_json, headings_json, parsed_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(rel_path) DO UPDATE SET + title=excluded.title, + frontmatter_json=excluded.frontmatter_json, + headings_json=excluded.headings_json, + parsed_at=excluded.parsed_at + """, + ( + rel_path, + title, + json.dumps(frontmatter, ensure_ascii=False), + json.dumps(headings, ensure_ascii=False), + now, + ), + ) + # Actualizar files_fts (rowid debe coincidir con files) + conn.execute("DELETE FROM files_fts WHERE rel_path = ?", (rel_path,)) + conn.execute( + """ + INSERT INTO files_fts(rowid, rel_path, content_text) + VALUES ((SELECT rowid FROM files WHERE rel_path = ?), ?, ?) + """, + (rel_path, rel_path, content_text), + ) + conn.commit() + persisted = True + except Exception: + conn.rollback() + raise + finally: + conn.close() + + return { + "rel_path": rel_path, + "title": title, + "frontmatter": frontmatter, + "headings": headings, + "content_text": content_text, + "persisted": persisted, + } diff --git a/python/functions/infra/vault_profile_dispatch.md b/python/functions/infra/vault_profile_dispatch.md new file mode 100644 index 00000000..20d16199 --- /dev/null +++ b/python/functions/infra/vault_profile_dispatch.md @@ -0,0 +1,54 @@ +--- +name: vault_profile_dispatch +kind: function +lang: py +domain: infra +version: "1.0.0" +purity: impure +signature: "def vault_profile_dispatch(vault_path: str, rel_path: str, kind: str, db_path: str | None = None) -> dict" +description: "CLI dispatcher que enruta un archivo del vault al profiler correcto segun su tipo (csv/pdf/md). Thin wrapper sobre vault_csv_profile, vault_pdf_extract y vault_knowledge_parse. Usable desde Go via os/exec para procesar archivos en bulk." +tags: [vault, profile, dispatch, profiler, csv, pdf, md, infra] +uses_functions: + - vault_csv_profile_py_datascience + - vault_pdf_extract_py_datascience + - vault_knowledge_parse_py_infra +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: vault_path + desc: "Ruta absoluta a la raiz del vault." + - name: rel_path + desc: "Ruta relativa del archivo dentro del vault." + - name: kind + desc: "Tipo de profiler: csv | pdf | md." + - name: db_path + desc: "Override de la ruta a vault_index.db. Default: /vault_index.db." +output: "Dict con resultado del profiler correspondiente. Para csv: {rel_path, cols, n_rows, encoding, date_min, date_max, persisted}. Para pdf: {rel_path, page_count, text_len, extracted_to, persisted}. Para md: resultado de vault_knowledge_parse." +tested: false +tests: [] +test_file_path: "" +file_path: "python/functions/infra/vault_profile_dispatch.py" +--- + +## Ejemplo + +```bash +# Desde CLI +python3 python/functions/infra/vault_profile_dispatch.py \ + --vault /home/lucas/vaults/turismo_spain \ + --rel-path data/raw/report.csv \ + --kind csv + +# Desde Go via os/exec (patron usado en fn vault profile) +python3 vault_profile_dispatch.py --vault --rel-path

--kind csv +``` + +## Notas + +Disenado para ser invocado desde Go via `os/exec`. Imprime resultado como JSON a stdout. +Codigos de salida: 0=exito, 1=args faltantes, 2=kind desconocido, 3=error del profiler. + +Detecta automaticamente el PYTHONPATH mirando `FN_REGISTRY_ROOT` o subiendo desde su propia ubicacion. diff --git a/python/functions/infra/vault_profile_dispatch.py b/python/functions/infra/vault_profile_dispatch.py new file mode 100644 index 00000000..024b2d25 --- /dev/null +++ b/python/functions/infra/vault_profile_dispatch.py @@ -0,0 +1,92 @@ +"""vault_profile_dispatch — CLI dispatcher that routes a single vault file to the right profiler. + +Usage: + python3 vault_profile_dispatch.py --vault --rel-path

--kind csv|pdf|md [--db-path

] + +Exit codes: + 0 success (result printed as JSON) + 1 missing required argument + 2 unknown kind + 3 profiler raised an error +""" + +from __future__ import annotations + +import argparse +import json +import sys +import os +from pathlib import Path + + +def _python_path_setup() -> None: + """Ensure the registry python/functions directory is on sys.path.""" + # Try FN_REGISTRY_ROOT env first, then walk up from this file's location. + registry_root = os.environ.get("FN_REGISTRY_ROOT", "") + if not registry_root: + # This file lives at python/functions/infra/vault_profile_dispatch.py + # So the registry root is four levels up from __file__. + candidate = Path(__file__).resolve().parent.parent.parent.parent + if (candidate / "go.mod").exists(): + registry_root = str(candidate) + + if registry_root: + fn_path = str(Path(registry_root) / "python" / "functions") + if fn_path not in sys.path: + sys.path.insert(0, fn_path) + + +def dispatch(vault_path: str, rel_path: str, kind: str, db_path: str | None) -> dict: + """Call the appropriate profiler based on kind.""" + if kind == "csv": + from datascience.vault_csv_profile import vault_csv_profile + return vault_csv_profile(vault_path, rel_path, db_path) + elif kind == "pdf": + from datascience.vault_pdf_extract import vault_pdf_extract + return vault_pdf_extract(vault_path, rel_path, db_path) + elif kind == "md": + from infra.vault_knowledge_parse import vault_knowledge_parse + return vault_knowledge_parse(vault_path, rel_path, db_path) + else: + raise ValueError(f"unknown kind: {kind!r} (expected csv, pdf, or md)") + + +def main(argv: list[str] | None = None) -> int: + _python_path_setup() + + parser = argparse.ArgumentParser( + prog="vault_profile_dispatch", + description="Route a single vault file to the right profiler (csv/pdf/md).", + ) + parser.add_argument("--vault", required=True, help="Absolute path to vault root") + parser.add_argument("--rel-path", required=True, dest="rel_path", help="Relative path of file inside vault") + parser.add_argument( + "--kind", + required=True, + choices=["csv", "pdf", "md"], + help="Profiler kind: csv | pdf | md", + ) + parser.add_argument( + "--db-path", + dest="db_path", + default=None, + help="Override path to vault_index.db (default: /vault_index.db)", + ) + + args = parser.parse_args(argv) + + try: + result = dispatch(args.vault, args.rel_path, args.kind, args.db_path) + except ValueError as exc: + print(f"error: {exc}", file=sys.stderr) + return 2 + except Exception as exc: + print(f"error: {exc}", file=sys.stderr) + return 3 + + print(json.dumps(result, indent=2, default=str)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/python/functions/ml/__init__.py b/python/functions/ml/__init__.py new file mode 100644 index 00000000..796e48d5 --- /dev/null +++ b/python/functions/ml/__init__.py @@ -0,0 +1 @@ +"""ml — tipos y funciones de generacion de imagenes con modelos de difusion.""" diff --git a/python/functions/ml/cuda_available.md b/python/functions/ml/cuda_available.md new file mode 100644 index 00000000..89a237b8 --- /dev/null +++ b/python/functions/ml/cuda_available.md @@ -0,0 +1,67 @@ +--- +name: cuda_available +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def cuda_available() -> dict" +description: "Detecta si CUDA esta disponible via torch. Devuelve device_count, nombres de GPU y version de CUDA. Si torch no esta instalado, retorna available=False sin lanzar excepcion." +tags: [cuda, gpu, torch, pytorch, hardware, probe, ml, device] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: [] +output: "dict con claves: available (bool), device_count (int), devices (list[str] con nombres de GPU), torch_version (str o 'not_installed'), cuda_version (str | None)" +tested: true +tests: + - "sin torch retorna available=False y torch_version=not_installed" + - "con torch sin cuda retorna available=False y device_count=0" + - "claves del dict siempre presentes" +test_file_path: "python/functions/ml/tests/test_cuda_available.py" +file_path: "python/functions/ml/cuda_available.py" +--- + +## Ejemplo + +```python +from ml.cuda_available import cuda_available + +info = cuda_available() +# Sin GPU: +# { +# "available": False, +# "device_count": 0, +# "devices": [], +# "torch_version": "2.3.0", +# "cuda_version": None +# } + +# Con GPU: +# { +# "available": True, +# "device_count": 1, +# "devices": ["NVIDIA RTX 4090"], +# "torch_version": "2.3.0", +# "cuda_version": "12.1" +# } + +# Sin torch instalado: +# { +# "available": False, +# "device_count": 0, +# "devices": [], +# "torch_version": "not_installed", +# "cuda_version": None +# } +``` + +## Notas + +- Nunca lanza ImportError aunque torch no este instalado. +- `cuda_version` es la version de CUDA con la que fue compilado torch, no necesariamente la del sistema. +- Usar junto a `torch_device_select` para elegir device y `gpu_info` para estadisticas de VRAM. +- impure: depende del estado del hardware y de librerias del sistema en tiempo de ejecucion. diff --git a/python/functions/ml/cuda_available.py b/python/functions/ml/cuda_available.py new file mode 100644 index 00000000..ebf55d08 --- /dev/null +++ b/python/functions/ml/cuda_available.py @@ -0,0 +1,42 @@ +"""Detecta disponibilidad de CUDA via torch sin lanzar excepcion si torch no esta instalado.""" + +from __future__ import annotations + + +def cuda_available() -> dict: + """Detecta si CUDA esta disponible y devuelve info de los dispositivos GPU. + + No requiere torch instalado: si no esta presente, devuelve + `torch_version='not_installed'` y `available=False`. + + Returns: + dict con claves: + available (bool): True si torch.cuda.is_available(). + device_count (int): numero de GPUs detectadas (0 si no hay CUDA). + devices (list[str]): nombres de cada GPU (ej. "NVIDIA RTX 4090"). + torch_version (str): version de torch o "not_installed". + cuda_version (str | None): version de CUDA usada por torch, o None. + """ + try: + import torch + except ImportError: + return { + "available": False, + "device_count": 0, + "devices": [], + "torch_version": "not_installed", + "cuda_version": None, + } + + available = torch.cuda.is_available() + device_count = torch.cuda.device_count() if available else 0 + devices = [torch.cuda.get_device_name(i) for i in range(device_count)] + cuda_version = torch.version.cuda if available else None + + return { + "available": available, + "device_count": device_count, + "devices": devices, + "torch_version": torch.__version__, + "cuda_version": cuda_version, + } diff --git a/python/functions/ml/diffusers_generate.md b/python/functions/ml/diffusers_generate.md new file mode 100644 index 00000000..acae3be8 --- /dev/null +++ b/python/functions/ml/diffusers_generate.md @@ -0,0 +1,71 @@ +--- +name: diffusers_generate +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def diffusers_generate(pipe: Any, cfg: GenerationConfig) -> ImageGenResult" +description: "Ejecuta inferencia con un pipeline diffusers usando GenerationConfig. Mide duracion y pico de VRAM. Retorna ImageGenResult con imagen PIL, meta y metricas." +tags: [diffusers, ml, image-generation, inference, vram, metrics] +uses_functions: [genconfig_to_diffusers_kwargs_py_ml] +uses_types: [generation_config_py_ml, image_gen_result_py_ml] +returns: [image_gen_result_py_ml] +returns_optional: false +error_type: "error_go_core" +imports: [torch, diffusers] +params: + - name: pipe + desc: "Pipeline diffusers cargado y listo para inferencia (resultado de diffusers_load_pipeline, opcionalmente con scheduler y LoRA configurados)." + - name: cfg + desc: "Parametros de generacion. cfg.seed >= 0 para semilla fija; -1 usa time-based. cfg.sampler se incluye en meta pero no se aplica aqui (usar diffusers_set_scheduler antes)." +output: "ImageGenResult con image=PIL.Image.Image, meta={backend, model, sampler, actual_steps, seed, width, height, cfg_scale}, duration_ms en entero milisegundos, vram_peak_mb (None si no hay CUDA)." +tested: true +tests: + - "genera imagen retorna ImageGenResult" +test_file_path: "python/functions/ml/tests/test_diffusers_backend.py" +file_path: "python/functions/ml/diffusers_generate.py" +--- + +## Ejemplo + +```python +from diffusers_load_pipeline import diffusers_load_pipeline +from diffusers_generate import diffusers_generate +from generation_config import GenerationConfig +from model_ref import ModelRef + +model = ModelRef( + name="sd-turbo", + model_type="sd15", + path="/home/lucas/vaults/imagegen_models/diffusers/sd-turbo", +) +cfg = GenerationConfig( + prompt="a photo of a cat", + seed=42, + steps=1, + cfg_scale=0.0, + sampler="euler", + width=512, + height=512, + model=model, +) +pipe = diffusers_load_pipeline(model, device="cuda", dtype="fp16") +result = diffusers_generate(pipe, cfg) +# result.image -> PIL.Image.Image 512x512 +# result.duration_ms -> int > 0 +# result.meta["backend"] -> "diffusers" +``` + +## Notas + +`cfg.seed = -1` genera seed aleatorio basado en `time.time()` (reproducible si +se guarda en `result.meta["seed"]`). + +VRAM: `torch.cuda.reset_peak_memory_stats()` antes de inferencia, +`torch.cuda.max_memory_allocated() // 1024 // 1024` despues. + +`genconfig_to_diffusers_kwargs` omite generator=None; esta funcion lo reemplaza +con `torch.Generator(device=device).manual_seed(seed)`. + +Import lazy de torch — ImportError descriptivo si no instalado. diff --git a/python/functions/ml/diffusers_generate.py b/python/functions/ml/diffusers_generate.py new file mode 100644 index 00000000..259f688a --- /dev/null +++ b/python/functions/ml/diffusers_generate.py @@ -0,0 +1,98 @@ +"""diffusers_generate — ejecuta inferencia con un pipeline diffusers y retorna ImageGenResult.""" + +from __future__ import annotations + +import sys +import os +import time +from typing import Any + +sys.path.insert(0, os.path.dirname(__file__)) + +from generation_config import GenerationConfig +from image_gen_result import ImageGenResult +from genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs + + +def diffusers_generate(pipe: Any, cfg: GenerationConfig) -> ImageGenResult: + """Ejecuta inferencia con un pipeline diffusers y retorna ImageGenResult. + + Convierte el GenerationConfig a kwargs via genconfig_to_diffusers_kwargs, + crea un torch.Generator con la semilla configurada, mide duracion y pico + de VRAM (si CUDA disponible). El campo meta del resultado incluye backend, + modelo, sampler, seed y steps usados. + + Args: + pipe: Pipeline diffusers cargado (resultado de diffusers_load_pipeline). + Debe ser callable: pipe(prompt=..., ...) -> objeto con .images[0]. + cfg: Parametros de generacion. cfg.seed se usa para torch.Generator. + cfg.model.name se incluye en meta. cfg.sampler se incluye en meta + pero NO se aplica aqui — usar diffusers_set_scheduler antes si se + quiere cambiar el sampler. + + Returns: + ImageGenResult con image=PIL.Image, meta con keys backend/model/sampler/ + actual_steps/seed, duration_ms y vram_peak_mb (None si no hay CUDA). + + Raises: + ImportError: Si torch o diffusers no estan instalados. + RuntimeError: Si la inferencia falla (OOM, modelo incompatible, etc.). + """ + try: + import torch + except ImportError as exc: + raise ImportError( + "diffusers_generate requiere torch. " + "Instalar con: pip install torch" + ) from exc + + # Determinar device del pipeline + device = "cpu" + if hasattr(pipe, "device"): + device = str(pipe.device) + + # Medir VRAM solo en CUDA + cuda_available = torch.cuda.is_available() + if cuda_available: + torch.cuda.reset_peak_memory_stats() + + # Construir kwargs desde GenerationConfig + kwargs = genconfig_to_diffusers_kwargs(cfg) + + # Crear generator con semilla + seed = cfg.seed if cfg.seed >= 0 else int(time.time()) % (2**32) + generator = torch.Generator(device=device).manual_seed(seed) + kwargs["generator"] = generator + + # Inferencia + t0 = time.perf_counter() + result = pipe(**kwargs) + t1 = time.perf_counter() + + duration_ms = int((t1 - t0) * 1000) + + # VRAM peak + vram_peak_mb: int | None = None + if cuda_available: + vram_peak_mb = torch.cuda.max_memory_allocated() // 1024 // 1024 + + # Nombre del modelo + model_name = cfg.model.name if cfg.model and hasattr(cfg.model, "name") else "unknown" + + meta: dict[str, Any] = { + "backend": "diffusers", + "model": model_name, + "sampler": cfg.sampler, + "actual_steps": cfg.steps, + "seed": seed, + "width": cfg.width, + "height": cfg.height, + "cfg_scale": cfg.cfg_scale, + } + + return ImageGenResult( + image=result.images[0], + meta=meta, + duration_ms=duration_ms, + vram_peak_mb=vram_peak_mb, + ) diff --git a/python/functions/ml/diffusers_load_lora.md b/python/functions/ml/diffusers_load_lora.md new file mode 100644 index 00000000..81b40cf9 --- /dev/null +++ b/python/functions/ml/diffusers_load_lora.md @@ -0,0 +1,49 @@ +--- +name: diffusers_load_lora +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def diffusers_load_lora(pipe: Any, lora: LoraRef) -> Any" +description: "Carga un adaptador LoRA en un pipeline diffusers via pipe.load_lora_weights. Si lora.weight != 1.0, aplica set_adapters para escalar la contribucion del LoRA." +tags: [diffusers, ml, lora, image-generation, fine-tuning] +uses_functions: [] +uses_types: [lora_ref_py_ml] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [diffusers] +params: + - name: pipe + desc: "Pipeline diffusers que soporte load_lora_weights (SD1.5, SDXL, etc.)." + - name: lora + desc: "Referencia al adaptador LoRA. lora.path al .safetensors o directorio. lora.weight escala la fusion (1.0 = completo, 0.5 = mitad)." +output: "El mismo pipe con el LoRA cargado y peso aplicado. Modificacion in-place, retorna pipe para composicion." +tested: false +tests: [] +test_file_path: "" +file_path: "python/functions/ml/diffusers_load_lora.py" +--- + +## Ejemplo + +```python +from diffusers_load_lora import diffusers_load_lora +from lora_ref import LoraRef + +lora = LoraRef(path="/path/to/my_lora.safetensors", weight=0.8) +pipe = diffusers_load_lora(pipe, lora) +``` + +## Notas + +Usa `pipe.load_lora_weights(path)` para cargar. Si `lora.weight != 1.0`: +- Intenta `pipe.set_adapters(["default"], adapter_weights=[weight])` (diffusers >= 0.20). +- Fallback a `pipe.fuse_lora(lora_scale=weight)` para versiones antiguas. + +El campo `lora.scale` (override de alpha) no se aplica aqui — diffusers no expone +un parametro directo equivalente en la API publica actual. Se puede setear via +`pipe.load_lora_weights(path, weight_name=...)` si el archivo tiene nombre especifico. + +Import lazy de diffusers — ImportError descriptivo si no instalado. diff --git a/python/functions/ml/diffusers_load_lora.py b/python/functions/ml/diffusers_load_lora.py new file mode 100644 index 00000000..f2a52a8d --- /dev/null +++ b/python/functions/ml/diffusers_load_lora.py @@ -0,0 +1,55 @@ +"""diffusers_load_lora — carga un adaptador LoRA en un pipeline diffusers.""" + +from __future__ import annotations + +import sys +import os +from typing import Any + +sys.path.insert(0, os.path.dirname(__file__)) + +from lora_ref import LoraRef + + +def diffusers_load_lora(pipe: Any, lora: LoraRef) -> Any: + """Carga un adaptador LoRA en un pipeline diffusers y ajusta su peso de fusion. + + Usa pipe.load_lora_weights(lora.path) para cargar los pesos del adaptador. + Si lora.weight != 1.0, aplica set_adapters(['default'], adapter_weights=[w]) + para escalar la contribucion del LoRA. Modifica el pipe in-place y retorna + el mismo objeto para composicion. + + Args: + pipe: Pipeline diffusers cargado. Debe soportar load_lora_weights + (StableDiffusionPipeline, StableDiffusionXLPipeline, etc.). + lora: Referencia al adaptador LoRA. lora.path apunta al archivo + .safetensors o directorio del adaptador. lora.weight controla + la intensidad de fusion (1.0 = peso completo, 0.0 = sin efecto). + + Returns: + El mismo pipe con el LoRA cargado y el peso de fusion aplicado. + + Raises: + ImportError: Si diffusers no esta instalado. + OSError: Si lora.path no existe o el formato del archivo es invalido. + """ + try: + import diffusers # noqa: F401 — verificar disponibilidad + except ImportError as exc: + raise ImportError( + "diffusers_load_lora requiere diffusers. " + "Instalar con: pip install diffusers" + ) from exc + + pipe.load_lora_weights(lora.path) + + if lora.weight != 1.0: + # set_adapters acepta lista de nombres y lista de pesos. + # El nombre "default" es el que diffusers asigna al primer LoRA cargado. + if hasattr(pipe, "set_adapters"): + pipe.set_adapters(["default"], adapter_weights=[lora.weight]) + elif hasattr(pipe, "fuse_lora"): + # Fallback para versiones antiguas de diffusers + pipe.fuse_lora(lora_scale=lora.weight) + + return pipe diff --git a/python/functions/ml/diffusers_load_pipeline.md b/python/functions/ml/diffusers_load_pipeline.md new file mode 100644 index 00000000..fc9c75a3 --- /dev/null +++ b/python/functions/ml/diffusers_load_pipeline.md @@ -0,0 +1,60 @@ +--- +name: diffusers_load_pipeline +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def diffusers_load_pipeline(model: ModelRef, device: str = 'auto', dtype: str = 'fp16') -> Any" +description: "Carga un pipeline diffusers (AutoPipelineForText2Image) con cache global por (model_key, dtype, device). Segunda llamada con mismos parametros retorna el objeto cacheado sin recargar disco." +tags: [diffusers, ml, image-generation, pipeline, cache, torch] +uses_functions: [torch_device_select_py_ml] +uses_types: [model_ref_py_ml] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [torch, diffusers] +params: + - name: model + desc: "Referencia al modelo. model.path si disponible (ruta local), model.name si no (HuggingFace Hub o nombre corto)." + - name: device + desc: "Preferencia de device: 'auto' (CUDA>MPS>CPU), 'cuda', 'cuda:N', 'mps', 'cpu'. Default 'auto'." + - name: dtype + desc: "Precision del modelo: 'fp16' (torch.float16 + variant=fp16), 'bf16' (bfloat16), 'fp32' (float32). Default 'fp16'." +output: "Pipeline diffusers cargado y movido al device. Callable via pipe(prompt=..., ...). Cacheado en _PIPELINE_CACHE." +tested: true +tests: + - "carga pipeline y retorna callable" + - "segunda carga usa cache (< 100ms)" +test_file_path: "python/functions/ml/tests/test_diffusers_backend.py" +file_path: "python/functions/ml/diffusers_load_pipeline.py" +--- + +## Ejemplo + +```python +import sys +sys.path.insert(0, "python/functions/ml") +from diffusers_load_pipeline import diffusers_load_pipeline +from model_ref import ModelRef + +model = ModelRef( + name="sd-turbo", + model_type="sd15", + quantization="fp16", + path="/home/lucas/vaults/imagegen_models/diffusers/sd-turbo", +) +pipe = diffusers_load_pipeline(model, device="cuda", dtype="fp16") +# Segunda llamada: cache hit, < 100ms +pipe2 = diffusers_load_pipeline(model, device="cuda", dtype="fp16") +assert pipe is pipe2 +``` + +## Notas + +Cache global `_PIPELINE_CACHE` indexado por `(model_key, dtype, resolved_device)`. +`model_key` es `model.path` si no es None, sino `model.name`. + +Para liberar memoria: usar `diffusers_unload(pipe=None)` que llama `_clear_pipeline_cache()`. + +Imports lazy de torch y diffusers dentro de la funcion — ImportError descriptivo si no instalados. diff --git a/python/functions/ml/diffusers_load_pipeline.py b/python/functions/ml/diffusers_load_pipeline.py new file mode 100644 index 00000000..98b67a58 --- /dev/null +++ b/python/functions/ml/diffusers_load_pipeline.py @@ -0,0 +1,102 @@ +"""diffusers_load_pipeline — carga un pipeline diffusers con cache global por (model, dtype, device).""" + +from __future__ import annotations + +import sys +import os +import time +from typing import Any + +sys.path.insert(0, os.path.dirname(__file__)) + +from model_ref import ModelRef +from torch_device_select import torch_device_select + +# Cache global: (model_key, dtype, device) -> pipeline object +_PIPELINE_CACHE: dict[tuple[str, str, str], Any] = {} + + +def _get_model_key(model: ModelRef) -> str: + """Retorna la clave de cache para un ModelRef.""" + return model.path if model.path else model.name + + +def diffusers_load_pipeline( + model: ModelRef, + device: str = "auto", + dtype: str = "fp16", +) -> Any: + """Carga un pipeline diffusers con cache global por (model_key, dtype, device). + + Usa AutoPipelineForText2Image.from_pretrained con torch_dtype=torch.float16 + y variant="fp16" por defecto. Hace pipe.to(device) tras la carga. Los + pipelines se cachean en memoria — segunda llamada con los mismos parametros + retorna el objeto cacheado sin recargar el modelo del disco. + + Args: + model: Referencia al modelo. model.path se usa si esta presente; + si no, model.name se pasa directo a from_pretrained (HF hub). + device: Preferencia de device. 'auto' delega a torch_device_select + (CUDA > MPS > CPU). Ejemplos: 'auto', 'cuda', 'cuda:0', 'cpu'. + dtype: Precision del modelo. 'fp16' usa torch.float16 + variant="fp16". + 'fp32' usa torch.float32 sin variant. 'bf16' usa torch.bfloat16. + + Returns: + Objeto pipeline diffusers cargado y movido al device seleccionado. + El tipo concreto depende del modelo (StableDiffusionPipeline, + StableDiffusionXLPipeline, etc.) pero siempre es callable via pipe(...). + + Raises: + ImportError: Si torch o diffusers no estan instalados. + OSError: Si el path del modelo no existe o el nombre del hub es invalido. + """ + try: + import torch + from diffusers import AutoPipelineForText2Image + except ImportError as exc: + raise ImportError( + "diffusers_load_pipeline requiere torch y diffusers. " + "Instalar con: pip install torch diffusers" + ) from exc + + resolved_device = torch_device_select(device) + model_key = _get_model_key(model) + cache_key = (model_key, dtype, resolved_device) + + if cache_key in _PIPELINE_CACHE: + return _PIPELINE_CACHE[cache_key] + + load_path = model.path if model.path else model.name + + if dtype == "fp16": + torch_dtype = torch.float16 + pipe = AutoPipelineForText2Image.from_pretrained( + load_path, + torch_dtype=torch_dtype, + variant="fp16", + ) + elif dtype == "bf16": + torch_dtype = torch.bfloat16 + pipe = AutoPipelineForText2Image.from_pretrained( + load_path, + torch_dtype=torch_dtype, + ) + elif dtype == "fp32": + torch_dtype = torch.float32 + pipe = AutoPipelineForText2Image.from_pretrained( + load_path, + torch_dtype=torch_dtype, + ) + else: + raise ValueError( + f"dtype '{dtype}' no soportado. Usar 'fp16', 'bf16' o 'fp32'." + ) + + pipe = pipe.to(resolved_device) + _PIPELINE_CACHE[cache_key] = pipe + return pipe + + +def _clear_pipeline_cache() -> None: + """Limpia el cache global de pipelines (uso interno y tests).""" + _PIPELINE_CACHE.clear() diff --git a/python/functions/ml/diffusers_set_scheduler.md b/python/functions/ml/diffusers_set_scheduler.md new file mode 100644 index 00000000..e28a4fab --- /dev/null +++ b/python/functions/ml/diffusers_set_scheduler.md @@ -0,0 +1,61 @@ +--- +name: diffusers_set_scheduler +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def diffusers_set_scheduler(pipe: Any, sampler: str) -> Any" +description: "Reemplaza el scheduler de un pipeline diffusers por la clase correspondiente al sampler solicitado. Usa from_config para heredar configuracion base del modelo." +tags: [diffusers, ml, scheduler, sampler, image-generation] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [diffusers] +params: + - name: pipe + desc: "Pipeline diffusers cargado con atributo pipe.scheduler y pipe.scheduler.config." + - name: sampler + desc: "Nombre del sampler: euler, euler_a, dpm++2m, dpm++2m_v2, heun, dpm2, lcm." +output: "El mismo pipe con pipe.scheduler reemplazado. Modificacion in-place, retorna pipe para composicion." +tested: true +tests: + - "euler cambia scheduler a EulerDiscreteScheduler" + - "sampler invalido lanza ValueError" +test_file_path: "python/functions/ml/tests/test_diffusers_backend.py" +file_path: "python/functions/ml/diffusers_set_scheduler.py" +--- + +## Ejemplo + +```python +from diffusers_load_pipeline import diffusers_load_pipeline +from diffusers_set_scheduler import diffusers_set_scheduler +from model_ref import ModelRef + +model = ModelRef(name="sd-turbo", model_type="sd15", path="/path/to/model") +pipe = diffusers_load_pipeline(model) +pipe = diffusers_set_scheduler(pipe, "euler_a") +# type(pipe.scheduler).__name__ == "EulerAncestralDiscreteScheduler" +``` + +## Mapping de samplers + +| sampler | clase diffusers | kwargs extra | +|--------------|------------------------------------|-------------------------------------------| +| euler | EulerDiscreteScheduler | — | +| euler_a | EulerAncestralDiscreteScheduler | — | +| dpm++2m | DPMSolverMultistepScheduler | algorithm_type="dpmsolver++" | +| dpm++2m_v2 | DPMSolverMultistepScheduler | algorithm_type="dpmsolver++", solver_order=2 | +| heun | HeunDiscreteScheduler | — | +| dpm2 | KDPM2DiscreteScheduler | — | +| lcm | LCMScheduler | — | + +## Notas + +Usa `SchedulerCls.from_config(pipe.scheduler.config, **extra_kwargs)` para +heredar `beta_start`, `beta_end`, `clip_sample`, etc. del modelo base. + +Import lazy de diffusers — ImportError descriptivo si no instalado. diff --git a/python/functions/ml/diffusers_set_scheduler.py b/python/functions/ml/diffusers_set_scheduler.py new file mode 100644 index 00000000..0533dd01 --- /dev/null +++ b/python/functions/ml/diffusers_set_scheduler.py @@ -0,0 +1,70 @@ +"""diffusers_set_scheduler — cambia el scheduler de un pipeline diffusers.""" + +from __future__ import annotations + +from typing import Any + + +# Mapping canónico sampler -> (scheduler_class_name, kwargs_extra) +_SCHEDULER_MAP: dict[str, tuple[str, dict]] = { + "euler": ("EulerDiscreteScheduler", {}), + "euler_a": ("EulerAncestralDiscreteScheduler", {}), + "dpm++2m": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++"}), + "dpm++2m_v2": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++", "solver_order": 2}), + "heun": ("HeunDiscreteScheduler", {}), + "dpm2": ("KDPM2DiscreteScheduler", {}), + "lcm": ("LCMScheduler", {}), +} + + +def diffusers_set_scheduler(pipe: Any, sampler: str) -> Any: + """Reemplaza el scheduler de un pipeline diffusers por el correspondiente al sampler. + + Usa .from_config(pipe.scheduler.config) para heredar la + configuracion base del modelo (betas, clip_sample, etc.) y aplica encima + los kwargs especificos del sampler. Modifica pipe.scheduler in-place y + retorna el mismo pipe para composicion. + + Args: + pipe: Pipeline diffusers cargado (StableDiffusionPipeline, + StableDiffusionXLPipeline, etc.). Debe tener atributo + pipe.scheduler con .config. + sampler: Nombre del sampler. Valores validos: euler, euler_a, + dpm++2m, dpm++2m_v2, heun, dpm2, lcm. + + Returns: + El mismo pipe con pipe.scheduler reemplazado por la clase + correspondiente al sampler solicitado. + + Raises: + ImportError: Si diffusers no esta instalado. + ValueError: Si el sampler no esta en el mapping soportado. + """ + try: + import diffusers + except ImportError as exc: + raise ImportError( + "diffusers_set_scheduler requiere diffusers. " + "Instalar con: pip install diffusers" + ) from exc + + if sampler not in _SCHEDULER_MAP: + supported = ", ".join(sorted(_SCHEDULER_MAP.keys())) + raise ValueError( + f"Sampler '{sampler}' no soportado. Valores validos: {supported}" + ) + + class_name, extra_kwargs = _SCHEDULER_MAP[sampler] + scheduler_cls = getattr(diffusers, class_name, None) + + if scheduler_cls is None: + raise ImportError( + f"La clase '{class_name}' no esta disponible en la version de diffusers " + f"instalada. Actualizar diffusers para usar el sampler '{sampler}'." + ) + + pipe.scheduler = scheduler_cls.from_config( + pipe.scheduler.config, + **extra_kwargs, + ) + return pipe diff --git a/python/functions/ml/diffusers_unload.md b/python/functions/ml/diffusers_unload.md new file mode 100644 index 00000000..a3fd7f38 --- /dev/null +++ b/python/functions/ml/diffusers_unload.md @@ -0,0 +1,49 @@ +--- +name: diffusers_unload +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def diffusers_unload(pipe: Any | None = None) -> None" +description: "Libera la memoria de un pipeline diffusers. Si pipe=None limpia el cache global de diffusers_load_pipeline. Siempre llama gc.collect() y torch.cuda.empty_cache()." +tags: [diffusers, ml, memory, cleanup, vram, cache] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [torch, gc] +params: + - name: pipe + desc: "Pipeline a liberar con del. Si None, limpia el cache global _PIPELINE_CACHE de diffusers_load_pipeline (descarga todos los pipelines cacheados)." +output: "None. Efecto secundario: del pipe si pasado, cache limpiado si None, gc.collect() y torch.cuda.empty_cache() siempre." +tested: true +tests: + - "unload None limpia cache cuda si disponible" +test_file_path: "python/functions/ml/tests/test_diffusers_backend.py" +file_path: "python/functions/ml/diffusers_unload.py" +--- + +## Ejemplo + +```python +from diffusers_unload import diffusers_unload + +# Liberar un pipeline especifico +diffusers_unload(pipe) + +# Limpiar TODO el cache (descarga todos los modelos en memoria) +diffusers_unload() +``` + +## Notas + +`del pipe` no garantiza liberacion inmediata si hay otras referencias al objeto. +Llamar `diffusers_unload(pipe)` + borrar la referencia local (`pipe = None`) +para asegurar que el GC pueda recolectar. + +`torch.cuda.empty_cache()` solo libera cache del allocator de PyTorch, no +memoria que otros procesos ocupen. Para liberacion total, el proceso debe terminar. + +Import lazy de torch — si no esta instalado, omite empty_cache silenciosamente. diff --git a/python/functions/ml/diffusers_unload.py b/python/functions/ml/diffusers_unload.py new file mode 100644 index 00000000..5a1ac467 --- /dev/null +++ b/python/functions/ml/diffusers_unload.py @@ -0,0 +1,47 @@ +"""diffusers_unload — libera memoria de un pipeline diffusers y limpia cache global.""" + +from __future__ import annotations + +import gc +import sys +import os +from typing import Any + + +def diffusers_unload(pipe: Any | None = None) -> None: + """Libera la memoria ocupada por un pipeline diffusers. + + Si se pasa pipe, lo elimina con del y llama gc.collect() + empty_cache(). + Si pipe es None, limpia ademas el cache global de diffusers_load_pipeline + (descarga TODOS los pipelines cacheados). En ambos casos invoca + torch.cuda.empty_cache() si CUDA esta disponible. + + Args: + pipe: Pipeline a liberar. Si None, limpia el cache global completo + de diffusers_load_pipeline ademas de llamar gc + empty_cache. + + Returns: + None. Efecto secundario: memoria GPU/CPU liberada. + """ + if pipe is None: + # Limpiar cache global de diffusers_load_pipeline si esta importado + try: + # Importar el modulo para acceder a su cache interno + load_module_path = os.path.join(os.path.dirname(__file__)) + if load_module_path not in sys.path: + sys.path.insert(0, load_module_path) + from diffusers_load_pipeline import _clear_pipeline_cache + _clear_pipeline_cache() + except ImportError: + pass + else: + del pipe + + gc.collect() + + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass diff --git a/python/functions/ml/genconfig_load_json.md b/python/functions/ml/genconfig_load_json.md new file mode 100644 index 00000000..3a2ea0cd --- /dev/null +++ b/python/functions/ml/genconfig_load_json.md @@ -0,0 +1,47 @@ +--- +name: genconfig_load_json +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def genconfig_load_json(path: str) -> GenerationConfig" +description: "Carga y valida un GenerationConfig desde un archivo JSON en disco. Usa pydantic model_validate si disponible; fallback a construccion manual desde dataclass. Raises FileNotFoundError si el archivo no existe." +tags: [ml, generation, json, io, deserialization] +params: + - name: path + desc: "Ruta al archivo JSON generado por genconfig_save_json. Relativa o absoluta." +output: "Instancia de GenerationConfig cargada y validada con todos sus campos." +uses_functions: + - genconfig_save_json_py_ml +uses_types: + - generation_config_py_ml +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +tested: true +tests: + - "save escribe archivo JSON valido en la ruta indicada" + - "save crea directorios padre si no existen" + - "json contiene claves en snake_case" +test_file_path: "python/functions/ml/tests/test_genconfig_json_roundtrip.py" +file_path: "python/functions/ml/genconfig_load_json.py" +--- + +## Ejemplo + +```python +from ml.genconfig_load_json import genconfig_load_json + +cfg = genconfig_load_json("/tmp/gen_config.json") +# cfg.prompt == "a forest at dusk" +# cfg.seed == 123 +``` + +## Notas + +Usa pydantic model_validate cuando disponible (valida literales de model_type, +quantization y sampler). Sin pydantic, construye la instancia directamente +sin validar literales. Para el contrato Go-Python es importante que los nombres +de clave sean snake_case (garantizado por pydantic model_dump_json). diff --git a/python/functions/ml/genconfig_load_json.py b/python/functions/ml/genconfig_load_json.py new file mode 100644 index 00000000..312d8a9a --- /dev/null +++ b/python/functions/ml/genconfig_load_json.py @@ -0,0 +1,77 @@ +"""genconfig_load_json — carga un GenerationConfig desde un archivo JSON.""" + +from __future__ import annotations + +import json +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +from generation_config import GenerationConfig + + +def genconfig_load_json(path: str) -> GenerationConfig: + """Carga y valida un GenerationConfig desde un archivo JSON en disco. + + Usa GenerationConfig.model_validate(data) si pydantic esta disponible + (version con validacion completa de tipos y literales). En caso de + fallback a dataclass, construye la instancia manualmente mapeando + los campos conocidos. + + Args: + path: Ruta al archivo JSON. Puede ser relativa o absoluta. + + Returns: + Instancia de GenerationConfig cargada y validada. + + Raises: + FileNotFoundError: Si el archivo no existe. + json.JSONDecodeError: Si el contenido no es JSON valido. + pydantic.ValidationError: Si los datos no cumplen el schema (version pydantic). + KeyError / TypeError: Si faltan campos obligatorios (version dataclass). + """ + abs_path = os.path.abspath(path) + with open(abs_path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Intentar deserializacion pydantic (version canonica con validacion) + try: + return GenerationConfig.model_validate(data) + except AttributeError: + pass + + # Fallback: dataclass — construir manualmente desde el dict + from lora_ref import LoraRef + from model_ref import ModelRef + + model_data = data["model"] + model = ModelRef( + name=model_data["name"], + model_type=model_data["model_type"], + quantization=model_data.get("quantization", "fp16"), + path=model_data.get("path"), + ) + + loras = [ + LoraRef( + path=lr["path"], + weight=lr.get("weight", 1.0), + scale=lr.get("scale"), + ) + for lr in data.get("loras", []) + ] + + return GenerationConfig( + prompt=data["prompt"], + negative_prompt=data.get("negative_prompt"), + seed=data["seed"], + steps=data["steps"], + cfg_scale=data["cfg_scale"], + sampler=data["sampler"], + width=data["width"], + height=data["height"], + model=model, + loras=tuple(loras), + clip_skip=data.get("clip_skip"), + ) diff --git a/python/functions/ml/genconfig_save_json.md b/python/functions/ml/genconfig_save_json.md new file mode 100644 index 00000000..ce0550dc --- /dev/null +++ b/python/functions/ml/genconfig_save_json.md @@ -0,0 +1,59 @@ +--- +name: genconfig_save_json +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def genconfig_save_json(cfg: GenerationConfig, path: str) -> str" +description: "Serializa un GenerationConfig a JSON (pydantic model_dump_json o dataclass fallback) y lo escribe en disco. Crea directorios padre si no existen. Retorna el path absoluto del archivo escrito." +tags: [ml, generation, json, io, serialization] +params: + - name: cfg + desc: "Instancia de GenerationConfig a serializar. Pydantic o dataclass." + - name: path + desc: "Ruta de destino del archivo JSON. Relativa o absoluta." +output: "Path absoluto (str) del archivo JSON escrito en disco." +uses_functions: [] +uses_types: + - generation_config_py_ml +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +tested: true +tests: + - "save escribe archivo JSON valido en la ruta indicada" + - "save crea directorios padre si no existen" + - "json contiene claves en snake_case" +test_file_path: "python/functions/ml/tests/test_genconfig_json_roundtrip.py" +file_path: "python/functions/ml/genconfig_save_json.py" +--- + +## Ejemplo + +```python +from ml.genconfig_save_json import genconfig_save_json +from ml.generation_config import GenerationConfig +from ml.model_ref import ModelRef + +cfg = GenerationConfig( + prompt="a forest at dusk", + seed=123, + steps=25, + cfg_scale=7.5, + sampler="euler", + width=512, + height=512, + model=ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15"), +) + +saved_path = genconfig_save_json(cfg, "/tmp/gen_config.json") +# saved_path == "/tmp/gen_config.json" +``` + +## Notas + +Usa pydantic model_dump_json cuando disponible (JSON canonico con snake_case +interoperable con Go). En entornos sin pydantic usa json.dumps + dataclasses.asdict. +Los directorios padre se crean con os.makedirs(exist_ok=True). diff --git a/python/functions/ml/genconfig_save_json.py b/python/functions/ml/genconfig_save_json.py new file mode 100644 index 00000000..ed2030fc --- /dev/null +++ b/python/functions/ml/genconfig_save_json.py @@ -0,0 +1,58 @@ +"""genconfig_save_json — persiste un GenerationConfig como JSON en disco.""" + +from __future__ import annotations + +import json +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +from generation_config import GenerationConfig + + +def genconfig_save_json(cfg: GenerationConfig, path: str) -> str: + """Serializa un GenerationConfig a JSON y lo escribe en disco. + + Usa model_dump_json(indent=2) si GenerationConfig es instancia de + pydantic.BaseModel (version con validacion). En caso de fallback a + dataclass, serializa con json.dumps usando un encoder que convierte + dataclasses a dict recursivamente. + + Crea los directorios padre si no existen (equivalente a mkdir -p). + + Args: + cfg: Instancia de GenerationConfig a serializar. + path: Ruta de destino del archivo JSON. Puede ser relativa o absoluta. + + Returns: + Path absoluto del archivo escrito. + + Raises: + OSError: Si no se puede crear el directorio o escribir el archivo. + """ + abs_path = os.path.abspath(path) + parent = os.path.dirname(abs_path) + if parent: + os.makedirs(parent, exist_ok=True) + + # Intentar serializacion pydantic (version canonica) + try: + json_str = cfg.model_dump_json(indent=2) + except AttributeError: + # Fallback: dataclass — serializar manualmente + import dataclasses + + def _to_dict(obj: object) -> object: + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + return {k: _to_dict(v) for k, v in dataclasses.asdict(obj).items()} + if isinstance(obj, (list, tuple)): + return [_to_dict(i) for i in obj] + return obj + + json_str = json.dumps(_to_dict(cfg), indent=2) + + with open(abs_path, "w", encoding="utf-8") as f: + f.write(json_str) + + return abs_path diff --git a/python/functions/ml/genconfig_to_diffusers_kwargs.md b/python/functions/ml/genconfig_to_diffusers_kwargs.md new file mode 100644 index 00000000..57d97bea --- /dev/null +++ b/python/functions/ml/genconfig_to_diffusers_kwargs.md @@ -0,0 +1,65 @@ +--- +name: genconfig_to_diffusers_kwargs +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: pure +signature: "def genconfig_to_diffusers_kwargs(cfg: GenerationConfig) -> dict" +description: "Convierte un GenerationConfig al dict de kwargs listo para pipe(**kwargs) de diffusers. Mapea prompt, steps, cfg_scale, width, height. LoRAs y sampler se aplican antes de la llamada; generator=None para que el caller setee torch.Generator por separado." +tags: [ml, diffusers, generation, converter, pure] +params: + - name: cfg + desc: "Instancia de GenerationConfig con los parametros de generacion validados." +output: "dict con claves prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, generator (None). Listo para desempaquetar con pipe(**kwargs)." +uses_functions: [] +uses_types: + - generation_config_py_ml +returns: [] +returns_optional: false +error_type: "" +imports: [] +tested: true +tests: + - "kwargs contiene todas las claves requeridas" + - "negative_prompt None se pasa tal cual" + - "steps y cfg_scale se mapean a num_inference_steps y guidance_scale" + - "generator siempre es None" +test_file_path: "python/functions/ml/tests/test_genconfig_to_diffusers_kwargs.py" +file_path: "python/functions/ml/genconfig_to_diffusers_kwargs.py" +--- + +## Ejemplo + +```python +from ml.genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs +from ml.generation_config import GenerationConfig +from ml.model_ref import ModelRef + +cfg = GenerationConfig( + prompt="a dog in the park", + seed=42, + steps=30, + cfg_scale=7.5, + sampler="euler_a", + width=512, + height=512, + model=ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15"), +) + +kwargs = genconfig_to_diffusers_kwargs(cfg) +# kwargs["num_inference_steps"] == 30 +# kwargs["guidance_scale"] == 7.5 +# kwargs["generator"] is None + +# El caller asigna el generator: +# kwargs["generator"] = torch.Generator(device=device).manual_seed(cfg.seed) +# image = pipe(**kwargs).images[0] +``` + +## Notas + +Funcion pura: sin I/O, sin torch, sin imports opcionales en tiempo de ejecucion. +Los LoRAs se aplican via `pipe.load_lora_weights(lora.path, adapter_name=...)` antes +de la llamada. El scheduler/sampler se configura via `pipe.scheduler = ...` tambien +antes. Ambos no tienen mapping directo a kwargs de `__call__`. diff --git a/python/functions/ml/genconfig_to_diffusers_kwargs.py b/python/functions/ml/genconfig_to_diffusers_kwargs.py new file mode 100644 index 00000000..059af3da --- /dev/null +++ b/python/functions/ml/genconfig_to_diffusers_kwargs.py @@ -0,0 +1,41 @@ +"""genconfig_to_diffusers_kwargs — convierte GenerationConfig a kwargs para diffusers pipe().""" + +from __future__ import annotations + +import sys +import os + +sys.path.insert(0, os.path.dirname(__file__)) + +from generation_config import GenerationConfig + + +def genconfig_to_diffusers_kwargs(cfg: GenerationConfig) -> dict: + """Convierte un GenerationConfig al dict de kwargs listo para pipe(**kwargs) de diffusers. + + Solo mapea los campos que diffusers StableDiffusionPipeline.__call__ acepta + directamente. Los LoRAs y el sampler/scheduler se configuran antes de la + llamada via load_lora_weights() y pipe.scheduler = ...; no tienen mapping + 1:1 con kwargs de __call__. + + El campo "generator" se devuelve como None; el caller debe asignar + torch.Generator(device=device).manual_seed(cfg.seed) por separado para + poder reutilizar el GenerationConfig en distintos devices sin importar torch + aqui (funcion pura). + + Args: + cfg: Parametros de generacion validados. Debe ser instancia de GenerationConfig. + + Returns: + dict con claves: prompt, negative_prompt, num_inference_steps, + guidance_scale, width, height, generator (None). + """ + return { + "prompt": cfg.prompt, + "negative_prompt": cfg.negative_prompt, + "num_inference_steps": cfg.steps, + "guidance_scale": cfg.cfg_scale, + "width": cfg.width, + "height": cfg.height, + "generator": None, + } diff --git a/python/functions/ml/genconfig_to_sdcpp_args.md b/python/functions/ml/genconfig_to_sdcpp_args.md new file mode 100644 index 00000000..5b35b46b --- /dev/null +++ b/python/functions/ml/genconfig_to_sdcpp_args.md @@ -0,0 +1,74 @@ +--- +name: genconfig_to_sdcpp_args +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: pure +signature: "def genconfig_to_sdcpp_args(cfg: GenerationConfig) -> list[str]" +description: "Convierte un GenerationConfig a lista de args CLI para stable-diffusion.cpp (sd-cli). Mapea sampler via _SAMPLER_MAP, aplana LoRAs como pares --lora path:weight. Sin I/O ni dependencias externas." +tags: [ml, sdcpp, stable-diffusion-cpp, cli, converter, pure] +params: + - name: cfg + desc: "Instancia de GenerationConfig con los parametros de generacion validados." +output: "Lista de strings con los argumentos CLI en orden. Listo para subprocess.run(['sd'] + args, ...)." +uses_functions: [] +uses_types: + - generation_config_py_ml +returns: [] +returns_optional: false +error_type: "" +imports: [] +tested: true +tests: + - "sampler euler_a se mapea a euler_a en el flag --sampling-method" + - "sampler dpm++2m se mapea a dpmpp2m" + - "lora con path y weight se agrega como --lora path:weight" + - "multiples loras generan multiples pares --lora" + - "negative_prompt None produce string vacio en --negative-prompt" + - "model.path tiene prioridad sobre model.name en -m" + - "args contiene --prompt --seed --steps --cfg-scale --sampling-method -W -H -m" +test_file_path: "python/functions/ml/tests/test_genconfig_to_sdcpp_args.py" +file_path: "python/functions/ml/genconfig_to_sdcpp_args.py" +--- + +## Ejemplo + +```python +from ml.genconfig_to_sdcpp_args import genconfig_to_sdcpp_args +from ml.generation_config import GenerationConfig +from ml.model_ref import ModelRef +from ml.lora_ref import LoraRef + +cfg = GenerationConfig( + prompt="a cat", + seed=1, + steps=20, + cfg_scale=7.0, + sampler="dpm++2m", + width=512, + height=512, + model=ModelRef(name="v1-5", model_type="sd15", path="/models/v1-5.ckpt"), + loras=[LoraRef(path="/loras/detail.safetensors", weight=0.8)], +) + +args = genconfig_to_sdcpp_args(cfg) +# ["--prompt", "a cat", "--negative-prompt", "", "--seed", "1", +# "--steps", "20", "--cfg-scale", "7.0", "--sampling-method", "dpmpp2m", +# "-W", "512", "-H", "512", "-m", "/models/v1-5.ckpt", +# "--lora", "/loras/detail.safetensors:0.8"] +``` + +## Notas + +Mapa de samplers (_SAMPLER_MAP): +- euler → euler +- euler_a → euler_a +- dpm++2m → dpmpp2m +- dpm++2m_v2 → dpmpp2mv2 +- heun → heun +- dpm2 → dpm2 +- lcm → lcm + +Si cfg.model.path es None, se usa cfg.model.name (nombre de hub o path relativo +segun configuracion del entorno sdcpp). Los LoRAs sin path se omiten silenciosamente. diff --git a/python/functions/ml/genconfig_to_sdcpp_args.py b/python/functions/ml/genconfig_to_sdcpp_args.py new file mode 100644 index 00000000..670b9b39 --- /dev/null +++ b/python/functions/ml/genconfig_to_sdcpp_args.py @@ -0,0 +1,65 @@ +"""genconfig_to_sdcpp_args — convierte GenerationConfig a args CLI para stable-diffusion.cpp.""" + +from __future__ import annotations + +import sys +import os + +sys.path.insert(0, os.path.dirname(__file__)) + +from generation_config import GenerationConfig + +# Mapa de SamplerName (dominio ml) a flags de sd-cli (stable-diffusion.cpp). +# Referencia: https://github.com/leejet/stable-diffusion.cpp#usage +_SAMPLER_MAP: dict[str, str] = { + "euler": "euler", + "euler_a": "euler_a", + "dpm++2m": "dpmpp2m", + "dpm++2m_v2": "dpmpp2mv2", + "heun": "heun", + "dpm2": "dpm2", + "lcm": "lcm", +} + + +def genconfig_to_sdcpp_args(cfg: GenerationConfig) -> list[str]: + """Convierte un GenerationConfig a la lista de args CLI para stable-diffusion.cpp. + + Genera los argumentos necesarios para invocar `sd` (sd-cli) de + stable-diffusion.cpp. El mapa _SAMPLER_MAP traduce los SamplerName + canonicos del dominio ml a los identificadores de sdcpp. + + Los LoRAs se pasan como repeticiones del par --lora "path:weight". + Si un LoRA no tiene path definido, se omite silenciosamente. + + El modelo se resuelve priorizando cfg.model.path; si es None usa cfg.model.name + (puede ser un nombre de hub o un path relativo segun la configuracion de sdcpp). + + Args: + cfg: Parametros de generacion validados. Debe ser instancia de GenerationConfig. + + Returns: + Lista de strings con los argumentos CLI en el orden esperado por sd-cli. + Listo para: subprocess.run(["sd"] + args, ...) o similar. + """ + model_path = cfg.model.path if cfg.model.path else cfg.model.name + sampler_flag = _SAMPLER_MAP.get(cfg.sampler, cfg.sampler) + + args: list[str] = [ + "--prompt", cfg.prompt, + "--negative-prompt", cfg.negative_prompt or "", + "--seed", str(cfg.seed), + "--steps", str(cfg.steps), + "--cfg-scale", str(cfg.cfg_scale), + "--sampling-method", sampler_flag, + "-W", str(cfg.width), + "-H", str(cfg.height), + "-m", model_path, + ] + + # Aplanar LoRAs: cada uno genera un par --lora "path:weight" + for lora in cfg.loras: + if lora.path: + args += ["--lora", f"{lora.path}:{lora.weight}"] + + return args diff --git a/python/functions/ml/generation_config.py b/python/functions/ml/generation_config.py new file mode 100644 index 00000000..6ab16385 --- /dev/null +++ b/python/functions/ml/generation_config.py @@ -0,0 +1,111 @@ +"""GenerationConfig — contrato de parametros para generacion de imagenes con difusion.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + pass + +_SAMPLER_VALUES = ( + "euler", + "euler_a", + "dpm++2m", + "dpm++2m_v2", + "heun", + "dpm2", + "lcm", +) + +SamplerName = Literal[ + "euler", + "euler_a", + "dpm++2m", + "dpm++2m_v2", + "heun", + "dpm2", + "lcm", +] + +try: + from pydantic import BaseModel, ConfigDict + + from lora_ref import LoraRef + from model_ref import ModelRef + + class GenerationConfig(BaseModel): + """Contrato de parametros para generacion de imagenes con modelos de difusion. + + Tipo producto central del dominio ml. Usado como contrato compartido entre + Python (diffusers, sd.cpp wrapper) y Go (orquestador). Serializa a JSON + canonico via model_dump_json() para intercambio entre servicios. + + Attributes: + prompt: Descripcion textual positiva de la imagen a generar. + negative_prompt: Descripcion de lo que se quiere evitar. None omite + el condicionamiento negativo (requiere soporte del modelo). + seed: Semilla para reproducibilidad. -1 usa semilla aleatoria. + steps: Numero de pasos de denoising. Rango tipico: 20-50. + LCM: 4-8 pasos. Valores altos aumentan calidad y tiempo. + cfg_scale: Classifier-Free Guidance scale. Controla cuanto el modelo + sigue el prompt. Rango tipico: 5.0-12.0. + 7.5 es el valor clasico. LCM: 1.0-2.0. + sampler: Algoritmo de denoising. Ver SamplerName para valores validos. + width: Ancho de la imagen en pixeles. Debe ser multiplo de 8. + SD1.5: 512. SDXL: 1024. Flux: 1024+. + height: Alto de la imagen en pixeles. Mismas restricciones que width. + model: Referencia al modelo base. Ver ModelRef. + loras: Lista de adaptadores LoRA a aplicar. Lista vacia = sin LoRA. + clip_skip: Numero de capas CLIP a saltar desde el final del encoder. + None usa el valor por defecto del modelo. Tipico: 1-2 para anime. + """ + + model_config = ConfigDict(frozen=True) + + prompt: str + negative_prompt: str | None = None + seed: int + steps: int + cfg_scale: float + sampler: SamplerName + width: int + height: int + model: ModelRef + loras: list[LoraRef] = [] + clip_skip: int | None = None + +except ImportError: + from dataclasses import dataclass, field + + @dataclass(frozen=True) + class GenerationConfig: # type: ignore[no-redef] + """Contrato de parametros para generacion de imagenes (fallback dataclass). + + Usar la version pydantic cuando este disponible para validacion y + serializacion JSON canonica compartida con Go. + + Attributes: + prompt: Descripcion textual positiva de la imagen. + negative_prompt: Descripcion de lo que evitar. None = sin condicionamiento negativo. + seed: Semilla. -1 = aleatoria. + steps: Pasos de denoising (20-50 tipico, 4-8 para LCM). + cfg_scale: CFG scale (5.0-12.0 tipico, 1.0-2.0 para LCM). + sampler: Algoritmo de denoising (ver SamplerName). + width: Ancho en pixeles, multiplo de 8. + height: Alto en pixeles, multiplo de 8. + model: Referencia al modelo base (ModelRef). + loras: Lista de LoRAs a aplicar (LoraRef[]). + clip_skip: Capas CLIP a saltar desde el final. None = default del modelo. + """ + + prompt: str + seed: int + steps: int + cfg_scale: float + sampler: str + width: int + height: int + model: object # ModelRef + negative_prompt: str | None = None + loras: tuple = field(default_factory=tuple) + clip_skip: int | None = None diff --git a/python/functions/ml/gpu_info.md b/python/functions/ml/gpu_info.md new file mode 100644 index 00000000..685cb74a --- /dev/null +++ b/python/functions/ml/gpu_info.md @@ -0,0 +1,58 @@ +--- +name: gpu_info +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def gpu_info() -> list[dict]" +description: "Consulta nvidia-smi para obtener informacion de cada GPU NVIDIA: nombre, VRAM total y libre, version de driver y CUDA. Devuelve lista vacia si nvidia-smi no esta disponible, sin lanzar excepcion." +tags: [gpu, nvidia, cuda, vram, hardware, probe, ml, nvidia-smi] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: [] +output: "lista de dicts por GPU con claves: index (int), name (str), vram_total_mb (int), vram_free_mb (int), driver_version (str), cuda_version (str). Lista vacia si nvidia-smi no esta disponible." +tested: true +tests: + - "sin nvidia-smi devuelve lista vacia" + - "formato CSV correcto devuelve lista con un dict por GPU" + - "fila malformada en CSV se ignora sin excepcion" +test_file_path: "python/functions/ml/tests/test_gpu_info.py" +file_path: "python/functions/ml/gpu_info.py" +--- + +## Ejemplo + +```python +from ml.gpu_info import gpu_info + +gpus = gpu_info() +# Sin nvidia-smi: [] + +# Con una GPU: +# [ +# { +# "index": 0, +# "name": "NVIDIA GeForce RTX 4090", +# "vram_total_mb": 24564, +# "vram_free_mb": 22000, +# "driver_version": "535.183.01", +# "cuda_version": "8.9" +# } +# ] + +for gpu in gpus: + pct = 100 * (1 - gpu["vram_free_mb"] / gpu["vram_total_mb"]) + print(f"GPU {gpu['index']}: {gpu['name']} — VRAM {pct:.1f}% usada") +``` + +## Notas + +- Usa `--query-gpu=compute_cap` como aproximacion de la version CUDA soportada. El campo `cuda_version` del output es la compute capability (ej. "8.9"), no la version CUDA del driver. +- Robusto a `FileNotFoundError` (nvidia-smi no instalado), `TimeoutExpired` (driver colgado), y `OSError`. +- Para datos de torch (no nvidia-smi), usar `cuda_available`. +- impure: consulta hardware y estado del sistema en tiempo de ejecucion. diff --git a/python/functions/ml/gpu_info.py b/python/functions/ml/gpu_info.py new file mode 100644 index 00000000..13ba4754 --- /dev/null +++ b/python/functions/ml/gpu_info.py @@ -0,0 +1,73 @@ +"""Consulta informacion de GPUs NVIDIA via nvidia-smi.""" + +from __future__ import annotations + +import csv +import subprocess + + +def gpu_info() -> list[dict]: + """Devuelve informacion de todas las GPUs NVIDIA detectadas por nvidia-smi. + + Consulta nvidia-smi via subprocess. Si nvidia-smi no esta disponible o + falla, devuelve lista vacia sin lanzar excepcion. + + Returns: + Lista de dicts, uno por GPU, con claves: + index (int): indice de la GPU (0, 1, ...). + name (str): nombre del modelo (ej. "NVIDIA GeForce RTX 4090"). + vram_total_mb (int): memoria total en MB. + vram_free_mb (int): memoria libre en MB. + driver_version (str): version del driver NVIDIA. + cuda_version (str): version maxima de CUDA soportada por el driver. + Lista vacia si nvidia-smi no esta disponible. + """ + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=index,name,memory.total,memory.free,driver_version,compute_cap", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + except FileNotFoundError: + return [] + except subprocess.TimeoutExpired: + return [] + except OSError: + return [] + + if result.returncode != 0: + return [] + + gpus = [] + reader = csv.reader(result.stdout.strip().splitlines()) + for row in reader: + if len(row) < 5: + continue + try: + index = int(row[0].strip()) + name = row[1].strip() + vram_total_mb = int(row[2].strip()) + vram_free_mb = int(row[3].strip()) + driver_version = row[4].strip() + # compute_cap (ej. "8.9") como aproximacion de cuda_version soportada + cuda_version = row[5].strip() if len(row) > 5 else "" + except (ValueError, IndexError): + continue + + gpus.append( + { + "index": index, + "name": name, + "vram_total_mb": vram_total_mb, + "vram_free_mb": vram_free_mb, + "driver_version": driver_version, + "cuda_version": cuda_version, + } + ) + + return gpus diff --git a/python/functions/ml/hf_snapshot_download.md b/python/functions/ml/hf_snapshot_download.md new file mode 100644 index 00000000..4f494a51 --- /dev/null +++ b/python/functions/ml/hf_snapshot_download.md @@ -0,0 +1,82 @@ +--- +name: hf_snapshot_download +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def hf_snapshot_download(repo_id: str, allow_patterns: list[str] | None = None, ignore_patterns: list[str] | None = None, local_dir: str | None = None, token: str | None = None) -> str" +description: "Descarga un snapshot de un repo HuggingFace Hub (completo o filtrado por patrones glob). Wrapper de huggingface_hub.snapshot_download con ImportError descriptivo. Soporta repos privados/gated via token. Retorna path local del snapshot." +tags: [huggingface, hf, download, snapshot, model, weights, safetensors, ml, hub] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [huggingface_hub] +params: + - name: repo_id + desc: "identificador del repo en HuggingFace Hub en formato 'owner/name' (ej: 'runwayml/stable-diffusion-v1-5')" + - name: allow_patterns + desc: "lista opcional de patrones glob para incluir solo ciertos archivos (ej: ['*.safetensors', 'config.json']). None descarga todo." + - name: ignore_patterns + desc: "lista opcional de patrones glob para excluir archivos (ej: ['*.bin', 'flax_*', 'tf_*']). Util para descargar solo safetensors y evitar duplicados en otro formato." + - name: local_dir + desc: "directorio local de destino. Si None, usa el cache global de HuggingFace (~/.cache/huggingface/hub/)." + - name: token + desc: "token de acceso HuggingFace para repos privados o gated (Llama, Gemma, etc.). Si None, usa la variable de entorno HF_TOKEN." +output: "string: path absoluto al directorio local donde quedo almacenado el snapshot" +tested: true +tests: + - "repo_id se pasa correctamente a snapshot_download" + - "retorna string (la ruta local)" + - "allow_patterns se incluye en los kwargs si se especifica" + - "ignore_patterns se incluye en los kwargs si se especifica" + - "local_dir se incluye en los kwargs si se especifica" + - "token se incluye en los kwargs si se especifica" + - "args opcionales None no se incluyen en kwargs" + - "ImportError descriptivo si huggingface_hub no esta instalado" +test_file_path: "python/functions/ml/tests/test_hf_snapshot_download.py" +file_path: "python/functions/ml/hf_snapshot_download.py" +--- + +## Ejemplo + +```python +from ml.hf_snapshot_download import hf_snapshot_download + +# Descargar solo safetensors y JSONs de SD v1.5 (evita el .bin de 4 GB) +path = hf_snapshot_download( + repo_id="runwayml/stable-diffusion-v1-5", + allow_patterns=["*.safetensors", "*.json", "*.txt"], + ignore_patterns=["*.bin"], + local_dir=".local/models/sd-v1-5", +) +# path = "/home/lucas/fn_registry/.local/models/sd-v1-5" + +# Descargar un modelo gated (Llama) con token +path = hf_snapshot_download( + repo_id="meta-llama/Llama-2-7b-hf", + ignore_patterns=["*.bin"], + local_dir=".local/models/llama-2-7b", + token="hf_xxxxxxxxxxxxxxxxxxxxxxxx", +) + +# Descargar al cache global (sin local_dir) +path = hf_snapshot_download("BAAI/bge-m3") +# path = "/home/lucas/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/..." +``` + +## Notas + +- El wrapper es minimo: no reimplementa logica de descarga, solo asegura que + `huggingface_hub` no sea requerido en tiempo de indexacion del registry. +- `snapshot_download` es idempotente: si el snapshot ya existe en el cache/local_dir + con los mismos hashes, no vuelve a descargar. +- `allow_patterns` y `ignore_patterns` usan la semantica de `fnmatch`. + Tienen precedencia: si un archivo coincide con ambos, `ignore_patterns` gana. +- Para repos grandes (>10 GB), conviene usar `ignore_patterns=["*.bin"]` si el + repo ofrece safetensors (formato mas seguro, sin pickle, y soporta mmap). +- El token puede ponerse tambien en `~/.cache/huggingface/token` via + `huggingface-cli login` para no pasarlo inline. +- impure: hace I/O de red, escribe en disco, depende de disponibilidad del Hub. diff --git a/python/functions/ml/hf_snapshot_download.py b/python/functions/ml/hf_snapshot_download.py new file mode 100644 index 00000000..927e2e52 --- /dev/null +++ b/python/functions/ml/hf_snapshot_download.py @@ -0,0 +1,59 @@ +"""Wrapper de huggingface_hub.snapshot_download con manejo de ImportError descriptivo.""" + +from __future__ import annotations + + +def hf_snapshot_download( + repo_id: str, + allow_patterns: list[str] | None = None, + ignore_patterns: list[str] | None = None, + local_dir: str | None = None, + token: str | None = None, +) -> str: + """Descarga un snapshot completo (o filtrado) de un repo de HuggingFace Hub. + + Wrapper sobre `huggingface_hub.snapshot_download` con manejo de ImportError + descriptivo. Si `local_dir` se especifica, el snapshot se descarga alli en + lugar del cache global de HuggingFace (~/.cache/huggingface/). + + Args: + repo_id: identificador del repositorio en HuggingFace Hub + (ej: "runwayml/stable-diffusion-v1-5", "meta-llama/Llama-2-7b-hf"). + allow_patterns: lista opcional de patrones glob para incluir solo ciertos + archivos (ej: ["*.safetensors", "*.json"]). + Si None, se descargan todos los archivos. + ignore_patterns: lista opcional de patrones glob para excluir archivos + (ej: ["*.bin", "flax_*"]). Util para evitar descargar + pesos en formato pytorch si ya se tienen en safetensors. + local_dir: directorio local donde guardar el snapshot. Si None, usa + el cache global de HuggingFace Hub. + token: token de acceso a HuggingFace (para repos privados o con gated + access como Llama). Si None, usa HF_TOKEN del entorno. + + Returns: + Path local (str) donde quedo almacenado el snapshot. + + Raises: + ImportError: si huggingface_hub no esta instalado, con sugerencia de instalacion. + Exception: cualquier error de red o autenticacion propagado desde snapshot_download. + """ + try: + from huggingface_hub import snapshot_download + except ImportError as exc: + raise ImportError( + "huggingface_hub no esta instalado. " + "Instalar con: pip install huggingface_hub" + ) from exc + + kwargs: dict = {"repo_id": repo_id} + if allow_patterns is not None: + kwargs["allow_patterns"] = allow_patterns + if ignore_patterns is not None: + kwargs["ignore_patterns"] = ignore_patterns + if local_dir is not None: + kwargs["local_dir"] = local_dir + if token is not None: + kwargs["token"] = token + + result = snapshot_download(**kwargs) + return str(result) diff --git a/python/functions/ml/image_compare_side_by_side.md b/python/functions/ml/image_compare_side_by_side.md new file mode 100644 index 00000000..1c459cb5 --- /dev/null +++ b/python/functions/ml/image_compare_side_by_side.md @@ -0,0 +1,84 @@ +--- +name: image_compare_side_by_side +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def image_compare_side_by_side(a, b, label_a='A', label_b='B', gap_px=16, show_diff=True, show_phash=True) -> dict" +description: "Compara dos PIL Images lado a lado generando una imagen compuesta A | diff | B con gap configurable. Calcula MSE pixel-wise y perceptual hash (imagehash si disponible). Util para inspeccionar diferencias entre generaciones de imagen." +tags: [image, compare, diff, phash, mse, pil, pillow, visualization, ml, side-by-side] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [Pillow, numpy, imagehash] +tested: true +tests: + - "grid es PIL.Image con dimensiones correctas show_diff=True" + - "grid es PIL.Image sin diff show_diff=False" + - "pixel_mse positivo para imagenes distintas" + - "pixel_mse cero para imagen identica" + - "phash None si imagehash no disponible" +test_file_path: "python/functions/ml/tests/test_image_compare_side_by_side.py" +file_path: "python/functions/ml/image_compare_side_by_side.py" +params: + - name: a + desc: "Primera imagen PIL (referencia). Se convierte a RGB internamente si es RGBA/L/etc." + - name: b + desc: "Segunda imagen PIL (comparacion). Se redimensiona a size de a si difieren." + - name: label_a + desc: "Etiqueta de texto para el panel A (default 'A')." + - name: label_b + desc: "Etiqueta de texto para el panel B (default 'B')." + - name: gap_px + desc: "Espacio en pixeles entre paneles y en los bordes del canvas (default 16)." + - name: show_diff + desc: "Si True (default), inserta panel central con PIL.ImageChops.difference + autocontrast." + - name: show_phash + desc: "Si True (default), calcula perceptual hash con imagehash. Si el paquete no esta instalado, retorna None silenciosamente." +output: "dict con: 'grid' (PIL.Image compuesta), 'phash_a' (str|None), 'phash_b' (str|None), 'phash_distance' (int|None, Hamming), 'pixel_mse' (float)." +--- + +## Ejemplo + +```python +from PIL import Image +from ml.image_compare_side_by_side import image_compare_side_by_side + +img_a = Image.open("outputs/gen_v1.png") +img_b = Image.open("outputs/gen_v2.png") + +result = image_compare_side_by_side(img_a, img_b, label_a="v1", label_b="v2") + +result["grid"].save("compare.png") +print(f"MSE: {result['pixel_mse']:.2f}") +print(f"pHash distance: {result['phash_distance']}") +``` + +## Layout del grid + +Con `show_diff=True` (default): + +``` +[gap] [A] [gap] [diff] [gap] [B] [gap] +``` + +Canvas width = 3*w + 4*gap +Canvas height = h + 2*gap + +Con `show_diff=False`: + +``` +[gap] [A] [gap] [B] [gap] +``` + +Canvas width = 2*w + 3*gap + +## Notas + +- `pixel_mse` usa numpy si disponible; fallback a loop puro stdlib (mas lento). +- `phash_*` requiere `pip install imagehash`. Sin el paquete, los tres campos son `None`. +- Las imagenes se convierten a RGB antes de cualquier operacion para consistencia. +- Si `a` y `b` tienen distinto tamano, `b` se redimensiona con LANCZOS al tamano de `a`. diff --git a/python/functions/ml/image_compare_side_by_side.py b/python/functions/ml/image_compare_side_by_side.py new file mode 100644 index 00000000..5805446d --- /dev/null +++ b/python/functions/ml/image_compare_side_by_side.py @@ -0,0 +1,147 @@ +"""Compara dos PIL Images lado a lado con diff opcional y metricas de similitud.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import PIL.Image + + +def image_compare_side_by_side( + a: "PIL.Image.Image", + b: "PIL.Image.Image", + label_a: str = "A", + label_b: str = "B", + gap_px: int = 16, + show_diff: bool = True, + show_phash: bool = True, +) -> dict: + """Crea una imagen comparativa lado a lado con metricas opcionales. + + Construye una imagen compuesta A | [diff] | B con gap configurable. + Calcula MSE pixel-wise y opcionalmente perceptual hash (imagehash). + + Args: + a: Primera imagen PIL (imagen de referencia). + b: Segunda imagen PIL (imagen a comparar). + label_a: Etiqueta de texto para la imagen A (default "A"). + label_b: Etiqueta de texto para la imagen B (default "B"). + gap_px: Espacio en pixeles entre paneles (default 16). + show_diff: Si True, inserta panel de diferencia autocontrastada entre A y B. + show_phash: Si True, calcula perceptual hash con imagehash si disponible. + + Returns: + dict con: + - "grid": PIL.Image.Image — imagen compuesta lado a lado. + - "phash_a": str | None — 16 hex chars de perceptual hash de A (None si imagehash no disponible). + - "phash_b": str | None — 16 hex chars de perceptual hash de B. + - "phash_distance": int | None — Distancia de Hamming entre phash_a y phash_b. + - "pixel_mse": float — MSE pixel-wise sobre canales RGB. + + Raises: + ImportError: si Pillow no esta instalado. + """ + try: + from PIL import Image, ImageChops, ImageDraw, ImageFont, ImageOps + except ImportError as exc: + raise ImportError( + "Pillow no esta instalado. Instalar con: pip install Pillow" + ) from exc + + # Normalizar a RGB + img_a = a.convert("RGB") + img_b = b.convert("RGB") + + # Asegurar mismo tamano para comparacion (resize b a size de a si difieren) + if img_a.size != img_b.size: + img_b = img_b.resize(img_a.size, Image.LANCZOS) + + w, h = img_a.size + + # --- Construir panels --- + panels = [img_a] + if show_diff: + diff = ImageChops.difference(img_a, img_b) + diff_contrast = ImageOps.autocontrast(diff) + panels.append(diff_contrast) + panels.append(img_b) + + n = len(panels) + canvas_w = n * w + (n + 1) * gap_px + canvas_h = h + 2 * gap_px + canvas = Image.new("RGB", (canvas_w, canvas_h), color=(20, 20, 20)) + + # Pegar panels + labels_map = {0: label_a, n - 1: label_b} + if show_diff: + labels_map[1] = "diff" + + try: + draw = ImageDraw.Draw(canvas) + font = ImageFont.load_default() + except Exception: + draw = None + font = None + + for i, panel in enumerate(panels): + x = gap_px + i * (w + gap_px) + y = gap_px + canvas.paste(panel, (x, y)) + if draw and i in labels_map: + draw.text((x + 4, y + 4), labels_map[i], fill=(255, 255, 255), font=font) + + # --- MSE pixel-wise --- + pixel_mse = _compute_mse(img_a, img_b) + + # --- Perceptual hash --- + phash_a: str | None = None + phash_b: str | None = None + phash_distance: int | None = None + + if show_phash: + try: + import imagehash # type: ignore[import] + h_a = imagehash.phash(img_a) + h_b = imagehash.phash(img_b) + phash_a = str(h_a) + phash_b = str(h_b) + phash_distance = int(h_a - h_b) + except ImportError: + pass # imagehash not installed — leave None + + return { + "grid": canvas, + "phash_a": phash_a, + "phash_b": phash_b, + "phash_distance": phash_distance, + "pixel_mse": pixel_mse, + } + + +def _compute_mse(img_a: "PIL.Image.Image", img_b: "PIL.Image.Image") -> float: + """Calcula MSE pixel-wise sobre canales RGB.""" + try: + import numpy as np + arr_a = np.asarray(img_a, dtype=np.float64) + arr_b = np.asarray(img_b, dtype=np.float64) + return float(np.mean((arr_a - arr_b) ** 2)) + except ImportError: + pass + + # Fallback puro stdlib (lento para imagenes grandes) + pixels_a = list(img_a.getdata()) + pixels_b = list(img_b.getdata()) + n_pixels = len(pixels_a) + if n_pixels == 0: + return 0.0 + + total = 0.0 + for pa, pb in zip(pixels_a, pixels_b): + # Each pixel is (R, G, B) + for ca, cb in zip(pa, pb): + diff = float(ca) - float(cb) + total += diff * diff + + channels = 3 + return total / (n_pixels * channels) diff --git a/python/functions/ml/image_gen_result.py b/python/functions/ml/image_gen_result.py new file mode 100644 index 00000000..26957e7c --- /dev/null +++ b/python/functions/ml/image_gen_result.py @@ -0,0 +1,97 @@ +"""ImageGenResult — resultado de una operacion de generacion de imagen.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +# PIL.Image.Image se importa solo para type-checking estatico (mypy/pyright). +# En runtime NO se importa aqui — el consumidor ya tiene PIL instalado si +# trabaja con imagenes reales. Esto evita ImportError cuando el modulo se +# importa en contextos sin Pillow (ej. el orquestador Go via grpc/json). +if TYPE_CHECKING: + from PIL.Image import Image as PILImage + +try: + from pydantic import BaseModel, ConfigDict, field_validator + + class ImageGenResult(BaseModel): + """Resultado de una operacion de generacion de imagen con modelo de difusion. + + El campo `image` contiene el objeto PIL.Image.Image generado. No es + serializable a JSON — se accede directamente para guardar a disco o + pasar a pipelines de post-proceso. Para serializar el resultado, + usar solo el campo `meta` (que incluye la config usada) y guardar + la imagen por separado. + + Attributes: + image: Imagen generada. Tipo PIL.Image.Image en runtime. + No incluido en model_dump() ni model_dump_json(). + None si la generacion fallo (ver meta["error"]). + meta: Diccionario con metadata de la generacion. Debe incluir: + - "config": GenerationConfig.model_dump() con los params usados. + - "model": nombre del modelo. + - "seed_used": semilla real usada (util cuando seed=-1). + - "sampler": nombre del sampler. + Puede incluir campos adicionales del backend. + duration_ms: Tiempo total de generacion en milisegundos. + vram_peak_mb: Pico de VRAM consumida durante la generacion en MiB. + None si no se pudo medir (CPU inference o backend sin soporte). + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + image: Any # PIL.Image.Image en runtime; Any para evitar dep dura + meta: dict[str, Any] + duration_ms: int + vram_peak_mb: int | None = None + + @field_validator("image", mode="before") + @classmethod + def _validate_image(cls, v: Any) -> Any: + # Aceptar None (generacion fallida) o cualquier objeto imagen. + # No forzamos importar PIL aqui — la validacion real la hace el backend. + return v + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + """Serializa a dict excluyendo el campo image (no serializable a JSON). + + Returns: + dict con meta, duration_ms y vram_peak_mb. El campo image se omite. + """ + return { + "meta": self.meta, + "duration_ms": self.duration_ms, + "vram_peak_mb": self.vram_peak_mb, + } + + def model_dump_json(self, **kwargs: Any) -> str: + """Serializa a JSON excluyendo el campo image. + + Returns: + String JSON con meta, duration_ms y vram_peak_mb. + """ + import json + + return json.dumps(self.model_dump()) + +except ImportError: + from dataclasses import dataclass + + @dataclass + class ImageGenResult: # type: ignore[no-redef] + """Resultado de generacion de imagen (fallback dataclass). + + El campo `image` es PIL.Image.Image en runtime. No serializable a JSON. + Usar `meta` + guardar imagen por separado para persistencia. + + Attributes: + image: PIL.Image.Image generada. None si fallo. + meta: Metadata: config usada, modelo, seed_used, sampler, etc. + duration_ms: Duracion total de generacion en milisegundos. + vram_peak_mb: Pico de VRAM en MiB. None si no se pudo medir. + """ + + image: Any # PIL.Image.Image + meta: dict + duration_ms: int + vram_peak_mb: int | None = None diff --git a/python/functions/ml/image_generator.py b/python/functions/ml/image_generator.py new file mode 100644 index 00000000..98bc7b64 --- /dev/null +++ b/python/functions/ml/image_generator.py @@ -0,0 +1,46 @@ +"""ImageGenerator — Protocol para backends de generacion de imagenes.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from generation_config import GenerationConfig + from image_gen_result import ImageGenResult + + +@runtime_checkable +class ImageGenerator(Protocol): + """Interfaz comun para backends de generacion de imagenes con difusion. + + Cualquier clase que implemente `generate(config) -> ImageGenResult` + satisface este Protocol sin herencia explicita (structural subtyping). + + Backends de ejemplo que satisfacen esta interfaz: + - DiffusersGenerator: usa HuggingFace diffusers + torch. + - SdCppGenerator: wrapper sobre stable-diffusion.cpp via ctypes/subprocess. + - ComfyUIGenerator: cliente HTTP a ComfyUI API. + - MockGenerator: implementacion de prueba sin GPU. + + El Protocol es `runtime_checkable`, por lo que se puede usar con isinstance(): + assert isinstance(my_backend, ImageGenerator) + + Nota: `isinstance()` con Protocol runtime_checkable solo verifica la presencia + del metodo `generate`, no la firma completa. Para verificacion estricta usar mypy. + """ + + def generate(self, config: "GenerationConfig") -> "ImageGenResult": + """Genera una imagen a partir de la configuracion de difusion. + + Args: + config: Parametros de generacion. Ver GenerationConfig. + + Returns: + Resultado con la imagen PIL, metadata de la generacion, + duracion total y pico de VRAM. Ver ImageGenResult. + + Raises: + Exception: El tipo concreto de error depende del backend. + Los backends deben documentar sus excepciones propias. + """ + ... diff --git a/python/functions/ml/image_grid.md b/python/functions/ml/image_grid.md new file mode 100644 index 00000000..b4caa4a4 --- /dev/null +++ b/python/functions/ml/image_grid.md @@ -0,0 +1,77 @@ +--- +name: image_grid +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def image_grid(images: list[PIL.Image.Image], cols: int = 4, labels: list[str] | None = None, gap_px: int = 8, bg_color: tuple = (20,20,20)) -> PIL.Image.Image" +description: "Combina una lista de PIL Images en un grid NxM con gap configurable, fondo oscuro y labels opcionales sobre cada celda. rows se calcula como ceil(n/cols). Retorna una sola PIL.Image RGB." +tags: [image, grid, pil, pillow, visualization, ml, montage, collage] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [Pillow] +params: + - name: images + desc: "lista de PIL.Image.Image a colocar en el grid (todas deben tener el mismo tamano o se usa el maximo)" + - name: cols + desc: "numero de columnas del grid (default 4)" + - name: labels + desc: "lista opcional de strings para etiquetar cada celda en la esquina superior izquierda" + - name: gap_px + desc: "espacio en pixeles entre celdas y en los bordes del canvas (default 8)" + - name: bg_color + desc: "color RGB de fondo del canvas como tupla (R, G, B), default (20,20,20) casi negro" +output: "PIL.Image.Image: imagen RGB con el grid montado. Lista con n imagenes en cols columnas y ceil(n/cols) filas." +tested: true +tests: + - "grid de 4 imagenes 16x16 cols=2 produce ancho/alto correcto" + - "grid de 4 imagenes cols=2 gap_px=8 tiene dimensiones correctas con gap" + - "grid de 1 imagen 1 col" + - "el resultado es una imagen RGB" + - "labels opcionales no lanza excepcion" + - "sin labels funciona correctamente" + - "lista vacia levanta ValueError" +test_file_path: "python/functions/ml/tests/test_image_grid.py" +file_path: "python/functions/ml/image_grid.py" +--- + +## Ejemplo + +```python +from PIL import Image +from ml.image_grid import image_grid +from ml.image_save_png import image_save_png + +# Generar 6 imagenes de prueba con colores distintos +colors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255),(255,0,255)] +imgs = [Image.new("RGB", (256, 256), c) for c in colors] + +grid = image_grid( + imgs, + cols=3, + labels=["rojo", "verde", "azul", "amarillo", "cyan", "magenta"], + gap_px=10, + bg_color=(30, 30, 30), +) +# grid.size == (3*256 + 4*10, 2*256 + 3*10) == (788, 542) + +image_save_png(grid, "outputs/preview_grid.png") +``` + +## Notas + +- Impure: asigna memoria para el canvas y ejecuta `ImageDraw` (efectos en + objetos PIL internos). Aunque no hace I/O de disco, las allocations PIL + y el draw tienen side-effects sobre objetos mutables. +- Las imagenes en modo no-RGB (RGBA, L, P, palette) se convierten a RGB + automaticamente con `.convert("RGB")` antes de pegar. +- Si la lista tiene menos imagenes que `cols * rows`, las celdas sobrantes + quedan en blanco (solo el color de fondo). +- El label usa `ImageFont.load_default()` (fuente bitmap monospace de PIL, + sin dependencias externas). Para fuentes TTF customizadas usar + `ImageFont.truetype(path, size)` externamente y pasar un `font` propio. +- Pillow se importa lazy para no bloquear `fn index`. diff --git a/python/functions/ml/image_grid.py b/python/functions/ml/image_grid.py new file mode 100644 index 00000000..13a4dbd9 --- /dev/null +++ b/python/functions/ml/image_grid.py @@ -0,0 +1,80 @@ +"""Combina una lista de PIL Images en un grid NxM con gap y labels opcionales.""" + +from __future__ import annotations + +import math + + +def image_grid( + images: list["PIL.Image.Image"], + cols: int = 4, + labels: list[str] | None = None, + gap_px: int = 8, + bg_color: tuple = (20, 20, 20), +) -> "PIL.Image.Image": + """Combina una lista de imagenes en un grid NxM. + + Asume que todas las imagenes tienen el mismo tamano (usa el maximo + ancho/alto detectado). Calcula rows = ceil(n / cols) automaticamente. + + Args: + images: lista de PIL.Image.Image a colocar en el grid. + cols: numero de columnas del grid (default 4). + labels: lista opcional de strings. Si se proporciona, se escribe + un label encima de cada celda usando la fuente default de PIL. + gap_px: espacio en pixeles entre celdas y en los bordes (default 8). + bg_color: color de fondo RGB del canvas (default casi negro (20,20,20)). + + Returns: + Una sola PIL.Image en modo RGB con el grid montado. + + Raises: + ImportError: si Pillow no esta instalado. + ValueError: si images esta vacio. + """ + try: + from PIL import Image, ImageDraw, ImageFont + except ImportError as exc: + raise ImportError( + "Pillow no esta instalado. Instalar con: pip install Pillow" + ) from exc + + if not images: + raise ValueError("image_grid: la lista de imagenes no puede estar vacia") + + n = len(images) + rows = math.ceil(n / cols) + + # Tamano de celda: max ancho y alto de todas las imagenes + cell_w = max(img.width for img in images) + cell_h = max(img.height for img in images) + + canvas_w = cols * cell_w + (cols + 1) * gap_px + canvas_h = rows * cell_h + (rows + 1) * gap_px + + canvas = Image.new("RGB", (canvas_w, canvas_h), color=bg_color) + + draw = ImageDraw.Draw(canvas) if labels else None + font = None + if draw: + try: + font = ImageFont.load_default() + except Exception: + font = None + + for idx, img in enumerate(images): + row = idx // cols + col = idx % cols + x = gap_px + col * (cell_w + gap_px) + y = gap_px + row * (cell_h + gap_px) + + # Convertir a RGB si hace falta (RGBA, L, P, etc.) + paste_img = img.convert("RGB") if img.mode != "RGB" else img + canvas.paste(paste_img, (x, y)) + + if draw and labels and idx < len(labels): + label = labels[idx] + # Texto en la esquina superior izquierda de la celda + draw.text((x + 2, y + 2), label, fill=(255, 255, 255), font=font) + + return canvas diff --git a/python/functions/ml/image_save_png.md b/python/functions/ml/image_save_png.md new file mode 100644 index 00000000..493ce247 --- /dev/null +++ b/python/functions/ml/image_save_png.md @@ -0,0 +1,69 @@ +--- +name: image_save_png +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def image_save_png(img: PIL.Image.Image, path: str, metadata: dict | None = None) -> str" +description: "Guarda una PIL Image como PNG con metadata embebida en chunks tEXt (prompt, seed, steps, sampler, model). Crea directorio padre si no existe. Retorna path absoluto escrito." +tags: [image, png, pil, pillow, metadata, save, ml, reproducibility] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [Pillow] +params: + - name: img + desc: "imagen PIL.Image.Image a guardar" + - name: path + desc: "ruta de destino del archivo PNG (absoluta o relativa)" + - name: metadata + desc: "dict opcional de pares clave/valor a embeber en chunks tEXt del PNG para reproducibilidad (prompt, seed, steps, etc.)" +output: "string: ruta absoluta del archivo PNG escrito" +tested: true +tests: + - "crea imagen 8x8, guarda y retorna ruta absoluta" + - "metadata se embebe en chunks tEXt y se puede releer con Image.text" + - "sin metadata el PNG se guarda igualmente" + - "crea directorio padre si no existe" + - "valores numericos en metadata se convierten a str automaticamente" +test_file_path: "python/functions/ml/tests/test_image_save_png.py" +file_path: "python/functions/ml/image_save_png.py" +--- + +## Ejemplo + +```python +from PIL import Image +from ml.image_save_png import image_save_png + +img = Image.new("RGB", (512, 512), color=(128, 64, 200)) +path = image_save_png( + img, + "outputs/gen_001.png", + metadata={ + "prompt": "a cat on a purple sofa", + "seed": 42, + "steps": 20, + "sampler": "euler_a", + "model": "sd-v1-5", + }, +) +# path = "/home/lucas/.../outputs/gen_001.png" +# Los metadatos quedan embebidos en el PNG y son legibles con exiftool o PIL. +``` + +## Notas + +- Usa `PngImagePlugin.PngInfo` para chunks `tEXt` (texto plano, no comprimido). + Para texto largo/comprimido existe `add_itxt`, pero `add_text` es compatible + con la mayoria de lectores (exiftool, A1111, ComfyUI, etc.). +- Los valores del dict se convierten a `str` automaticamente — se puede pasar + int, float o bool sin castear. +- Si `metadata` es `None` o `{}`, el PNG se guarda sin chunks extra (igual que + `img.save(path)`). +- Pillow no esta en los imports por defecto del registry para no bloquear + `fn index`. Se importa lazy dentro de la funcion. +- impure: escribe en disco y crea directorios. diff --git a/python/functions/ml/image_save_png.py b/python/functions/ml/image_save_png.py new file mode 100644 index 00000000..e22cd8e5 --- /dev/null +++ b/python/functions/ml/image_save_png.py @@ -0,0 +1,48 @@ +"""Guarda una PIL Image como PNG con metadata embebida en chunks tEXt.""" + +from __future__ import annotations + +import os + + +def image_save_png(img: "PIL.Image.Image", path: str, metadata: dict | None = None) -> str: + """Guarda una PIL Image como PNG en la ruta indicada. + + Embebe metadata arbitraria en chunks tEXt del PNG (clave/valor string). + Util para registrar prompt, seed, steps, sampler, model dentro del archivo + para reproducibilidad. + + Crea el directorio padre si no existe. + + Args: + img: imagen PIL a guardar. + path: ruta de destino (absoluta o relativa). Debe terminar en .png. + metadata: dict opcional de pares {clave: valor} a embeber en el PNG. + Los valores se convierten a str automaticamente. + + Returns: + Ruta absoluta del archivo PNG escrito. + + Raises: + ImportError: si Pillow no esta instalado. + OSError: si no se puede escribir en la ruta indicada. + """ + try: + from PIL import PngImagePlugin + except ImportError as exc: + raise ImportError( + "Pillow no esta instalado. Instalar con: pip install Pillow" + ) from exc + + abs_path = os.path.abspath(path) + parent = os.path.dirname(abs_path) + if parent: + os.makedirs(parent, exist_ok=True) + + png_info = PngImagePlugin.PngInfo() + if metadata: + for key, value in metadata.items(): + png_info.add_text(str(key), str(value)) + + img.save(abs_path, format="PNG", pnginfo=png_info) + return abs_path diff --git a/python/functions/ml/lora_ref.py b/python/functions/ml/lora_ref.py new file mode 100644 index 00000000..fd03628d --- /dev/null +++ b/python/functions/ml/lora_ref.py @@ -0,0 +1,47 @@ +"""LoraRef — referencia a un adaptador LoRA para generacion de imagenes.""" + +from __future__ import annotations + +try: + from pydantic import BaseModel, ConfigDict + + class LoraRef(BaseModel): + """Referencia a un adaptador LoRA (Low-Rank Adaptation). + + Un LoRA modifica el comportamiento de un modelo base sin cambiar + sus pesos originales. Se aplica multiplicando matrices de rango bajo + durante la inferencia. + + Attributes: + path: Ruta al archivo .safetensors o .bin del adaptador LoRA. + Puede ser absoluta o relativa al directorio de modelos. + weight: Factor de escala global del LoRA. 1.0 aplica el LoRA + con su fuerza original. 0.0 lo desactiva completamente. + Rango tipico: 0.0 a 1.5. + scale: Override del alpha del LoRA (escala de rango). None usa + el alpha del propio archivo. Util para ajuste fino sin + reentrenar. + """ + + model_config = ConfigDict(frozen=True) + + path: str + weight: float = 1.0 + scale: float | None = None + +except ImportError: + from dataclasses import dataclass + + @dataclass(frozen=True) + class LoraRef: # type: ignore[no-redef] + """Referencia a un adaptador LoRA (fallback dataclass). + + Attributes: + path: Ruta al archivo del adaptador LoRA (.safetensors o .bin). + weight: Factor de escala global. Rango tipico 0.0-1.5. Por defecto 1.0. + scale: Override del alpha. None usa el alpha del archivo. + """ + + path: str + weight: float = 1.0 + scale: float | None = None diff --git a/python/functions/ml/model_ref.py b/python/functions/ml/model_ref.py new file mode 100644 index 00000000..242e1806 --- /dev/null +++ b/python/functions/ml/model_ref.py @@ -0,0 +1,67 @@ +"""ModelRef — referencia a un modelo de generacion de imagenes.""" + +from __future__ import annotations + +from typing import Literal + +try: + from pydantic import BaseModel, ConfigDict + + class ModelRef(BaseModel): + """Referencia a un modelo de generacion de imagenes. + + Identifica el modelo por nombre (HuggingFace hub o ruta local), + tipo de arquitectura y cuantizacion. Serializable a JSON canonico + con model_dump() / model_dump_json() para el contrato compartido con Go. + + Attributes: + name: Nombre del modelo en HuggingFace Hub o identificador local. + Ejemplo: "stabilityai/stable-diffusion-xl-base-1.0". + model_type: Arquitectura del modelo. Uno de los literales definidos. + quantization: Precision numerica del checkpoint. Por defecto "fp16". + path: Ruta local al checkpoint si ya fue descargado. None si + se debe descargar del hub. + """ + + model_config = ConfigDict(frozen=True) + + name: str + model_type: Literal[ + "sd15", + "sd20", + "sdxl", + "sd3", + "flux_dev", + "flux_schnell", + "flux_kontext", + "qwen_image", + "chroma", + "z_image", + ] + quantization: Literal[ + "fp32", "fp16", "bf16", "q8_0", "q5_1", "q5_0", "q4_1", "q4_0" + ] = "fp16" + path: str | None = None + +except ImportError: + from dataclasses import dataclass + + @dataclass(frozen=True) + class ModelRef: # type: ignore[no-redef] + """Referencia a un modelo de generacion de imagenes (fallback dataclass). + + Usar la version pydantic cuando este disponible para validacion y + serializacion JSON canonica. Esta version no valida los literales en + tiempo de ejecucion. + + Attributes: + name: Nombre del modelo en HuggingFace Hub o ruta local. + model_type: Arquitectura del modelo (sd15|sd20|sdxl|sd3|flux_dev|...). + quantization: Precision numerica (fp32|fp16|bf16|q8_0|...). Por defecto "fp16". + path: Ruta local al checkpoint. None si no esta descargado. + """ + + name: str + model_type: str + quantization: str = "fp16" + path: str | None = None diff --git a/python/functions/ml/safetensors_inspect.md b/python/functions/ml/safetensors_inspect.md new file mode 100644 index 00000000..cd27f640 --- /dev/null +++ b/python/functions/ml/safetensors_inspect.md @@ -0,0 +1,99 @@ +--- +name: safetensors_inspect +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def safetensors_inspect(path: str) -> dict" +description: "Lee SOLO el header de un archivo .safetensors sin cargar los tensores en RAM. Retorna metadata del modelo, lista de tensores con dtype/shape/offsets, tamano total y conteo. Util para inspeccionar checkpoints de varios GB sin agotarlos en memoria." +tags: [safetensors, model, inspect, header, ml, huggingface, checkpoint, dtype, shape] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: path + desc: "ruta al archivo .safetensors a inspeccionar (absoluta o relativa)" +output: "dict con claves: path (str ruta absoluta), metadata (dict con __metadata__ del header), tensors (list[dict] con name/dtype/shape/data_offsets por tensor), total_size_bytes (int), n_tensors (int)" +tested: true +tests: + - "n_tensors refleja el numero de tensores en el header" + - "total_size_bytes refleja el tamano real del archivo" + - "metadata devuelve el contenido de __metadata__" + - "tensors es lista con una entrada por tensor del header" + - "cada tensor tiene dtype, shape y data_offsets" + - "result path es la ruta absoluta del archivo" + - "FileNotFoundError si el archivo no existe" + - "ValueError si el header no es JSON valido" + - "ValueError si el archivo esta vacio" + - "si no hay __metadata__ metadata retorna dict vacio" +test_file_path: "python/functions/ml/tests/test_safetensors_inspect.py" +file_path: "python/functions/ml/safetensors_inspect.py" +--- + +## Ejemplo + +```python +from ml.safetensors_inspect import safetensors_inspect + +info = safetensors_inspect("/models/sd-v1-5/model.safetensors") +print(info["n_tensors"]) # 1344 +print(info["total_size_bytes"]) # 3_975_733_952 (~3.7 GB) +print(info["metadata"]) # {"format": "pt", "model_type": "stable_diffusion"} + +# Ver los 5 primeros tensores +for t in info["tensors"][:5]: + print(t["name"], t["dtype"], t["shape"]) +# model.diffusion_model.input_blocks.0.0.weight F16 [320, 4, 3, 3] +# model.diffusion_model.input_blocks.0.0.bias F16 [320] +# ... +``` + +## Notas + +### Formato safetensors + +``` +[8 bytes: uint64 LE = N (longitud del header JSON)] +[N bytes: JSON con metadata y descriptores] +[datos binarios de los tensores (no se leen)] +``` + +El JSON tiene esta estructura: +```json +{ + "__metadata__": {"format": "pt", ...}, + "tensor_name": { + "dtype": "F32", + "shape": [1024, 768], + "data_offsets": [0, 3145728] + }, + ... +} +``` + +`data_offsets` son relativos al inicio del bloque de datos (despues del header), +no al inicio del archivo. Para acceso lazy a un tensor concreto: +`offset_en_archivo = 8 + header_len + data_offsets[0]`. + +### Por que no usar la libreria `safetensors` + +Esta funcion solo usa stdlib (`struct`, `json`, `os`) para no requerir +instalaciones adicionales y ser ejecutable durante `fn index`. La libreria +oficial `safetensors` de HuggingFace cargaria los tensores en RAM al usar +`safe_open` sin `framework=None`. Esta implementacion es read-only sobre +el header y garantiza que no se carga ningun dato de tensor. + +### Dtypes comunes + +| dtype | descripcion | +|-------|-------------| +| F32 | float32 (full precision) | +| BF16 | bfloat16 (training, ampere+) | +| F16 | float16 (inference) | +| I32 | int32 | +| I64 | int64 | +| U8 | uint8 | diff --git a/python/functions/ml/safetensors_inspect.py b/python/functions/ml/safetensors_inspect.py new file mode 100644 index 00000000..170d48ac --- /dev/null +++ b/python/functions/ml/safetensors_inspect.py @@ -0,0 +1,100 @@ +"""Lee solo el header de un archivo safetensors sin cargar los tensores en RAM.""" + +from __future__ import annotations + +import json +import struct + + +def safetensors_inspect(path: str) -> dict: + """Lee el header de un archivo safetensors sin cargar los tensores. + + El formato safetensors almacena al inicio del archivo: + - 8 bytes: uint64 little-endian con la longitud del header JSON (N). + - N bytes: JSON con metadata y descriptores de tensores. + + Este enfoque evita cargar gigabytes de pesos en RAM para inspeccionar + un checkpoint: solo se leen los primeros 8 + N bytes. + + Spec: https://github.com/huggingface/safetensors#format + + Args: + path: ruta al archivo .safetensors (absoluta o relativa). + + Returns: + dict con claves: + path (str): ruta absoluta del archivo. + metadata (dict): metadatos del modelo (campo "__metadata__" del header). + tensors (list[dict]): lista de tensores, cada uno con: + name (str): nombre del tensor. + dtype (str): tipo de dato (F32, BF16, F16, I32, etc.). + shape (list[int]): dimensiones del tensor. + data_offsets (list[int]): [inicio, fin] en bytes dentro del + bloque de datos (para acceso lazy si se necesita). + total_size_bytes (int): tamano total del archivo en bytes. + n_tensors (int): numero de tensores en el archivo. + + Raises: + FileNotFoundError: si el archivo no existe. + ValueError: si el archivo no es un safetensors valido. + OSError: si no se puede leer el archivo. + """ + import os + + abs_path = os.path.abspath(path) + + if not os.path.isfile(abs_path): + raise FileNotFoundError(f"safetensors_inspect: archivo no encontrado: {abs_path}") + + total_size = os.path.getsize(abs_path) + + with open(abs_path, "rb") as f: + # Leer los 8 bytes del tamano del header + raw_len = f.read(8) + if len(raw_len) < 8: + raise ValueError( + f"safetensors_inspect: archivo demasiado corto para ser safetensors: {abs_path}" + ) + + header_len = struct.unpack(" total_size: + raise ValueError( + f"safetensors_inspect: header_len invalido ({header_len}) en {abs_path}" + ) + + raw_header = f.read(header_len) + if len(raw_header) < header_len: + raise ValueError( + f"safetensors_inspect: no se pudo leer el header completo en {abs_path}" + ) + + try: + header = json.loads(raw_header.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + raise ValueError( + f"safetensors_inspect: header JSON invalido en {abs_path}: {exc}" + ) from exc + + metadata = header.get("__metadata__", {}) + + tensors = [] + for name, desc in header.items(): + if name == "__metadata__": + continue + tensors.append( + { + "name": name, + "dtype": desc.get("dtype", ""), + "shape": desc.get("shape", []), + "data_offsets": desc.get("data_offsets", []), + } + ) + + return { + "path": abs_path, + "metadata": metadata, + "tensors": tensors, + "total_size_bytes": total_size, + "n_tensors": len(tensors), + } diff --git a/python/functions/ml/sampler_name.py b/python/functions/ml/sampler_name.py new file mode 100644 index 00000000..8a6c8055 --- /dev/null +++ b/python/functions/ml/sampler_name.py @@ -0,0 +1,16 @@ +"""SamplerName — subset de samplers compartido entre diffusers y stable-diffusion.cpp.""" + +from typing import Literal + +# Sum type: valores validos de sampler para GenerationConfig. +# Subset estricto que tienen correspondencia directa en diffusers (schedulers) +# y en stable-diffusion.cpp (--sampling-method). +SamplerName = Literal[ + "euler", + "euler_a", + "dpm++2m", + "dpm++2m_v2", + "heun", + "dpm2", + "lcm", +] diff --git a/python/functions/ml/tests/test_cuda_available.py b/python/functions/ml/tests/test_cuda_available.py new file mode 100644 index 00000000..84907295 --- /dev/null +++ b/python/functions/ml/tests/test_cuda_available.py @@ -0,0 +1,54 @@ +"""Tests para cuda_available.""" + +import sys +import unittest +from unittest.mock import patch + + +# Asegurar que el modulo ml es importable desde el path del registry +sys.path.insert(0, "python/functions") + +from ml.cuda_available import cuda_available + + +class TestCudaAvailable(unittest.TestCase): + + def test_claves_del_dict_siempre_presentes(self): + """claves del dict siempre presentes""" + result = cuda_available() + for key in ("available", "device_count", "devices", "torch_version", "cuda_version"): + self.assertIn(key, result, f"Falta clave: {key}") + + def test_sin_torch_retorna_available_False_y_torch_version_not_installed(self): + """sin torch retorna available=False y torch_version=not_installed""" + with patch.dict(sys.modules, {"torch": None}): + result = cuda_available() + self.assertFalse(result["available"]) + self.assertEqual(result["torch_version"], "not_installed") + self.assertEqual(result["device_count"], 0) + self.assertEqual(result["devices"], []) + self.assertIsNone(result["cuda_version"]) + + def test_con_torch_sin_cuda_retorna_available_False_y_device_count_0(self): + """con torch sin cuda retorna available=False y device_count=0""" + import types + fake_torch = types.ModuleType("torch") + fake_torch.__version__ = "2.3.0" + fake_torch.cuda = types.SimpleNamespace( + is_available=lambda: False, + device_count=lambda: 0, + ) + fake_torch.version = types.SimpleNamespace(cuda=None) + + with patch.dict(sys.modules, {"torch": fake_torch}): + result = cuda_available() + + self.assertFalse(result["available"]) + self.assertEqual(result["device_count"], 0) + self.assertEqual(result["devices"], []) + self.assertEqual(result["torch_version"], "2.3.0") + self.assertIsNone(result["cuda_version"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/functions/ml/tests/test_diffusers_backend.py b/python/functions/ml/tests/test_diffusers_backend.py new file mode 100644 index 00000000..cf101f8c --- /dev/null +++ b/python/functions/ml/tests/test_diffusers_backend.py @@ -0,0 +1,212 @@ +"""Tests para el backend diffusers: load_pipeline, set_scheduler, generate, unload.""" + +from __future__ import annotations + +import sys +import os +import time + +import pytest + +# Ajustar path para importar desde python/functions/ml/ +_ML_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", +) +sys.path.insert(0, os.path.abspath(_ML_PATH)) + +# Importaciones lazy de torch y diffusers — las omitimos si no estan disponibles. +torch = pytest.importorskip("torch", reason="torch no instalado — skip tests diffusers") +pytest.importorskip("diffusers", reason="diffusers no instalado — skip tests diffusers") + +from ml.model_ref import ModelRef +from ml.generation_config import GenerationConfig +from ml.image_gen_result import ImageGenResult +from ml.diffusers_load_pipeline import diffusers_load_pipeline, _clear_pipeline_cache +from ml.diffusers_set_scheduler import diffusers_set_scheduler +from ml.diffusers_unload import diffusers_unload + +# diffusers_generate importa image_gen_result sin prefijo de paquete. +# Para evitar el double-import problem (ml.image_gen_result != image_gen_result), +# forzamos que sys.modules["image_gen_result"] apunte al modulo ya cargado +# como ml.image_gen_result antes de importar diffusers_generate. +import sys as _sys +import ml.image_gen_result as _igr_module +import ml.generation_config as _gcfg_module +import ml.genconfig_to_diffusers_kwargs as _gkwargs_module +for _alias, _mod in [ + ("image_gen_result", _igr_module), + ("generation_config", _gcfg_module), + ("genconfig_to_diffusers_kwargs", _gkwargs_module), +]: + if _alias not in _sys.modules: + _sys.modules[_alias] = _mod + +from ml.diffusers_generate import diffusers_generate + + +# --------------------------------------------------------------------------- +# Constantes +# --------------------------------------------------------------------------- +SD_TURBO_PATH = "/home/lucas/vaults/imagegen_models/diffusers/sd-turbo" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="session") +def sd_turbo_model() -> ModelRef: + """ModelRef apuntando a SD Turbo local.""" + if not os.path.isdir(SD_TURBO_PATH): + pytest.skip(f"SD Turbo no encontrado en {SD_TURBO_PATH}") + return ModelRef( + name="sd-turbo", + model_type="sd15", + quantization="fp16", + path=SD_TURBO_PATH, + ) + + +@pytest.fixture(scope="session") +def loaded_pipe(sd_turbo_model: ModelRef): + """Pipeline SD Turbo cargado una sola vez para toda la sesion de tests.""" + # Intentar fp16 primero; si falla (no hay variante fp16) usar fp32 + try: + pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16") + except Exception: + _clear_pipeline_cache() + pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp32") + yield pipe + # Teardown: liberar al final de la sesion + diffusers_unload(None) + + +@pytest.fixture(scope="session") +def sd_turbo_cfg(sd_turbo_model: ModelRef) -> GenerationConfig: + """GenerationConfig minimo para SD Turbo (1 step, 512x512).""" + return GenerationConfig( + prompt="a simple red circle on white background", + negative_prompt=None, + seed=42, + steps=1, + cfg_scale=0.0, + sampler="euler", + width=512, + height=512, + model=sd_turbo_model, + loras=[], + ) + + +# --------------------------------------------------------------------------- +# Test: carga pipeline y retorna callable +# --------------------------------------------------------------------------- + +def test_load_pipeline_returns_callable(sd_turbo_model: ModelRef) -> None: + """carga pipeline y retorna callable""" + _clear_pipeline_cache() + pipe = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16") + assert callable(pipe), "El pipeline debe ser callable" + assert hasattr(pipe, "scheduler"), "El pipeline debe tener atributo scheduler" + + +# --------------------------------------------------------------------------- +# Test: segunda carga usa cache (< 100ms) +# --------------------------------------------------------------------------- + +def test_load_pipeline_caches(sd_turbo_model: ModelRef) -> None: + """segunda carga usa cache (< 100ms)""" + # Primera carga (puede tardar varios segundos) + _clear_pipeline_cache() + _ = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16") + + # Segunda carga debe ser cache hit + t0 = time.perf_counter() + pipe2 = diffusers_load_pipeline(sd_turbo_model, device="auto", dtype="fp16") + elapsed_ms = (time.perf_counter() - t0) * 1000 + + assert elapsed_ms < 100, ( + f"Segunda carga tardo {elapsed_ms:.1f}ms (esperado < 100ms — debe ser cache hit)" + ) + assert pipe2 is not None + + +# --------------------------------------------------------------------------- +# Test: set_scheduler cambia la clase del scheduler +# --------------------------------------------------------------------------- + +def test_set_scheduler_changes_scheduler_class(loaded_pipe) -> None: + """euler cambia scheduler a EulerDiscreteScheduler""" + pipe = diffusers_set_scheduler(loaded_pipe, "euler") + scheduler_name = type(pipe.scheduler).__name__ + assert scheduler_name == "EulerDiscreteScheduler", ( + f"Esperado EulerDiscreteScheduler, obtenido {scheduler_name}" + ) + + +def test_set_scheduler_euler_a(loaded_pipe) -> None: + """euler_a cambia scheduler a EulerAncestralDiscreteScheduler""" + pipe = diffusers_set_scheduler(loaded_pipe, "euler_a") + scheduler_name = type(pipe.scheduler).__name__ + assert scheduler_name == "EulerAncestralDiscreteScheduler", ( + f"Esperado EulerAncestralDiscreteScheduler, obtenido {scheduler_name}" + ) + # Restaurar euler para no afectar otros tests + diffusers_set_scheduler(loaded_pipe, "euler") + + +def test_set_scheduler_invalid_raises_value_error(loaded_pipe) -> None: + """sampler invalido lanza ValueError""" + with pytest.raises(ValueError, match="no soportado"): + diffusers_set_scheduler(loaded_pipe, "nonexistent_sampler_xyz") + + +# --------------------------------------------------------------------------- +# Test: genera imagen retorna ImageGenResult +# --------------------------------------------------------------------------- + +def test_generate_returns_image_gen_result( + loaded_pipe, sd_turbo_cfg: GenerationConfig +) -> None: + """genera imagen retorna ImageGenResult""" + result = diffusers_generate(loaded_pipe, sd_turbo_cfg) + + assert isinstance(result, ImageGenResult), ( + f"Esperado ImageGenResult, obtenido {type(result)}" + ) + assert result.image is not None, "result.image no debe ser None" + assert result.duration_ms > 0, ( + f"duration_ms debe ser positivo, obtenido {result.duration_ms}" + ) + assert "backend" in result.meta, "meta debe tener key 'backend'" + assert result.meta["backend"] == "diffusers", ( + f"meta['backend'] debe ser 'diffusers', obtenido {result.meta['backend']}" + ) + assert "model" in result.meta, "meta debe tener key 'model'" + + # Verificar que la imagen tiene las dimensiones correctas + w, h = result.image.size + assert w == sd_turbo_cfg.width and h == sd_turbo_cfg.height, ( + f"Imagen esperada {sd_turbo_cfg.width}x{sd_turbo_cfg.height}, " + f"obtenida {w}x{h}" + ) + + +# --------------------------------------------------------------------------- +# Test: unload limpia cache cuda si disponible +# --------------------------------------------------------------------------- + +def test_unload_clears_cuda() -> None: + """unload None limpia cache cuda si disponible""" + cuda_available = torch.cuda.is_available() + + # Limpiar cache — no debe lanzar excepcion independientemente de si hay CUDA + diffusers_unload(None) + + if cuda_available: + # Despues de empty_cache, la memoria reservada por el allocator baja + # No podemos asumir que sea 0 (otros tensores pueden estar vivos), + # pero la llamada debe completarse sin error. + reserved = torch.cuda.memory_reserved() + # Solo verificamos que no lanza excepcion y que la llamada completo + assert reserved >= 0, "memory_reserved debe ser >= 0" diff --git a/python/functions/ml/tests/test_genconfig_json_roundtrip.py b/python/functions/ml/tests/test_genconfig_json_roundtrip.py new file mode 100644 index 00000000..fde5616c --- /dev/null +++ b/python/functions/ml/tests/test_genconfig_json_roundtrip.py @@ -0,0 +1,165 @@ +"""Tests de roundtrip JSON para genconfig_save_json y genconfig_load_json.""" + +import json +import os +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from ml.genconfig_save_json import genconfig_save_json +from ml.genconfig_load_json import genconfig_load_json +from ml.generation_config import GenerationConfig + + +def _make_cfg(**overrides): + """Construye un GenerationConfig sintetico usando model_validate para evitar + problemas de identidad de clase entre modulos pydantic separados.""" + defaults = dict( + prompt="a forest at dusk", + negative_prompt="blurry, low quality", + seed=123, + steps=25, + cfg_scale=7.5, + sampler="euler", + width=512, + height=512, + model={"name": "runwayml/stable-diffusion-v1-5", "model_type": "sd15"}, + loras=[{"path": "/loras/detail.safetensors", "weight": 0.7}], + ) + defaults.update(overrides) + try: + return GenerationConfig.model_validate(defaults) + except AttributeError: + from ml.model_ref import ModelRef + from ml.lora_ref import LoraRef + m = defaults.pop("model") + if isinstance(m, dict): + m = ModelRef(**m) + loras = defaults.pop("loras", []) + built = [LoraRef(**lr) if isinstance(lr, dict) else lr for lr in loras] + return GenerationConfig(model=m, loras=tuple(built), **defaults) + + +class TestGenconfigJsonRoundtrip(unittest.TestCase): + + def test_save_escribe_archivo_json_valido_en_la_ruta_indicada(self): + """save escribe archivo JSON valido en la ruta indicada""" + cfg = _make_cfg() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "config.json") + saved = genconfig_save_json(cfg, path) + self.assertTrue(os.path.isfile(saved)) + with open(saved, "r", encoding="utf-8") as f: + data = json.load(f) + self.assertIsInstance(data, dict) + self.assertEqual(data["prompt"], "a forest at dusk") + + def test_save_crea_directorios_padre_si_no_existen(self): + """save crea directorios padre si no existen""" + cfg = _make_cfg() + with tempfile.TemporaryDirectory() as tmpdir: + nested = os.path.join(tmpdir, "a", "b", "c", "config.json") + saved = genconfig_save_json(cfg, nested) + self.assertTrue(os.path.isfile(saved)) + + def test_json_contiene_claves_en_snake_case(self): + """json contiene claves en snake_case""" + cfg = _make_cfg() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "config.json") + genconfig_save_json(cfg, path) + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + # Claves deben ser snake_case (interoperabilidad con Go) + expected_keys = { + "prompt", "negative_prompt", "seed", "steps", + "cfg_scale", "sampler", "width", "height", "model", + } + for key in expected_keys: + self.assertIn(key, data, f"Clave snake_case faltante: {key}") + # No debe haber camelCase + self.assertNotIn("negativePrompt", data) + self.assertNotIn("cfgScale", data) + self.assertNotIn("numInferenceSteps", data) + + def test_roundtrip_preserva_campos_escalares(self): + """roundtrip save→load preserva campos escalares""" + cfg = _make_cfg() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "config.json") + genconfig_save_json(cfg, path) + loaded = genconfig_load_json(path) + self.assertEqual(loaded.prompt, cfg.prompt) + self.assertEqual(loaded.negative_prompt, cfg.negative_prompt) + self.assertEqual(loaded.seed, cfg.seed) + self.assertEqual(loaded.steps, cfg.steps) + self.assertAlmostEqual(loaded.cfg_scale, cfg.cfg_scale) + self.assertEqual(loaded.sampler, cfg.sampler) + self.assertEqual(loaded.width, cfg.width) + self.assertEqual(loaded.height, cfg.height) + + def test_roundtrip_preserva_model_ref(self): + """roundtrip preserva ModelRef""" + cfg = _make_cfg( + model={ + "name": "stabilityai/sdxl-base-1.0", + "model_type": "sdxl", + "quantization": "fp16", + "path": "/models/sdxl.safetensors", + } + ) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "config.json") + genconfig_save_json(cfg, path) + loaded = genconfig_load_json(path) + self.assertEqual(loaded.model.name, "stabilityai/sdxl-base-1.0") + self.assertEqual(loaded.model.model_type, "sdxl") + self.assertEqual(loaded.model.quantization, "fp16") + self.assertEqual(loaded.model.path, "/models/sdxl.safetensors") + + def test_roundtrip_preserva_loras(self): + """roundtrip preserva lista de LoraRef""" + cfg = _make_cfg( + loras=[ + {"path": "/loras/a.safetensors", "weight": 0.8}, + {"path": "/loras/b.safetensors", "weight": 0.5}, + ] + ) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "config.json") + genconfig_save_json(cfg, path) + loaded = genconfig_load_json(path) + loaded_loras = list(loaded.loras) + self.assertEqual(len(loaded_loras), 2) + paths = [lr.path for lr in loaded_loras] + self.assertIn("/loras/a.safetensors", paths) + self.assertIn("/loras/b.safetensors", paths) + + def test_roundtrip_negative_prompt_none(self): + """roundtrip con negative_prompt=None""" + cfg = _make_cfg(negative_prompt=None) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "config.json") + genconfig_save_json(cfg, path) + loaded = genconfig_load_json(path) + self.assertIsNone(loaded.negative_prompt) + + def test_load_falla_con_file_not_found(self): + """load lanza FileNotFoundError si el archivo no existe""" + with self.assertRaises(FileNotFoundError): + genconfig_load_json("/tmp/nonexistent_fn_registry_test_12345.json") + + def test_save_retorna_path_absoluto(self): + """save retorna path absoluto aunque se pase path relativo""" + cfg = _make_cfg() + with tempfile.TemporaryDirectory() as tmpdir: + abs_path = os.path.join(tmpdir, "cfg.json") + result = genconfig_save_json(cfg, abs_path) + self.assertTrue(os.path.isabs(result)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/functions/ml/tests/test_genconfig_to_diffusers_kwargs.py b/python/functions/ml/tests/test_genconfig_to_diffusers_kwargs.py new file mode 100644 index 00000000..fc887580 --- /dev/null +++ b/python/functions/ml/tests/test_genconfig_to_diffusers_kwargs.py @@ -0,0 +1,113 @@ +"""Tests para genconfig_to_diffusers_kwargs.""" + +import sys +import os +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from ml.genconfig_to_diffusers_kwargs import genconfig_to_diffusers_kwargs +from ml.generation_config import GenerationConfig + + +def _make_cfg(**overrides): + """Crea un GenerationConfig sintetico para tests via model_validate / constructor.""" + defaults = dict( + prompt="a dog in the park", + seed=42, + steps=30, + cfg_scale=7.5, + sampler="euler_a", + width=512, + height=768, + model={"name": "runwayml/stable-diffusion-v1-5", "model_type": "sd15"}, + ) + defaults.update(overrides) + try: + return GenerationConfig.model_validate(defaults) + except AttributeError: + # dataclass fallback: model y loras ya son dicts, construir manualmente + from ml.model_ref import ModelRef + from ml.lora_ref import LoraRef + m = defaults.pop("model") + if isinstance(m, dict): + m = ModelRef(**m) + loras = defaults.pop("loras", []) + built_loras = [] + for lr in loras: + if isinstance(lr, dict): + built_loras.append(LoraRef(**lr)) + else: + built_loras.append(lr) + return GenerationConfig(model=m, loras=tuple(built_loras), **defaults) + + +class TestGenconfigToDiffusersKwargs(unittest.TestCase): + + def test_kwargs_contiene_todas_las_claves_requeridas(self): + """kwargs contiene todas las claves requeridas""" + cfg = _make_cfg() + kwargs = genconfig_to_diffusers_kwargs(cfg) + required_keys = { + "prompt", + "negative_prompt", + "num_inference_steps", + "guidance_scale", + "width", + "height", + "generator", + } + self.assertEqual(set(kwargs.keys()), required_keys) + + def test_negative_prompt_none_se_pasa_tal_cual(self): + """negative_prompt None se pasa tal cual""" + cfg = _make_cfg(negative_prompt=None) + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertIsNone(kwargs["negative_prompt"]) + + def test_steps_y_cfg_scale_se_mapean_a_num_inference_steps_y_guidance_scale(self): + """steps y cfg_scale se mapean a num_inference_steps y guidance_scale""" + cfg = _make_cfg(steps=20, cfg_scale=8.0) + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertEqual(kwargs["num_inference_steps"], 20) + self.assertAlmostEqual(kwargs["guidance_scale"], 8.0) + + def test_generator_siempre_es_none(self): + """generator siempre es None""" + cfg = _make_cfg() + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertIsNone(kwargs["generator"]) + + def test_prompt_se_copia_sin_modificar(self): + """prompt se copia sin modificar""" + cfg = _make_cfg(prompt="a cat on a roof") + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertEqual(kwargs["prompt"], "a cat on a roof") + + def test_width_y_height_se_preservan(self): + """width y height se preservan""" + cfg = _make_cfg(width=1024, height=768) + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertEqual(kwargs["width"], 1024) + self.assertEqual(kwargs["height"], 768) + + def test_negative_prompt_string_se_pasa_tal_cual(self): + """negative_prompt string se pasa tal cual""" + cfg = _make_cfg(negative_prompt="blurry, low quality") + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertEqual(kwargs["negative_prompt"], "blurry, low quality") + + def test_no_incluye_seed_sampler_ni_loras(self): + """no incluye seed sampler ni loras en el dict""" + cfg = _make_cfg( + loras=[{"path": "/loras/detail.safetensors", "weight": 0.8}] + ) + kwargs = genconfig_to_diffusers_kwargs(cfg) + self.assertNotIn("seed", kwargs) + self.assertNotIn("sampler", kwargs) + self.assertNotIn("loras", kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/functions/ml/tests/test_genconfig_to_sdcpp_args.py b/python/functions/ml/tests/test_genconfig_to_sdcpp_args.py new file mode 100644 index 00000000..0f041bad --- /dev/null +++ b/python/functions/ml/tests/test_genconfig_to_sdcpp_args.py @@ -0,0 +1,150 @@ +"""Tests para genconfig_to_sdcpp_args.""" + +import sys +import os +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from ml.genconfig_to_sdcpp_args import genconfig_to_sdcpp_args, _SAMPLER_MAP +from ml.generation_config import GenerationConfig + + +def _make_cfg(**overrides): + """Crea un GenerationConfig sintetico para tests via model_validate / constructor.""" + defaults = dict( + prompt="a cat", + seed=1, + steps=20, + cfg_scale=7.0, + sampler="euler", + width=512, + height=512, + model={"name": "v1-5-pruned.ckpt", "model_type": "sd15", "path": "/models/v1-5.ckpt"}, + ) + defaults.update(overrides) + # Normalizar loras a dicts si fueron pasados como LoraRef + if "loras" in defaults: + normalized = [] + for lr in defaults["loras"]: + if hasattr(lr, "__dict__") and not isinstance(lr, dict): + normalized.append({"path": lr.path, "weight": lr.weight, "scale": lr.scale}) + else: + normalized.append(lr) + defaults["loras"] = normalized + try: + return GenerationConfig.model_validate(defaults) + except AttributeError: + from ml.model_ref import ModelRef + from ml.lora_ref import LoraRef + m = defaults.pop("model") + if isinstance(m, dict): + m = ModelRef(**m) + loras = defaults.pop("loras", []) + built = [LoraRef(**lr) if isinstance(lr, dict) else lr for lr in loras] + return GenerationConfig(model=m, loras=tuple(built), **defaults) + + +def _get_flag_value(args: list[str], flag: str) -> str | None: + """Extrae el valor de un flag en la lista de args.""" + try: + idx = args.index(flag) + return args[idx + 1] + except (ValueError, IndexError): + return None + + +def _get_all_flag_values(args: list[str], flag: str) -> list[str]: + """Extrae todos los valores de un flag repetido (ej. --lora).""" + values = [] + for i, arg in enumerate(args): + if arg == flag and i + 1 < len(args): + values.append(args[i + 1]) + return values + + +class TestGenconfigToSdcppArgs(unittest.TestCase): + + def test_sampler_euler_a_se_mapea_a_euler_a_en_el_flag_sampling_method(self): + """sampler euler_a se mapea a euler_a en el flag --sampling-method""" + cfg = _make_cfg(sampler="euler_a") + args = genconfig_to_sdcpp_args(cfg) + self.assertEqual(_get_flag_value(args, "--sampling-method"), "euler_a") + + def test_sampler_dpm_pp_2m_se_mapea_a_dpmpp2m(self): + """sampler dpm++2m se mapea a dpmpp2m""" + cfg = _make_cfg(sampler="dpm++2m") + args = genconfig_to_sdcpp_args(cfg) + self.assertEqual(_get_flag_value(args, "--sampling-method"), "dpmpp2m") + + def test_lora_con_path_y_weight_se_agrega_como_lora_path_weight(self): + """lora con path y weight se agrega como --lora path:weight""" + cfg = _make_cfg( + loras=[{"path": "/loras/detail.safetensors", "weight": 0.8}] + ) + args = genconfig_to_sdcpp_args(cfg) + lora_values = _get_all_flag_values(args, "--lora") + self.assertEqual(len(lora_values), 1) + self.assertEqual(lora_values[0], "/loras/detail.safetensors:0.8") + + def test_multiples_loras_generan_multiples_pares_lora(self): + """multiples loras generan multiples pares --lora""" + cfg = _make_cfg( + loras=[ + {"path": "/loras/a.safetensors", "weight": 0.5}, + {"path": "/loras/b.safetensors", "weight": 1.0}, + ] + ) + args = genconfig_to_sdcpp_args(cfg) + lora_values = _get_all_flag_values(args, "--lora") + self.assertEqual(len(lora_values), 2) + self.assertIn("/loras/a.safetensors:0.5", lora_values) + self.assertIn("/loras/b.safetensors:1.0", lora_values) + + def test_negative_prompt_none_produce_string_vacio_en_negative_prompt(self): + """negative_prompt None produce string vacio en --negative-prompt""" + cfg = _make_cfg(negative_prompt=None) + args = genconfig_to_sdcpp_args(cfg) + self.assertEqual(_get_flag_value(args, "--negative-prompt"), "") + + def test_model_path_tiene_prioridad_sobre_model_name_en_m(self): + """model.path tiene prioridad sobre model.name en -m""" + cfg = _make_cfg( + model={"name": "hub-name", "model_type": "sd15", "path": "/local/path/model.ckpt"} + ) + args = genconfig_to_sdcpp_args(cfg) + self.assertEqual(_get_flag_value(args, "-m"), "/local/path/model.ckpt") + + def test_sin_path_usa_model_name_en_m(self): + """sin path usa model.name en -m""" + cfg = _make_cfg( + model={"name": "runwayml/sd-v1-5", "model_type": "sd15", "path": None} + ) + args = genconfig_to_sdcpp_args(cfg) + self.assertEqual(_get_flag_value(args, "-m"), "runwayml/sd-v1-5") + + def test_args_contiene_flags_obligatorios(self): + """args contiene --prompt --seed --steps --cfg-scale --sampling-method -W -H -m""" + cfg = _make_cfg() + args = genconfig_to_sdcpp_args(cfg) + for flag in ["--prompt", "--seed", "--steps", "--cfg-scale", "--sampling-method", "-W", "-H", "-m"]: + self.assertIn(flag, args, f"Flag faltante: {flag}") + + def test_sampler_map_cubre_todos_los_samplers_canonicos(self): + """_SAMPLER_MAP cubre todos los samplers canonicos del dominio ml""" + canonical = {"euler", "euler_a", "dpm++2m", "dpm++2m_v2", "heun", "dpm2", "lcm"} + self.assertEqual(set(_SAMPLER_MAP.keys()), canonical) + + def test_seed_steps_width_height_se_convierten_a_string(self): + """seed steps width height se convierten a string en los args""" + cfg = _make_cfg(seed=42, steps=25, width=768, height=512) + args = genconfig_to_sdcpp_args(cfg) + self.assertEqual(_get_flag_value(args, "--seed"), "42") + self.assertEqual(_get_flag_value(args, "--steps"), "25") + self.assertEqual(_get_flag_value(args, "-W"), "768") + self.assertEqual(_get_flag_value(args, "-H"), "512") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/functions/ml/tests/test_generation_config_serialization.py b/python/functions/ml/tests/test_generation_config_serialization.py new file mode 100644 index 00000000..27585214 --- /dev/null +++ b/python/functions/ml/tests/test_generation_config_serialization.py @@ -0,0 +1,131 @@ +"""Tests para GenerationConfig — serialización, roundtrip y frozen.""" + +import json +import sys + +# Añadir python/functions/ml al path para que los imports internos del módulo +# (from lora_ref import LoraRef, from model_ref import ModelRef) funcionen. +# Los módulos se importan directamente desde el subdirectorio para evitar +# colisiones de tipos entre ml.generation_config.* y generation_config.*. +sys.path.insert(0, "python/functions/ml") + +import pytest + +from generation_config import GenerationConfig +from lora_ref import LoraRef +from model_ref import ModelRef + + +def _make_model() -> ModelRef: + return ModelRef(name="stabilityai/stable-diffusion-v1-5", model_type="sd15") + + +def _make_config() -> GenerationConfig: + return GenerationConfig( + prompt="a cat in the moonlight", + negative_prompt="blurry, low quality", + seed=42, + steps=30, + cfg_scale=7.5, + sampler="euler_a", + width=512, + height=512, + model=_make_model(), + loras=[], + clip_skip=1, + ) + + +def test_instancia_ok(): + """GenerationConfig crea instancia sin errores""" + cfg = _make_config() + assert cfg.prompt == "a cat in the moonlight" + assert cfg.seed == 42 + assert cfg.steps == 30 + assert cfg.cfg_scale == 7.5 + assert cfg.sampler == "euler_a" + assert cfg.width == 512 + assert cfg.height == 512 + assert cfg.clip_skip == 1 + + +def test_model_dump_keys_snake_case(): + """model_dump devuelve dict con keys snake_case incluyendo negative_prompt, cfg_scale, clip_skip""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + cfg = _make_config() + d = cfg.model_dump() + assert isinstance(d, dict) + assert "negative_prompt" in d + assert "cfg_scale" in d + assert "clip_skip" in d + assert d["negative_prompt"] == "blurry, low quality" + assert d["cfg_scale"] == 7.5 + assert d["clip_skip"] == 1 + + +def test_model_dump_json_parseable(): + """model_dump_json retorna str JSON parseable""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + cfg = _make_config() + raw = cfg.model_dump_json() + assert isinstance(raw, str) + parsed = json.loads(raw) + assert isinstance(parsed, dict) + assert parsed["prompt"] == "a cat in the moonlight" + assert parsed["seed"] == 42 + + +def test_roundtrip_model_validate(): + """GenerationConfig.model_validate(json.loads(...)) roundtrip ok""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + cfg = _make_config() + raw_json = cfg.model_dump_json() + parsed = json.loads(raw_json) + cfg2 = GenerationConfig.model_validate(parsed) + assert cfg2.prompt == cfg.prompt + assert cfg2.seed == cfg.seed + assert cfg2.cfg_scale == cfg.cfg_scale + assert cfg2.sampler == cfg.sampler + assert cfg2.clip_skip == cfg.clip_skip + assert cfg2.model.name == cfg.model.name + assert cfg2.model.model_type == cfg.model.model_type + + +def test_frozen_levanta_error_al_mutar(): + """frozen: intentar mutar levanta AttributeError, ValidationError o FrozenInstanceError""" + cfg = _make_config() + raised = False + try: + # dataclass frozen y pydantic frozen levantan distintas excepciones + cfg.prompt = "mutated" # type: ignore[misc] + except Exception: + raised = True + + assert raised, "Se esperaba que mutar un campo frozen lanzara una excepcion" + + +def test_negative_prompt_opcional(): + """negative_prompt es opcional (default None)""" + cfg = GenerationConfig( + prompt="mountains", + seed=0, + steps=20, + cfg_scale=7.0, + sampler="euler", + width=512, + height=512, + model=_make_model(), + ) + assert cfg.negative_prompt is None diff --git a/python/functions/ml/tests/test_gpu_info.py b/python/functions/ml/tests/test_gpu_info.py new file mode 100644 index 00000000..bdc2a701 --- /dev/null +++ b/python/functions/ml/tests/test_gpu_info.py @@ -0,0 +1,48 @@ +"""Tests para gpu_info.""" + +import sys +import unittest +from unittest.mock import MagicMock, patch + +sys.path.insert(0, "python/functions") + +from ml.gpu_info import gpu_info + + +class TestGpuInfo(unittest.TestCase): + + def test_sin_nvidia_smi_devuelve_lista_vacia(self): + """sin nvidia-smi devuelve lista vacia""" + with patch("subprocess.run", side_effect=FileNotFoundError()): + result = gpu_info() + self.assertEqual(result, []) + + def test_formato_CSV_correcto_devuelve_lista_con_un_dict_por_GPU(self): + """formato CSV correcto devuelve lista con un dict por GPU""" + csv_output = " 0, NVIDIA RTX 4090, 24564, 22000, 535.183.01, 8.9\n" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = csv_output + with patch("subprocess.run", return_value=mock_result): + result = gpu_info() + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["index"], 0) + self.assertEqual(result[0]["name"], "NVIDIA RTX 4090") + self.assertEqual(result[0]["vram_total_mb"], 24564) + self.assertEqual(result[0]["vram_free_mb"], 22000) + self.assertEqual(result[0]["driver_version"], "535.183.01") + self.assertEqual(result[0]["cuda_version"], "8.9") + + def test_fila_malformada_en_CSV_se_ignora_sin_excepcion(self): + """fila malformada en CSV se ignora sin excepcion""" + csv_output = " 0, RTX 4090, NONNUMERIC, 22000, 535.183.01, 8.9\n" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = csv_output + with patch("subprocess.run", return_value=mock_result): + result = gpu_info() + self.assertEqual(result, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/functions/ml/tests/test_hf_snapshot_download.py b/python/functions/ml/tests/test_hf_snapshot_download.py new file mode 100644 index 00000000..86853014 --- /dev/null +++ b/python/functions/ml/tests/test_hf_snapshot_download.py @@ -0,0 +1,128 @@ +"""Tests para hf_snapshot_download — mockear snapshot_download y verificar args.""" + +import sys +import types + +sys.path.insert(0, "python/functions/ml") + +import pytest + +# Saltar si huggingface_hub no esta disponible Y no podemos mockearlo +# Usamos un mock inline para no requerir la lib real. +# Si la lib esta disponible, monkeypatch la reemplaza. Si no, la inyectamos manualmente. + + +def _inject_fake_hf_hub(monkeypatch, capture: list): + """Inyecta un modulo huggingface_hub falso con snapshot_download que captura kwargs.""" + + def fake_snapshot_download(**kwargs): + capture.append(kwargs) + return "/tmp/fake_snapshot" + + fake_module = types.ModuleType("huggingface_hub") + fake_module.snapshot_download = fake_snapshot_download + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_module) + + +def test_args_minimos_repo_id(monkeypatch): + """repo_id se pasa correctamente a snapshot_download""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + result = hf_snapshot_download("runwayml/stable-diffusion-v1-5") + + assert len(capture) == 1 + assert capture[0]["repo_id"] == "runwayml/stable-diffusion-v1-5" + assert result == "/tmp/fake_snapshot" + + +def test_retorna_string(monkeypatch): + """hf_snapshot_download retorna un string (la ruta local)""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + result = hf_snapshot_download("some/repo") + + assert isinstance(result, str) + + +def test_allow_patterns_se_pasa(monkeypatch): + """allow_patterns se incluye en los kwargs si se especifica""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + hf_snapshot_download("some/repo", allow_patterns=["*.safetensors", "*.json"]) + + assert "allow_patterns" in capture[0] + assert capture[0]["allow_patterns"] == ["*.safetensors", "*.json"] + + +def test_ignore_patterns_se_pasa(monkeypatch): + """ignore_patterns se incluye en los kwargs si se especifica""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + hf_snapshot_download("some/repo", ignore_patterns=["*.bin", "flax_*"]) + + assert "ignore_patterns" in capture[0] + assert capture[0]["ignore_patterns"] == ["*.bin", "flax_*"] + + +def test_local_dir_se_pasa(monkeypatch): + """local_dir se incluye en los kwargs si se especifica""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + hf_snapshot_download("some/repo", local_dir="/models/sd15") + + assert "local_dir" in capture[0] + assert capture[0]["local_dir"] == "/models/sd15" + + +def test_token_se_pasa(monkeypatch): + """token se incluye en los kwargs si se especifica""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + hf_snapshot_download("private/model", token="hf_mytoken123") + + assert "token" in capture[0] + assert capture[0]["token"] == "hf_mytoken123" + + +def test_none_args_no_se_pasan(monkeypatch): + """args opcionales None no se incluyen en kwargs (no contaminar snapshot_download)""" + capture = [] + _inject_fake_hf_hub(monkeypatch, capture) + + from hf_snapshot_download import hf_snapshot_download + hf_snapshot_download("some/repo") + + kwargs = capture[0] + # Solo repo_id debe estar presente — los None no se incluyen + assert "allow_patterns" not in kwargs + assert "ignore_patterns" not in kwargs + assert "local_dir" not in kwargs + assert "token" not in kwargs + + +def test_import_error_sin_huggingface_hub(monkeypatch): + """ImportError descriptivo si huggingface_hub no esta instalado""" + import importlib + + # Inyectar None en sys.modules para simular libreria no instalada + monkeypatch.setitem(sys.modules, "huggingface_hub", None) + + # Recargar el modulo para que el try/except del top-level vea el None + import hf_snapshot_download as _mod + importlib.reload(_mod) + + from hf_snapshot_download import hf_snapshot_download + with pytest.raises(ImportError, match="huggingface_hub"): + hf_snapshot_download("any/repo") diff --git a/python/functions/ml/tests/test_image_compare_side_by_side.py b/python/functions/ml/tests/test_image_compare_side_by_side.py new file mode 100644 index 00000000..5270d022 --- /dev/null +++ b/python/functions/ml/tests/test_image_compare_side_by_side.py @@ -0,0 +1,128 @@ +"""Tests para image_compare_side_by_side.""" + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import pytest + +PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping") + +from PIL import Image +from image_compare_side_by_side import image_compare_side_by_side + + +def _black(w=16, h=16): + return Image.new("RGB", (w, h), color=(0, 0, 0)) + + +def _white(w=16, h=16): + return Image.new("RGB", (w, h), color=(255, 255, 255)) + + +# --------------------------------------------------------------------------- +# Grid shape +# --------------------------------------------------------------------------- + +def test_grid_es_pil_image_con_dimensiones_correctas_show_diff_True(): + """grid es PIL.Image con dimensiones correctas show_diff=True""" + w, h = 16, 16 + gap = 16 + result = image_compare_side_by_side(_black(w, h), _white(w, h), gap_px=gap, show_diff=True) + + grid = result["grid"] + assert isinstance(grid, Image.Image), "grid debe ser PIL.Image.Image" + + expected_w = 3 * w + 4 * gap # A + diff + B + 4 gaps + expected_h = h + 2 * gap + assert grid.size == (expected_w, expected_h), ( + f"Esperado ({expected_w}, {expected_h}), got {grid.size}" + ) + + +def test_grid_es_pil_image_sin_diff_show_diff_False(): + """grid es PIL.Image sin diff show_diff=False""" + w, h = 16, 16 + gap = 8 + result = image_compare_side_by_side(_black(w, h), _white(w, h), gap_px=gap, show_diff=False) + + grid = result["grid"] + assert isinstance(grid, Image.Image), "grid debe ser PIL.Image.Image" + + expected_w = 2 * w + 3 * gap # A + B + 3 gaps + expected_h = h + 2 * gap + assert grid.size == (expected_w, expected_h), ( + f"Esperado ({expected_w}, {expected_h}), got {grid.size}" + ) + + +# --------------------------------------------------------------------------- +# MSE +# --------------------------------------------------------------------------- + +def test_pixel_mse_positivo_para_imagenes_distintas(): + """pixel_mse positivo para imagenes distintas""" + result = image_compare_side_by_side(_black(), _white()) + mse = result["pixel_mse"] + assert isinstance(mse, float), f"pixel_mse debe ser float, got {type(mse)}" + assert mse > 0.0, f"pixel_mse debe ser > 0 para imagenes distintas, got {mse}" + + +def test_pixel_mse_cero_para_imagen_identica(): + """pixel_mse cero para imagen identica""" + img = _black() + result = image_compare_side_by_side(img, img.copy()) + mse = result["pixel_mse"] + assert mse == 0.0, f"pixel_mse debe ser 0.0 para imagenes identicas, got {mse}" + + +# --------------------------------------------------------------------------- +# pHash +# --------------------------------------------------------------------------- + +def test_phash_none_si_imagehash_no_disponible(): + """phash None si imagehash no disponible""" + try: + import imagehash # noqa: F401 + pytest.skip("imagehash esta instalado — test de fallback no aplica") + except ImportError: + pass + + result = image_compare_side_by_side(_black(), _white(), show_phash=True) + assert result["phash_a"] is None, "phash_a debe ser None si imagehash no instalado" + assert result["phash_b"] is None, "phash_b debe ser None si imagehash no instalado" + assert result["phash_distance"] is None, "phash_distance debe ser None si imagehash no instalado" + + +def test_phash_presente_si_imagehash_disponible(): + """phash presente si imagehash disponible""" + try: + import imagehash # noqa: F401 + except ImportError: + pytest.skip("imagehash no instalado") + + result = image_compare_side_by_side(_black(), _white(), show_phash=True) + assert isinstance(result["phash_a"], str), "phash_a debe ser str" + assert isinstance(result["phash_b"], str), "phash_b debe ser str" + assert isinstance(result["phash_distance"], int), "phash_distance debe ser int" + assert len(result["phash_a"]) == 16, f"phash_a debe tener 16 hex chars, got {len(result['phash_a'])}" + + +# --------------------------------------------------------------------------- +# Campos del resultado +# --------------------------------------------------------------------------- + +def test_resultado_tiene_todas_las_claves(): + """resultado tiene todas las claves esperadas""" + result = image_compare_side_by_side(_black(), _white()) + for key in ("grid", "phash_a", "phash_b", "phash_distance", "pixel_mse"): + assert key in result, f"Clave '{key}' faltante en resultado" + + +def test_show_phash_false_deja_phash_none(): + """show_phash=False deja phash* en None sin intentar import""" + result = image_compare_side_by_side(_black(), _white(), show_phash=False) + assert result["phash_a"] is None + assert result["phash_b"] is None + assert result["phash_distance"] is None diff --git a/python/functions/ml/tests/test_image_gen_result.py b/python/functions/ml/tests/test_image_gen_result.py new file mode 100644 index 00000000..8a4f6315 --- /dev/null +++ b/python/functions/ml/tests/test_image_gen_result.py @@ -0,0 +1,99 @@ +"""Tests para ImageGenResult — dump excluye image, meta viaja correctamente.""" + +import json +import sys + +sys.path.insert(0, "python/functions/ml") + +from image_gen_result import ImageGenResult + + +def _make_result(image=None, duration_ms=1234, vram_peak_mb=None, meta=None): + if meta is None: + meta = { + "model": "sd15", + "seed_used": 42, + "sampler": "euler_a", + "prompt": "a cat", + } + return ImageGenResult( + image=image, + meta=meta, + duration_ms=duration_ms, + vram_peak_mb=vram_peak_mb, + ) + + +def test_instancia_ok(): + """ImageGenResult crea instancia sin errores""" + r = _make_result(duration_ms=500) + assert r.duration_ms == 500 + assert isinstance(r.meta, dict) + + +def test_dump_excluye_image(): + """model_dump excluye el campo image automaticamente""" + + class FakeImage: + """Objeto imagen simulado (no PIL real).""" + pass + + r = _make_result(image=FakeImage(), duration_ms=800) + d = r.model_dump() + assert isinstance(d, dict) + assert "image" not in d, "image no debe aparecer en model_dump()" + + +def test_dump_incluye_meta_duration_vram(): + """model_dump incluye meta, duration_ms y vram_peak_mb""" + meta = {"model": "sdxl", "seed_used": 99, "sampler": "dpm++2m"} + r = _make_result(duration_ms=2000, vram_peak_mb=6144, meta=meta) + d = r.model_dump() + assert "meta" in d + assert "duration_ms" in d + assert "vram_peak_mb" in d + assert d["duration_ms"] == 2000 + assert d["vram_peak_mb"] == 6144 + + +def test_meta_dict_viaja_completo(): + """meta dict se conserva completo en model_dump""" + meta = { + "model": "flux_dev", + "seed_used": 777, + "sampler": "euler", + "custom_key": "custom_value", + "nested": {"a": 1}, + } + r = _make_result(meta=meta) + d = r.model_dump() + assert d["meta"] == meta + assert d["meta"]["custom_key"] == "custom_value" + assert d["meta"]["nested"] == {"a": 1} + + +def test_dump_json_parseable(): + """model_dump_json retorna string JSON parseable sin image""" + meta = {"model": "sd15", "seed_used": 1} + r = _make_result(duration_ms=100, meta=meta) + raw = r.model_dump_json() + assert isinstance(raw, str) + parsed = json.loads(raw) + assert "meta" in parsed + assert "duration_ms" in parsed + assert "image" not in parsed + + +def test_vram_peak_mb_none_serializa(): + """vram_peak_mb=None se serializa correctamente a null""" + r = _make_result(vram_peak_mb=None) + d = r.model_dump() + assert d["vram_peak_mb"] is None + + +def test_image_none_permitido(): + """image puede ser None (generacion fallida)""" + r = _make_result(image=None) + assert r.image is None + d = r.model_dump() + assert "image" not in d diff --git a/python/functions/ml/tests/test_image_generator_protocol.py b/python/functions/ml/tests/test_image_generator_protocol.py new file mode 100644 index 00000000..bf987dfc --- /dev/null +++ b/python/functions/ml/tests/test_image_generator_protocol.py @@ -0,0 +1,70 @@ +"""Tests para ImageGenerator Protocol — runtime_checkable y structural subtyping.""" + +import sys + +sys.path.insert(0, "python/functions/ml") + +from image_gen_result import ImageGenResult +from image_generator import ImageGenerator + + +class MockGenerator: + """Implementacion dummy que satisface ImageGenerator sin herencia explicita.""" + + def generate(self, config): + """Retorna un ImageGenResult sin imagen real.""" + return ImageGenResult( + image=None, + meta={"model": "mock", "seed_used": 0, "sampler": "euler"}, + duration_ms=1, + vram_peak_mb=None, + ) + + +class NotAGenerator: + """Clase que NO implementa generate — no satisface el Protocol.""" + + def predict(self, x): + return x + + +def test_dummy_satisface_protocol(): + """clase dummy que implementa generate satisface isinstance(x, ImageGenerator)""" + gen = MockGenerator() + assert isinstance(gen, ImageGenerator), ( + "MockGenerator debe satisfacer ImageGenerator Protocol (runtime_checkable)" + ) + + +def test_resultado_es_image_gen_result(): + """generate() retorna ImageGenResult""" + gen = MockGenerator() + result = gen.generate(config=None) + assert isinstance(result, ImageGenResult) + + +def test_clase_sin_generate_no_satisface_protocol(): + """clase sin metodo generate NO satisface isinstance check""" + not_gen = NotAGenerator() + assert not isinstance(not_gen, ImageGenerator), ( + "NotAGenerator no debe satisfacer ImageGenerator Protocol" + ) + + +def test_multiples_instancias_satisfacen_protocol(): + """multiples instancias del mismo dummy satisfacen el Protocol""" + for _ in range(3): + gen = MockGenerator() + assert isinstance(gen, ImageGenerator) + + +def test_lambda_con_callable_no_satisface_protocol(): + """un callable lambda no satisface el Protocol (no tiene metodo .generate)""" + + class LambdaLike: + def __call__(self, config): + return None + + obj = LambdaLike() + # __call__ no es lo mismo que .generate — no debe satisfacer el protocol + assert not isinstance(obj, ImageGenerator) diff --git a/python/functions/ml/tests/test_image_grid.py b/python/functions/ml/tests/test_image_grid.py new file mode 100644 index 00000000..f22cd34f --- /dev/null +++ b/python/functions/ml/tests/test_image_grid.py @@ -0,0 +1,85 @@ +"""Tests para image_grid — combina imagenes en grid NxM.""" + +import sys + +sys.path.insert(0, "python/functions/ml") + +import math +import pytest + +PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping") + +from image_grid import image_grid + + +def _make_images(n: int, w: int = 16, h: int = 16): + from PIL import Image + return [Image.new("RGB", (w, h), color=(i * 10, i * 10, i * 10)) for i in range(n)] + + +def test_grid_4_imagenes_2_cols_dimensiones_correctas(): + """grid de 4 imagenes 16x16 cols=2 produce ancho/alto correcto""" + images = _make_images(4, w=16, h=16) + result = image_grid(images, cols=2, gap_px=0) + + # rows = ceil(4/2) = 2 + # canvas_w = 2*16 + 3*0 = 32 (con gap_px=0: cols*w + (cols+1)*0) + # canvas_h = 2*16 + 3*0 = 32 + assert result.width == 32, f"Ancho esperado 32, got {result.width}" + assert result.height == 32, f"Alto esperado 32, got {result.height}" + + +def test_grid_4_imagenes_2_cols_con_gap(): + """grid de 4 imagenes cols=2 gap_px=8 tiene dimensiones correctas con gap""" + images = _make_images(4, w=16, h=16) + gap = 8 + cols = 2 + rows = math.ceil(4 / cols) + expected_w = cols * 16 + (cols + 1) * gap + expected_h = rows * 16 + (rows + 1) * gap + + result = image_grid(images, cols=cols, gap_px=gap) + assert result.width == expected_w, f"Ancho: expected {expected_w}, got {result.width}" + assert result.height == expected_h, f"Alto: expected {expected_h}, got {result.height}" + + +def test_grid_1_imagen_1_col(): + """grid de 1 imagen 1 col = imagen sola mas gaps""" + images = _make_images(1, w=32, h=32) + result = image_grid(images, cols=1, gap_px=4) + # rows=1, cols=1 → w = 1*32 + 2*4 = 40, h = 1*32 + 2*4 = 40 + assert result.width == 40 + assert result.height == 40 + + +def test_grid_retorna_imagen_rgb(): + """el resultado es una imagen RGB""" + from PIL import Image + images = _make_images(2, w=8, h=8) + result = image_grid(images, cols=2) + assert isinstance(result, Image.Image) + assert result.mode == "RGB" + + +def test_grid_con_labels_no_falla(): + """labels opcionales — no lanza excepcion""" + images = _make_images(4, w=16, h=16) + labels = ["a", "b", "c", "d"] + result = image_grid(images, cols=2, labels=labels, gap_px=0) + # Debe devolver imagen válida + assert result.width > 0 + assert result.height > 0 + + +def test_grid_sin_labels_no_falla(): + """sin labels funciona correctamente""" + images = _make_images(3, w=16, h=16) + result = image_grid(images, cols=3, labels=None, gap_px=0) + assert result.width == 3 * 16 + assert result.height == 16 # 1 row + + +def test_grid_lista_vacia_levanta_value_error(): + """lista vacia levanta ValueError""" + with pytest.raises(ValueError): + image_grid([], cols=2) diff --git a/python/functions/ml/tests/test_image_save_png.py b/python/functions/ml/tests/test_image_save_png.py new file mode 100644 index 00000000..c77e5672 --- /dev/null +++ b/python/functions/ml/tests/test_image_save_png.py @@ -0,0 +1,81 @@ +"""Tests para image_save_png — guarda PNG con metadata tEXt embebida.""" + +import sys + +sys.path.insert(0, "python/functions/ml") + +import pytest + +PIL = pytest.importorskip("PIL", reason="Pillow no instalado — skipping") + +from image_save_png import image_save_png + + +def test_guarda_archivo_y_retorna_ruta_absoluta(tmp_path): + """crea imagen 8x8, guarda y retorna ruta absoluta""" + from PIL import Image + + img = Image.new("RGB", (8, 8), color=(255, 0, 0)) + dest = str(tmp_path / "test.png") + result = image_save_png(img, dest) + + import os + assert os.path.isfile(result), f"El archivo no existe: {result}" + assert os.path.isabs(result), f"La ruta no es absoluta: {result}" + + +def test_metadata_embebida_en_chunks_text(tmp_path): + """metadata se embebe en chunks tEXt y se puede releer con Image.text""" + from PIL import Image + + img = Image.new("RGB", (8, 8), color=(0, 128, 0)) + dest = str(tmp_path / "with_meta.png") + meta = {"prompt": "hi", "seed": "42"} + image_save_png(img, dest, metadata=meta) + + reopened = Image.open(dest) + text_data = reopened.text # dict de chunks tEXt del PNG + assert "prompt" in text_data, f"Falta clave 'prompt' en PNG text chunks: {text_data}" + assert "seed" in text_data, f"Falta clave 'seed' en PNG text chunks: {text_data}" + assert text_data["prompt"] == "hi" + assert text_data["seed"] == "42" + + +def test_sin_metadata_no_falla(tmp_path): + """sin metadata el PNG se guarda igualmente""" + from PIL import Image + + img = Image.new("RGB", (8, 8)) + dest = str(tmp_path / "no_meta.png") + result = image_save_png(img, dest, metadata=None) + + import os + assert os.path.isfile(result) + + +def test_crea_directorio_padre(tmp_path): + """crea directorio padre si no existe""" + from PIL import Image + import os + + img = Image.new("RGB", (8, 8)) + dest = str(tmp_path / "subdir" / "deep" / "image.png") + result = image_save_png(img, dest) + assert os.path.isfile(result) + + +def test_metadata_valores_numericos_se_convierten_a_str(tmp_path): + """valores numericos en metadata se convierten a str automaticamente""" + from PIL import Image + + img = Image.new("RGB", (8, 8)) + dest = str(tmp_path / "numeric.png") + meta = {"steps": 30, "cfg_scale": 7.5} + image_save_png(img, dest, metadata=meta) + + reopened = Image.open(dest) + text_data = reopened.text + assert "steps" in text_data + assert "cfg_scale" in text_data + assert text_data["steps"] == "30" + assert text_data["cfg_scale"] == "7.5" diff --git a/python/functions/ml/tests/test_model_ref_lora_ref.py b/python/functions/ml/tests/test_model_ref_lora_ref.py new file mode 100644 index 00000000..47ccfbdb --- /dev/null +++ b/python/functions/ml/tests/test_model_ref_lora_ref.py @@ -0,0 +1,136 @@ +"""Tests para ModelRef y LoraRef — instanciación, dump y validación.""" + +import json +import sys + +# Importar desde el subdirectorio ml directamente para evitar colisiones de tipos +# entre ml.model_ref.ModelRef y model_ref.ModelRef en pydantic. +sys.path.insert(0, "python/functions/ml") + +import pytest + +from lora_ref import LoraRef +from model_ref import ModelRef + + +# --------------------------------------------------------------------------- +# ModelRef +# --------------------------------------------------------------------------- + + +def test_model_ref_instancia_ok(): + """ModelRef instancia sin errores""" + m = ModelRef(name="stabilityai/sdxl-base-1.0", model_type="sdxl") + assert m.name == "stabilityai/sdxl-base-1.0" + assert m.model_type == "sdxl" + + +def test_model_ref_quantization_default_fp16(): + """quantization default es fp16""" + m = ModelRef(name="runwayml/stable-diffusion-v1-5", model_type="sd15") + assert m.quantization == "fp16" + + +def test_model_ref_quantization_override(): + """quantization se puede cambiar a otro valor válido""" + m = ModelRef(name="some/model", model_type="flux_dev", quantization="bf16") + assert m.quantization == "bf16" + + +def test_model_ref_path_default_none(): + """path es None por defecto""" + m = ModelRef(name="some/model", model_type="sd15") + assert m.path is None + + +def test_model_ref_path_set(): + """path se puede especificar""" + m = ModelRef(name="some/model", model_type="sd15", path="/models/sd15.safetensors") + assert m.path == "/models/sd15.safetensors" + + +def test_model_ref_dump(): + """model_dump devuelve dict con las claves esperadas""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + m = ModelRef(name="some/model", model_type="sdxl", quantization="q8_0") + d = m.model_dump() + assert isinstance(d, dict) + assert d["name"] == "some/model" + assert d["model_type"] == "sdxl" + assert d["quantization"] == "q8_0" + + +def test_model_ref_validate_roundtrip(): + """roundtrip model_dump_json / model_validate ok""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + m = ModelRef(name="some/model", model_type="sd3", quantization="fp32") + raw = json.loads(m.model_dump_json()) + m2 = ModelRef.model_validate(raw) + assert m2.name == m.name + assert m2.model_type == m.model_type + assert m2.quantization == m.quantization + + +# --------------------------------------------------------------------------- +# LoraRef +# --------------------------------------------------------------------------- + + +def test_lora_ref_instancia_ok(): + """LoraRef instancia con path obligatorio""" + lr = LoraRef(path="/loras/anime.safetensors") + assert lr.path == "/loras/anime.safetensors" + + +def test_lora_ref_weight_default_1(): + """LoraRef weight default es 1.0""" + lr = LoraRef(path="/loras/style.safetensors") + assert lr.weight == 1.0 + + +def test_lora_ref_weight_override(): + """LoraRef weight se puede cambiar""" + lr = LoraRef(path="/loras/style.safetensors", weight=0.7) + assert lr.weight == 0.7 + + +def test_lora_ref_scale_default_none(): + """LoraRef scale default es None""" + lr = LoraRef(path="/loras/x.safetensors") + assert lr.scale is None + + +def test_lora_ref_dump(): + """LoraRef model_dump devuelve dict con las claves esperadas""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + lr = LoraRef(path="/loras/x.safetensors", weight=0.8, scale=0.9) + d = lr.model_dump() + assert d["path"] == "/loras/x.safetensors" + assert d["weight"] == 0.8 + assert d["scale"] == 0.9 + + +def test_lora_ref_validate_roundtrip(): + """roundtrip dump / validate ok""" + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic no disponible") + + lr = LoraRef(path="/loras/x.safetensors", weight=0.5) + raw = json.loads(lr.model_dump_json()) + lr2 = LoraRef.model_validate(raw) + assert lr2.path == lr.path + assert lr2.weight == lr.weight diff --git a/python/functions/ml/tests/test_safetensors_inspect.py b/python/functions/ml/tests/test_safetensors_inspect.py new file mode 100644 index 00000000..96aaf4c0 --- /dev/null +++ b/python/functions/ml/tests/test_safetensors_inspect.py @@ -0,0 +1,160 @@ +"""Tests para safetensors_inspect — parseo de header sin dependencias externas.""" + +import json +import os +import struct +import sys +import tempfile + +sys.path.insert(0, "python/functions/ml") + +import pytest + +from safetensors_inspect import safetensors_inspect + + +def _write_safetensors(path: str, header: dict, data: bytes = b"") -> None: + """Escribe un archivo safetensors mínimo siguiendo la spec oficial. + + Spec: https://github.com/huggingface/safetensors#format + - 8 bytes: uint64 little-endian con la longitud N del header JSON + - N bytes: JSON del header + - (opcional) bytes de datos de tensores + """ + header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") + header_len = len(header_bytes) + with open(path, "wb") as f: + f.write(struct.pack(" dict: + """Genera un header con n_tensors tensores sintéticos.""" + header = { + "__metadata__": {"format": "pt", "creator": "test"}, + } + for i in range(n_tensors): + header[f"tensor_{i}"] = { + "dtype": "F32", + "shape": [4, 4], + "data_offsets": [i * 64, (i + 1) * 64], + } + return header + + +def test_n_tensors_correcto(tmp_path): + """n_tensors refleja el numero de tensores en el header""" + path = str(tmp_path / "model.safetensors") + _write_safetensors(path, _make_minimal_header(n_tensors=3)) + + result = safetensors_inspect(path) + assert result["n_tensors"] == 3 + + +def test_total_size_bytes_correcto(tmp_path): + """total_size_bytes refleja el tamaño real del archivo""" + path = str(tmp_path / "model.safetensors") + data = b"\x00" * 128 # 128 bytes de datos de tensor + _write_safetensors(path, _make_minimal_header(2), data=data) + + file_size = os.path.getsize(path) + result = safetensors_inspect(path) + assert result["total_size_bytes"] == file_size + + +def test_metadata_campo_dunder_presente(tmp_path): + """metadata devuelve el contenido de __metadata__""" + path = str(tmp_path / "model.safetensors") + header = { + "__metadata__": {"format": "pt", "model_name": "test_model"}, + "weight": {"dtype": "BF16", "shape": [8], "data_offsets": [0, 16]}, + } + _write_safetensors(path, header) + + result = safetensors_inspect(path) + assert result["metadata"] == {"format": "pt", "model_name": "test_model"} + + +def test_tensors_lista_correcta(tmp_path): + """tensors es lista con una entrada por tensor del header""" + path = str(tmp_path / "model.safetensors") + header = { + "__metadata__": {}, + "embed.weight": {"dtype": "F16", "shape": [128, 64], "data_offsets": [0, 16384]}, + "proj.bias": {"dtype": "F32", "shape": [64], "data_offsets": [16384, 16640]}, + } + _write_safetensors(path, header) + + result = safetensors_inspect(path) + assert result["n_tensors"] == 2 + names = {t["name"] for t in result["tensors"]} + assert "embed.weight" in names + assert "proj.bias" in names + + +def test_tensor_campos_dtype_shape_offsets(tmp_path): + """cada tensor tiene dtype, shape y data_offsets""" + path = str(tmp_path / "model.safetensors") + header = { + "__metadata__": {}, + "my_tensor": {"dtype": "I32", "shape": [2, 3], "data_offsets": [0, 24]}, + } + _write_safetensors(path, header) + + result = safetensors_inspect(path) + t = result["tensors"][0] + assert t["dtype"] == "I32" + assert t["shape"] == [2, 3] + assert t["data_offsets"] == [0, 24] + + +def test_path_absoluto_en_resultado(tmp_path): + """result['path'] es la ruta absoluta del archivo""" + path = str(tmp_path / "model.safetensors") + _write_safetensors(path, _make_minimal_header(1)) + + result = safetensors_inspect(path) + assert os.path.isabs(result["path"]) + assert result["path"].endswith("model.safetensors") + + +def test_archivo_no_encontrado_levanta_file_not_found(tmp_path): + """FileNotFoundError si el archivo no existe""" + with pytest.raises(FileNotFoundError): + safetensors_inspect(str(tmp_path / "nonexistent.safetensors")) + + +def test_header_invalido_levanta_value_error(tmp_path): + """ValueError si el header no es JSON válido""" + path = str(tmp_path / "bad.safetensors") + with open(path, "wb") as f: + bad_header = b"NOT JSON!!" + f.write(struct.pack(" 0) + finally: + for p in patches: + p.stop() + + def test_preference_desconocida_retorna_cpu_con_warning(self): + """preference desconocida retorna cpu con warning""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = torch_device_select("vulkan") + self.assertEqual(result, "cpu") + self.assertTrue(len(w) > 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/functions/ml/tests/test_vram_budget.py b/python/functions/ml/tests/test_vram_budget.py new file mode 100644 index 00000000..a3cd44af --- /dev/null +++ b/python/functions/ml/tests/test_vram_budget.py @@ -0,0 +1,78 @@ +"""Tests para vram_budget.""" +import sys +import os + +# Ajustar path para importar desde python/functions/ +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from ml.vram_budget import vram_budget + + +def test_sdxl_fp16_fits_24gb(): + """SDXL fp16 en una GPU de 24 GB deberia caber con headroom positivo.""" + result = vram_budget( + gpu_vram_total_mb=24576, # 24 GB + model_type="sdxl", + quantization="fp16", + n_loras=0, + width=1024, + height=1024, + batch_size=1, + ) + assert result["required_mb"] > 0, "required_mb debe ser positivo" + assert result["fits"] is True, f"SDXL fp16 debe caber en 24 GB, required={result['required_mb']} MB" + assert result["headroom_mb"] > 0, f"headroom debe ser positivo, got {result['headroom_mb']}" + assert result["warning"] is None, f"no debe haber warning, got: {result['warning']}" + + +def test_flux_fp16_no_fits_8gb(): + """Flux fp16 (~23 GB) no debe caber en una GPU de 8 GB.""" + result = vram_budget( + gpu_vram_total_mb=8192, # 8 GB + model_type="flux_dev", + quantization="fp16", + n_loras=0, + width=1024, + height=1024, + batch_size=1, + ) + assert result["required_mb"] > 8192, f"Flux fp16 debe requerir mas de 8 GB, got {result['required_mb']} MB" + assert result["fits"] is False, "Flux fp16 no debe caber en 8 GB" + assert result["headroom_mb"] < 0, f"headroom debe ser negativo, got {result['headroom_mb']}" + assert result["warning"] is not None, "debe haber warning con informacion de deficit" + assert "+N MB" in result["warning"] or "+" in result["warning"], \ + f"warning debe indicar cuantos MB extra se necesitan: {result['warning']}" + + +def test_lora_plus_quant_warning(): + """LoRA con quantization q8_0 debe emitir warning de incompatibilidad.""" + result = vram_budget( + gpu_vram_total_mb=24576, + model_type="sdxl", + quantization="q8_0", + n_loras=2, + width=1024, + height=1024, + batch_size=1, + ) + assert result["warning"] is not None, "debe haber warning por lora+quantization incompatible" + assert "incompatible" in result["warning"].lower(), \ + f"warning debe mencionar incompatibilidad: {result['warning']}" + assert "fp16" in result["warning"], \ + f"warning debe sugerir fp16: {result['warning']}" + + +def test_unknown_combo(): + """Combinacion (model_type, quant) desconocida debe retornar required_mb=-1 y warning.""" + result = vram_budget( + gpu_vram_total_mb=24576, + model_type="modelo_inventado", + quantization="q99_k", + n_loras=0, + ) + assert result["required_mb"] == -1, \ + f"required_mb debe ser -1 para combo desconocido, got {result['required_mb']}" + assert result["fits"] is False, "fits debe ser False para combo desconocido" + assert result["warning"] is not None, "debe haber warning para combo desconocido" + assert "unknown" in result["warning"].lower(), \ + f"warning debe mencionar 'unknown': {result['warning']}" diff --git a/python/functions/ml/torch_device_select.md b/python/functions/ml/torch_device_select.md new file mode 100644 index 00000000..e2f1f7b2 --- /dev/null +++ b/python/functions/ml/torch_device_select.md @@ -0,0 +1,67 @@ +--- +name: torch_device_select +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: impure +signature: "def torch_device_select(preference: str = 'auto') -> str" +description: "Selecciona el torch device optimo segun preferencia y disponibilidad real del hardware. 'auto' elige CUDA > MPS > CPU. Para preferencias explicitas valida disponibilidad y hace fallback a CPU con warnings.warn." +tags: [torch, pytorch, cuda, mps, device, hardware, probe, ml, apple-silicon] +uses_functions: [cuda_available_py_ml] +uses_types: [] +returns: [] +returns_optional: false +error_type: "error_go_core" +imports: [] +params: + - name: preference + desc: "'auto' detecta el mejor device disponible (CUDA > MPS > CPU). 'cuda' fuerza cuda:0. 'cuda:N' fuerza GPU N. 'mps' fuerza Apple Silicon. 'cpu' siempre retorna cpu." +output: "string de device listo para torch: 'cuda:0', 'cuda:N', 'mps' o 'cpu'. Nunca lanza excepcion — fallback a 'cpu' con warning si el device solicitado no esta disponible." +tested: true +tests: + - "preference=cpu siempre retorna cpu" + - "preference=auto sin cuda ni mps retorna cpu" + - "preference=cuda sin cuda disponible retorna cpu con warning" + - "preference=cuda:5 con solo 1 GPU retorna cpu con warning" + - "preference desconocida retorna cpu con warning" +test_file_path: "python/functions/ml/tests/test_torch_device_select.py" +file_path: "python/functions/ml/torch_device_select.py" +--- + +## Ejemplo + +```python +from ml.torch_device_select import torch_device_select + +# Deteccion automatica (recomendado) +device = torch_device_select() # "cuda:0" o "mps" o "cpu" + +# Forzar CPU para reproducibilidad +device = torch_device_select("cpu") # siempre "cpu" + +# Preferencia explicita con fallback automatico +device = torch_device_select("cuda") # "cuda:0" o "cpu" + warning + +# Uso tipico al cargar un modelo +import torch +device_str = torch_device_select("auto") +model = MyModel().to(torch.device(device_str)) +``` + +## Comparacion con gliner_load_model + +`gliner_load_model` usa internamente `_resolve_device` con la misma logica +CUDA/CPU. `torch_device_select` extiende ese patron con: +- Soporte MPS (Apple Silicon M1/M2/M3). +- Seleccion de GPU especifica (`cuda:N`). +- Fallback con `warnings.warn` en vez de silencio. + +## Notas + +- No levanta excepcion si torch no esta instalado: todos los helpers internos + capturan ImportError y tratan el device como no disponible. +- `warnings.warn` en vez de logging para no imponer dependencia de logging al caller. +- MPS requiere torch >= 1.12 y macOS 12.3+. En sistemas Linux/Windows + `torch.backends.mps` puede no existir — el helper lo maneja con `hasattr`. +- impure: depende del estado del hardware y de las librerias instaladas. diff --git a/python/functions/ml/torch_device_select.py b/python/functions/ml/torch_device_select.py new file mode 100644 index 00000000..aff8582a --- /dev/null +++ b/python/functions/ml/torch_device_select.py @@ -0,0 +1,108 @@ +"""Selecciona el mejor torch device disponible segun preferencia.""" + +from __future__ import annotations + +import warnings + + +def _cuda_available() -> bool: + """Retorna True si torch esta instalado y CUDA disponible.""" + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False + + +def _mps_available() -> bool: + """Retorna True si torch esta instalado y MPS (Apple Silicon) disponible.""" + try: + import torch + return ( + hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + ) + except ImportError: + return False + + +def _cuda_device_count() -> int: + """Retorna el numero de dispositivos CUDA disponibles.""" + try: + import torch + return torch.cuda.device_count() if torch.cuda.is_available() else 0 + except ImportError: + return 0 + + +def torch_device_select(preference: str = "auto") -> str: + """Selecciona el torch device optimo segun preferencia y disponibilidad. + + Con preference='auto': elige CUDA si disponible, luego MPS (Apple M1/M2), + luego CPU. Para preferencias explicitas, valida disponibilidad y hace + fallback a CPU con advertencia si el device solicitado no esta disponible. + + Args: + preference: 'auto' | 'cuda' | 'cuda:N' | 'mps' | 'cpu'. + 'auto': detecta automaticamente el mejor device. + 'cuda': usa cuda:0 si disponible, fallback a cpu. + 'cuda:N': usa el dispositivo N si existe, fallback a cpu. + 'mps': usa MPS si disponible (Mac Apple Silicon), fallback a cpu. + 'cpu': siempre retorna 'cpu'. + + Returns: + String de device para torch: 'cuda:0', 'cuda:N', 'mps' o 'cpu'. + """ + if preference == "cpu": + return "cpu" + + if preference == "auto": + if _cuda_available(): + return "cuda:0" + if _mps_available(): + return "mps" + return "cpu" + + if preference == "mps": + if _mps_available(): + return "mps" + warnings.warn( + "MPS no esta disponible en este sistema. Usando 'cpu'.", + stacklevel=2, + ) + return "cpu" + + if preference == "cuda": + if _cuda_available(): + return "cuda:0" + warnings.warn( + "CUDA no esta disponible en este sistema. Usando 'cpu'.", + stacklevel=2, + ) + return "cpu" + + if preference.startswith("cuda:"): + try: + device_idx = int(preference.split(":")[1]) + except (IndexError, ValueError): + warnings.warn( + f"Formato de device no valido: '{preference}'. Usando 'cpu'.", + stacklevel=2, + ) + return "cpu" + + count = _cuda_device_count() + if _cuda_available() and device_idx < count: + return preference + warnings.warn( + f"Device '{preference}' no disponible " + f"(cuda_count={count}). Usando 'cpu'.", + stacklevel=2, + ) + return "cpu" + + warnings.warn( + f"Preferencia desconocida: '{preference}'. Usando 'cpu'.", + stacklevel=2, + ) + return "cpu" diff --git a/python/functions/ml/vram_budget.md b/python/functions/ml/vram_budget.md new file mode 100644 index 00000000..43e09d4f --- /dev/null +++ b/python/functions/ml/vram_budget.md @@ -0,0 +1,73 @@ +--- +name: vram_budget +kind: function +lang: py +domain: ml +version: "1.0.0" +purity: pure +signature: "def vram_budget(gpu_vram_total_mb: int, model_type: str, quantization: str, n_loras: int = 0, width: int = 1024, height: int = 1024, batch_size: int = 1) -> dict" +description: "Estima la VRAM requerida para ejecutar un modelo de generacion de imagen via heuristicas tabuladas por (model_type, quantization). Retorna VRAM estimada, si cabe en la GPU indicada, headroom disponible, y warnings por incompatibilidades (lora+quant) o falta de VRAM. Funcion pura: solo lookup y aritmetica, sin GPU ni runtime." +tags: [ml, vram, gpu, budget, stable-diffusion, flux, sdxl, quantization, lora, estimation, pure] +uses_functions: [] +uses_types: [] +returns: [] +returns_optional: false +error_type: "" +imports: [] +params: + - name: gpu_vram_total_mb + desc: "VRAM total de la GPU objetivo en MB. Obtener con gpu_info() o torch.cuda.get_device_properties()." + - name: model_type + desc: "Tipo de modelo. Valores soportados: sd15, sdxl, flux_dev, flux_schnell, sd3, qwen_image. Combinaciones fuera de la tabla retornan required_mb=-1." + - name: quantization + desc: "Esquema de cuantizacion. Valores: fp16, q8_0, q4_0 (y variantes q4_k_m, q5_k_m, q6_k). Afecta tanto el tamano base como la compatibilidad con LoRAs." + - name: n_loras + desc: "Numero de LoRAs a cargar simultaneamente en VRAM. Cada LoRA suma ~300 MB. Con quantization != fp16 se emite warning de incompatibilidad." + - name: width + desc: "Ancho en pixeles de la imagen a generar. Afecta el overhead de latentes (mayor resolucion = mas VRAM para activaciones)." + - name: height + desc: "Alto en pixeles de la imagen a generar." + - name: batch_size + desc: "Numero de imagenes generadas en paralelo. El overhead de latentes escala linealmente con batch_size." +output: "dict con: required_mb (int, -1 si combo desconocido), fits (bool, True si cabe en gpu_vram_total_mb), headroom_mb (int, negativo si no cabe, 0 si combo desconocido), warning (str o None con aviso de incompatibilidad lora+quant o deficit de VRAM)." +tested: true +tests: + - "sdxl fp16 cabe en 24gb con headroom positivo" + - "flux fp16 no cabe en 8gb warning con deficit" + - "lora con quantization incompatible emite warning" + - "combo desconocido retorna required minus1 y warning" +test_file_path: "python/functions/ml/tests/test_vram_budget.py" +file_path: "python/functions/ml/vram_budget.py" +--- + +## Ejemplo + +```python +from ml.vram_budget import vram_budget + +# SDXL fp16 en 24 GB — cabe +r = vram_budget(24576, "sdxl", "fp16") +# {"required_mb": 6960, "fits": True, "headroom_mb": 17616, "warning": None} + +# Flux dev fp16 en 8 GB — no cabe +r = vram_budget(8192, "flux_dev", "fp16") +# {"required_mb": 23512, "fits": False, "headroom_mb": -15320, "warning": "needs +15320 MB ..."} + +# Flux dev q4_0 en 8 GB con 1 LoRA — incompatible +r = vram_budget(8192, "flux_dev", "q4_0", n_loras=1) +# {"required_mb": 7300, "fits": True, "headroom_mb": 892, +# "warning": "lora+quantization incompatible — usa fp16 para cargar LoRAs con flux_dev"} + +# Combo desconocido +r = vram_budget(24576, "mi_modelo", "q99_k") +# {"required_mb": -1, "fits": False, "headroom_mb": 0, +# "warning": "unknown model/quant combo: ('mi_modelo', 'q99_k')"} +``` + +## Notas + +- La tabla `_MODEL_VRAM_MB` es una estimacion inicial; el usuario debe calibrarla con mediciones reales (nvidia-smi durante inference). +- El overhead de latentes se calcula como `w*h/64 MB` para SD/SDXL/SD3 y `w*h/32 MB` para modelos Flux (espacio latente con mas canales). +- LoRA warning tiene prioridad sobre el warning de no-fits: si hay incompatibilidad lora+quant, ese warning se emite aunque el modelo no quepa. +- Para obtener gpu_vram_total_mb en tiempo real usar `gpu_info_py_ml` (impure). +- Funcion pura: misma entrada, misma salida. Sin I/O ni dependencias externas. diff --git a/python/functions/ml/vram_budget.py b/python/functions/ml/vram_budget.py new file mode 100644 index 00000000..c898f602 --- /dev/null +++ b/python/functions/ml/vram_budget.py @@ -0,0 +1,111 @@ +"""Estimador de VRAM requerida para modelos de generacion de imagen.""" + +# Base weights por (model_type, quantization) en MB. +# Incluye pesos del modelo + overhead tipico del contexto de inferencia. +_MODEL_VRAM_MB: dict[tuple[str, str], int] = { + ("sd15", "fp16"): 2100, + ("sd15", "q8_0"): 1200, + ("sd15", "q4_0"): 700, + ("sdxl", "fp16"): 6800, + ("sdxl", "q8_0"): 3800, + ("sdxl", "q4_0"): 2200, + ("flux_dev", "fp16"): 23000, + ("flux_dev", "q8_0"): 13000, + ("flux_dev", "q4_0"): 7000, + ("flux_schnell", "fp16"): 23000, + ("flux_schnell", "q8_0"): 12500, + ("flux_schnell", "q4_0"): 6500, + ("sd3", "fp16"): 8500, + ("sd3", "q8_0"): 4800, + ("sd3", "q4_0"): 2800, + ("qwen_image", "fp16"): 8000, + ("qwen_image", "q8_0"): 4500, + ("qwen_image", "q4_0"): 2600, +} + +# MB por LoRA adicional (estimacion conservadora en fp16). +_LORA_MB = 300 + +# Modelos que requieren overhead de latente mas alto (Flux usa bloques transformer mas grandes). +_FLUX_MODELS = {"flux_dev", "flux_schnell"} + +# Quantizaciones que son incompatibles con LoRA en la mayoria de runtimes. +_QUANT_LORA_INCOMPATIBLE = {"q8_0", "q4_0", "q4_k_m", "q5_k_m", "q6_k"} + + +def _latent_overhead_mb(model_type: str, width: int, height: int, batch_size: int) -> int: + """Estima el overhead de VRAM para activaciones y latentes en MB.""" + pixels = width * height + if model_type in _FLUX_MODELS: + # Flux usa un espacio latente 16x mas comprimido pero con mas canales. + overhead = pixels // 32 + else: + # SD 1.5 / SDXL / SD3: overhead aprox w*h/64 MB. + overhead = pixels // 64 + return overhead * batch_size + + +def vram_budget( + gpu_vram_total_mb: int, + model_type: str, + quantization: str, + n_loras: int = 0, + width: int = 1024, + height: int = 1024, + batch_size: int = 1, +) -> dict: + """Estima la VRAM requerida para ejecutar un modelo de generacion de imagen. + + Usa heuristicas tabuladas por (model_type, quantization) mas overhead de + latentes y LoRAs. No requiere GPU ni runtime — solo lookup y aritmetica. + + Args: + gpu_vram_total_mb: VRAM total de la GPU en MB. + model_type: Tipo de modelo. Valores: sd15, sdxl, flux_dev, flux_schnell, sd3, qwen_image. + quantization: Esquema de cuantizacion. Valores: fp16, q8_0, q4_0, etc. + n_loras: Numero de LoRAs a cargar simultaneamente (default 0). + width: Ancho de la imagen a generar en pixeles (default 1024). + height: Alto de la imagen a generar en pixeles (default 1024). + batch_size: Numero de imagenes en paralelo (default 1). + + Returns: + dict con: + - required_mb (int): VRAM estimada necesaria en MB. -1 si combo desconocido. + - fits (bool): True si required_mb <= gpu_vram_total_mb. + - headroom_mb (int): MB sobrantes (negativo si no cabe). 0 si combo desconocido. + - warning (str | None): Aviso sobre incompatibilidades o ajustes necesarios. + None si no hay advertencias. + """ + key = (model_type, quantization) + + if key not in _MODEL_VRAM_MB: + return { + "required_mb": -1, + "fits": False, + "headroom_mb": 0, + "warning": f"unknown model/quant combo: ({model_type!r}, {quantization!r})", + } + + base_mb = _MODEL_VRAM_MB[key] + latent_mb = _latent_overhead_mb(model_type, width, height, batch_size) + lora_mb = n_loras * _LORA_MB + + required_mb = base_mb + latent_mb + lora_mb + fits = required_mb <= gpu_vram_total_mb + headroom_mb = gpu_vram_total_mb - required_mb + + warning: str | None = None + + # LoRA + quantization incompatible en la mayoria de runtimes. + if n_loras > 0 and quantization in _QUANT_LORA_INCOMPATIBLE: + warning = f"lora+quantization incompatible — usa fp16 para cargar LoRAs con {model_type}" + elif not fits: + deficit = required_mb - gpu_vram_total_mb + warning = f"needs +{deficit} MB (required {required_mb} MB, available {gpu_vram_total_mb} MB)" + + return { + "required_mb": required_mb, + "fits": fits, + "headroom_mb": headroom_mb, + "warning": warning, + } diff --git a/python/pyproject.toml b/python/pyproject.toml index e400cfb5..cbb78f2f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -5,6 +5,7 @@ description = "Funciones Python del fn-registry: Metabase API, ML, utilidades" readme = "README.md" requires-python = ">=3.12" dependencies = [ + "chardet>=7.4.3", "contextily>=1.7.0", "cryptography>=46.0.6", "duckdb>=1.5.2", @@ -17,6 +18,7 @@ dependencies = [ "httpx", "matplotlib>=3.10.9", "openpyxl>=3.1.5", + "polars>=1.40.1", "pypdf>=6.10.0", "pyproj>=3.7.2", "python-docx>=1.2.0", diff --git a/python/types/ml/generation_config.md b/python/types/ml/generation_config.md new file mode 100644 index 00000000..636d3593 --- /dev/null +++ b/python/types/ml/generation_config.md @@ -0,0 +1,71 @@ +--- +name: generation_config +lang: py +domain: ml +version: "1.0.0" +algebraic: product +definition: | + class GenerationConfig(BaseModel): + prompt: str + negative_prompt: str | None = None + seed: int + steps: int + cfg_scale: float + sampler: SamplerName + width: int + height: int + model: ModelRef + loras: list[LoraRef] = [] + clip_skip: int | None = None +description: "Contrato de parametros para generacion de imagenes con modelos de difusion. Tipo central del dominio ml. Serializable a JSON canonico para intercambio con Go." +tags: [ml, diffusion, generation, config, stable-diffusion, flux, prompt, seed, sampler] +uses_types: [model_ref_py_ml, lora_ref_py_ml, sampler_name_py_ml] +file_path: "python/functions/ml/generation_config.py" +--- + +## Ejemplo + +```python +from model_ref import ModelRef +from lora_ref import LoraRef +from generation_config import GenerationConfig + +cfg = GenerationConfig( + prompt="a futuristic city at night, neon lights, photorealistic", + negative_prompt="blurry, low quality, watermark", + seed=42, + steps=30, + cfg_scale=7.5, + sampler="euler_a", + width=1024, + height=1024, + model=ModelRef( + name="stabilityai/stable-diffusion-xl-base-1.0", + model_type="sdxl", + ), + loras=[LoraRef(path="loras/detail_tweaker.safetensors", weight=0.8)], +) + +# Serializar a JSON canonico (para Go u otro servicio) +json_bytes = cfg.model_dump_json() + +# Deserializar desde JSON +cfg2 = GenerationConfig.model_validate_json(json_bytes) +assert cfg == cfg2 +``` + +## Notas + +Tipo producto inmutable (`frozen=True`). Fallback a `dataclass(frozen=True)` si pydantic no esta instalado (sin validacion de literales en runtime). + +Restricciones de dimensiones por arquitectura: + +| model_type | width/height recomendado | multiplo | +|---------------|--------------------------|----------| +| sd15, sd20 | 512 x 512 | 64 | +| sdxl, sd3 | 1024 x 1024 | 64 | +| flux_* | 1024 x 1024 (libre) | 16 | + +`seed=-1` indica semilla aleatoria — el backend genera una y la incluye en `ImageGenResult.meta["seed_used"]`. + +`loras=[]` es el default correcto (lista vacia). En el fallback dataclass se usa `tuple` para mantener la inmutabilidad (dataclass frozen + list = error). diff --git a/python/types/ml/image_gen_result.md b/python/types/ml/image_gen_result.md new file mode 100644 index 00000000..3401c5df --- /dev/null +++ b/python/types/ml/image_gen_result.md @@ -0,0 +1,46 @@ +--- +name: image_gen_result +lang: py +domain: ml +version: "1.0.0" +algebraic: product +definition: | + class ImageGenResult(BaseModel): + image: Any # PIL.Image.Image en runtime; no serializable + meta: dict[str, Any] # config, model, seed_used, sampler, ... + duration_ms: int + vram_peak_mb: int | None = None +description: "Resultado de una operacion de generacion de imagen con modelo de difusion. Contiene la imagen PIL, metadata de la generacion, duracion y pico de VRAM." +tags: [ml, diffusion, result, image, pil, vram, latency, stable-diffusion, flux] +uses_types: [generation_config_py_ml] +file_path: "python/functions/ml/image_gen_result.py" +--- + +## Notas + +El campo `image` es `PIL.Image.Image` en runtime pero se declara como `Any` para evitar un import duro de Pillow en este modulo. Esto permite importar `ImageGenResult` en contextos sin Pillow instalado (orquestadores, tests de tipos, servicios Go que parsean solo el JSON de `meta`). + +**Como guardar la imagen:** +```python +result: ImageGenResult = generator.generate(config) +result.image.save("output.png") # PIL API directa +``` + +**Como serializar para logging/Go:** +```python +# model_dump() excluye el campo image automaticamente +payload = result.model_dump() +# payload = {"meta": {...}, "duration_ms": 1234, "vram_peak_mb": 4096} +``` + +**Contenido esperado de `meta`:** + +| Clave | Tipo | Descripcion | +|----------------|--------|----------------------------------------------------| +| `config` | dict | `GenerationConfig.model_dump()` con params usados | +| `model` | str | Nombre del modelo (model_ref.name) | +| `seed_used` | int | Semilla real usada (importante cuando seed=-1) | +| `sampler` | str | Nombre del sampler ejecutado | +| `backend` | str | "diffusers", "sd_cpp", "comfy", etc. | + +El campo `vram_peak_mb` es `None` en inferencia CPU o cuando el backend no expone estadisticas de memoria GPU. En backends con soporte (diffusers + torch.cuda.max_memory_allocated) se rellena automaticamente. diff --git a/python/types/ml/image_generator.md b/python/types/ml/image_generator.md new file mode 100644 index 00000000..6df8b450 --- /dev/null +++ b/python/types/ml/image_generator.md @@ -0,0 +1,57 @@ +--- +name: image_generator +lang: py +domain: ml +version: "1.0.0" +algebraic: product +definition: | + @runtime_checkable + class ImageGenerator(Protocol): + def generate(self, config: GenerationConfig) -> ImageGenResult: ... +description: "Protocol para backends de generacion de imagenes con difusion. Cualquier clase con generate(config) -> ImageGenResult satisface la interfaz sin herencia explicita." +tags: [ml, diffusion, protocol, interface, generator, backend, stable-diffusion, flux] +uses_types: [generation_config_py_ml, image_gen_result_py_ml] +file_path: "python/functions/ml/image_generator.py" +--- + +## Notas + +`algebraic: product` porque un Protocol en Python es un tipo estructural que define un conjunto fijo de metodos — analogia mas cercana a product que a sum. No es un Literal/Union/Enum. + +El Protocol usa `@runtime_checkable` para permitir `isinstance(backend, ImageGenerator)` en runtime. La verificacion en runtime solo comprueba la presencia del metodo `generate`, no la firma completa. Para verificacion de tipos estatica estricta, usar mypy o pyright. + +**Structural subtyping** — no se necesita herencia: + +```python +from image_generator import ImageGenerator +from generation_config import GenerationConfig +from image_gen_result import ImageGenResult + +class DiffusersGenerator: + def generate(self, config: GenerationConfig) -> ImageGenResult: + # implementacion con diffusers + ... + +# Satisface ImageGenerator sin heredar de el +gen: ImageGenerator = DiffusersGenerator() +assert isinstance(gen, ImageGenerator) # True +``` + +**Uso tipico en funciones del registry:** + +```python +def batch_generate( + generator: ImageGenerator, + configs: list[GenerationConfig], +) -> list[ImageGenResult]: + return [generator.generate(cfg) for cfg in configs] +``` + +**Backends planificados:** + +| Backend | Descripcion | +|------------------|--------------------------------------------------| +| DiffusersGenerator | HuggingFace diffusers + torch (GPU/CPU) | +| SdCppGenerator | stable-diffusion.cpp via subprocess/ctypes | +| ComfyUIGenerator | Cliente HTTP a ComfyUI API local | +| MockGenerator | Genera imagen negra para tests sin GPU | diff --git a/python/types/ml/lora_ref.md b/python/types/ml/lora_ref.md new file mode 100644 index 00000000..c672db0d --- /dev/null +++ b/python/types/ml/lora_ref.md @@ -0,0 +1,29 @@ +--- +name: lora_ref +lang: py +domain: ml +version: "1.0.0" +algebraic: product +definition: | + class LoraRef(BaseModel): + path: str + weight: float = 1.0 + scale: float | None = None +description: "Referencia a un adaptador LoRA para generacion de imagenes. Especifica ruta, peso global y override de alpha. Inmutable." +tags: [ml, diffusion, lora, adapter, fine-tuning, stable-diffusion] +uses_types: [] +file_path: "python/functions/ml/lora_ref.py" +--- + +## Notas + +Tipo producto inmutable. Fallback a `dataclass(frozen=True)` si pydantic no esta instalado. + +`weight` controla cuanto influye el LoRA en la imagen generada: +- 0.0 — LoRA completamente desactivado (como si no estuviera). +- 1.0 — fuerza original del LoRA (lo que el autor recomienda tipicamente). +- >1.0 — sobre-aplicacion: puede producir artefactos o saturacion del estilo. + +`scale` es el override del alpha (escala del rango): si el archivo del LoRA tiene `alpha=16` y `rank=32`, el factor efectivo es `alpha/rank = 0.5`. Pasar `scale=1.0` fuerza factor 1.0 independientemente del alpha interno. + +Multiples LoRAs se listan en `GenerationConfig.loras` en orden de aplicacion. Los backends los aplican secuencialmente. diff --git a/python/types/ml/model_ref.md b/python/types/ml/model_ref.md new file mode 100644 index 00000000..08ce8422 --- /dev/null +++ b/python/types/ml/model_ref.md @@ -0,0 +1,38 @@ +--- +name: model_ref +lang: py +domain: ml +version: "1.0.0" +algebraic: product +definition: | + class ModelRef(BaseModel): + name: str + model_type: Literal["sd15","sd20","sdxl","sd3","flux_dev","flux_schnell","flux_kontext","qwen_image","chroma","z_image"] + quantization: Literal["fp32","fp16","bf16","q8_0","q5_1","q5_0","q4_1","q4_0"] = "fp16" + path: str | None = None +description: "Referencia a un modelo de generacion de imagenes. Identifica arquitectura, cuantizacion y ruta local opcional. Serializable a JSON canonico para contrato compartido con Go." +tags: [ml, diffusion, model, stable-diffusion, flux, quantization, huggingface] +uses_types: [] +file_path: "python/functions/ml/model_ref.py" +--- + +## Notas + +Tipo producto inmutable (`frozen=True` en pydantic). Fallback a `dataclass(frozen=True)` si pydantic no esta instalado. + +`model_type` es un tipo suma embebido que cubre las arquitecturas mas comunes: + +| Valor | Arquitectura | +|---------------|-------------------------------------| +| sd15 | Stable Diffusion 1.5 | +| sd20 | Stable Diffusion 2.0/2.1 | +| sdxl | Stable Diffusion XL | +| sd3 | Stable Diffusion 3 | +| flux_dev | FLUX.1-dev | +| flux_schnell | FLUX.1-schnell | +| flux_kontext | FLUX.1-Kontext | +| qwen_image | Qwen2-VL / Qwen2.5-VL con imagen | +| chroma | Chroma (variante flux controlada) | +| z_image | Z-Image (pipeline experimental) | + +`quantization` controla el formato del checkpoint: fp32/fp16/bf16 para pesos flotantes, q4_0-q8_0 para formatos GGUF (stable-diffusion.cpp). diff --git a/python/types/ml/sampler_name.md b/python/types/ml/sampler_name.md new file mode 100644 index 00000000..1a66bbc5 --- /dev/null +++ b/python/types/ml/sampler_name.md @@ -0,0 +1,39 @@ +--- +name: sampler_name +lang: py +domain: ml +version: "1.0.0" +algebraic: sum +definition: | + SamplerName = Literal[ + "euler", + "euler_a", + "dpm++2m", + "dpm++2m_v2", + "heun", + "dpm2", + "lcm", + ] +description: "Subset estricto de samplers compartido entre diffusers y stable-diffusion.cpp. Cada valor tiene correspondencia directa en ambos backends." +tags: [ml, diffusion, sampler, scheduler, stable-diffusion, diffusers] +uses_types: [] +file_path: "python/functions/ml/sampler_name.py" +--- + +## Notas + +Tipo suma — exactamente uno de los valores literales. No hay valor por defecto definido aqui; `GenerationConfig` lo establece. + +Correspondencia de backends: + +| SamplerName | diffusers scheduler | sd.cpp --sampling-method | +|--------------|-----------------------------|--------------------------| +| euler | EulerDiscreteScheduler | euler | +| euler_a | EulerAncestralDiscreteScheduler | euler_a | +| dpm++2m | DPMSolverMultistepScheduler | dpm++2m | +| dpm++2m_v2 | DPMSolverMultistepScheduler v2 | dpm++2m_v2 | +| heun | HeunDiscreteScheduler | heun | +| dpm2 | KDPM2DiscreteScheduler | dpm2 | +| lcm | LCMScheduler | lcm | + +Para agregar nuevos samplers: crear un nuevo tipo o extender con `Union[SamplerName, Literal["nuevo"]]` en el consumidor. No modificar este tipo para mantener el contrato con Go. diff --git a/python/uv.lock b/python/uv.lock index c98717b6..748b336e 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -226,6 +226,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/3a/dbeec9d1ee0844c679f6bb5d6ad4e9f198b1224f4e7a32825f47f6192b0c/cffi-2.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9", size = 184195, upload-time = "2025-09-08T23:23:43.004Z" }, ] +[[package]] +name = "chardet" +version = "7.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/b6/9df434a8eeba2e6628c465a1dfa31034228ef79b26f76f46278f4ef7e49d/chardet-7.4.3.tar.gz", hash = "sha256:cc1d4eb92a4ec1c2df3b490836ffa46922e599d34ce0bb75cf41fd2bf6303d56", size = 784800, upload-time = "2026-04-13T21:33:39.803Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/33/29de185079e6675c3f375546e30a559b7ddc75ce972f18d6e566cd9ea4eb/chardet-7.4.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:75d3c65cc16bddf40b8da1fd25ba84fca5f8070f2b14e86083653c1c85aee971", size = 874870, upload-time = "2026-04-13T21:33:05.977Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2f/4c5af01fd1a7506a1d5375403d68925eac70289229492db5aa68b58103d8/chardet-7.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:29af5999f654e8729d251f1724a62b538b1262d9292cccaefddf8a02aae1ef6a", size = 854859, upload-time = "2026-04-13T21:33:07.381Z" }, + { url = "https://files.pythonhosted.org/packages/36/21/edb36ad5dfa48d7f8eed97ab43931ecdaa8c15166c21b1d614967e49d681/chardet-7.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:626f00299ad62dfe937058a09572beed442ccc7b58f87aa667949b20fd3db235", size = 875032, upload-time = "2026-04-13T21:33:08.741Z" }, + { url = "https://files.pythonhosted.org/packages/e5/59/a32a241d861cf180853a11c8e5a67641cb1b2af13c3a5ccce83ec07e2c9f/chardet-7.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a4904dd5f071b7a7d7f50b4a67a86db3c902d243bf31708f1d5cde2f68239cb", size = 888283, upload-time = "2026-04-13T21:33:10.213Z" }, + { url = "https://files.pythonhosted.org/packages/87/2e/e1ee6a77abf3782c00e05b89c4d4328c8353bf9500661c4348df1dd68614/chardet-7.4.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5d2879598bc220689e8ce509fe9c3f37ad2fca53a36be9c9bd91abdd91dd364f", size = 879974, upload-time = "2026-04-13T21:33:11.448Z" }, + { url = "https://files.pythonhosted.org/packages/32/60/fca69c534602a7ced04280c952a246ad1edde2a6ca3a164f65d32ac41fe7/chardet-7.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:4b2799bd58e7245cfa8d4ab2e8ad1d76a5c3a5b1f32318eb6acca4c69a3e7101", size = 943973, upload-time = "2026-04-13T21:33:12.756Z" }, + { url = "https://files.pythonhosted.org/packages/7c/43/79ac9b4db5bc87020c9dbc419125371d80882d1d197e9c4765ba8682b605/chardet-7.4.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a9e4486df251b8962e86ea9f139ca235aa6e0542a00f7844c9a04160afb99aa9", size = 873769, upload-time = "2026-04-13T21:33:14.002Z" }, + { url = "https://files.pythonhosted.org/packages/55/5f/25bdec773905bff0ff6cf35ca73b17bd05593b4f87bd8c5fa43705f7167d/chardet-7.4.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4fbff1907925b0c5a1064cffb5e040cd5e338585c9c552625f30de6bc2f3107a", size = 853991, upload-time = "2026-04-13T21:33:15.564Z" }, + { url = "https://files.pythonhosted.org/packages/b4/07/a29380ee0b215d23d77733b5ad60c5c0c7969650e080c667acdf9462040d/chardet-7.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:365135eaf37ba65a828f8e668eb0a8c38c479dcbec724dc25f4dfd781049c357", size = 874024, upload-time = "2026-04-13T21:33:16.915Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b1/3338e121cbd4c8a126b8ccb1061170c2ce51a53f678c502793ea49c6fd6d/chardet-7.4.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfc134b70c846c21ead8e43ada3ae1a805fff732f6922f8abcf2ff27b8f6493d", size = 887410, upload-time = "2026-04-13T21:33:18.368Z" }, + { url = "https://files.pythonhosted.org/packages/63/1c/44a9a9e0c59c185a5d307ceaeee8768afa1558f0a24f7a4b5fa11b67586b/chardet-7.4.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9acd9988a93e09390f3cd231201ea7166c415eb8da1b735928990ffc05cb9fbb", size = 879269, upload-time = "2026-04-13T21:33:20.377Z" }, + { url = "https://files.pythonhosted.org/packages/1b/b3/5d0e77ea774bd3224321c248880ea0c0379000ac5c2bb6d77609549de247/chardet-7.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:e1b98790c284ff813f18f7cf7de5f05ea2435a080030c7f1a8318f3a4f80b131", size = 944155, upload-time = "2026-04-13T21:33:21.694Z" }, + { url = "https://files.pythonhosted.org/packages/70/a8/bf0811d859e13801279a2ae64f37a408027b282f2047bc0001c75dd356ad/chardet-7.4.3-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:d892d3dcd652fdef53e3d6327d39b17c0df40a899dfc919abaeb64c974497531", size = 872887, upload-time = "2026-04-13T21:33:23.328Z" }, + { url = "https://files.pythonhosted.org/packages/51/ac/b9d68ebddfe1b02c77af5bf81120e12b036b4432dc6af7a303d90e2bc38b/chardet-7.4.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:acc46d1b8b7d5783216afe15db56d1c179b9a40e5a1558bc13164c4fd20674c4", size = 853964, upload-time = "2026-04-13T21:33:24.724Z" }, + { url = "https://files.pythonhosted.org/packages/2a/81/17fa103ea9caf5d325a5e4051ab2ba65996fd66baa60b81ee41af1f54e10/chardet-7.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ac3bf11c645734a1701a3804e43eabd98851838192267d08c353a834ab79fea", size = 876006, upload-time = "2026-04-13T21:33:26.098Z" }, + { url = "https://files.pythonhosted.org/packages/c2/20/193faab46a68ea550587331a698c3dca8099f8901d10937c4443135c7ed9/chardet-7.4.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e3bd9f936e04bae89c254262af08d9e5b98f805175ba1e29d454e6cba3107b7", size = 887680, upload-time = "2026-04-13T21:33:27.49Z" }, + { url = "https://files.pythonhosted.org/packages/40/c6/94a3c673327392652ee8bdea9a45bc8a5f5365197a7387d68f0eed007115/chardet-7.4.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:27cc23da03630cdecc9aa81a895aa86629c211f995cd57651f0fbc280717bf93", size = 879865, upload-time = "2026-04-13T21:33:29.052Z" }, + { url = "https://files.pythonhosted.org/packages/b1/2c/cad8b5e3623a987f3c930b68e2bdd06cfc388cd91cd42ed05f1227701b73/chardet-7.4.3-cp314-cp314-win_amd64.whl", hash = "sha256:b95c934b9ad59e2ba8abb9be49df70d3ad1b0d95d864b9fdb7588d4fa8bd921c", size = 939594, upload-time = "2026-04-13T21:33:31.391Z" }, + { url = "https://files.pythonhosted.org/packages/33/e0/d06e42fd6f02a58e5e227e5106587751cb38adcff0aaf949add744b78b6e/chardet-7.4.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:c77867f0c1cb8bd819502249fcdc500364aedb07881e11b743726fa2148e7b6e", size = 889714, upload-time = "2026-04-13T21:33:32.772Z" }, + { url = "https://files.pythonhosted.org/packages/d4/ed/40d091954d48abea037baae6be8fb79905e5f78d34d12ea955132c7d8011/chardet-7.4.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cf1efeaf65a6ef2f5b9cc3a1df6f08ba2831b369ccaa4c7018eaf90aa757bb11", size = 872319, upload-time = "2026-04-13T21:33:34.427Z" }, + { url = "https://files.pythonhosted.org/packages/bb/77/82a46821dbfbdfe062710d2bf2ede13426304e3567a23c57d919c0c31630/chardet-7.4.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f3504c139a2ad544077dd2d9e412cd08b01786843d76997cd43bb6de311723c", size = 892021, upload-time = "2026-04-13T21:33:35.766Z" }, + { url = "https://files.pythonhosted.org/packages/49/57/42d30c562bda5b4a839766c1aad8d5856b798ad2a1c3247b72a679afec94/chardet-7.4.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:457f619882ba66327d4d8d14c6c342269bdb1e4e1c38e8117df941d14d351b04", size = 902509, upload-time = "2026-04-13T21:33:37.096Z" }, + { url = "https://files.pythonhosted.org/packages/8c/6c/0a40afdb50a0fe041ab95553b835a8160b6cf0e81edf2ae2fe9f5224cbf9/chardet-7.4.3-py3-none-any.whl", hash = "sha256:1173b74051570cf08099d7429d92e4882d375ad4217f92a6e5240ccfb26f231e", size = 626562, upload-time = "2026-04-13T21:33:38.559Z" }, +] + [[package]] name = "charset-normalizer" version = "3.4.7" @@ -651,6 +682,7 @@ name = "fn-registry-python" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "chardet" }, { name = "contextily" }, { name = "cryptography" }, { name = "duckdb" }, @@ -663,6 +695,7 @@ dependencies = [ { name = "httpx" }, { name = "matplotlib" }, { name = "openpyxl" }, + { name = "polars" }, { name = "pypdf" }, { name = "pyproj" }, { name = "python-docx" }, @@ -687,6 +720,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "chardet", specifier = ">=7.4.3" }, { name = "contextily", specifier = ">=1.7.0" }, { name = "cryptography", specifier = ">=46.0.6" }, { name = "duckdb", specifier = ">=1.5.2" }, @@ -701,6 +735,7 @@ requires-dist = [ { name = "httpx" }, { name = "matplotlib", specifier = ">=3.10.9" }, { name = "openpyxl", specifier = ">=3.1.5" }, + { name = "polars", specifier = ">=1.40.1" }, { name = "pypdf", specifier = ">=6.10.0" }, { name = "pyproj", specifier = ">=3.7.2" }, { name = "python-docx", specifier = ">=1.2.0" }, @@ -2134,6 +2169,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "polars" +version = "1.40.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/8c/bc9bc948058348ed43117cecc3007cd608f395915dae8a00974579a5dab1/polars-1.40.1.tar.gz", hash = "sha256:ab2694134b137596b5a59bfd7b4c54ebbc9b59f9403127f18e32d363777552e8", size = 733574, upload-time = "2026-04-22T19:15:55.507Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/91/74fc60d94488685a92ac9d49d7ec55f3e91fe9b77942a6235a5fa7f249c3/polars-1.40.1-py3-none-any.whl", hash = "sha256:c0f861219d1319cdea45c4ce4d30355a47176b8f98dcedf95ea8269f131b8abd", size = 828723, upload-time = "2026-04-22T19:14:25.452Z" }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.40.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ba/26d40f039be9f552b5fd7365a621bdfc0f8e912ef77094ae4693491b0bae/polars_runtime_32-1.40.1.tar.gz", hash = "sha256:37f3065615d1bf90d03b5326222df4c5c1f8a5d33e50470aa588e3465e6eb814", size = 2935843, upload-time = "2026-04-22T19:15:57.26Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/46/22c8af5eed68ac2eeb556e0fa3ca8a7b798e984ceff4450888f3b5ac61fd/polars_runtime_32-1.40.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b748ef652270cc49e9e69f99a035e0eb4d5f856d42bcd6ac4d9d80a40142aa1e", size = 52098755, upload-time = "2026-04-22T19:14:28.555Z" }, + { url = "https://files.pythonhosted.org/packages/c6/3e/48599a38009ca60ff82a6f38c8a621ce3c0286aa7397c7d79e741bd9060e/polars_runtime_32-1.40.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:d249b3743e05986060cec0a7aaa542d020df6c6b876e556023a310efd581f9be", size = 46367542, upload-time = "2026-04-22T19:14:32.433Z" }, + { url = "https://files.pythonhosted.org/packages/43/e9/384bc069367a1a36ee31c13782c178dbd039b2b873b772d4a0fc23a2373d/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5987b30e7aa1059d069498496e8dda35afd592b0ac3d46ed87e3ff8df1ad652c", size = 50252104, upload-time = "2026-04-22T19:14:35.945Z" }, + { url = "https://files.pythonhosted.org/packages/15/ef/7d57ceb0651af74194e97ed6583e148d352f03d696090221b8059cdfc90b/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d7f42a8b3f16fc66002cc0f6516f7dd7653396886ae0ed362ab95c0b3408b59", size = 56250788, upload-time = "2026-04-22T19:14:39.743Z" }, + { url = "https://files.pythonhosted.org/packages/10/0f/e4b3ffc748827a14a474ec9c42e45c066050e440fec57e914091d9adda75/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e5f7becc237a7ec9d9a10878dc8e54b73bbf4e2d94a2991c37d7a0b38590d8f9", size = 50432590, upload-time = "2026-04-22T19:14:43.388Z" }, + { url = "https://files.pythonhosted.org/packages/d9/0b/b8d95fbed869fa4caabe9c400e4210374913b376e925e96fdcfa9be6416b/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:992d14cf191dde043d36fbdbc98a65e43fbc7e9a5024cecd45f838ac4988c1ee", size = 54155564, upload-time = "2026-04-22T19:14:47.239Z" }, + { url = "https://files.pythonhosted.org/packages/06/d9/d091d8fb5cbed5e9536adfed955c4c89987a4cc3b8e73ae4532402b91c74/polars_runtime_32-1.40.1-cp310-abi3-win_amd64.whl", hash = "sha256:f78bb2abd00101cbb23cc0cb068f7e36e081057a15d2ec2dde3dda280709f030", size = 51829755, upload-time = "2026-04-22T19:14:50.85Z" }, + { url = "https://files.pythonhosted.org/packages/65/ad/b33c3022a394f3eb55c3310597cec615412a8a33880055eee191d154a628/polars_runtime_32-1.40.1-cp310-abi3-win_arm64.whl", hash = "sha256:b5cbfaf6b085b420b4bfcbe24e8f665076d1cccfdb80c0484c02a023ce205537", size = 45822104, upload-time = "2026-04-22T19:14:54.192Z" }, +] + [[package]] name = "propcache" version = "0.4.1" diff --git a/registry/migrations/012_vault_files.sql b/registry/migrations/012_vault_files.sql new file mode 100644 index 00000000..8e91072e --- /dev/null +++ b/registry/migrations/012_vault_files.sql @@ -0,0 +1,16 @@ +CREATE TABLE IF NOT EXISTS vault_files ( + vault_id TEXT NOT NULL, + vault_name TEXT NOT NULL, + rel_path TEXT NOT NULL, + size INTEGER NOT NULL, + mtime INTEGER NOT NULL, + sha256 TEXT NOT NULL, + mime TEXT NOT NULL DEFAULT '', + ext TEXT NOT NULL DEFAULT '', + bucket TEXT NOT NULL DEFAULT '', + sub_bucket TEXT NOT NULL DEFAULT '', + indexed_at INTEGER NOT NULL, + PRIMARY KEY (vault_id, rel_path) +); +CREATE INDEX IF NOT EXISTS idx_vault_files_sha256 ON vault_files(sha256); +CREATE INDEX IF NOT EXISTS idx_vault_files_vault ON vault_files(vault_id); diff --git a/types/infra/gpu_info.md b/types/infra/gpu_info.md new file mode 100644 index 00000000..784cba80 --- /dev/null +++ b/types/infra/gpu_info.md @@ -0,0 +1,39 @@ +--- +name: gpu_info +lang: go +domain: infra +version: "1.0.0" +algebraic: product +definition: | + type GpuInfo struct { + Index int `json:"index"` + Name string `json:"name"` + VramTotalMb int `json:"vram_total_mb"` + VramFreeMb int `json:"vram_free_mb"` + DriverVersion string `json:"driver_version"` + CudaVersion string `json:"cuda_version,omitempty"` + } +description: "Describe una GPU detectada en el sistema con capacidad de VRAM total y libre, version de driver y version de CUDA (opcional, solo NVIDIA)." +tags: [gpu, cuda, hardware, infra, ml] +uses_types: [] +file_path: "functions/infra/gpu_info.go" +--- + +## Ejemplo + +```go +gpu := GpuInfo{ + Index: 0, + Name: "NVIDIA GeForce RTX 4090", + VramTotalMb: 24576, + VramFreeMb: 20000, + DriverVersion: "545.23.08", + CudaVersion: "12.3", +} +``` + +## Notas + +`CudaVersion` es opcional y solo se rellena en GPUs NVIDIA con driver CUDA instalado. +Los valores de VRAM estan en megabytes enteros para evitar ambiguedad con unidades. +Espejo JSON-compatible de `GpuInfo_py_ml` (si existe) — tags `json:` en snake_case. diff --git a/types/infra/vault_file.md b/types/infra/vault_file.md new file mode 100644 index 00000000..ad3f3241 --- /dev/null +++ b/types/infra/vault_file.md @@ -0,0 +1,51 @@ +--- +name: vault_file +lang: go +domain: infra +version: "1.0.0" +algebraic: product +definition: | + type VaultFile struct { + VaultID string `json:"vault_id"` + VaultName string `json:"vault_name"` + RelPath string `json:"rel_path"` + Size int64 `json:"size"` + Mtime int64 `json:"mtime"` + Sha256 string `json:"sha256"` + Mime string `json:"mime"` + Ext string `json:"ext"` + Bucket string `json:"bucket"` + SubBucket string `json:"sub_bucket"` + } +description: "Describe un fichero individual dentro de un vault: identidad (vault + ruta relativa), metadatos de contenido (tamano, mtime, sha256, mime) y clasificacion estructural (bucket, sub-bucket)." +tags: [vault, file, metadata, infra, storage] +uses_types: [] +file_path: "functions/infra/vault_file.go" +--- + +## Ejemplo + +```go +vf := VaultFile{ + VaultID: "turismo_spain_app_turismo", + VaultName: "turismo_spain", + RelPath: "data/raw/hoteles_2024.csv", + Size: 142340, + Mtime: 1715000000, + Sha256: "e3b0c44298fc1c149afb...", + Mime: "text/csv", + Ext: ".csv", + Bucket: "data", + SubBucket: "raw", +} +``` + +## Notas + +`Bucket` acepta exactamente "data" o "knowledge". `SubBucket` puede ser vacio si el fichero +esta directamente en la raiz del bucket. Los valores conocidos de SubBucket son: +- data: raw, processed, exports +- knowledge: decisions, domains, models, benchmarks, test_documents + +`Mtime` en unix seconds UTC. `Sha256` hex lowercase sin prefijo. `Size` en bytes. +JSON tags en snake_case para serializar directamente a la tabla `vault_files` de operations.db. diff --git a/types/ml/generation_config.md b/types/ml/generation_config.md new file mode 100644 index 00000000..1951239d --- /dev/null +++ b/types/ml/generation_config.md @@ -0,0 +1,53 @@ +--- +name: generation_config +lang: go +domain: ml +version: "1.0.0" +algebraic: product +definition: | + type GenerationConfig struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Seed int64 `json:"seed"` + Steps int `json:"steps"` + CfgScale float64 `json:"cfg_scale"` + Sampler string `json:"sampler"` + Width int `json:"width"` + Height int `json:"height"` + Model ModelRef `json:"model"` + Loras []LoraRef `json:"loras,omitempty"` + ClipSkip *int `json:"clip_skip,omitempty"` + } +description: "Parametros de una solicitud de generacion de imagen. Espejo JSON-compatible de GenerationConfig_py_ml: roundtrip JSON bytes <-> Python sin perdida." +tags: [ml, image-gen, diffusion, config] +uses_types: [model_ref_go_ml, lora_ref_go_ml] +file_path: "functions/ml/generation_config.go" +--- + +## Ejemplo + +```go +clipSkip := 2 +cfg := GenerationConfig{ + Prompt: "a photo of a cat in space, 8k, highly detailed", + Seed: 42, + Steps: 20, + CfgScale: 7.0, + Sampler: "euler_a", + Width: 512, + Height: 512, + Model: ModelRef{ + Name: "dreamshaper_8", + ModelType: "sd15", + Quantization: "fp16", + }, + ClipSkip: &clipSkip, +} +``` + +## Notas + +Todos los campos usan tags `json:` snake_case para coincidir con el dataclass Python homólogo. +`NegativePrompt`, `Loras` y `ClipSkip` son opcionales (`omitempty` / puntero). +`Sampler` es string libre — ver `SamplerName_go_ml` para los valores documentados. +`Seed` es `int64` para compatibilidad con seeds grandes de algunos backends. diff --git a/types/ml/image_gen_result.md b/types/ml/image_gen_result.md new file mode 100644 index 00000000..66ec40fb --- /dev/null +++ b/types/ml/image_gen_result.md @@ -0,0 +1,39 @@ +--- +name: image_gen_result +lang: go +domain: ml +version: "1.0.0" +algebraic: product +definition: | + type ImageGenResult struct { + ImageBytes []byte `json:"-"` + Format string `json:"format"` + Meta map[string]any `json:"meta"` + DurationMs int64 `json:"duration_ms"` + VramPeakMb *int `json:"vram_peak_mb,omitempty"` + } +description: "Resultado de una solicitud de generacion de imagen. ImageBytes contiene los bytes raw del PNG (excluido del JSON). Meta transporta metadata del backend (seed efectivo, steps, modelo, etc.)." +tags: [ml, image-gen, diffusion, result] +uses_types: [] +file_path: "functions/ml/image_gen_result.go" +--- + +## Ejemplo + +```go +result, err := gen.Generate(ctx, cfg) +if err != nil { + log.Fatal(err) +} +// Bytes raw del PNG — escribir a disco o enviar como multipart +os.WriteFile("output.png", result.ImageBytes, 0644) +// Metadata serializable a JSON +fmt.Println(result.DurationMs, result.Format) +``` + +## Notas + +`ImageBytes` usa `json:"-"` porque viaja por canal binario separado (multipart, gRPC bytes, archivo). +El JSON de este tipo solo transporta metadata, no los bytes de la imagen. +`Meta` es un mapa libre — los backends lo usan para devolver seed efectivo, steps realizados, +nombre del modelo, etc. `VramPeakMb` es opcional y solo se rellena si el backend lo reporta. diff --git a/types/ml/image_generator.md b/types/ml/image_generator.md new file mode 100644 index 00000000..bb7a5573 --- /dev/null +++ b/types/ml/image_generator.md @@ -0,0 +1,33 @@ +--- +name: image_generator +lang: go +domain: ml +version: "1.0.0" +algebraic: product +definition: | + type ImageGenerator interface { + Generate(ctx context.Context, cfg GenerationConfig) (ImageGenResult, error) + } +description: "Interface para backends de generacion de imagenes. Implementaciones pueden ser locales (ComfyUI, diffusers via Python subprocess) o remotas (API HTTP)." +tags: [ml, image-gen, diffusion, interface] +uses_types: [generation_config_go_ml, image_gen_result_go_ml] +file_path: "functions/ml/image_generator.go" +--- + +## Ejemplo + +```go +// Cualquier backend que implemente Generate puede usarse de forma intercambiable. +var gen ImageGenerator = NewComfyUIBackend(cfg) +result, err := gen.Generate(ctx, GenerationConfig{ + Prompt: "a serene mountain lake at dawn", + Steps: 25, + // ... +}) +``` + +## Notas + +Interface minima de un metodo. Los backends locales (ComfyUI, diffusers) y remotos (APIs) +implementan el mismo contrato. El caller no necesita saber el backend concreto. +`ctx` permite cancelacion y timeout. Los errores de backend se propagan directamente. diff --git a/types/ml/lora_ref.md b/types/ml/lora_ref.md new file mode 100644 index 00000000..a181dabb --- /dev/null +++ b/types/ml/lora_ref.md @@ -0,0 +1,39 @@ +--- +name: lora_ref +lang: go +domain: ml +version: "1.0.0" +algebraic: product +definition: | + type LoraRef struct { + Path string `json:"path"` + Weight float64 `json:"weight"` + Scale *float64 `json:"scale,omitempty"` + } +description: "Referencia a un adaptador LoRA con su peso de fusion (0.0-1.0) y escala opcional de activacion." +tags: [ml, lora, image-gen, diffusion] +uses_types: [] +file_path: "functions/ml/lora_ref.go" +--- + +## Ejemplo + +```go +lora := LoraRef{ + Path: "/loras/detail_tweaker.safetensors", + Weight: 0.7, +} + +scale := 0.5 +lora2 := LoraRef{ + Path: "/loras/lighting.safetensors", + Weight: 0.8, + Scale: &scale, +} +``` + +## Notas + +Espejo JSON-compatible de `LoraRef_py_ml`. +`Weight` es el multiplicador de fusion aplicado al modelo base (0.0 = sin efecto, 1.0 = efecto completo). +`Scale` es un override de escala de activacion — omitir si se usa el default del backend. diff --git a/types/ml/model_ref.md b/types/ml/model_ref.md new file mode 100644 index 00000000..e217e4b0 --- /dev/null +++ b/types/ml/model_ref.md @@ -0,0 +1,36 @@ +--- +name: model_ref +lang: go +domain: ml +version: "1.0.0" +algebraic: product +definition: | + type ModelRef struct { + Name string `json:"name"` + ModelType string `json:"model_type"` + Quantization string `json:"quantization"` + Path string `json:"path,omitempty"` + } +description: "Referencia a un modelo de generacion de imagenes: nombre, arquitectura (sd15|sdxl|flux_dev|...), cuantizacion (fp16|q8_0|...) y path opcional en disco." +tags: [ml, model, image-gen, diffusion] +uses_types: [] +file_path: "functions/ml/model_ref.go" +--- + +## Ejemplo + +```go +m := ModelRef{ + Name: "dreamshaper_8", + ModelType: "sd15", + Quantization: "fp16", + Path: "/models/dreamshaper_8.safetensors", +} +``` + +## Notas + +Espejo JSON-compatible de `ModelRef_py_ml`. Los tags `json:` coinciden con los +campos snake_case del dataclass Python para roundtrip sin perdida. +`ModelType` es string libre — los valores documentados son `sd15`, `sdxl`, +`flux_dev`, `flux_schnell`. `Quantization` habitual: `fp16`, `bf16`, `q8_0`, `q4_K_M`.