#include "ws_client.h"

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <mutex>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
typedef SOCKET sock_t;
#define SOCK_INVALID INVALID_SOCKET
#define SOCK_CLOSE closesocket
#define SOCK_SHUT_BOTH SD_BOTH
#define SOCK_ERR WSAGetLastError()
#define FN_SOCK_NONBLOCK(s) do { u_long m = 1; ioctlsocket((s), FIONBIO, &m); } while (0)
#else
#include <arpa/inet.h>
#include <cerrno>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
typedef int sock_t;
#define SOCK_INVALID (-1)
#define SOCK_CLOSE close
#define SOCK_SHUT_BOTH SHUT_RDWR
#define SOCK_ERR errno
#define FN_SOCK_NONBLOCK(s) do { int f = fcntl((s), F_GETFL, 0); fcntl((s), F_SETFL, f | O_NONBLOCK); } while (0)
#endif

#ifdef _WIN32
// Initializes Winsock once per process and reports whether WSAStartup
// succeeded (later calls return the cached result).
static bool wsa_init_ws() {
  static std::once_flag flag;
  static bool ok = false;
  std::call_once(flag, []() {
    WSADATA wsa;
    ok = WSAStartup(MAKEWORD(2, 2), &wsa) == 0;
  });
  return ok;
}
#endif

