package infra import ( "encoding/json" "fmt" "io" "net/http" "strings" "time" ) // SynapseLoginFlowsCheckConfig holds the parameters for polling the Synapse // login-flows endpoint and verifying that the MAS (Matrix Authentication // Service) SSO flow is active. type SynapseLoginFlowsCheckConfig struct { HomeserverURL string // Public URL of the homeserver (e.g. https://matrix.example.com) ExpectedSsoIdpID string // IdP id to find in m.login.sso.identity_providers[].id (empty = only check SSO presence) MaxRetries int // Number of attempts before giving up (default: 10) RetryDelaySeconds int // Seconds to wait between attempts (default: 3) HttpTimeoutSeconds int // Per-request HTTP timeout in seconds (default: 5) } // SynapseLoginFlowsCheckResult contains the parsed state of the login-flows // endpoint after the last successful (or final failed) attempt. type SynapseLoginFlowsCheckResult struct { Flows []string // All flow types returned (e.g. ["m.login.sso"]) SsoPresent bool // true if "m.login.sso" is in Flows IdpFound bool // true if ExpectedSsoIdpID was found (or ExpectedSsoIdpID is empty and SsoPresent) PasswordEnabled bool // true if "m.login.password" is in Flows LastResponseJSON string // Raw JSON body from the last HTTP response AttemptsUsed int // Number of HTTP attempts made } // loginFlowsResponse is the structure returned by // GET /_matrix/client/v3/login type loginFlowsResponse struct { Flows []loginFlow `json:"flows"` } type loginFlow struct { Type string `json:"type"` IdentityProviders []idpProvider `json:"identity_providers,omitempty"` } type idpProvider struct { ID string `json:"id"` Name string `json:"name"` } // SynapseLoginFlowsCheck polls GET {HomeserverURL}/_matrix/client/v3/login // and checks that the SSO/MAS flow is present and password login is disabled. // It retries up to MaxRetries times with RetryDelaySeconds delay between each. // // Success condition: // - "m.login.sso" is present in flows // - ExpectedSsoIdpID found in identity_providers (skipped when empty) // - "m.login.password" is NOT present // // Returns the result from the last attempt. On convergence failure it also // returns a non-nil error describing the final state. func SynapseLoginFlowsCheck(cfg SynapseLoginFlowsCheckConfig) (SynapseLoginFlowsCheckResult, error) { if cfg.HomeserverURL == "" { return SynapseLoginFlowsCheckResult{}, fmt.Errorf("synapse_login_flows_check: HomeserverURL must not be empty") } cfg.HomeserverURL = strings.TrimRight(cfg.HomeserverURL, "/") if cfg.MaxRetries <= 0 { cfg.MaxRetries = 10 } if cfg.RetryDelaySeconds < 0 { cfg.RetryDelaySeconds = 3 } if cfg.HttpTimeoutSeconds <= 0 { cfg.HttpTimeoutSeconds = 5 } endpoint := cfg.HomeserverURL + "/_matrix/client/v3/login" httpClient := &http.Client{ Timeout: time.Duration(cfg.HttpTimeoutSeconds) * time.Second, } var result SynapseLoginFlowsCheckResult for attempt := 1; attempt <= cfg.MaxRetries; attempt++ { result.AttemptsUsed = attempt resp, body, parseErr := fetchAndParse(httpClient, endpoint) result.LastResponseJSON = body if parseErr != nil { // On the last attempt, surface the parse/network error if attempt == cfg.MaxRetries { return result, fmt.Errorf("synapse_login_flows_check: attempt %d/%d: %w", attempt, cfg.MaxRetries, parseErr) } sleepSeconds(cfg.RetryDelaySeconds) continue } // Build result from parsed response result.Flows = extractFlowTypes(resp.Flows) result.SsoPresent = containsFlow(resp.Flows, "m.login.sso") result.PasswordEnabled = containsFlow(resp.Flows, "m.login.password") if result.SsoPresent { if cfg.ExpectedSsoIdpID == "" { result.IdpFound = true } else { result.IdpFound = findIdp(resp.Flows, cfg.ExpectedSsoIdpID) } } else { result.IdpFound = false } // Check success condition if result.SsoPresent && result.IdpFound && !result.PasswordEnabled { return result, nil } if attempt < cfg.MaxRetries { sleepSeconds(cfg.RetryDelaySeconds) } } // Exhausted retries — build a descriptive error msg := buildConvergenceError(result, cfg) return result, fmt.Errorf("synapse_login_flows_check: %s", msg) } // fetchAndParse performs one HTTP GET and returns the parsed response plus the // raw body. On any error (network, status, JSON) the raw body may be partial. func fetchAndParse(client *http.Client, url string) (*loginFlowsResponse, string, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, "", fmt.Errorf("build request: %w", err) } req.Header.Set("Accept", "application/json") httpResp, err := client.Do(req) if err != nil { return nil, "", fmt.Errorf("http get: %w", err) } defer httpResp.Body.Close() raw, err := io.ReadAll(httpResp.Body) if err != nil { return nil, "", fmt.Errorf("read body: %w", err) } body := string(raw) if httpResp.StatusCode != http.StatusOK { return nil, body, fmt.Errorf("unexpected status %d: %s", httpResp.StatusCode, body) } var parsed loginFlowsResponse if err := json.Unmarshal(raw, &parsed); err != nil { return nil, body, fmt.Errorf("json unmarshal: %w", err) } return &parsed, body, nil } // extractFlowTypes returns the "type" field of each flow entry. func extractFlowTypes(flows []loginFlow) []string { types := make([]string, 0, len(flows)) for _, f := range flows { types = append(types, f.Type) } return types } // containsFlow reports whether any flow entry has the given type. func containsFlow(flows []loginFlow, flowType string) bool { for _, f := range flows { if f.Type == flowType { return true } } return false } // findIdp reports whether any identity_provider in a "m.login.sso" flow has // the given id. func findIdp(flows []loginFlow, idpID string) bool { for _, f := range flows { if f.Type != "m.login.sso" { continue } for _, idp := range f.IdentityProviders { if idp.ID == idpID { return true } } } return false } // buildConvergenceError assembles a human-readable error message describing // why the final state is not the expected post-migration state. func buildConvergenceError(r SynapseLoginFlowsCheckResult, cfg SynapseLoginFlowsCheckConfig) string { var parts []string if !r.SsoPresent { parts = append(parts, "m.login.sso not present") } if cfg.ExpectedSsoIdpID != "" && !r.IdpFound { parts = append(parts, fmt.Sprintf("IdP %q not found in identity_providers", cfg.ExpectedSsoIdpID)) } if r.PasswordEnabled { parts = append(parts, "m.login.password still enabled (MSC3861 not fully applied)") } reason := strings.Join(parts, "; ") return fmt.Sprintf("MAS migration not confirmed after %d attempt(s): %s", r.AttemptsUsed, reason) } // sleepSeconds sleeps for n seconds. Extracted for test patching via a // package-level variable. var sleepSeconds = func(n int) { if n > 0 { time.Sleep(time.Duration(n) * time.Second) } }