package infra import ( "bytes" "fmt" "os" "path/filepath" "regexp" "strings" "time" "gopkg.in/yaml.v3" ) // SynapseMsc3861Config holds parameters for enabling MSC3861 (MAS) in homeserver.yaml. type SynapseMsc3861Config struct { // HomeserverYamlPath is the absolute path to the homeserver.yaml file. HomeserverYamlPath string // MasEndpoint is the internal MAS URL (e.g. http://mas:8080/). MasEndpoint string // MasSecret is the shared_secret hex (64 hex chars, 32 bytes) matching mas/config.yaml::matrix.secret. MasSecret string // BackupDir is the directory where the original file backup is stored. BackupDir string // DryRun: if true, compute diff only without writing files. DryRun bool } // SynapseMsc3861Result holds the output of SynapseMsc3861Enable. type SynapseMsc3861Result struct { // BackupPath is the path of the backup file created (empty if DryRun=true). BackupPath string // LinesAdded is the number of added lines in the diff. LinesAdded int // LinesRemoved is the number of removed lines in the diff. LinesRemoved int // Diff is the unified diff string between original and modified content. Diff string } // hexPattern matches exactly 64 lowercase hex characters. var hexPattern = regexp.MustCompile(`^[0-9a-f]{64}$`) // SynapseMsc3861Enable edits a Synapse homeserver.yaml to enable MSC3861 (Matrix Authentication Service). // // Steps: // 1. Validate inputs. // 2. Backup the original file to BackupDir. // 3. Parse the YAML using the yaml.v3 Node API (preserves comments). // 4. Uncomment / add the matrix_authentication_service block. // 5. Ensure experimental_features.msc3861.enabled = true. // 6. Ensure password_config.enabled = false. // 7. Compute a unified diff. // 8. Write the result unless DryRun=true. func SynapseMsc3861Enable(cfg SynapseMsc3861Config) (SynapseMsc3861Result, error) { var result SynapseMsc3861Result // --- 1. 
Validate inputs --- if cfg.HomeserverYamlPath == "" { return result, fmt.Errorf("HomeserverYamlPath is required") } if _, err := os.Stat(cfg.HomeserverYamlPath); err != nil { return result, fmt.Errorf("HomeserverYamlPath %q not found: %w", cfg.HomeserverYamlPath, err) } if cfg.MasEndpoint == "" { return result, fmt.Errorf("MasEndpoint is required") } if !strings.HasPrefix(cfg.MasEndpoint, "http://") && !strings.HasPrefix(cfg.MasEndpoint, "https://") { return result, fmt.Errorf("MasEndpoint must start with http:// or https://") } if !hexPattern.MatchString(cfg.MasSecret) { return result, fmt.Errorf("MasSecret must be exactly 64 lowercase hex characters (32 bytes)") } if cfg.BackupDir == "" { return result, fmt.Errorf("BackupDir is required") } // --- Read original file --- originalBytes, err := os.ReadFile(cfg.HomeserverYamlPath) if err != nil { return result, fmt.Errorf("reading homeserver.yaml: %w", err) } originalContent := string(originalBytes) // --- 2. Backup --- if !cfg.DryRun { if err := os.MkdirAll(cfg.BackupDir, 0o755); err != nil { return result, fmt.Errorf("creating backup dir %q: %w", cfg.BackupDir, err) } ts := time.Now().Unix() backupName := fmt.Sprintf("homeserver_%d.yaml", ts) backupPath := filepath.Join(cfg.BackupDir, backupName) if err := os.WriteFile(backupPath, originalBytes, 0o644); err != nil { return result, fmt.Errorf("writing backup: %w", err) } result.BackupPath = backupPath } // --- 3–6. Modify content using line-level and YAML node processing --- modifiedContent, err := applyMsc3861Edits(originalContent, cfg.MasEndpoint, cfg.MasSecret) if err != nil { return result, fmt.Errorf("applying MSC3861 edits: %w", err) } // --- 7. Compute diff --- diff := unifiedDiff("homeserver.yaml (original)", "homeserver.yaml (modified)", originalContent, modifiedContent) result.Diff = diff added, removed := countDiffLines(diff) result.LinesAdded = added result.LinesRemoved = removed // --- 8. 
Write if not DryRun --- if !cfg.DryRun { if err := os.WriteFile(cfg.HomeserverYamlPath, []byte(modifiedContent), 0o644); err != nil { return result, fmt.Errorf("writing modified homeserver.yaml: %w", err) } } return result, nil } // applyMsc3861Edits performs all required YAML edits on the raw content string. // It uses a line-based approach so that comments are preserved exactly. func applyMsc3861Edits(content, masEndpoint, masSecret string) (string, error) { // We work line-by-line for the commented-block replacement and password_config, // then use yaml.v3 Node API for experimental_features.msc3861. lines := strings.Split(content, "\n") lines = enableMasBlock(lines, masEndpoint, masSecret) lines = setPasswordConfigDisabled(lines) modified := strings.Join(lines, "\n") // Now handle experimental_features.msc3861 via yaml.v3 Node API. modified, err := ensureExperimentalMsc3861(modified) if err != nil { return "", fmt.Errorf("updating experimental_features: %w", err) } return modified, nil } // masBlockTemplate is the YAML block we want active in the file. func masBlockLines(endpoint, secret string) []string { return []string{ "matrix_authentication_service:", " enabled: true", fmt.Sprintf(" endpoint: %q", endpoint), fmt.Sprintf(" secret: %q", secret), } } // enableMasBlock finds the commented-out matrix_authentication_service block // (lines starting with "# matrix_authentication_service:") or an existing active // block, and replaces/inserts the correct active block. func enableMasBlock(lines []string, endpoint, secret string) []string { // Patterns to detect the section. 
commentedHeader := regexp.MustCompile(`^#\s*matrix_authentication_service:`) activeHeader := regexp.MustCompile(`^matrix_authentication_service:`) commentedSubkey := regexp.MustCompile(`^#\s+\w`) newBlock := masBlockLines(endpoint, secret) var result []string i := 0 injected := false for i < len(lines) { line := lines[i] if commentedHeader.MatchString(line) && !injected { // Replace the commented block (consume commented sub-lines too). result = append(result, newBlock...) injected = true i++ // Skip subsequent commented sub-lines belonging to this block. for i < len(lines) && commentedSubkey.MatchString(lines[i]) { i++ } continue } if activeHeader.MatchString(line) && !injected { // Already active — replace it to ensure correct values. result = append(result, newBlock...) injected = true i++ // Skip existing sub-lines (indented). for i < len(lines) && (strings.HasPrefix(lines[i], " ") || lines[i] == "") { // Stop at the next top-level key. if lines[i] != "" && !strings.HasPrefix(lines[i], " ") { break } if strings.HasPrefix(lines[i], " ") { i++ continue } break } continue } result = append(result, line) i++ } if !injected { // Block not found anywhere — append at end (before trailing blank lines). result = append(result, "") result = append(result, newBlock...) } return result } // setPasswordConfigDisabled ensures `password_config:\n enabled: false` in the file. func setPasswordConfigDisabled(lines []string) []string { headerRe := regexp.MustCompile(`^password_config:`) commentedRe := regexp.MustCompile(`^#\s*password_config:`) var result []string i := 0 injected := false for i < len(lines) { line := lines[i] if commentedRe.MatchString(line) && !injected { // Replace commented block. 
result = append(result, "password_config:") result = append(result, " enabled: false") injected = true i++ for i < len(lines) && regexp.MustCompile(`^#\s+\w`).MatchString(lines[i]) { i++ } continue } if headerRe.MatchString(line) && !injected { // Active block — update or add enabled: false sub-key. result = append(result, line) injected = true i++ foundEnabled := false var subLines []string for i < len(lines) && strings.HasPrefix(lines[i], " ") { sl := lines[i] if regexp.MustCompile(`^\s+enabled:`).MatchString(sl) { subLines = append(subLines, " enabled: false") foundEnabled = true } else { subLines = append(subLines, sl) } i++ } if !foundEnabled { subLines = append([]string{" enabled: false"}, subLines...) } result = append(result, subLines...) continue } result = append(result, line) i++ } if !injected { result = append(result, "") result = append(result, "password_config:") result = append(result, " enabled: false") } return result } // ensureExperimentalMsc3861 uses yaml.v3 Node API to set // experimental_features.msc3861.enabled = true preserving other keys. func ensureExperimentalMsc3861(content string) (string, error) { var doc yaml.Node if err := yaml.Unmarshal([]byte(content), &doc); err != nil { return content, fmt.Errorf("yaml unmarshal: %w", err) } if doc.Kind == 0 { // Empty document — append the block. return content + "\nexperimental_features:\n msc3861:\n enabled: true\n", nil } root := &doc if root.Kind == yaml.DocumentNode && len(root.Content) > 0 { root = root.Content[0] } if root.Kind != yaml.MappingNode { return content, fmt.Errorf("unexpected root YAML node kind %v", root.Kind) } // Find or create experimental_features. expNode := findMappingValue(root, "experimental_features") if expNode == nil { // Append experimental_features block. 
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "experimental_features"} valNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} root.Content = append(root.Content, keyNode, valNode) expNode = valNode } // Find or create msc3861 under experimental_features. mscNode := findMappingValue(expNode, "msc3861") if mscNode == nil { keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "msc3861"} valNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} expNode.Content = append(expNode.Content, keyNode, valNode) mscNode = valNode } // Set enabled: true inside msc3861. enabledNode := findMappingValue(mscNode, "enabled") if enabledNode == nil { keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "enabled"} valNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"} mscNode.Content = append(mscNode.Content, keyNode, valNode) } else { enabledNode.Value = "true" enabledNode.Tag = "!!bool" } var buf bytes.Buffer enc := yaml.NewEncoder(&buf) enc.SetIndent(2) if err := enc.Encode(&doc); err != nil { return content, fmt.Errorf("yaml marshal: %w", err) } if err := enc.Close(); err != nil { return content, fmt.Errorf("yaml encoder close: %w", err) } return buf.String(), nil } // findMappingValue returns the value node for the given key in a mapping node, or nil. func findMappingValue(node *yaml.Node, key string) *yaml.Node { if node.Kind != yaml.MappingNode { return nil } for i := 0; i+1 < len(node.Content); i += 2 { if node.Content[i].Value == key { return node.Content[i+1] } } return nil } // unifiedDiff produces a simple unified diff between two texts. 
// It returns "" when the two texts are identical. Otherwise the result starts
// with "--- fromLabel" and "+++ toLabel" header lines, followed by one or more
// "@@ -start,count +start,count @@" hunks with up to three lines of context.
// Both texts are split on "\n", so a trailing newline shows up as a final
// empty line.
func unifiedDiff(fromLabel, toLabel, original, modified string) string {
	if original == modified {
		return ""
	}
	origLines := strings.Split(original, "\n")
	modLines := strings.Split(modified, "\n")

	var sb strings.Builder
	fmt.Fprintf(&sb, "--- %s\n", fromLabel)
	fmt.Fprintf(&sb, "+++ %s\n", toLabel)

	// Lines on a longest common subsequence are unchanged; everything between
	// two consecutive LCS entries is a change.
	lcs := computeLCS(origLines, modLines)
	formatDiff(&sb, origLines, modLines, lcs)
	return sb.String()
}

