feat: add testing support for crypto initialization and process management, including auto-recovery and filtering of go wrapper processes
This commit is contained in:
+89
-19
@@ -7,8 +7,10 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/cryptohelper"
|
||||
@@ -44,6 +46,37 @@ func New(cfg config.MatrixCfg) (*Client, error) {
|
||||
return &Client{raw: raw, cfg: cfg}, nil
|
||||
}
|
||||
|
||||
// cryptoIniter abstracts crypto helper creation for testing.
|
||||
type cryptoIniter interface {
|
||||
newHelper(pickleKey []byte, storePath string) (cryptoHelper, error)
|
||||
}
|
||||
|
||||
// cryptoHelper abstracts the mautrix CryptoHelper for testing.
|
||||
type cryptoHelper interface {
|
||||
io.Closer
|
||||
Init(ctx context.Context) error
|
||||
SetAccountID(id string)
|
||||
}
|
||||
|
||||
// mautrixCryptoIniter is the real implementation using mautrix.
|
||||
type mautrixCryptoIniter struct {
|
||||
raw *mautrix.Client
|
||||
}
|
||||
|
||||
func (m *mautrixCryptoIniter) newHelper(pickleKey []byte, storePath string) (cryptoHelper, error) {
|
||||
h, err := cryptohelper.NewCryptoHelper(m.raw, pickleKey, storePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mautrixCryptoWrapper{h}, nil
|
||||
}
|
||||
|
||||
type mautrixCryptoWrapper struct {
|
||||
*cryptohelper.CryptoHelper
|
||||
}
|
||||
|
||||
func (w *mautrixCryptoWrapper) SetAccountID(id string) { w.DBAccountID = id }
|
||||
|
||||
// InitCrypto sets up end-to-end encryption using the mautrix cryptohelper.
|
||||
// storePath is the SQLite file path for crypto material (e.g. "./agents/<id>/data/crypto/crypto.db").
|
||||
// pickleKeyHex is a hex-encoded key for encrypting crypto material at rest. If empty,
|
||||
@@ -51,44 +84,81 @@ func New(cfg config.MatrixCfg) (*Client, error) {
|
||||
// agentID namespaces the crypto state within the database.
|
||||
// Returns an io.Closer that must be called on agent shutdown to flush the crypto store.
|
||||
func (c *Client) InitCrypto(ctx context.Context, storePath, pickleKeyHex, agentID string) (io.Closer, error) {
|
||||
// Resolve the actual device ID from the server — the value in config may differ
|
||||
// from what the registration process assigned.
|
||||
// Resolve the actual device ID from the server.
|
||||
whoami, err := c.raw.Whoami(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("whoami for crypto init: %w", err)
|
||||
}
|
||||
c.raw.DeviceID = whoami.DeviceID
|
||||
|
||||
// Use explicit pickle key if provided, otherwise derive from access token.
|
||||
var pickleKey []byte
|
||||
if pickleKeyHex != "" {
|
||||
pickleKey, err = hex.DecodeString(pickleKeyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode pickle_key_env: %w", err)
|
||||
}
|
||||
} else {
|
||||
sum := sha256.Sum256([]byte(c.raw.AccessToken))
|
||||
pickleKey = sum[:]
|
||||
initer := &mautrixCryptoIniter{raw: c.raw}
|
||||
closer, helper, err := initCryptoCore(ctx, storePath, pickleKeyHex, c.raw.AccessToken, agentID, initer, slog.Default())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Assign the real mautrix crypto helper — this satisfies mautrix.CryptoHelper.
|
||||
c.raw.Crypto = helper.(*mautrixCryptoWrapper)
|
||||
return closer, nil
|
||||
}
|
||||
|
||||
// initCryptoCore contains the testable logic: pickle key resolution, store
|
||||
// creation, and auto-recovery on stale crypto.db. Returns (closer, helper, err).
|
||||
func initCryptoCore(ctx context.Context, storePath, pickleKeyHex, accessToken, agentID string, initer cryptoIniter, logger *slog.Logger) (io.Closer, cryptoHelper, error) {
|
||||
pickleKey, err := resolvePickleKey(pickleKeyHex, accessToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(storePath), 0700); err != nil {
|
||||
return nil, fmt.Errorf("create crypto store dir: %w", err)
|
||||
return nil, nil, fmt.Errorf("create crypto store dir: %w", err)
|
||||
}
|
||||
|
||||
helper, err := cryptohelper.NewCryptoHelper(c.raw, pickleKey, storePath)
|
||||
helper, err := initHelper(ctx, initer, pickleKey, storePath, agentID)
|
||||
if err != nil && strings.Contains(err.Error(), "not marked as shared") {
|
||||
logger.Warn("crypto store inconsistent, attempting auto-recovery",
|
||||
"store", storePath,
|
||||
)
|
||||
if removeErr := os.Remove(storePath); removeErr != nil && !os.IsNotExist(removeErr) {
|
||||
return nil, nil, fmt.Errorf("auto-recovery: remove stale crypto.db: %w (original: %w)", removeErr, err)
|
||||
}
|
||||
helper, err = initHelper(ctx, initer, pickleKey, storePath, agentID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("e2ee init after auto-recovery: %w", err)
|
||||
}
|
||||
logger.Info("e2ee auto-recovery succeeded")
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("init e2ee: %w", err)
|
||||
}
|
||||
|
||||
return helper, helper, nil
|
||||
}
|
||||
|
||||
func initHelper(ctx context.Context, initer cryptoIniter, pickleKey []byte, storePath, agentID string) (cryptoHelper, error) {
|
||||
helper, err := initer.newHelper(pickleKey, storePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create crypto helper: %w", err)
|
||||
}
|
||||
helper.DBAccountID = agentID
|
||||
|
||||
helper.SetAccountID(agentID)
|
||||
if err := helper.Init(ctx); err != nil {
|
||||
return nil, fmt.Errorf("init e2ee: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.raw.Crypto = helper
|
||||
return helper, nil
|
||||
}
|
||||
|
||||
// resolvePickleKey decodes a hex key or derives one from the access token.
|
||||
func resolvePickleKey(pickleKeyHex, accessToken string) ([]byte, error) {
|
||||
if pickleKeyHex != "" {
|
||||
key, err := hex.DecodeString(pickleKeyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode pickle_key_env: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
sum := sha256.Sum256([]byte(accessToken))
|
||||
return sum[:], nil
|
||||
}
|
||||
|
||||
// SendText sends a plain-text message to a room.
|
||||
// If the room has E2EE enabled and crypto is initialized, the message is encrypted automatically.
|
||||
func (c *Client) SendText(ctx context.Context, roomID, text string) error {
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fakeCryptoHelper implements cryptoHelper for testing.
|
||||
type fakeCryptoHelper struct {
|
||||
initErr error
|
||||
closed bool
|
||||
accountID string
|
||||
}
|
||||
|
||||
func (f *fakeCryptoHelper) Init(ctx context.Context) error { return f.initErr }
|
||||
func (f *fakeCryptoHelper) Close() error { f.closed = true; return nil }
|
||||
func (f *fakeCryptoHelper) SetAccountID(id string) { f.accountID = id }
|
||||
|
||||
// fakeCryptoIniter implements cryptoIniter for testing.
|
||||
type fakeCryptoIniter struct {
|
||||
calls int
|
||||
helpers []*fakeCryptoHelper // one per call
|
||||
}
|
||||
|
||||
func (f *fakeCryptoIniter) newHelper(pickleKey []byte, storePath string) (cryptoHelper, error) {
|
||||
idx := f.calls
|
||||
f.calls++
|
||||
if idx < len(f.helpers) {
|
||||
return f.helpers[idx], nil
|
||||
}
|
||||
return &fakeCryptoHelper{}, nil
|
||||
}
|
||||
|
||||
func TestInitCryptoCore_Success(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
||||
|
||||
initer := &fakeCryptoIniter{
|
||||
helpers: []*fakeCryptoHelper{{initErr: nil}},
|
||||
}
|
||||
|
||||
closer, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if closer == nil {
|
||||
t.Error("expected non-nil closer")
|
||||
}
|
||||
if initer.calls != 1 {
|
||||
t.Errorf("expected 1 call to newHelper, got %d", initer.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitCryptoCore_AutoRecoveryOnStaleStore(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
||||
|
||||
// Create a stale crypto.db file.
|
||||
_ = os.MkdirAll(filepath.Dir(storePath), 0700)
|
||||
_ = os.WriteFile(storePath, []byte("stale"), 0o644)
|
||||
|
||||
// First call fails with "not marked as shared", second succeeds.
|
||||
initer := &fakeCryptoIniter{
|
||||
helpers: []*fakeCryptoHelper{
|
||||
{initErr: errors.New("device keys not marked as shared")},
|
||||
{initErr: nil},
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
||||
if err != nil {
|
||||
t.Fatalf("expected auto-recovery to succeed, got: %v", err)
|
||||
}
|
||||
if initer.calls != 2 {
|
||||
t.Errorf("expected 2 calls (fail + retry), got %d", initer.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitCryptoCore_AutoRecoveryFailsTwice(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
||||
|
||||
_ = os.MkdirAll(filepath.Dir(storePath), 0700)
|
||||
_ = os.WriteFile(storePath, []byte("stale"), 0o644)
|
||||
|
||||
initer := &fakeCryptoIniter{
|
||||
helpers: []*fakeCryptoHelper{
|
||||
{initErr: errors.New("not marked as shared")},
|
||||
{initErr: errors.New("still broken after recovery")},
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
||||
if err == nil {
|
||||
t.Fatal("expected error when recovery also fails")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "after auto-recovery") {
|
||||
t.Errorf("expected 'after auto-recovery' in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitCryptoCore_NonRecoverableError(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
storePath := filepath.Join(dir, "crypto", "crypto.db")
|
||||
|
||||
initer := &fakeCryptoIniter{
|
||||
helpers: []*fakeCryptoHelper{
|
||||
{initErr: errors.New("connection refused")},
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := initCryptoCore(context.Background(), storePath, "", "fake-token", "test-agent", initer, slog.Default())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-recoverable failure")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "init e2ee") {
|
||||
t.Errorf("expected 'init e2ee' in error, got: %v", err)
|
||||
}
|
||||
// Should NOT have retried.
|
||||
if initer.calls != 1 {
|
||||
t.Errorf("expected 1 call (no retry for non-stale error), got %d", initer.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePickleKey_BadHex(t *testing.T) {
|
||||
_, err := resolvePickleKey("not-hex!", "token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid hex pickle key")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "decode pickle_key_env") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePickleKey_DeriveFromToken(t *testing.T) {
|
||||
key, err := resolvePickleKey("", "my-access-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
t.Errorf("expected 32-byte sha256 key, got %d bytes", len(key))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePickleKey_Explicit(t *testing.T) {
|
||||
hexKey := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
key, err := resolvePickleKey(hexKey, "ignored")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
t.Errorf("expected 32 bytes, got %d", len(key))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitHelper_SetsAccountID(t *testing.T) {
|
||||
helper := &fakeCryptoHelper{}
|
||||
initer := &fakeCryptoIniter{helpers: []*fakeCryptoHelper{helper}}
|
||||
|
||||
_, err := initHelper(context.Background(), initer, []byte("key"), "/fake", "my-agent")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if helper.accountID != "my-agent" {
|
||||
t.Errorf("expected accountID='my-agent', got '%s'", helper.accountID)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user