package embeddednats_test

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/enmanuel/unibus/pkg/busauth"
	"github.com/enmanuel/unibus/pkg/embeddednats"
	server "github.com/nats-io/nats-server/v2/server"
	"github.com/nats-io/nats.go"
)

// freePort returns an OS-assigned free TCP port on loopback.
//
// The probe listener is closed before the port is handed back, so the port is
// only reserved for the duration of this call; callers bind it afterwards.
func freePort(t *testing.T) int {
	t.Helper()
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("free port: %v", err)
	}
	defer l.Close()
	return l.Addr().(*net.TCPAddr).Port
}

// startNode boots a clustered embedded NATS node. peerRoutePorts are the route
// ports of the OTHER nodes; user/pass gate the route layer (empty disables it);
// routeTLS, when non-nil, secures the routes with mutual TLS.
//
// The node is shut down (and waited for) automatically via t.Cleanup. Any
// failure to build the TLS config or start the server fails the test.
func startNode(t *testing.T, name string, clientPort, routePort int, peerRoutePorts []int, user, pass string, routeTLS *clusterTLS) *server.Server {
	t.Helper()

	// Carry the cluster credentials in the route URL so this node
	// authenticates outbound to its peers' route listeners. The prefix is the
	// same for every peer, so build it once.
	userinfo := ""
	if user != "" {
		userinfo = user + ":" + pass + "@"
	}
	routes := make([]string, 0, len(peerRoutePorts))
	for _, p := range peerRoutePorts {
		routes = append(routes, fmt.Sprintf("nats://%s127.0.0.1:%d", userinfo, p))
	}

	cc := &embeddednats.ClusterConfig{
		Name:     "unibus-test",
		Host:     "127.0.0.1",
		Port:     routePort,
		Routes:   routes,
		Username: user,
		Password: pass,
	}
	if routeTLS != nil {
		cfg, err := busauth.RouteTLSConfig(routeTLS.cert, routeTLS.key, routeTLS.ca)
		if err != nil {
			t.Fatalf("route TLS for %s: %v", name, err)
		}
		cc.TLS = cfg
	}

	ns, err := embeddednats.StartServer(embeddednats.ServerConfig{
		StoreDir:   t.TempDir(),
		Host:       "127.0.0.1",
		Port:       clientPort,
		ServerName: name,
		Cluster:    cc,
	})
	if err != nil {
		t.Fatalf("start node %s: %v", name, err)
	}
	t.Cleanup(func() {
		ns.Shutdown()
		ns.WaitForShutdown()
	})
	return ns
}