// lcsEntry pairs the index of a line in the original (o) with the index of
// the identical line in the modified text (m).
type lcsEntry struct{ o, m int }

// computeLCS computes a longest common subsequence of a and b.
//
// It returns the matched (origIdx, modIdx) pairs in increasing order, or nil
// when the inputs have no line in common. It builds a full
// (len(a)+1) x (len(b)+1) table.
func computeLCS(a, b []string) []lcsEntry {
	la, lb := len(a), len(b)

	// dp[i][j] = LCS length for a[:i], b[:j]
	dp := make([][]int, la+1)
	for i := range dp {
		dp[i] = make([]int, lb+1)
	}
	for i := 1; i <= la; i++ {
		for j := 1; j <= lb; j++ {
			switch {
			case a[i-1] == b[j-1]:
				dp[i][j] = dp[i-1][j-1] + 1
			case dp[i-1][j] >= dp[i][j-1]:
				dp[i][j] = dp[i-1][j]
			default:
				dp[i][j] = dp[i][j-1]
			}
		}
	}

	n := dp[la][lb]
	if n == 0 {
		return nil
	}

	// Backtrack from the end, filling the result from the back so that no
	// per-match prepend (and its copying) is needed.
	result := make([]lcsEntry, n)
	i, j := la, lb
	for k := n; i > 0 && j > 0; {
		switch {
		case a[i-1] == b[j-1]:
			k--
			result[k] = lcsEntry{i - 1, j - 1}
			i--
			j--
		case dp[i-1][j] >= dp[i][j-1]:
			i--
		default:
			j--
		}
	}
	return result
}