namespace {

// Reconnect backoff for WsClient::run(): 0.5s doubling up to 8s. Sleeping
// happens in slices so stop() is noticed within one slice.
constexpr int kInitialBackoffMs = 500;
constexpr int kMaxBackoffMs = 8000;
constexpr int kBackoffSliceMs = 100;

// How long send_all() waits for a full send buffer to drain before giving up.
constexpr int kSendStallMs = 5000;

// Largest frame payload accepted from the server. Bounds memory use and
// rejects corrupt 64-bit lengths before they can overflow offset math.
constexpr uint64_t kMaxFramePayload = 64ull * 1024 * 1024;

// Maximum number of undrained inbound text messages; oldest are dropped.
constexpr size_t kMaxInQueue = 512;

// Flags for every send(): on platforms that define MSG_NOSIGNAL this stops a
// write to a dead peer from raising SIGPIPE (which would kill the process).
// Where SO_NOSIGPIPE exists instead, connect_once() sets it on the socket.
#ifdef MSG_NOSIGNAL
constexpr int kSendFlags = MSG_NOSIGNAL;
#else
constexpr int kSendFlags = 0;
#endif

// ----- Base64 (small, sufficient for 16-byte WS key) -----
constexpr char kBase64[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

// Standard (padded) base64 encoding of `len` bytes at `in`.
std::string base64_encode(const uint8_t* in, size_t len) {
  std::string out;
  out.reserve(((len + 2) / 3) * 4);
  size_t i = 0;
  for (; i + 3 <= len; i += 3) {
    uint32_t v = (uint32_t(in[i]) << 16) | (uint32_t(in[i + 1]) << 8) | uint32_t(in[i + 2]);
    out.push_back(kBase64[(v >> 18) & 63]);
    out.push_back(kBase64[(v >> 12) & 63]);
    out.push_back(kBase64[(v >> 6) & 63]);
    out.push_back(kBase64[v & 63]);
  }
  if (i < len) {
    uint32_t v = uint32_t(in[i]) << 16;
    if (i + 1 < len) v |= uint32_t(in[i + 1]) << 8;
    out.push_back(kBase64[(v >> 18) & 63]);
    out.push_back(kBase64[(v >> 12) & 63]);
    out.push_back(i + 1 < len ? kBase64[(v >> 6) & 63] : '=');
    out.push_back('=');
  }
  return out;
}

// True when the most recent socket error means "retry later" (would-block or
// interrupted by a signal). Must be called right after the failing call.
bool last_error_is_transient() {
  const int err = SOCK_ERR;
#ifdef _WIN32
  return err == WSAEWOULDBLOCK || err == WSAEINTR;
#else
  return err == EAGAIN || err == EWOULDBLOCK || err == EINTR;
#endif
}

// Waits up to `timeout_ms` for `sock` to become writable.
bool wait_writable(sock_t sock, int timeout_ms) {
  fd_set wfds;
  FD_ZERO(&wfds);
  FD_SET(sock, &wfds);
  struct timeval tv;
  tv.tv_sec = timeout_ms / 1000;
  tv.tv_usec = (timeout_ms % 1000) * 1000;
  return select(static_cast<int>(sock) + 1, nullptr, &wfds, nullptr, &tv) > 0;
}

// Shuts down and closes `s`. The shutdown() makes a recv()/select() blocked on
// `s` in another thread return; close() alone does not guarantee that (on
// Linux it does not).
void close_socket(sock_t s) {
  shutdown(s, SOCK_SHUT_BOTH);
  SOCK_CLOSE(s);
}

// Sends exactly n bytes. Works on blocking and non-blocking sockets: a
// would-block / EINTR result waits (bounded by kSendStallMs) for the socket to
// become writable, then retries. Returns false on error or stall.
bool send_all(sock_t sock, const char* data, size_t n) {
  size_t sent = 0;
  while (sent < n) {
    const int k = send(sock, data + sent, static_cast<int>(n - sent), kSendFlags);
    if (k > 0) {
      sent += static_cast<size_t>(k);
      continue;
    }
    if (k < 0 && last_error_is_transient() && wait_writable(sock, kSendStallMs)) continue;
    return false;
  }
  return true;
}

// Receives exactly n bytes (blocking socket). Returns false on error / EOF.
bool recv_all(sock_t sock, char* data, size_t n) {
  size_t got = 0;
  while (got < n) {
    const int k = recv(sock, data + got, static_cast<int>(n - got), 0);
    if (k <= 0) return false;
    got += static_cast<size_t>(k);
  }
  return true;
}

// Receives up to n bytes; returns the count, -1 on error / EOF, or -2 when no
// data is available yet (would-block or interrupted by a signal).
int recv_some(sock_t sock, char* data, size_t n) {
  const int k = recv(sock, data, static_cast<int>(n), 0);
  if (k > 0) return k;
  if (k < 0 && last_error_is_transient()) return -2;
  return -1;
}

// One decoded server -> client frame.
struct Frame {
  int opcode = 0;
  std::string payload;  // already unmasked
};

enum class ParseResult { Incomplete, Ok, TooLarge };

// Decodes one RFC 6455 frame from [data, data + avail). On Ok, fills `out` and
// sets `consumed` to the frame's total length in bytes. Handles 7/16/64-bit
// payload lengths and masked frames (servers should not mask, but it is
// tolerated). The FIN bit is not interpreted.
ParseResult parse_frame(const uint8_t* data, size_t avail, Frame& out, size_t& consumed) {
  if (avail < 2) return ParseResult::Incomplete;
  const uint8_t b0 = data[0];
  const uint8_t b1 = data[1];
  const bool masked = (b1 & 0x80) != 0;
  uint64_t len = b1 & 0x7F;
  size_t pos = 2;
  if (len == 126) {
    if (avail < pos + 2) return ParseResult::Incomplete;
    len = (uint64_t(data[pos]) << 8) | uint64_t(data[pos + 1]);
    pos += 2;
  } else if (len == 127) {
    if (avail < pos + 8) return ParseResult::Incomplete;
    len = 0;
    for (int i = 0; i < 8; i++) len = (len << 8) | data[pos + i];
    pos += 8;
  }
  // Checked before any offset arithmetic so a bogus length cannot overflow.
  if (len > kMaxFramePayload) return ParseResult::TooLarge;
  uint8_t mkey[4] = {0, 0, 0, 0};
  if (masked) {
    if (avail < pos + 4) return ParseResult::Incomplete;
    std::memcpy(mkey, data + pos, sizeof(mkey));
    pos += 4;
  }
  const size_t n = static_cast<size_t>(len);
  if (avail - pos < n) return ParseResult::Incomplete;
  out.opcode = b0 & 0x0F;
  out.payload.assign(reinterpret_cast<const char*>(data + pos), n);
  if (masked) {
    for (size_t i = 0; i < n; i++) {
      out.payload[i] = static_cast<char>(static_cast<uint8_t>(out.payload[i]) ^ mkey[i & 3]);
    }
  }
  consumed = pos + n;
  return ParseResult::Ok;
}

}  // namespace

WsClient::WsClient() = default;

WsClient::~WsClient() { stop(); }

// Starts the background connect/read worker. Only has an effect when the
// client is Idle; otherwise the call is ignored.
void WsClient::start(const std::string& host, int port, const std::string& path) {
  State expected = State::Idle;
  if (!state_.compare_exchange_strong(expected, State::Connecting)) return;
  host_ = host;
  port_ = port;
  path_ = path;
  stop_flag_.store(false);
  worker_ = std::thread([this]() { this->run(); });
}

// Signals the worker to exit, tears down the live socket (if any), joins the
// worker and marks the client Stopped.
void WsClient::stop() {
  stop_flag_.store(true);
  // Whoever wins this exchange owns closing the descriptor. shutdown() inside
  // close_socket() makes a worker blocked in handshake recv()/select() return
  // promptly instead of waiting for the 5s socket timeout.
  // NOTE(review): the worker may still hold its own copy of the descriptor
  // for one more select() call; stop_flag_ makes it exit on the next check.
  int s = sock_.exchange(-1);
  if (s != -1) close_socket(static_cast<sock_t>(s));
  out_cv_.notify_all();
  if (worker_.joinable()) worker_.join();
  state_.store(State::Stopped);
}

// Moves up to `max` queued inbound messages (oldest first) into `out`;
// returns how many were moved.
int WsClient::drain(std::vector<std::string>& out, int max) {
  std::lock_guard<std::mutex> g(in_mu_);
  int n = 0;
  while (!in_queue_.empty() && n < max) {
    out.emplace_back(std::move(in_queue_.front()));
    in_queue_.pop_front();
    n++;
  }
  return n;
}

// Queues `payload` to be sent as a text frame by the worker. Returns false
// (and queues nothing) unless the client is currently Connected.
bool WsClient::send_text(const std::string& payload) {
  if (state_.load() != State::Connected) return false;
  {
    std::lock_guard<std::mutex> g(out_mu_);
    out_queue_.push_back(payload);
  }
  out_cv_.notify_one();
  return true;
}

// Worker thread body: connect, serve the session, back off, repeat until
// stop_flag_ is set.
void WsClient::run() {
  int backoff_ms = kInitialBackoffMs;
  while (!stop_flag_.load()) {
    state_.store(State::Connecting);
    if (connect_once()) {
      backoff_ms = kInitialBackoffMs;  // reset on successful connect
      state_.store(State::Connected);
      read_loop();
      // Session over (error, server close, or stop()). Release the socket
      // unless stop() already took it; without this every reconnect would
      // leak one descriptor.
      int s = sock_.exchange(-1);
      if (s != -1) close_socket(static_cast<sock_t>(s));
    }
    state_.store(State::Backoff);
    if (stop_flag_.load()) break;
    // Exponential backoff: 0.5s -> 1s -> 2s -> 4s -> 8s (cap).
    for (int slept = 0; slept < backoff_ms && !stop_flag_.load(); slept += kBackoffSliceMs) {
      std::this_thread::sleep_for(std::chrono::milliseconds(kBackoffSliceMs));
    }
    backoff_ms = std::min(backoff_ms * 2, kMaxBackoffMs);
  }
  state_.store(State::Stopped);
}

// Resolves host_, connects over TCP, performs the WebSocket handshake and
// switches the socket to non-blocking. On success the socket is published in
// sock_; on failure nothing is left open.
bool WsClient::connect_once() {
#ifdef _WIN32
  if (!wsa_init_ws()) return false;
#endif
  // Resolve host_ (dotted IPv4 or a name such as "localhost") to IPv4.
  // Runs on the worker thread; a slow resolver delays stop() accordingly.
  struct addrinfo hints;
  std::memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;
  struct addrinfo* res = nullptr;
  if (getaddrinfo(host_.c_str(), nullptr, &hints, &res) != 0 || res == nullptr) return false;
  struct sockaddr_in addr;
  std::memset(&addr, 0, sizeof(addr));
  std::memcpy(&addr, res->ai_addr, std::min<size_t>(sizeof(addr), res->ai_addrlen));
  freeaddrinfo(res);
  addr.sin_family = AF_INET;
  addr.sin_port = htons(static_cast<uint16_t>(port_));

  sock_t sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  if (sock == SOCK_INVALID) return false;
#ifndef _WIN32
  // FD_SET on a descriptor >= FD_SETSIZE is undefined; read_loop uses select().
  if (sock >= FD_SETSIZE) {
    SOCK_CLOSE(sock);
    return false;
  }
#endif
#ifdef SO_NOSIGPIPE
  {
    int one = 1;
    setsockopt(sock, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
  }
#endif

  // 5s send/receive timeouts via SO_*TIMEO. The socket stays blocking for the
  // handshake; read_loop switches to non-blocking with select().
  // NOTE: Linux also applies SO_SNDTIMEO to connect(); Winsock does not.
#ifdef _WIN32
  DWORD timeout_ms = 5000;
  setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_ms, sizeof(timeout_ms));
  setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout_ms, sizeof(timeout_ms));
#else
  struct timeval tv;
  tv.tv_sec = 5;
  tv.tv_usec = 0;
  setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
  setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
#endif