// waitRoutes waits until ns has at least want established routes, or fails.
// It polls every 50ms for up to 8s.
func waitRoutes(t *testing.T, ns *server.Server, want int) {
	t.Helper()
	deadline := time.Now().Add(8 * time.Second)
	for time.Now().Before(deadline) {
		if ns.NumRoutes() >= want {
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Fatalf("node %q never reached %d routes (have %d)", ns.Name(), want, ns.NumRoutes())
}

// stableRouteCount waits for ns's route count to stop changing (the NATS route
// pool opens several connections per peer asynchronously) and returns it, so a
// test can use it as a baseline that an impostor must not increase.
//
// The count is considered stable once it has held the same value for 750ms.
// If that never happens within 5s the last observed count is still returned,
// but the condition is logged so a flaky baseline is visible in test output.
func stableRouteCount(t *testing.T, ns *server.Server) int {
	t.Helper()
	const (
		settleFor = 750 * time.Millisecond
		giveUp    = 5 * time.Second
		pollEvery = 50 * time.Millisecond
	)

	prev := -1
	stableSince := time.Now()
	deadline := time.Now().Add(giveUp)
	for time.Now().Before(deadline) {
		n := ns.NumRoutes()
		if n != prev {
			// Count moved: restart the stability window from now.
			prev = n
			stableSince = time.Now()
		} else if time.Since(stableSince) >= settleFor {
			return n
		}
		time.Sleep(pollEvery)
	}

	// Best effort: hand back the last reading, but do not do so silently.
	t.Logf("node %q route count did not settle within %v; using last observed value %d", ns.Name(), giveUp, prev)
	return prev
}

// pubSubAcrossNodes connects a subscriber to subURL and a publisher to pubURL,
// publishes one message on subject, and reports whether it arrived within 3s.
// This proves the cluster forwards client subjects between nodes.
func pubSubAcrossNodes(t *testing.T, subURL, pubURL, subject, payload string) bool { t.Helper() subConn, err := nats.Connect(subURL) if err != nil { t.Fatalf("subscriber connect %s: %v", subURL, err) } defer subConn.Close() got := make(chan string, 1) if _, err := subConn.Subscribe(subject, func(m *nats.Msg) { select { case got <- string(m.Data): default: } }); err != nil { t.Fatalf("subscribe: %v", err) } if err := subConn.Flush(); err != nil { t.Fatalf("flush sub: %v", err) } pubConn, err := nats.Connect(pubURL) if err != nil { t.Fatalf("publisher connect %s: %v", pubURL, err) } defer pubConn.Close() // Retry the publish for a moment: route interest propagation across the // cluster is asynchronous, so the very first publish can race the gossip. deadline := time.Now().Add(3 * time.Second) for time.Now().Before(deadline) { if err := pubConn.Publish(subject, []byte(payload)); err != nil { t.Fatalf("publish: %v", err) } _ = pubConn.Flush() select { case v := <-got: return v == payload case <-time.After(100 * time.Millisecond): } } return false } // --- golden: two-node cluster forwards client subjects across nodes ---------- func TestClusterForwardsAcrossNodes(t *testing.T) { rp0, rp1 := freePort(t), freePort(t) n0 := startNode(t, "n0", freePort(t), rp0, []int{rp1}, "clusteruser", "clusterpass", nil) n1 := startNode(t, "n1", freePort(t), rp1, []int{rp0}, "clusteruser", "clusterpass", nil) waitRoutes(t, n0, 1) waitRoutes(t, n1, 1) if !pubSubAcrossNodes(t, n0.ClientURL(), n1.ClientURL(), "test.cross", "hello-cluster") { t.Fatalf("subject published on n1 did not reach subscriber on n0") } } // --- edge: three-node cluster (HA shape) forwards between non-adjacent nodes -- func TestClusterThreeNodesForward(t *testing.T) { rp0, rp1, rp2 := freePort(t), freePort(t), freePort(t) n0 := startNode(t, "n0", freePort(t), rp0, []int{rp1, rp2}, "u", "p", nil) n1 := startNode(t, "n1", freePort(t), rp1, []int{rp0, rp2}, "u", "p", nil) n2 := startNode(t, "n2", freePort(t), rp2, 
[]int{rp0, rp1}, "u", "p", nil) waitRoutes(t, n0, 2) waitRoutes(t, n1, 2) waitRoutes(t, n2, 2) // Publish on n2, subscribe on n0: a message must traverse the cluster. if !pubSubAcrossNodes(t, n0.ClientURL(), n2.ClientURL(), "test.ha", "three-node") { t.Fatalf("subject published on n2 did not reach subscriber on n0") } } // --- error: a node with the wrong cluster password is rejected as a route ----- func TestClusterRejectsBadRouteAuth(t *testing.T) { rp0, rp1 := freePort(t), freePort(t) good := startNode(t, "good", freePort(t), rp0, []int{rp1}, "secret", "right", nil) _ = startNode(t, "peer", freePort(t), rp1, []int{rp0}, "secret", "right", nil) waitRoutes(t, good, 1) // Let the route pool settle so the baseline count is stable (NATS opens a // pool of route connections per peer, so NumRoutes counts connections, not // distinct peers). base := stableRouteCount(t, good) // Impostor knows the addresses but not the cluster password. It tries to // route to `good`; the route handshake must be rejected, so the impostor // never establishes a route. impostor := startNode(t, "impostor", freePort(t), freePort(t), []int{rp0}, "secret", "WRONG", nil) // Give the route layer ample time to (fail to) connect, then assert it never // formed: the impostor has zero routes, and `good`'s route count is unchanged // (it did not accept a route from the impostor). 
time.Sleep(2 * time.Second) if n := impostor.NumRoutes(); n != 0 { t.Fatalf("impostor with wrong cluster password formed %d routes, want 0", n) } if n := good.NumRoutes(); n != base { t.Fatalf("legit node route count changed from %d to %d after impostor attempt (it accepted the impostor)", base, n) } } // --- golden (TLS): mutual-TLS routes forward across nodes --------------------- func TestClusterMutualTLSForwards(t *testing.T) { ca, caKey := genCA(t) dir := t.TempDir() tlsA := writeNodeCert(t, dir, "a", ca, caKey) tlsB := writeNodeCert(t, dir, "b", ca, caKey) rp0, rp1 := freePort(t), freePort(t) n0 := startNode(t, "n0", freePort(t), rp0, []int{rp1}, "u", "p", tlsA) n1 := startNode(t, "n1", freePort(t), rp1, []int{rp0}, "u", "p", tlsB) waitRoutes(t, n0, 1) waitRoutes(t, n1, 1) if !pubSubAcrossNodes(t, n0.ClientURL(), n1.ClientURL(), "test.tls", "mtls-ok") { t.Fatalf("subject did not cross the mutual-TLS cluster") } } // --- error (TLS): a node whose cert is not signed by the bus CA cannot join --- func TestClusterRejectsUnsignedNode(t *testing.T) { ca, caKey := genCA(t) dir := t.TempDir() tlsGood := writeNodeCert(t, dir, "good", ca, caKey) tlsPeer := writeNodeCert(t, dir, "peer", ca, caKey) // The impostor signs its node cert with a DIFFERENT CA, and pins only that // CA. The legit nodes' RequireAndVerifyClientCert against the bus CA rejects // it; the impostor likewise rejects the legit node's cert. No route forms. 
otherCA, otherKey := genCA(t) tlsImpostor := writeNodeCert(t, dir, "impostor", otherCA, otherKey) rp0, rp1 := freePort(t), freePort(t) good := startNode(t, "good", freePort(t), rp0, []int{rp1}, "u", "p", tlsGood) _ = startNode(t, "peer", freePort(t), rp1, []int{rp0}, "u", "p", tlsPeer) waitRoutes(t, good, 1) base := stableRouteCount(t, good) impostor := startNode(t, "impostor", freePort(t), freePort(t), []int{rp0}, "u", "p", tlsImpostor) time.Sleep(2 * time.Second) if n := impostor.NumRoutes(); n != 0 { t.Fatalf("impostor with unsigned cert formed %d routes, want 0", n) } if n := good.NumRoutes(); n != base { t.Fatalf("legit node route count changed from %d to %d after unsigned impostor attempt (it accepted the impostor)", base, n) } } // --- cert helpers ------------------------------------------------------------ type clusterTLS struct{ cert, key, ca string } // PEM file paths // genCA creates a self-signed ECDSA CA certificate and its key. func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("gen CA key: %v", err) } tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "unibus-test-CA"}, NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, IsCA: true, } der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) if err != nil { t.Fatalf("create CA cert: %v", err) } caCert, err := x509.ParseCertificate(der) if err != nil { t.Fatalf("parse CA cert: %v", err) } return caCert, key } // writeNodeCert issues a node certificate signed by ca (SAN 127.0.0.1/::1, // usable as both server and client) and writes cert/key/ca PEM files, returning // their paths for RouteTLSConfig. 
func writeNodeCert(t *testing.T, dir, name string, ca *x509.Certificate, caKey *ecdsa.PrivateKey) *clusterTLS { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("gen node key: %v", err) } tmpl := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().UnixNano()), Subject: pkix.Name{CommonName: name}, NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, DNSNames: []string{"localhost"}, } der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey) if err != nil { t.Fatalf("create node cert: %v", err) } certPath := filepath.Join(dir, name+".crt") keyPath := filepath.Join(dir, name+".key") caPath := filepath.Join(dir, name+"-ca.crt") writePEM(t, certPath, "CERTIFICATE", der) keyDER, err := x509.MarshalECPrivateKey(key) if err != nil { t.Fatalf("marshal node key: %v", err) } writePEM(t, keyPath, "EC PRIVATE KEY", keyDER) writePEM(t, caPath, "CERTIFICATE", ca.Raw) return &clusterTLS{cert: certPath, key: keyPath, ca: caPath} } func writePEM(t *testing.T, path, blockType string, der []byte) { t.Helper() b := pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der}) if err := os.WriteFile(path, b, 0o600); err != nil { t.Fatalf("write %s: %v", path, err) } }