// formatDiff writes unified diff hunks for orig -> mod to sb.
//
// lcs must hold the unchanged line pairs in increasing order (as returned by
// computeLCS). Each change gets up to three context lines on either side;
// changes whose context touches or overlaps are emitted as a single hunk.
func formatDiff(sb *strings.Builder, orig, mod []string, lcs []lcsEntry) {
	const ctx = 3

	// Step 1: collect the change regions, i.e. the gaps between consecutive
	// LCS entries (plus the tail after the last one). No context yet.
	var changes []diffHunk
	oi, mi := 0, 0
	for k := 0; k <= len(lcs); k++ {
		lo, lm := len(orig), len(mod)
		if k < len(lcs) {
			lo, lm = lcs[k].o, lcs[k].m
		}
		if oi < lo || mi < lm {
			changes = append(changes, diffHunk{o1: oi, o2: lo, m1: mi, m2: lm})
		}
		oi, mi = lo+1, lm+1
	}

	// Step 2: surround every change with context. The context is clamped to
	// the neighbouring changes so that a context line is never a line that
	// another change removes.
	hunks := make([]diffHunk, 0, len(changes))
	for k, c := range changes {
		start := c.o1 - ctx
		floor := 0
		if k > 0 {
			floor = changes[k-1].o2
		}
		if start < floor {
			start = floor
		}
		end := c.o2 + ctx
		ceil := len(orig)
		if k+1 < len(changes) {
			ceil = changes[k+1].o1
		}
		if end > ceil {
			end = ceil
		}

		lines := make([]string, 0, (end-start)+(c.m2-c.m1))
		for _, l := range orig[start:c.o1] {
			lines = append(lines, " "+l)
		}
		for _, l := range orig[c.o1:c.o2] {
			lines = append(lines, "-"+l)
		}
		for _, l := range mod[c.m1:c.m2] {
			lines = append(lines, "+"+l)
		}
		for _, l := range orig[c.o2:end] {
			lines = append(lines, " "+l)
		}

		// Context lines are unchanged, so they occupy the same number of
		// lines in mod as in orig.
		hunks = append(hunks, diffHunk{
			o1:    start,
			o2:    end,
			m1:    c.m1 - (c.o1 - start),
			m2:    c.m2 + (end - c.o2),
			lines: lines,
		})
	}

	// Step 3: merge touching hunks and print.
	for _, h := range mergeHunks(hunks) {
		fmt.Fprintf(sb, "@@ -%s +%s @@\n", hunkRange(h.o1, h.o2-h.o1), hunkRange(h.m1, h.m2-h.m1))
		for _, l := range h.lines {
			sb.WriteString(l)
			sb.WriteByte('\n')
		}
	}
}