  if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != 0) {
    SOCK_CLOSE(sock);
    return false;
  }
  // Published before the handshake so stop() can interrupt a stalled one.
  sock_.store(static_cast<int>(sock));
  if (!handshake()) {
    int s = sock_.exchange(-1);
    if (s != -1) close_socket(static_cast<sock_t>(s));
    return false;
  }
  // Non-blocking for the read loop.
  FN_SOCK_NONBLOCK(sock);
  return true;
}

// Sends the HTTP Upgrade request on sock_ and validates the 101 response.
// Consumes exactly the response header, so any frame the server sends
// immediately after it stays in the socket for read_loop().
bool WsClient::handshake() {
  const sock_t sock = static_cast<sock_t>(sock_.load());

  // 16 random bytes -> base64 -> Sec-WebSocket-Key.
  uint8_t key_raw[16];
  std::random_device rd;
  for (auto& b : key_raw) b = static_cast<uint8_t>(rd() & 0xff);
  std::string key_b64 = base64_encode(key_raw, sizeof(key_raw));

  std::ostringstream req;
  req << "GET " << path_ << " HTTP/1.1\r\n";
  req << "Host: " << host_ << ":" << port_ << "\r\n";
  req << "Upgrade: websocket\r\n";
  req << "Connection: Upgrade\r\n";
  req << "Sec-WebSocket-Key: " << key_b64 << "\r\n";
  req << "Sec-WebSocket-Version: 13\r\n";
  req << "Origin: http://" << host_ << ":" << port_ << "\r\n";
  req << "\r\n";
  std::string raw = req.str();
  if (!send_all(sock, raw.c_str(), raw.size())) return false;

  // Read response headers (up to 4KB). Each chunk is peeked first and only
  // the bytes up to and including "\r\n\r\n" are actually consumed.
  std::string resp;
  char buf[1024];
  for (;;) {
    const int k = recv(sock, buf, static_cast<int>(sizeof(buf)), MSG_PEEK);
    if (k <= 0) return false;
    const size_t old = resp.size();
    resp.append(buf, static_cast<size_t>(k));
    // The terminator may straddle the previous chunk, so search 3 bytes back.
    const size_t end = resp.find("\r\n\r\n", old >= 3 ? old - 3 : 0);
    const size_t take = (end == std::string::npos) ? static_cast<size_t>(k) : end + 4 - old;
    resp.resize(old + take);
    if (!recv_all(sock, buf, take)) return false;
    if (end != std::string::npos) break;
    if (resp.size() > 4096) return false;
  }

  // Expect "HTTP/1.1 101".
  if (resp.compare(0, 12, "HTTP/1.1 101") != 0 && resp.compare(0, 12, "HTTP/1.0 101") != 0) {
    fprintf(stderr, "[ws] handshake failed: %.*s\n",
            static_cast<int>(std::min<size_t>(resp.size(), 120)), resp.c_str());
    return false;
  }
  // We intentionally skip Sec-WebSocket-Accept verification — controlled
  // server, localhost-only, 101 status is enough for this app.
  return true;
}

