487 lines
13 KiB
Go
487 lines
13 KiB
Go
package matrix
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/crypto/ssss"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// fakeCryptoHelper is a test double for the cryptoHelper interface. Init
// reports a preconfigured error, while Close and SetAccountID record what
// happened so tests can inspect it afterwards.
type fakeCryptoHelper struct {
	initErr   error  // returned as-is by Init
	closed    bool   // flipped to true by Close
	accountID string // last value handed to SetAccountID
}
|
|
|
|
// Init ignores its context and returns the configured initErr (nil means success).
func (f *fakeCryptoHelper) Init(_ context.Context) error {
	return f.initErr
}
|
|
// Close marks the helper as closed; it never returns an error.
func (f *fakeCryptoHelper) Close() error {
	f.closed = true
	return nil
}
|
|
// SetAccountID stores the given account ID so tests can assert on it.
// The parameter is deliberately not named `id`, which would shadow the
// imported mautrix `id` package inside this method.
func (f *fakeCryptoHelper) SetAccountID(accountID string) {
	f.accountID = accountID
}
|
|
|
|
// fakeCryptoIniter implements cryptoIniter for testing.
|
|
type fakeCryptoIniter struct {
|
|
calls int
|
|
helpers []*fakeCryptoHelper // one per call
|
|
}
|
|
|
|
func (f *fakeCryptoIniter) newHelper(pickleKey []byte, storePath string) (cryptoHelper, error) {
|
|
idx := f.calls
|
|
f.calls++
|
|
if idx < len(f.helpers) {
|
|
return f.helpers[idx], nil
|
|
}
|
|
return &fakeCryptoHelper{}, nil
|
|
}
|
|
|
|
func TestInitCryptoCore_Success(t *testing.T) {
|
|
dir := t.TempDir()
|
|
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
|
|
|
initer := &fakeCryptoIniter{
|
|
helpers: []*fakeCryptoHelper{{initErr: nil}},
|
|
}
|
|
|
|
closer, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if closer == nil {
|
|
t.Error("expected non-nil closer")
|
|
}
|
|
if initer.calls != 1 {
|
|
t.Errorf("expected 1 call to newHelper, got %d", initer.calls)
|
|
}
|
|
}
|
|
|
|
func TestInitCryptoCore_AutoRecoveryOnStaleStore(t *testing.T) {
|
|
dir := t.TempDir()
|
|
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
|
|
|
// Create a stale crypto.db file.
|
|
_ = os.MkdirAll(filepath.Dir(storePath), 0700)
|
|
_ = os.WriteFile(storePath, []byte("stale"), 0o644)
|
|
|
|
// First call fails with "not marked as shared", second succeeds.
|
|
initer := &fakeCryptoIniter{
|
|
helpers: []*fakeCryptoHelper{
|
|
{initErr: errors.New("device keys not marked as shared")},
|
|
{initErr: nil},
|
|
},
|
|
}
|
|
|
|
_, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
|
if err != nil {
|
|
t.Fatalf("expected auto-recovery to succeed, got: %v", err)
|
|
}
|
|
if initer.calls != 2 {
|
|
t.Errorf("expected 2 calls (fail + retry), got %d", initer.calls)
|
|
}
|
|
}
|
|
|
|
func TestInitCryptoCore_AutoRecoveryFailsTwice(t *testing.T) {
|
|
dir := t.TempDir()
|
|
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
|
|
|
_ = os.MkdirAll(filepath.Dir(storePath), 0700)
|
|
_ = os.WriteFile(storePath, []byte("stale"), 0o644)
|
|
|
|
initer := &fakeCryptoIniter{
|
|
helpers: []*fakeCryptoHelper{
|
|
{initErr: errors.New("not marked as shared")},
|
|
{initErr: errors.New("still broken after recovery")},
|
|
},
|
|
}
|
|
|
|
_, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
|
if err == nil {
|
|
t.Fatal("expected error when recovery also fails")
|
|
}
|
|
if !strings.Contains(err.Error(), "after auto-recovery") {
|
|
t.Errorf("expected 'after auto-recovery' in error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestInitCryptoCore_NonRecoverableError(t *testing.T) {
|
|
dir := t.TempDir()
|
|
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
|
|
|
initer := &fakeCryptoIniter{
|
|
helpers: []*fakeCryptoHelper{
|
|
{initErr: errors.New("connection refused")},
|
|
},
|
|
}
|
|
|
|
_, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
|
if err == nil {
|
|
t.Fatal("expected error for non-recoverable failure")
|
|
}
|
|
if !strings.Contains(err.Error(), "init e2ee") {
|
|
t.Errorf("expected 'init e2ee' in error, got: %v", err)
|
|
}
|
|
// Should NOT have retried.
|
|
if initer.calls != 1 {
|
|
t.Errorf("expected 1 call (no retry for non-stale error), got %d", initer.calls)
|
|
}
|
|
}
|
|
|
|
func TestResolvePickleKey_BadHex(t *testing.T) {
|
|
_, err := resolvePickleKey("not-hex!", "token")
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid hex pickle key")
|
|
}
|
|
if !strings.Contains(err.Error(), "decode pickle_key_env") {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestResolvePickleKey_DeriveFromToken(t *testing.T) {
|
|
key, err := resolvePickleKey("", "my-access-token")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(key) != 32 {
|
|
t.Errorf("expected 32-byte sha256 key, got %d bytes", len(key))
|
|
}
|
|
}
|
|
|
|
func TestResolvePickleKey_Explicit(t *testing.T) {
|
|
hexKey := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
|
key, err := resolvePickleKey(hexKey, "ignored")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(key) != 32 {
|
|
t.Errorf("expected 32 bytes, got %d", len(key))
|
|
}
|
|
}
|
|
|
|
func TestInitHelper_SetsAccountID(t *testing.T) {
|
|
helper := &fakeCryptoHelper{}
|
|
initer := &fakeCryptoIniter{helpers: []*fakeCryptoHelper{helper}}
|
|
|
|
_, err := initHelper(context.Background(), initer, []byte("key"), "/fake", "my-agent")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if helper.accountID != "my-agent" {
|
|
t.Errorf("expected accountID='my-agent', got '%s'", helper.accountID)
|
|
}
|
|
}
|
|
|
|
// --- diagMachine fake for testing diagnostics ---
|
|
|
|
type fakeDiagMachine struct {
|
|
pubKeys *crypto.CrossSigningPublicKeysCache
|
|
ownDevice *id.Device
|
|
seeds crypto.CrossSigningSeeds
|
|
seedsPanic bool // simulate ExportCrossSigningKeys panic
|
|
trustState id.TrustState
|
|
trustErr error
|
|
deviceTrusted bool
|
|
}
|
|
|
|
func (f *fakeDiagMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *crypto.CrossSigningPublicKeysCache {
|
|
return f.pubKeys
|
|
}
|
|
|
|
func (f *fakeDiagMachine) OwnIdentity() *id.Device {
|
|
return f.ownDevice
|
|
}
|
|
|
|
func (f *fakeDiagMachine) ExportCrossSigningKeys() crypto.CrossSigningSeeds {
|
|
if f.seedsPanic {
|
|
panic("nil pointer dereference")
|
|
}
|
|
return f.seeds
|
|
}
|
|
|
|
func (f *fakeDiagMachine) ResolveTrustContext(ctx context.Context, device *id.Device) (id.TrustState, error) {
|
|
return f.trustState, f.trustErr
|
|
}
|
|
|
|
func (f *fakeDiagMachine) IsDeviceTrusted(device *id.Device) bool {
|
|
return f.deviceTrusted
|
|
}
|
|
|
|
// testLogger builds a debug-level text logger that writes into buf, so tests
// can assert on the emitted log lines.
func testLogger(buf *bytes.Buffer) *slog.Logger {
	opts := &slog.HandlerOptions{Level: slog.LevelDebug}
	handler := slog.NewTextHandler(buf, opts)
	return slog.New(handler)
}
|
|
|
|
func TestLogCryptoDiagnosticsCore_NilOwnDevice(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
machine := &fakeDiagMachine{
|
|
pubKeys: &crypto.CrossSigningPublicKeysCache{MasterKey: "abc", SelfSigningKey: "def", UserSigningKey: "ghi"},
|
|
ownDevice: nil, // nil device — was causing panic before the fix
|
|
seeds: crypto.CrossSigningSeeds{MasterKey: []byte("x"), SelfSigningKey: []byte("y"), UserSigningKey: []byte("z")},
|
|
}
|
|
|
|
logCryptoDiagnosticsCore(context.Background(), machine, "@bot:test", "DEVICE1", logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "own device identity is nil") {
|
|
t.Errorf("expected warning about nil device identity, got:\n%s", out)
|
|
}
|
|
}
|
|
|
|
func TestLogCryptoDiagnosticsCore_NilPublicKeys(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
machine := &fakeDiagMachine{
|
|
pubKeys: nil, // no cross-signing public keys
|
|
ownDevice: nil,
|
|
seeds: crypto.CrossSigningSeeds{},
|
|
}
|
|
|
|
logCryptoDiagnosticsCore(context.Background(), machine, "@bot:test", "DEVICE1", logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "NO cross-signing public keys found") {
|
|
t.Errorf("expected warning about missing public keys, got:\n%s", out)
|
|
}
|
|
}
|
|
|
|
func TestLogCrossSigningSeeds_PanicRecovery(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
machine := &fakeDiagMachine{seedsPanic: true}
|
|
|
|
// Must not panic — should recover gracefully.
|
|
logCrossSigningSeeds(machine, logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "cross-signing private keys not available") {
|
|
t.Errorf("expected recovery warning, got:\n%s", out)
|
|
}
|
|
}
|
|
|
|
func TestLogCrossSigningSeeds_AllPresent(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
machine := &fakeDiagMachine{
|
|
seeds: crypto.CrossSigningSeeds{
|
|
MasterKey: []byte("master"),
|
|
SelfSigningKey: []byte("self"),
|
|
UserSigningKey: []byte("user"),
|
|
},
|
|
}
|
|
|
|
logCrossSigningSeeds(machine, logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "cross-signing private keys in store") {
|
|
t.Errorf("expected info about keys in store, got:\n%s", out)
|
|
}
|
|
if strings.Contains(out, "self-signing private key NOT in store") {
|
|
t.Error("should not warn when self-signing key is present")
|
|
}
|
|
}
|
|
|
|
func TestLogCrossSigningSeeds_MissingSelfSigning(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
machine := &fakeDiagMachine{
|
|
seeds: crypto.CrossSigningSeeds{
|
|
MasterKey: []byte("master"),
|
|
// SelfSigningKey intentionally missing
|
|
},
|
|
}
|
|
|
|
logCrossSigningSeeds(machine, logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "self-signing private key NOT in store") {
|
|
t.Errorf("expected warning about missing self-signing key, got:\n%s", out)
|
|
}
|
|
}
|
|
|
|
func TestLogDeviceTrust_Trusted(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
device := &id.Device{DeviceID: "DEV1"}
|
|
machine := &fakeDiagMachine{
|
|
trustState: id.TrustStateCrossSignedTOFU,
|
|
deviceTrusted: true,
|
|
}
|
|
|
|
logDeviceTrust(context.Background(), machine, device, logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "own device trust state") {
|
|
t.Errorf("expected trust state log, got:\n%s", out)
|
|
}
|
|
if strings.Contains(out, "device is NOT cross-signed") {
|
|
t.Error("should not warn for trusted device")
|
|
}
|
|
}
|
|
|
|
func TestLogDeviceTrust_Untrusted(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
device := &id.Device{DeviceID: "DEV2"}
|
|
machine := &fakeDiagMachine{
|
|
trustState: id.TrustStateUnset,
|
|
deviceTrusted: false,
|
|
}
|
|
|
|
logDeviceTrust(context.Background(), machine, device, logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "device is NOT cross-signed") {
|
|
t.Errorf("expected cross-sign warning, got:\n%s", out)
|
|
}
|
|
}
|
|
|
|
func TestLogDeviceTrust_ResolveTrustError(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
device := &id.Device{DeviceID: "DEV3"}
|
|
machine := &fakeDiagMachine{
|
|
trustErr: errors.New("crypto store unavailable"),
|
|
}
|
|
|
|
logDeviceTrust(context.Background(), machine, device, logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "failed to resolve device trust") {
|
|
t.Errorf("expected trust resolve error, got:\n%s", out)
|
|
}
|
|
}
|
|
|
|
func TestLogCryptoDiagnosticsCore_FullHappyPath(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := testLogger(&buf)
|
|
|
|
machine := &fakeDiagMachine{
|
|
pubKeys: &crypto.CrossSigningPublicKeysCache{
|
|
MasterKey: "masterkey123",
|
|
SelfSigningKey: "selfkey456",
|
|
UserSigningKey: "userkey789",
|
|
},
|
|
ownDevice: &id.Device{DeviceID: "MYDEV"},
|
|
trustState: id.TrustStateCrossSignedTOFU,
|
|
deviceTrusted: true,
|
|
seeds: crypto.CrossSigningSeeds{
|
|
MasterKey: []byte("m"),
|
|
SelfSigningKey: []byte("s"),
|
|
UserSigningKey: []byte("u"),
|
|
},
|
|
}
|
|
|
|
logCryptoDiagnosticsCore(context.Background(), machine, "@bot:hs", "MYDEV", logger)
|
|
|
|
out := buf.String()
|
|
if !strings.Contains(out, "device info") {
|
|
t.Error("expected device info log")
|
|
}
|
|
if !strings.Contains(out, "cross-signing public keys found") {
|
|
t.Error("expected public keys log")
|
|
}
|
|
if !strings.Contains(out, "own device trust state") {
|
|
t.Error("expected trust state log")
|
|
}
|
|
if !strings.Contains(out, "cross-signing private keys in store") {
|
|
t.Error("expected private keys log")
|
|
}
|
|
}
|
|
|
|
// --- SSSS key fetcher fakes for testing fetchCrossSigningKeysCore ---
|
|
|
|
type fakeSSSSKeyVerifier struct {
|
|
key *ssss.Key
|
|
err error
|
|
}
|
|
|
|
func (f *fakeSSSSKeyVerifier) VerifyRecoveryKey(keyID, recoveryKey string) (*ssss.Key, error) {
|
|
return f.key, f.err
|
|
}
|
|
|
|
type fakeSSSSKeyFetcher struct {
|
|
keyID string
|
|
verifier ssssKeyVerifier
|
|
getErr error
|
|
fetchErr error
|
|
}
|
|
|
|
func (f *fakeSSSSKeyFetcher) GetDefaultKeyData(ctx context.Context) (string, ssssKeyVerifier, error) {
|
|
return f.keyID, f.verifier, f.getErr
|
|
}
|
|
|
|
func (f *fakeSSSSKeyFetcher) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error {
|
|
return f.fetchErr
|
|
}
|
|
|
|
func TestFetchCrossSigningKeysCore_Success(t *testing.T) {
|
|
fetcher := &fakeSSSSKeyFetcher{
|
|
keyID: "key1",
|
|
verifier: &fakeSSSSKeyVerifier{key: &ssss.Key{ID: "key1"}},
|
|
}
|
|
|
|
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "valid-recovery-key")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestFetchCrossSigningKeysCore_GetDefaultKeyFails(t *testing.T) {
|
|
fetcher := &fakeSSSSKeyFetcher{
|
|
getErr: errors.New("no default key"),
|
|
}
|
|
|
|
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "any-key")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "get SSSS default key") {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestFetchCrossSigningKeysCore_VerifyRecoveryKeyFails(t *testing.T) {
|
|
fetcher := &fakeSSSSKeyFetcher{
|
|
keyID: "key1",
|
|
verifier: &fakeSSSSKeyVerifier{err: errors.New("invalid recovery key")},
|
|
}
|
|
|
|
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "bad-key")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "verify recovery key") {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestFetchCrossSigningKeysCore_FetchFromSSSSFails(t *testing.T) {
|
|
fetcher := &fakeSSSSKeyFetcher{
|
|
keyID: "key1",
|
|
verifier: &fakeSSSSKeyVerifier{key: &ssss.Key{ID: "key1"}},
|
|
fetchErr: errors.New("decryption failed"),
|
|
}
|
|
|
|
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "valid-key")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "fetch cross-signing keys from SSSS") {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|