// hunkRange formats one side of a hunk header as "start,count". first is the
// 0-based index of the first line. For an empty range the unified format
// names the line before the range, which is first itself in 1-based terms.
func hunkRange(first, count int) string {
	if count == 0 {
		return fmt.Sprintf("%d,0", first)
	}
	return fmt.Sprintf("%d,%d", first+1, count)
}

// diffHunk is one hunk of a unified diff. [o1, o2) and [m1, m2) are the
// 0-based half-open line ranges it covers in the original and the modified
// text (context included); lines holds the hunk body, each line prefixed
// with ' ', '-' or '+'.
type diffHunk struct {
	o1, o2, m1, m2 int
	lines          []string
}

// mergeHunks joins consecutive hunks whose original ranges touch or overlap.
//
// The overlapping original lines are expected to be context lines present
// both at the end of the earlier hunk and at the start of the later one (as
// produced by formatDiff); the later hunk's copies are dropped so that no
// line is emitted twice.
func mergeHunks(hunks []diffHunk) []diffHunk {
	var merged []diffHunk
	for _, dh := range hunks {
		n := len(merged)
		if n == 0 || dh.o1 > merged[n-1].o2 {
			merged = append(merged, dh)
			continue
		}
		prev := &merged[n-1]
		overlap := prev.o2 - dh.o1
		if overlap > len(dh.lines) {
			overlap = len(dh.lines)
		}
		prev.lines = append(prev.lines, dh.lines[overlap:]...)
		if dh.o2 > prev.o2 {
			prev.o2 = dh.o2
		}
		if dh.m2 > prev.m2 {
			prev.m2 = dh.m2
		}
	}
	return merged
}

// countDiffLines counts added (+) and removed (-) lines in a unified diff.
//
// The "--- "/"+++ " file header pair is recognised by position (at the start
// of the diff, or directly in front of an "@@ " hunk header) rather than by
// prefix alone, so changed lines whose own text begins with "--" or "++" —
// such as a YAML "---" document marker — are still counted.
func countDiffLines(diff string) (added, removed int) {
	lines := strings.Split(diff, "\n")
	for i := 0; i < len(lines); i++ {
		if isDiffFileHeader(lines, i) {
			i++ // also skip the "+++ " half of the header
			continue
		}
		switch {
		case strings.HasPrefix(lines[i], "+"):
			added++
		case strings.HasPrefix(lines[i], "-"):
			removed++
		}
	}
	return added, removed
}

// isDiffFileHeader reports whether lines[i] and lines[i+1] form the
// "--- from" / "+++ to" header pair of a unified diff.
func isDiffFileHeader(lines []string, i int) bool {
	if i+1 >= len(lines) || !strings.HasPrefix(lines[i], "--- ") || !strings.HasPrefix(lines[i+1], "+++ ") {
		return false
	}
	return i == 0 || (i+2 < len(lines) && strings.HasPrefix(lines[i+2], "@@ "))
}