// Serves one connected session: reads frames, flushes the outgoing queue and
// dispatches inbound frames until an error, a server Close, or stop().
// Returns false when the session ended because of the connection, true when
// stop_flag_ was set.
bool WsClient::read_loop() {
  const sock_t sock = static_cast<sock_t>(sock_.load());
  std::vector<uint8_t> rb;  // bytes received but not yet parsed
  rb.reserve(64 * 1024);
  uint8_t tmp[8192];

  while (!stop_flag_.load()) {
    // Block on select() for up to 100ms so we can both read and check
    // the outgoing queue.
    fd_set rfds;
    FD_ZERO(&rfds);
    FD_SET(sock, &rfds);
    struct timeval tv;
    tv.tv_sec = 0;
    tv.tv_usec = 100 * 1000;
    const int sel = select(static_cast<int>(sock) + 1, &rfds, nullptr, nullptr, &tv);
    if (sel < 0) {
      if (last_error_is_transient()) continue;  // interrupted by a signal
      return false;
    }
    if (sel > 0 && FD_ISSET(sock, &rfds)) {
      const int k = recv_some(sock, reinterpret_cast<char*>(tmp), sizeof(tmp));
      if (k == -1) return false;
      if (k > 0) rb.insert(rb.end(), tmp, tmp + k);
    }

    // Drain outgoing queue (text frames, masked).
    for (;;) {
      std::string payload;
      {
        std::lock_guard<std::mutex> g(out_mu_);
        if (out_queue_.empty()) break;
        payload = std::move(out_queue_.front());
        out_queue_.pop_front();
      }
      if (!send_frame(0x1, payload)) return false;
    }

    // Parse every complete frame. RFC6455 minimal: no continuation frames;
    // opcodes 0x1 text, 0x8 close, 0x9 ping, 0xA pong. Consumed bytes are
    // erased once per pass instead of once per frame.
    size_t off = 0;
    for (;;) {
      Frame frame;
      size_t used = 0;
      const ParseResult r = parse_frame(rb.data() + off, rb.size() - off, frame, used);
      if (r == ParseResult::Incomplete) break;
      if (r == ParseResult::TooLarge) {
        fprintf(stderr, "[ws] frame exceeds size limit, dropping connection\n");
        return false;
      }
      off += used;
      switch (frame.opcode) {
        case 0x1: {  // text
          // NOTE(review): the duration_cast target was lost in this copy of
          // the file; milliseconds assumed — confirm against ws_client.h users.
          last_event_ts_.store(std::chrono::duration_cast<std::chrono::milliseconds>(
                                   std::chrono::system_clock::now().time_since_epoch())
                                   .count());
          std::lock_guard<std::mutex> g(in_mu_);
          in_queue_.emplace_back(std::move(frame.payload));
          // Bound the queue. Drop oldest if too big — UI consumes
          // each frame so this should only kick in if the dashboard
          // is paused (window minimized, etc.).
          while (in_queue_.size() > kMaxInQueue) in_queue_.pop_front();
          break;
        }
        case 0x8:  // close: echo the status code back (RFC 6455 5.5.1), best effort
          send_frame(0x8, frame.payload.substr(0, 2));
          return false;
        case 0x9:  // ping -> reply with pong (same payload)
          if (!send_frame(0xA, frame.payload)) return false;
          break;
        case 0xA:  // pong, ignore
          break;
        default:
          // 0x0 continuation or unexpected opcode: bail (server
          // controlled by us, shouldn't happen).
          return false;
      }
    }
    if (off > 0) rb.erase(rb.begin(), rb.begin() + static_cast<std::ptrdiff_t>(off));
  }
  return true;
}

