package matrix import ( "bytes" "context" "errors" "log/slog" "os" "path/filepath" "strings" "testing" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/id" ) // fakeCryptoHelper implements cryptoHelper for testing. type fakeCryptoHelper struct { initErr error closed bool accountID string } func (f *fakeCryptoHelper) Init(ctx context.Context) error { return f.initErr } func (f *fakeCryptoHelper) Close() error { f.closed = true; return nil } func (f *fakeCryptoHelper) SetAccountID(id string) { f.accountID = id } // fakeCryptoIniter implements cryptoIniter for testing. type fakeCryptoIniter struct { calls int helpers []*fakeCryptoHelper // one per call } func (f *fakeCryptoIniter) newHelper(pickleKey []byte, storePath string) (cryptoHelper, error) { idx := f.calls f.calls++ if idx < len(f.helpers) { return f.helpers[idx], nil } return &fakeCryptoHelper{}, nil } func TestInitCryptoCore_Success(t *testing.T) { dir := t.TempDir() storePath := filepath.Join(dir, "crypto", "crypto.db") initer := &fakeCryptoIniter{ helpers: []*fakeCryptoHelper{{initErr: nil}}, } closer, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default()) if err != nil { t.Fatalf("unexpected error: %v", err) } if closer == nil { t.Error("expected non-nil closer") } if initer.calls != 1 { t.Errorf("expected 1 call to newHelper, got %d", initer.calls) } } func TestInitCryptoCore_AutoRecoveryOnStaleStore(t *testing.T) { dir := t.TempDir() storePath := filepath.Join(dir, "crypto", "crypto.db") // Create a stale crypto.db file. _ = os.MkdirAll(filepath.Dir(storePath), 0700) _ = os.WriteFile(storePath, []byte("stale"), 0o644) // First call fails with "not marked as shared", second succeeds. initer := &fakeCryptoIniter{ helpers: []*fakeCryptoHelper{ {initErr: errors.New("device keys not marked as shared")}, {initErr: nil}, }, } _, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default()) if err != nil { t.Fatalf("expected auto-recovery to succeed, got: %v", err) } if initer.calls != 2 { t.Errorf("expected 2 calls (fail + retry), got %d", initer.calls) } } func TestInitCryptoCore_AutoRecoveryFailsTwice(t *testing.T) { dir := t.TempDir() storePath := filepath.Join(dir, "crypto", "crypto.db") _ = os.MkdirAll(filepath.Dir(storePath), 0700) _ = os.WriteFile(storePath, []byte("stale"), 0o644) initer := &fakeCryptoIniter{ helpers: []*fakeCryptoHelper{ {initErr: errors.New("not marked as shared")}, {initErr: errors.New("still broken after recovery")}, }, } _, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default()) if err == nil { t.Fatal("expected error when recovery also fails") } if !strings.Contains(err.Error(), "after auto-recovery") { t.Errorf("expected 'after auto-recovery' in error, got: %v", err) } } func TestInitCryptoCore_NonRecoverableError(t *testing.T) { dir := t.TempDir() storePath := filepath.Join(dir, "crypto", "crypto.db") initer := &fakeCryptoIniter{ helpers: []*fakeCryptoHelper{ {initErr: errors.New("connection refused")}, }, } _, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default()) if err == nil { t.Fatal("expected error for non-recoverable failure") } if !strings.Contains(err.Error(), "init e2ee") { t.Errorf("expected 'init e2ee' in error, got: %v", err) } // Should NOT have retried. if initer.calls != 1 { t.Errorf("expected 1 call (no retry for non-stale error), got %d", initer.calls) } } func TestResolvePickleKey_BadHex(t *testing.T) { _, err := resolvePickleKey("not-hex!", "token") if err == nil { t.Fatal("expected error for invalid hex pickle key") } if !strings.Contains(err.Error(), "decode pickle_key_env") { t.Errorf("unexpected error: %v", err) } } func TestResolvePickleKey_DeriveFromToken(t *testing.T) { key, err := resolvePickleKey("", "my-access-token") if err != nil { t.Fatalf("unexpected error: %v", err) } if len(key) != 32 { t.Errorf("expected 32-byte sha256 key, got %d bytes", len(key)) } } func TestResolvePickleKey_Explicit(t *testing.T) { hexKey := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" key, err := resolvePickleKey(hexKey, "ignored") if err != nil { t.Fatalf("unexpected error: %v", err) } if len(key) != 32 { t.Errorf("expected 32 bytes, got %d", len(key)) } } func TestInitHelper_SetsAccountID(t *testing.T) { helper := &fakeCryptoHelper{} initer := &fakeCryptoIniter{helpers: []*fakeCryptoHelper{helper}} _, err := initHelper(context.Background(), initer, []byte("key"), "/fake", "my-agent") if err != nil { t.Fatalf("unexpected error: %v", err) } if helper.accountID != "my-agent" { t.Errorf("expected accountID='my-agent', got '%s'", helper.accountID) } } // --- diagMachine fake for testing diagnostics --- type fakeDiagMachine struct { pubKeys *crypto.CrossSigningPublicKeysCache ownDevice *id.Device seeds crypto.CrossSigningSeeds seedsPanic bool // simulate ExportCrossSigningKeys panic trustState id.TrustState trustErr error deviceTrusted bool } func (f *fakeDiagMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *crypto.CrossSigningPublicKeysCache { return f.pubKeys } func (f *fakeDiagMachine) OwnIdentity() *id.Device { return f.ownDevice } func (f *fakeDiagMachine) ExportCrossSigningKeys() crypto.CrossSigningSeeds { if f.seedsPanic { panic("nil pointer dereference") } return f.seeds } func (f *fakeDiagMachine) ResolveTrustContext(ctx context.Context, device *id.Device) (id.TrustState, error) { return f.trustState, f.trustErr } func (f *fakeDiagMachine) IsDeviceTrusted(device *id.Device) bool { return f.deviceTrusted } // testLogger returns a logger that writes to a buffer for assertions. func testLogger(buf *bytes.Buffer) *slog.Logger { return slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) } func TestLogCryptoDiagnosticsCore_NilOwnDevice(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) machine := &fakeDiagMachine{ pubKeys: &crypto.CrossSigningPublicKeysCache{MasterKey: "abc", SelfSigningKey: "def", UserSigningKey: "ghi"}, ownDevice: nil, // nil device — was causing panic before the fix seeds: crypto.CrossSigningSeeds{MasterKey: []byte("x"), SelfSigningKey: []byte("y"), UserSigningKey: []byte("z")}, } logCryptoDiagnosticsCore(context.Background(), machine, "@bot:test", "DEVICE1", logger) out := buf.String() if !strings.Contains(out, "own device identity is nil") { t.Errorf("expected warning about nil device identity, got:\n%s", out) } } func TestLogCryptoDiagnosticsCore_NilPublicKeys(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) machine := &fakeDiagMachine{ pubKeys: nil, // no cross-signing public keys ownDevice: nil, seeds: crypto.CrossSigningSeeds{}, } logCryptoDiagnosticsCore(context.Background(), machine, "@bot:test", "DEVICE1", logger) out := buf.String() if !strings.Contains(out, "NO cross-signing public keys found") { t.Errorf("expected warning about missing public keys, got:\n%s", out) } } func TestLogCrossSigningSeeds_PanicRecovery(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) machine := &fakeDiagMachine{seedsPanic: true} // Must not panic — should recover gracefully. logCrossSigningSeeds(machine, logger) out := buf.String() if !strings.Contains(out, "cross-signing private keys not available") { t.Errorf("expected recovery warning, got:\n%s", out) } } func TestLogCrossSigningSeeds_AllPresent(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) machine := &fakeDiagMachine{ seeds: crypto.CrossSigningSeeds{ MasterKey: []byte("master"), SelfSigningKey: []byte("self"), UserSigningKey: []byte("user"), }, } logCrossSigningSeeds(machine, logger) out := buf.String() if !strings.Contains(out, "cross-signing private keys in store") { t.Errorf("expected info about keys in store, got:\n%s", out) } if strings.Contains(out, "self-signing private key NOT in store") { t.Error("should not warn when self-signing key is present") } } func TestLogCrossSigningSeeds_MissingSelfSigning(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) machine := &fakeDiagMachine{ seeds: crypto.CrossSigningSeeds{ MasterKey: []byte("master"), // SelfSigningKey intentionally missing }, } logCrossSigningSeeds(machine, logger) out := buf.String() if !strings.Contains(out, "self-signing private key NOT in store") { t.Errorf("expected warning about missing self-signing key, got:\n%s", out) } } func TestLogDeviceTrust_Trusted(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) device := &id.Device{DeviceID: "DEV1"} machine := &fakeDiagMachine{ trustState: id.TrustStateCrossSignedTOFU, deviceTrusted: true, } logDeviceTrust(context.Background(), machine, device, logger) out := buf.String() if !strings.Contains(out, "own device trust state") { t.Errorf("expected trust state log, got:\n%s", out) } if strings.Contains(out, "device is NOT cross-signed") { t.Error("should not warn for trusted device") } } func TestLogDeviceTrust_Untrusted(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) device := &id.Device{DeviceID: "DEV2"} machine := &fakeDiagMachine{ trustState: id.TrustStateUnset, deviceTrusted: false, } logDeviceTrust(context.Background(), machine, device, logger) out := buf.String() if !strings.Contains(out, "device is NOT cross-signed") { t.Errorf("expected cross-sign warning, got:\n%s", out) } } func TestLogDeviceTrust_ResolveTrustError(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) device := &id.Device{DeviceID: "DEV3"} machine := &fakeDiagMachine{ trustErr: errors.New("crypto store unavailable"), } logDeviceTrust(context.Background(), machine, device, logger) out := buf.String() if !strings.Contains(out, "failed to resolve device trust") { t.Errorf("expected trust resolve error, got:\n%s", out) } } func TestLogCryptoDiagnosticsCore_FullHappyPath(t *testing.T) { var buf bytes.Buffer logger := testLogger(&buf) machine := &fakeDiagMachine{ pubKeys: &crypto.CrossSigningPublicKeysCache{ MasterKey: "masterkey123", SelfSigningKey: "selfkey456", UserSigningKey: "userkey789", }, ownDevice: &id.Device{DeviceID: "MYDEV"}, trustState: id.TrustStateCrossSignedTOFU, deviceTrusted: true, seeds: crypto.CrossSigningSeeds{ MasterKey: []byte("m"), SelfSigningKey: []byte("s"), UserSigningKey: []byte("u"), }, } logCryptoDiagnosticsCore(context.Background(), machine, "@bot:hs", "MYDEV", logger) out := buf.String() if !strings.Contains(out, "device info") { t.Error("expected device info log") } if !strings.Contains(out, "cross-signing public keys found") { t.Error("expected public keys log") } if !strings.Contains(out, "own device trust state") { t.Error("expected trust state log") } if !strings.Contains(out, "cross-signing private keys in store") { t.Error("expected private keys log") } } // --- SSSS key fetcher fakes for testing fetchCrossSigningKeysCore --- type fakeSSSSKeyVerifier struct { key *ssss.Key err error } func (f *fakeSSSSKeyVerifier) VerifyRecoveryKey(keyID, recoveryKey string) (*ssss.Key, error) { return f.key, f.err } type fakeSSSSKeyFetcher struct { keyID string verifier ssssKeyVerifier getErr error fetchErr error } func (f *fakeSSSSKeyFetcher) GetDefaultKeyData(ctx context.Context) (string, ssssKeyVerifier, error) { return f.keyID, f.verifier, f.getErr } func (f *fakeSSSSKeyFetcher) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error { return f.fetchErr } func TestFetchCrossSigningKeysCore_Success(t *testing.T) { fetcher := &fakeSSSSKeyFetcher{ keyID: "key1", verifier: &fakeSSSSKeyVerifier{key: &ssss.Key{ID: "key1"}}, } err := fetchCrossSigningKeysCore(context.Background(), fetcher, "valid-recovery-key") if err != nil { t.Fatalf("unexpected error: %v", err) } } func TestFetchCrossSigningKeysCore_GetDefaultKeyFails(t *testing.T) { fetcher := &fakeSSSSKeyFetcher{ getErr: errors.New("no default key"), } err := fetchCrossSigningKeysCore(context.Background(), fetcher, "any-key") if err == nil { t.Fatal("expected error") } if !strings.Contains(err.Error(), "get SSSS default key") { t.Errorf("unexpected error: %v", err) } } func TestFetchCrossSigningKeysCore_VerifyRecoveryKeyFails(t *testing.T) { fetcher := &fakeSSSSKeyFetcher{ keyID: "key1", verifier: &fakeSSSSKeyVerifier{err: errors.New("invalid recovery key")}, } err := fetchCrossSigningKeysCore(context.Background(), fetcher, "bad-key") if err == nil { t.Fatal("expected error") } if !strings.Contains(err.Error(), "verify recovery key") { t.Errorf("unexpected error: %v", err) } } func TestFetchCrossSigningKeysCore_FetchFromSSSSFails(t *testing.T) { fetcher := &fakeSSSSKeyFetcher{ keyID: "key1", verifier: &fakeSSSSKeyVerifier{key: &ssss.Key{ID: "key1"}}, fetchErr: errors.New("decryption failed"), } err := fetchCrossSigningKeysCore(context.Background(), fetcher, "valid-key") if err == nil { t.Fatal("expected error") } if !strings.Contains(err.Error(), "fetch cross-signing keys from SSSS") { t.Errorf("unexpected error: %v", err) } }