feat: add recovery key support for E2EE agents, including configuration and documentation updates
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/cryptohelper"
|
||||
"maunium.net/go/mautrix/crypto/ssss"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
@@ -108,6 +109,66 @@ func (c *Client) InitCrypto(ctx context.Context, storePath, pickleKeyHex, agentI
|
||||
return closer, nil
|
||||
}
|
||||
|
||||
// ssssKeyFetcher abstracts the SSSS + cross-signing key retrieval for testing.
|
||||
type ssssKeyFetcher interface {
|
||||
GetDefaultKeyData(ctx context.Context) (string, ssssKeyVerifier, error)
|
||||
FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error
|
||||
}
|
||||
|
||||
// ssssKeyVerifier abstracts the SSSS key metadata verification.
|
||||
type ssssKeyVerifier interface {
|
||||
VerifyRecoveryKey(keyID, recoveryKey string) (*ssss.Key, error)
|
||||
}
|
||||
|
||||
// olmSSSSFetcher adapts *crypto.OlmMachine to the ssssKeyFetcher interface.
|
||||
type olmSSSSFetcher struct {
|
||||
machine *crypto.OlmMachine
|
||||
}
|
||||
|
||||
func (o *olmSSSSFetcher) GetDefaultKeyData(ctx context.Context) (string, ssssKeyVerifier, error) {
|
||||
keyID, keyData, err := o.machine.SSSS.GetDefaultKeyData(ctx)
|
||||
return keyID, keyData, err
|
||||
}
|
||||
|
||||
func (o *olmSSSSFetcher) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error {
|
||||
return o.machine.FetchCrossSigningKeysFromSSSS(ctx, key)
|
||||
}
|
||||
|
||||
// FetchCrossSigningKeys retrieves cross-signing private keys from SSSS
|
||||
// (server-side secret storage) using the given base58 recovery key.
|
||||
// This allows the agent to sign its own device, eliminating the
|
||||
// "Encrypted by a device not verified by its owner" warning.
|
||||
func (c *Client) FetchCrossSigningKeys(ctx context.Context, recoveryKey string) error {
|
||||
wrapper, ok := c.raw.Crypto.(*mautrixCryptoWrapper)
|
||||
if !ok || wrapper == nil {
|
||||
return fmt.Errorf("crypto not initialized")
|
||||
}
|
||||
machine := wrapper.Machine()
|
||||
if machine == nil {
|
||||
return fmt.Errorf("olm machine not available")
|
||||
}
|
||||
return fetchCrossSigningKeysCore(ctx, &olmSSSSFetcher{machine}, recoveryKey)
|
||||
}
|
||||
|
||||
// fetchCrossSigningKeysCore contains the testable logic for SSSS key retrieval.
|
||||
func fetchCrossSigningKeysCore(ctx context.Context, fetcher ssssKeyFetcher, recoveryKey string) error {
|
||||
keyID, keyData, err := fetcher.GetDefaultKeyData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get SSSS default key: %w", err)
|
||||
}
|
||||
|
||||
key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify recovery key: %w", err)
|
||||
}
|
||||
|
||||
if err := fetcher.FetchCrossSigningKeysFromSSSS(ctx, key); err != nil {
|
||||
return fmt.Errorf("fetch cross-signing keys from SSSS: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initCryptoCore contains the testable logic: pickle key resolution, store
|
||||
// creation, and auto-recovery on stale crypto.db. Returns (closer, helper, err).
|
||||
func initCryptoCore(ctx context.Context, storePath, pickleKeyHex, accessToken, agentID string, initer cryptoIniter, logger *slog.Logger) (io.Closer, cryptoHelper, error) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/ssss"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
@@ -400,3 +401,86 @@ func TestLogCryptoDiagnosticsCore_FullHappyPath(t *testing.T) {
|
||||
t.Error("expected private keys log")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SSSS key fetcher fakes for testing fetchCrossSigningKeysCore ---
|
||||
|
||||
type fakeSSSSKeyVerifier struct {
|
||||
key *ssss.Key
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeSSSSKeyVerifier) VerifyRecoveryKey(keyID, recoveryKey string) (*ssss.Key, error) {
|
||||
return f.key, f.err
|
||||
}
|
||||
|
||||
type fakeSSSSKeyFetcher struct {
|
||||
keyID string
|
||||
verifier ssssKeyVerifier
|
||||
getErr error
|
||||
fetchErr error
|
||||
}
|
||||
|
||||
func (f *fakeSSSSKeyFetcher) GetDefaultKeyData(ctx context.Context) (string, ssssKeyVerifier, error) {
|
||||
return f.keyID, f.verifier, f.getErr
|
||||
}
|
||||
|
||||
func (f *fakeSSSSKeyFetcher) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error {
|
||||
return f.fetchErr
|
||||
}
|
||||
|
||||
func TestFetchCrossSigningKeysCore_Success(t *testing.T) {
|
||||
fetcher := &fakeSSSSKeyFetcher{
|
||||
keyID: "key1",
|
||||
verifier: &fakeSSSSKeyVerifier{key: &ssss.Key{ID: "key1"}},
|
||||
}
|
||||
|
||||
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "valid-recovery-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchCrossSigningKeysCore_GetDefaultKeyFails(t *testing.T) {
|
||||
fetcher := &fakeSSSSKeyFetcher{
|
||||
getErr: errors.New("no default key"),
|
||||
}
|
||||
|
||||
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "any-key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "get SSSS default key") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchCrossSigningKeysCore_VerifyRecoveryKeyFails(t *testing.T) {
|
||||
fetcher := &fakeSSSSKeyFetcher{
|
||||
keyID: "key1",
|
||||
verifier: &fakeSSSSKeyVerifier{err: errors.New("invalid recovery key")},
|
||||
}
|
||||
|
||||
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "bad-key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "verify recovery key") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchCrossSigningKeysCore_FetchFromSSSSFails(t *testing.T) {
|
||||
fetcher := &fakeSSSSKeyFetcher{
|
||||
keyID: "key1",
|
||||
verifier: &fakeSSSSKeyVerifier{key: &ssss.Key{ID: "key1"}},
|
||||
fetchErr: errors.New("decryption failed"),
|
||||
}
|
||||
|
||||
err := fetchCrossSigningKeysCore(context.Background(), fetcher, "valid-key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "fetch cross-signing keys from SSSS") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user