#include "viz/scatter_plot.h"

#include <cmath>

#include "viz/plot_static.h"
#include "implot.h"

namespace {

// Fraction of the data span added as margin on each side of an axis so the
// extreme points are not drawn on the plot frame.
constexpr double kAxisPadding = 0.05;

// Spans narrower than kMinSpan are treated as degenerate (all values equal)
// and replaced by kDegenerateSpan so the axis keeps a non-zero extent.
constexpr double kMinSpan = 1e-9;
constexpr double kDegenerateSpan = 1.0;

// Plot height in pixels used when the caller passes a height <= 0.
constexpr float kDefaultHeight = 200.0f;

// Lower/upper limit for one plot axis.
struct AxisRange {
    double lo;
    double hi;
};

// Returns [min_value - pad, max_value + pad], where pad = kAxisPadding * span.
// A span below kMinSpan is replaced by kDegenerateSpan first.
AxisRange padded_range(double min_value, double max_value) {
    double span = max_value - min_value;
    if (span < kMinSpan) span = kDegenerateSpan;
    const double pad = span * kAxisPadding;
    return AxisRange{min_value - pad, max_value + pad};
}

// Draws `count` (xs[i], ys[i]) points as a scatter plot titled `title`.
//
// The axis limits are fitted to the bounding box of the finite points, padded
// by kAxisPadding on each side, and forced every frame (ImPlotCond_Always).
// Nothing is drawn when count <= 0 or either pointer is null. A height <= 0
// selects kDefaultHeight. The plot width of -1 asks ImPlot to fill the
// available width.
template <typename T>
void draw_scatter(const char* title, const T* xs, const T* ys, int count, float height) {
    if (count <= 0 || xs == nullptr || ys == nullptr) return;

    // The bounding box uses only points whose x and y are both finite.
    // A NaN makes every comparison false, which would poison the range if it
    // were the first element. An infinity would yield infinite axis limits.
    // Values are widened to double before comparing; this is exact for
    // float/double and also avoids float overflow in (max - min) below.
    bool have_point = false;
    double x_min = 0.0;
    double x_max = 0.0;
    double y_min = 0.0;
    double y_max = 0.0;
    for (int i = 0; i < count; ++i) {
        const double x = static_cast<double>(xs[i]);
        const double y = static_cast<double>(ys[i]);
        if (!std::isfinite(x) || !std::isfinite(y)) continue;
        if (!have_point) {
            x_min = x_max = x;
            y_min = y_max = y;
            have_point = true;
            continue;
        }
        if (x < x_min) x_min = x;
        if (x > x_max) x_max = x;
        if (y < y_min) y_min = y;
        if (y > y_max) y_max = y;
    }
    // With no finite point the bounds stay at [0, 0], and padded_range widens
    // them to a small window around the origin.

    const AxisRange x_range = padded_range(x_min, x_max);
    const AxisRange y_range = padded_range(y_min, y_max);

    const ImVec2 plot_size(-1.0f, height > 0.0f ? height : kDefaultHeight);
    if (ImPlot::BeginPlot(title, plot_size, plot_static::kPlotFlags)) {
        ImPlot::SetupAxes(nullptr, nullptr, plot_static::kAxisFlags, plot_static::kAxisFlags);
        // ImPlotCond_Always re-applies the limits each frame, so the view
        // always tracks the current data.
        ImPlot::SetupAxisLimits(ImAxis_X1, x_range.lo, x_range.hi, ImPlotCond_Always);
        ImPlot::SetupAxisLimits(ImAxis_Y1, y_range.lo, y_range.hi, ImPlotCond_Always);
        ImPlot::PlotScatter("##data", xs, ys, count);
        ImPlot::EndPlot();
    }
}

}  // namespace

// Draws a scatter plot of `count` float (x, y) pairs; see draw_scatter for
// the axis-fitting and early-return rules.
void scatter_plot(const char* title, const float* xs, const float* ys, int count, float height) {
    draw_scatter(title, xs, ys, count, height);
}

// Draws a scatter plot of `count` double (x, y) pairs; see draw_scatter for
// the axis-fitting and early-return rules.
void scatter_plot(const char* title, const double* xs, const double* ys, int count, float height) {
    draw_scatter(title, xs, ys, count, height);
}