package infra

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// loginFlowsJSON builds a minimal /_matrix/client/v3/login response body.
//
// It panics if the flows cannot be marshalled. Such a failure means the test
// fixture itself is broken. Failing loudly here is clearer than serving an
// empty body and getting a misleading decode error from the code under test.
func loginFlowsJSON(flows []loginFlow) string {
	b, err := json.Marshal(loginFlowsResponse{Flows: flows})
	if err != nil {
		panic("loginFlowsJSON: marshal login flows: " + err.Error())
	}
	return string(b)
}

// masFlows returns a typical post-migration response: only SSO with one IdP.
func masFlows(idpID string) []loginFlow {
	return []loginFlow{
		{
			Type: "m.login.sso",
			IdentityProviders: []idpProvider{
				{ID: idpID, Name: "MAS"},
			},
		},
	}
}

// legacyFlows returns a pre-migration response: password + application_service.
func legacyFlows() []loginFlow {
	return []loginFlow{
		{Type: "m.login.password"},
		{Type: "m.login.application_service"},
	}
}

func TestSynapseLoginFlowsCheck(t *testing.T) {
	// Disable real sleep during tests. This swaps package-level state, so the
	// subtests below must not be run in parallel.
	origSleep := sleepSeconds
	sleepSeconds = func(int) {}
	t.Cleanup(func() { sleepSeconds = origSleep })

	// serve starts a fake homeserver. For every request it returns 200 OK with
	// a JSON body produced by respond; call is the 1-based request number.
	// serve returns the server plus a counter of requests received (read it
	// with atomic.LoadInt32). The server is closed when t finishes.
	serve := func(t *testing.T, respond func(call int) string) (*httptest.Server, *int32) {
		t.Helper()
		var calls int32
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			n := int(atomic.AddInt32(&calls, 1))
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(respond(n)))
		}))
		t.Cleanup(srv.Close)
		return srv, &calls
	}

	t.Run("SSO + IdP expected -> success on first attempt", func(t *testing.T) {
		srv, _ := serve(t, func(int) string { return loginFlowsJSON(masFlows("oidc-mas")) })

		cfg := SynapseLoginFlowsCheckConfig{
			HomeserverURL:      srv.URL,
			ExpectedSsoIdpID:   "oidc-mas",
			MaxRetries:         5,
			RetryDelaySeconds:  0,
			HttpTimeoutSeconds: 5,
		}

		res, err := SynapseLoginFlowsCheck(cfg)
		if err != nil {
			t.Fatalf("expected no error, got: %v", err)
		}
		if !res.SsoPresent {
			t.Error("SsoPresent should be true")
		}
		if !res.IdpFound {
			t.Error("IdpFound should be true")
		}
		if res.PasswordEnabled {
			t.Error("PasswordEnabled should be false")
		}
		if res.AttemptsUsed != 1 {
			t.Errorf("expected 1 attempt, got %d", res.AttemptsUsed)
		}
		if len(res.LastResponseJSON) == 0 {
			t.Error("LastResponseJSON should not be empty")
		}
	})

	t.Run("legacy response then SSO on 3rd attempt -> success after retries", func(t *testing.T) {
		srv, calls := serve(t, func(call int) string {
			if call < 3 {
				return loginFlowsJSON(legacyFlows())
			}
			return loginFlowsJSON(masFlows("oidc-mas"))
		})

		cfg := SynapseLoginFlowsCheckConfig{
			HomeserverURL:      srv.URL,
			ExpectedSsoIdpID:   "oidc-mas",
			MaxRetries:         10,
			RetryDelaySeconds:  0,
			HttpTimeoutSeconds: 5,
		}

		res, err := SynapseLoginFlowsCheck(cfg)
		if err != nil {
			t.Fatalf("expected no error, got: %v", err)
		}
		if res.AttemptsUsed != 3 {
			t.Errorf("expected 3 attempts, got %d", res.AttemptsUsed)
		}
		// Cross-check AttemptsUsed against what the server actually observed.
		if got := atomic.LoadInt32(calls); got != 3 {
			t.Errorf("expected server to receive 3 requests, got %d", got)
		}
		if !res.SsoPresent {
			t.Error("SsoPresent should be true")
		}
		if res.PasswordEnabled {
			t.Error("PasswordEnabled should be false")
		}
	})

	t.Run("response never changes -> error after maxRetries", func(t *testing.T) {
		srv, calls := serve(t, func(int) string { return loginFlowsJSON(legacyFlows()) })

		cfg := SynapseLoginFlowsCheckConfig{
			HomeserverURL:      srv.URL,
			ExpectedSsoIdpID:   "oidc-mas",
			MaxRetries:         3,
			RetryDelaySeconds:  0,
			HttpTimeoutSeconds: 5,
		}

		res, err := SynapseLoginFlowsCheck(cfg)
		if err == nil {
			t.Fatal("expected error after max retries, got nil")
		}
		if !strings.Contains(err.Error(), "MAS migration not confirmed") {
			t.Errorf("expected 'MAS migration not confirmed' in error message, got: %v", err)
		}
		if res.AttemptsUsed != 3 {
			t.Errorf("expected 3 attempts used, got %d", res.AttemptsUsed)
		}
		// Every attempt must really re-query the server, not reuse a stale body.
		if got := atomic.LoadInt32(calls); got != 3 {
			t.Errorf("expected server to receive 3 requests, got %d", got)
		}
		if !res.PasswordEnabled {
			t.Error("PasswordEnabled should be true (legacy still active)")
		}
	})

	t.Run("HTTP timeout -> error", func(t *testing.T) {
		// Stall well past the 1s client timeout. The stall is bounded and is
		// then answered with a *valid* MAS response. An implementation that
		// ignores HttpTimeoutSeconds therefore fails with "got nil" instead of
		// deadlocking the test binary until `go test -timeout` fires.
		const stall = 5 * time.Second
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			select {
			case <-r.Context().Done():
				// Client gave up (the expected path); return so srv.Close can finish.
				return
			case <-time.After(stall):
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(loginFlowsJSON(masFlows("oidc-mas"))))
		}))
		t.Cleanup(srv.Close)

		cfg := SynapseLoginFlowsCheckConfig{
			HomeserverURL:      srv.URL,
			ExpectedSsoIdpID:   "oidc-mas",
			MaxRetries:         1,
			RetryDelaySeconds:  0,
			HttpTimeoutSeconds: 1,
		}

		_, err := SynapseLoginFlowsCheck(cfg)
		if err == nil {
			t.Fatal("expected error on timeout, got nil")
		}
		if !strings.Contains(err.Error(), "synapse_login_flows_check") {
			t.Errorf("expected error to contain function name, got: %v", err)
		}
	})

	t.Run("malformed JSON -> error", func(t *testing.T) {
		srv, _ := serve(t, func(int) string { return `{not valid json` })

		cfg := SynapseLoginFlowsCheckConfig{
			HomeserverURL:      srv.URL,
			MaxRetries:         1,
			RetryDelaySeconds:  0,
			HttpTimeoutSeconds: 5,
		}

		_, err := SynapseLoginFlowsCheck(cfg)
		if err == nil {
			t.Fatal("expected error on malformed JSON, got nil")
		}
		if !strings.Contains(err.Error(), "json unmarshal") {
			t.Errorf("expected json unmarshal error, got: %v", err)
		}
	})
}