diff --git a/golibdave/golibdave.go b/golibdave/golibdave.go index 857625d..4de12bb 100644 --- a/golibdave/golibdave.go +++ b/golibdave/golibdave.go @@ -192,6 +192,10 @@ func (s *session) prepareEpoch(epoch int, protocolVersion uint16) { return } + s.logger.Warn("prepareEpoch: resetting MLS session via Init", + slog.Int("epoch", epoch), + slog.Int("protocol_version", int(protocolVersion)), + ) s.session.Init(protocolVersion, uint64(s.channelID), string(s.selfUserID)) } @@ -228,17 +232,33 @@ func (s *session) setupKeyRatchetForUser(userID godave.UserID, protocolVersion u disabled := protocolVersion == disabledProtocolVersion if userID == s.selfUserID { + kr := s.session.GetKeyRatchet(string(userID)) + if !disabled && kr == nil { + s.logger.Warn("nil key ratchet for self after GetKeyRatchet", + slog.String("user_id", string(userID)), + slog.Int("protocol_version", int(protocolVersion)), + ) + return + } s.encryptor.SetPassthroughMode(disabled) if !disabled { - s.encryptor.SetKeyRatchet(s.session.GetKeyRatchet(string(userID))) + s.encryptor.SetKeyRatchet(kr) } return } decryptor := s.decryptors[userID] + kr := s.session.GetKeyRatchet(string(userID)) + if !disabled && kr == nil { + s.logger.Warn("nil key ratchet for user after GetKeyRatchet", + slog.String("user_id", string(userID)), + slog.Int("protocol_version", int(protocolVersion)), + ) + return + } decryptor.TransitionToPassthroughMode(disabled) if !disabled { - decryptor.TransitionToKeyRatchet(s.session.GetKeyRatchet(string(userID))) + decryptor.TransitionToKeyRatchet(kr) } } diff --git a/libdave/key_ratchet.go b/libdave/key_ratchet.go index c8d3b5f..2cb5e42 100644 --- a/libdave/key_ratchet.go +++ b/libdave/key_ratchet.go @@ -11,6 +11,10 @@ type KeyRatchet struct { } func newKeyRatchet(handle keyRatchetHandle) *KeyRatchet { + if handle == nil { + return nil + } + keyRatchet := &KeyRatchet{handle: handle} runtime.SetFinalizer(keyRatchet, func(k *KeyRatchet) {