// Builds a single masked, FIN-set client frame with `opcode` and `payload` and
// writes it to sock_. Returns false if there is no socket or the send fails.
bool WsClient::send_frame(int opcode, const std::string& payload) {
  const sock_t sock = static_cast<sock_t>(sock_.load());
  if (sock == SOCK_INVALID || static_cast<int>(sock) < 0) return false;

  const uint64_t len = payload.size();
  std::vector<uint8_t> frame;
  frame.reserve(payload.size() + 16);
  frame.push_back(static_cast<uint8_t>(0x80 | (opcode & 0x0F)));  // FIN + opcode
  if (len < 126) {
    frame.push_back(static_cast<uint8_t>(0x80 | len));  // mask bit set
  } else if (len <= 0xFFFF) {
    frame.push_back(static_cast<uint8_t>(0x80 | 126));
    frame.push_back(static_cast<uint8_t>((len >> 8) & 0xFF));
    frame.push_back(static_cast<uint8_t>(len & 0xFF));
  } else {
    frame.push_back(static_cast<uint8_t>(0x80 | 127));
    for (int i = 7; i >= 0; i--) frame.push_back(static_cast<uint8_t>((len >> (8 * i)) & 0xFF));
  }

  // Mask key (RFC requires mask for client -> server frames). The device is
  // created once per thread rather than once per frame.
  thread_local std::random_device rd;
  uint8_t mkey[4];
  for (auto& b : mkey) b = static_cast<uint8_t>(rd() & 0xff);
  frame.insert(frame.end(), mkey, mkey + 4);
  for (size_t i = 0; i < payload.size(); i++) {
    frame.push_back(static_cast<uint8_t>(static_cast<uint8_t>(payload[i]) ^ mkey[i & 3]));
  }
  return send_all(sock, reinterpret_cast<const char*>(frame.data()), frame.size());
}