From 29d5ca8ade200386ea457ecd435b5b37b6460eac Mon Sep 17 00:00:00 2001 From: harshabose Date: Fri, 2 May 2025 02:52:05 +0530 Subject: [PATCH] added NewBaseMessage function in message.go --- pkg/interceptor/interceptor.go | 25 +++-- pkg/middleware/encrypt/encryptor/aes256.go | 93 +++++++++++++++++++ .../encrypt/encryptor/encryptor_messages.go | 42 +++++++++ pkg/middleware/encrypt/encryptor/skipper.go | 64 +++++++++++++ pkg/middleware/encrypt/interceptor.go | 55 ++++++++++- .../encrypt/interfaces/encryptor.go | 18 ++-- .../encrypt/interfaces/keyexchange.go | 1 + .../encrypt/keyexchange/curve25519messages.go | 43 +++++++-- .../keyexchange/keyexchange_manager.go | 25 +++++ pkg/middleware/encrypt/state/state.go | 4 + 10 files changed, 345 insertions(+), 25 deletions(-) create mode 100644 pkg/middleware/encrypt/encryptor/aes256.go create mode 100644 pkg/middleware/encrypt/encryptor/encryptor_messages.go create mode 100644 pkg/middleware/encrypt/encryptor/skipper.go diff --git a/pkg/interceptor/interceptor.go b/pkg/interceptor/interceptor.go index 0a7c589..0189366 100644 --- a/pkg/interceptor/interceptor.go +++ b/pkg/interceptor/interceptor.go @@ -4,6 +4,8 @@ import ( "context" "io" "sync" + + "github.com/harshabose/socket-comm/pkg/message" ) type Registry struct { @@ -56,29 +58,34 @@ type Interceptor interface { } type Writer interface { - Write(conn Connection, message Message) error + Write(conn Connection, message message.Message) error } type Reader interface { - Read(conn Connection) (Message, error) + Read(conn Connection) (message.Message, error) } -type ReaderFunc func(conn Connection) (Message, error) +type ReaderFunc func(conn Connection) (message.Message, error) -func (f ReaderFunc) Read(conn Connection) (Message, error) { +func (f ReaderFunc) Read(conn Connection) (message.Message, error) { return f(conn) } -type WriterFunc func(conn Connection, message Message) error +type WriterFunc func(conn Connection, message message.Message) error -func (f WriterFunc) Write(conn Connection, message Message) error { +func (f WriterFunc) Write(conn Connection, message message.Message) error { return f(conn, message) } type NoOpInterceptor struct { - ID string - Mutex sync.RWMutex - Ctx context.Context + ID string + messageRegistry message.Registry + Mutex sync.RWMutex + Ctx context.Context +} + +func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry { + return interceptor.messageRegistry } func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) { diff --git a/pkg/middleware/encrypt/encryptor/aes256.go b/pkg/middleware/encrypt/encryptor/aes256.go new file mode 100644 index 0000000..ac8b4a2 --- /dev/null +++ b/pkg/middleware/encrypt/encryptor/aes256.go @@ -0,0 +1,93 @@ +package encryptor + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" + "sync" + + "github.com/harshabose/socket-comm/pkg/message" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" +) + +type AES256Encryptor struct { + encryptor cipher.AEAD + decryptor cipher.AEAD + sessionID types.EncryptionSessionID + mux sync.RWMutex +} + +func (a *AES256Encryptor) SetKeys(encryptorKey, decryptorKey types.Key) error { + a.mux.Lock() + defer a.mux.Unlock() + + // Setup encryption AEAD + encBlock, err := aes.NewCipher(encryptorKey[:]) + if err != nil { + return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err) + } + + encGCM, err := cipher.NewGCM(encBlock) + if err != nil { + return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err) + } + + // Setup decryption AEAD + decBlock, err := aes.NewCipher(decryptorKey[:]) + if err != nil { + return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err) + } + + decGCM, err := cipher.NewGCM(decBlock) + if err != nil { + return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err) + } + + a.encryptor = encGCM + a.decryptor = decGCM + + return nil +} + +func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error) { + a.mux.Lock() + defer a.mux.Unlock() + + nonce := types.Nonce{} + + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + data, err := msg.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal message: %w", err) + } +} + +func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) { + // TODO implement me + panic("implement me") +} + +func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) { + a.mux.Lock() + defer a.mux.Unlock() + + a.sessionID = id +} + +func (a *AES256Encryptor) Ready() bool { + a.mux.RLock() + defer a.mux.RUnlock() + + return a.encryptor != nil && a.decryptor != nil +} + +func (a *AES256Encryptor) Close() error { + // TODO implement me + panic("implement me") +} diff --git a/pkg/middleware/encrypt/encryptor/encryptor_messages.go b/pkg/middleware/encrypt/encryptor/encryptor_messages.go new file mode 100644 index 0000000..195349f --- /dev/null +++ b/pkg/middleware/encrypt/encryptor/encryptor_messages.go @@ -0,0 +1,42 @@ +package encryptor + +import ( + "github.com/harshabose/socket-comm/pkg/interceptor" + "github.com/harshabose/socket-comm/pkg/message" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" +) + +type EncryptedMessage struct { + interceptor.BaseMessage +} + +func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error { + i, ok := _i.(interfaces.CanGetState) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + s, err := i.GetState(conn) + if err != nil { + return err + } + + ss, ok := s.(interfaces.CanDecrypt) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + msg, err := ss.Decrypt(m) + if err != nil { + return err + } + + // TODO: message.Registry is not implemented yet + decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload) + if err != nil { + return err + } + + return nil +} diff --git a/pkg/middleware/encrypt/encryptor/skipper.go b/pkg/middleware/encrypt/encryptor/skipper.go new file mode 100644 index 0000000..ee7a045 --- /dev/null +++ b/pkg/middleware/encrypt/encryptor/skipper.go @@ -0,0 +1,64 @@ +package encryptor + +import ( + "github.com/harshabose/socket-comm/pkg/message" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" +) + +type SkipEncryptionChecker func(message.Message) bool + +// NewProtocolSkipChecker creates a checker that skips encryption for specific protocols +func NewProtocolSkipChecker(protocolsToSkip ...message.Protocol) SkipEncryptionChecker { + skipMap := make(map[message.Protocol]bool) + for _, protocol := range protocolsToSkip { + skipMap[protocol] = true + } + + return func(msg message.Message) bool { + return skipMap[msg.GetProtocol()] + } +} + +type SkipperEncryptor struct { + wrapped interfaces.Encryptor + skip SkipEncryptionChecker +} + +func NewSkipperEncryptor(wrapped interfaces.Encryptor, skip SkipEncryptionChecker) *SkipperEncryptor { + return &SkipperEncryptor{ + wrapped: wrapped, + skip: skip, + } +} + +func (e *SkipperEncryptor) SetSessionID(id types.EncryptionSessionID) { + e.wrapped.SetSessionID(id) +} + +func (e *SkipperEncryptor) SetKeys(encryptorKey, decryptorKey types.Key) error { + return e.wrapped.SetKeys(encryptorKey, decryptorKey) +} + +func (e *SkipperEncryptor) Encrypt(msg message.Message) (message.Message, error) { + if e.skip(msg) { + return msg, nil + } + + return e.wrapped.Encrypt(msg) +} + +func (e *SkipperEncryptor) Decrypt(msg message.Message) (message.Message, error) { + if !e.skip(msg) { + return e.wrapped.Decrypt(msg) + } + return msg, nil +} + +func (e *SkipperEncryptor) Ready() bool { + return e.wrapped.Ready() +} + +func (e *SkipperEncryptor) Close() error { + return e.wrapped.Close() +} diff --git a/pkg/middleware/encrypt/interceptor.go b/pkg/middleware/encrypt/interceptor.go index 4bd441e..a36bdc7 100644 --- a/pkg/middleware/encrypt/interceptor.go +++ b/pkg/middleware/encrypt/interceptor.go @@ -5,6 +5,7 @@ import ( "time" "github.com/harshabose/socket-comm/pkg/interceptor" + "github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" @@ -53,23 +54,71 @@ func (i *Interceptor) Init(connection interceptor.Connection) error { waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted) - return i.Process(waiter, s) + if err := i.Process(waiter, s); err != nil { + return err + } + + return i.keyExchangeManager.Finalise(s) } func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { + return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error { + s, err := i.GetState(conn) + if err != nil { + return err + } + ss, ok := s.(interfaces.CanEncrypt) + if !ok { + return err + } + + encrypted, err := ss.Encrypt(msg) + if err != nil { + if !s.GetConfig().RequireEncryption { + return writer.Write(conn, msg) + } + return err + } + + return writer.Write(conn, encrypted) + }) } func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader { + return interceptor.ReaderFunc(func(conn interceptor.Connection) (message.Message, error) { + msg, err := reader.Read(conn) + if err != nil { + return msg, err + } + s, err := i.GetState(conn) + if err != nil { + return nil, err + } + + ss, ok := s.(interfaces.CanDecrypt) + if !ok { + return msg, err + } + + m, err := ss.Decrypt(msg) + if err != nil { + return nil, err + } + + return m, nil + }) } func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) { - + // TODO: Implement full closing } func (i *Interceptor) Close() error { - + // TODO: Use UnBindSocketConnection to close all + // TODO: Close interceptor + return nil } func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) { diff --git a/pkg/middleware/encrypt/interfaces/encryptor.go b/pkg/middleware/encrypt/interfaces/encryptor.go index ffbc7a2..f193d98 100644 --- a/pkg/middleware/encrypt/interfaces/encryptor.go +++ b/pkg/middleware/encrypt/interfaces/encryptor.go @@ -16,19 +16,25 @@ type KeyGetter interface { GetKeys() (encKey, decKey types.Key, err error) } +type CanEncrypt interface { + // Encrypt encrypts a message between sender and receiver + Encrypt(message message.Message) (message.Message, error) +} + +type CanDecrypt interface { + // Decrypt decrypts an encrypted message in-place + Decrypt(message message.Message) (message.Message, error) +} + // Encryptor defines the interface for message encryption and decryption type Encryptor interface { KeySetter + CanEncrypt + CanDecrypt // SetSessionID sets the session identifier for this encryption session SetSessionID(id types.EncryptionSessionID) - // Encrypt encrypts a message between sender and receiver - Encrypt(senderID, receiverID string, message message.Message) (message.Message, error) - - // Decrypt decrypts an encrypted message in-place - Decrypt(message message.Message) error - // Ready checks if the encryptor is properly initialized and ready to use Ready() bool diff --git a/pkg/middleware/encrypt/interfaces/keyexchange.go b/pkg/middleware/encrypt/interfaces/keyexchange.go index 64ff9f2..55591e2 100644 --- a/pkg/middleware/encrypt/interfaces/keyexchange.go +++ b/pkg/middleware/encrypt/interfaces/keyexchange.go @@ -10,6 +10,7 @@ type ProtocolProcessor interface { type KeyExchangeManager interface { Init(state State, options ...ProtocolFactoryOption) error + Finalise(state State) error } type CanGetSessionState interface { diff --git a/pkg/middleware/encrypt/keyexchange/curve25519messages.go b/pkg/middleware/encrypt/keyexchange/curve25519messages.go index 5f6b9a4..b840d74 100644 --- a/pkg/middleware/encrypt/keyexchange/curve25519messages.go +++ b/pkg/middleware/encrypt/keyexchange/curve25519messages.go @@ -225,21 +225,50 @@ func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Co return pp.Process(m, s) } -func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error { - ss, ok := _s.(interfaces.KeySetter) - if !ok { - return encryptionerr.ErrInvalidInterceptor - } - +func (m *Done) Process(protocol interfaces.Protocol, s interfaces.State) error { p, ok := protocol.(*Curve25519Protocol) if !ok { return encryptionerr.ErrInvalidMessageType } - if err := ss.SetKeys(p.encKey, p.decKey); err != nil { + msg, err := NewDoneResponse() + if err != nil { + return err + } + + if err := s.WriteMessage(msg); err != nil { return err } p.state = types.SessionStateCompleted return nil } + +type DoneResponse struct { + Done +} + +func NewDoneResponse() (*DoneResponse, error) { + msg := &DoneResponse{ + Done: Done{ + Timestamp: time.Now(), + }, + } + bmsg, err := interceptor.NewBaseMessage(message.NoneProtocol, nil, msg) + if err != nil { + return nil, err + } + msg.BaseMessage = bmsg + + return msg, nil +} + +func (m *DoneResponse) Process(protocol interfaces.Protocol, s interfaces.State) error { + p, ok := protocol.(*Curve25519Protocol) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + p.state = types.SessionStateCompleted + return nil +} diff --git a/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go b/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go index fe8c3d8..24d4c5d 100644 --- a/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go +++ b/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go @@ -55,6 +55,31 @@ func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactory return nil } +func (m *Manager) Finalise(s interfaces.State) error { + sessionID := s.GetKeyExchangeSessionID() + session, exists := m.sessions[sessionID] + if !exists { + return encryptionerr.ErrExchangeNotComplete + } + + ss, ok := s.(interfaces.KeySetter) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + p, ok := (session.protocol).(interfaces.KeyGetter) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + encKey, decKey, err := p.GetKeys() + if err != nil { + return err + } + + return ss.SetKeys(encKey, decKey) +} + func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error { session, exists := m.sessions[s.GetKeyExchangeSessionID()] if !exists { diff --git a/pkg/middleware/encrypt/state/state.go b/pkg/middleware/encrypt/state/state.go index cb79682..d938809 100644 --- a/pkg/middleware/encrypt/state/state.go +++ b/pkg/middleware/encrypt/state/state.go @@ -91,3 +91,7 @@ func (s *State) SetKeys(encKey, decKey types.Key) error { return s.encryptor.SetKeys(encKey, decKey) } + +func (s *State) Decrypt(msg message.Message) (message.Message, error) { + return s.encryptor.Decrypt(msg) +}