// Tests for data_table::compute_pipeline (cpp/functions/core/compute_pipeline.cpp). // Pure logic: no ImGui, no I/O. Exercises stage chaining. #define CATCH_CONFIG_MAIN #include "catch_amalgamated.hpp" #include "core/compute_pipeline.h" #include #include using namespace data_table; // --------------------------------------------------------------------------- // Helper // --------------------------------------------------------------------------- namespace { struct Table { std::vector backing; std::vector cells; std::vector headers; std::vector types; int rows = 0, cols = 0; void add_row(std::initializer_list row) { for (const char* s : row) backing.emplace_back(s ? s : ""); ++rows; } void finalize() { cols = headers.empty() ? 0 : (int)headers.size(); cells.clear(); cells.reserve(backing.size()); for (const auto& s : backing) cells.push_back(s.c_str()); } }; } // anon // --------------------------------------------------------------------------- // Test: empty stages returns passthrough // --------------------------------------------------------------------------- TEST_CASE("compute_pipeline empty stages returns passthrough") { Table t; t.headers = {"x", "y"}; t.types = {ColumnType::Int, ColumnType::Int}; t.add_row({"1", "2"}); t.add_row({"3", "4"}); t.finalize(); StageOutput out = compute_pipeline(t.cells.data(), t.rows, t.cols, t.headers, t.types, {}); REQUIRE(out.rows == 2); REQUIRE(out.cols == 2); } // --------------------------------------------------------------------------- // Test: single stage pipeline equals compute_stage directly // --------------------------------------------------------------------------- TEST_CASE("compute_pipeline single stage equals compute_stage") { Table t; t.headers = {"dept", "amount"}; t.types = {ColumnType::String, ColumnType::Float}; t.add_row({"eng", "100"}); t.add_row({"mktg", "200"}); t.add_row({"eng", "150"}); t.finalize(); Stage stage; Filter f; f.col = 0; f.op = Op::Eq; f.value = "eng"; stage.filters.push_back(f); StageOutput direct = compute_stage(t.cells.data(), t.rows, t.cols, t.headers, t.types, stage); StageOutput via_pipe = compute_pipeline(t.cells.data(), t.rows, t.cols, t.headers, t.types, {stage}); REQUIRE(direct.rows == via_pipe.rows); REQUIRE(direct.cols == via_pipe.cols); } // --------------------------------------------------------------------------- // Test: two-stage chain — filter then group+sum // --------------------------------------------------------------------------- TEST_CASE("compute_pipeline two stages chain filter then group") { Table t; t.headers = {"region", "type", "revenue"}; t.types = {ColumnType::String, ColumnType::String, ColumnType::Float}; t.add_row({"EU", "A", "100"}); t.add_row({"US", "A", "200"}); t.add_row({"EU", "B", "300"}); t.add_row({"EU", "A", "50"}); t.finalize(); // Stage 0: filter EU only. Stage s0; Filter f; f.col = 0; f.op = Op::Eq; f.value = "EU"; s0.filters.push_back(f); // Stage 1: group by type + sum revenue, sort desc. Stage s1; s1.breakouts = {"type"}; Aggregation agg; agg.fn = AggFn::Sum; agg.col = "revenue"; s1.aggregations.push_back(agg); SortClause sc; sc.col = "sum_revenue"; sc.desc = true; s1.sorts.push_back(sc); StageOutput out = compute_pipeline(t.cells.data(), t.rows, t.cols, t.headers, t.types, {s0, s1}); // EU rows: A=100+50=150, B=300. Sorted desc -> B first. REQUIRE(out.rows == 2); REQUIRE(std::string(out.cells[0 * out.cols + 0]) == "B"); REQUIRE(std::string(out.cells[1 * out.cols + 0]) == "A"); double rev_b = std::stod(out.cells[0 * out.cols + 1]); double rev_a = std::stod(out.cells[1 * out.cols + 1]); REQUIRE(rev_b == 300.0); REQUIRE(rev_a == 150.0); } // --------------------------------------------------------------------------- // Test: three-stage chain — group, then filter on aggregated column, then sort // --------------------------------------------------------------------------- TEST_CASE("compute_pipeline three stage chain") { Table t; t.headers = {"cat", "val"}; t.types = {ColumnType::String, ColumnType::Int}; t.add_row({"A", "10"}); t.add_row({"A", "20"}); t.add_row({"B", "5"}); t.add_row({"C", "100"}); t.add_row({"C", "50"}); t.finalize(); // Stage 0: group by cat, sum val. Stage s0; s0.breakouts = {"cat"}; Aggregation agg; agg.fn = AggFn::Sum; agg.col = "val"; s0.aggregations.push_back(agg); // Stage 1: filter where sum_val > 10. Stage s1; Filter f; f.col = 1; f.op = Op::Gt; f.value = "10"; s1.filters.push_back(f); // Stage 2: sort asc by sum_val. Stage s2; SortClause sc; sc.col = "sum_val"; sc.desc = false; s2.sorts.push_back(sc); StageOutput out = compute_pipeline(t.cells.data(), t.rows, t.cols, t.headers, t.types, {s0, s1, s2}); // A=30, B=5 (filtered out), C=150. Sorted asc -> A(30), C(150). REQUIRE(out.rows == 2); REQUIRE(std::string(out.cells[0 * out.cols + 0]) == "A"); REQUIRE(std::string(out.cells[1 * out.cols + 0]) == "C"); } // --------------------------------------------------------------------------- // Test: pipeline with empty input table // --------------------------------------------------------------------------- TEST_CASE("compute_pipeline empty table") { Stage s0; Filter f; f.col = 0; f.op = Op::Eq; f.value = "x"; s0.filters.push_back(f); StageOutput out = compute_pipeline(nullptr, 0, 0, {}, {}, {s0}); REQUIRE(out.rows == 0); }