From 162c4108c2c0b5579462ef98b1ab4ed17554aa9d Mon Sep 17 00:00:00 2001 From: harshabose Date: Fri, 2 May 2025 19:31:13 +0530 Subject: [PATCH] added InterceptSocketWriter and InterceptSocketReader functions on encryption interceptor --- pkg/interceptor/message.go | 2 +- pkg/message/message.go | 24 ++++++-- pkg/message/registry.go | 10 ++++ pkg/middleware/encrypt/encryptor/aes256.go | 19 ++++--- .../encrypt/encryptor/encryptor_messages.go | 43 +++++++++++--- pkg/middleware/encrypt/interceptor.go | 57 ++++++++----------- 6 files changed, 101 insertions(+), 54 deletions(-) diff --git a/pkg/interceptor/message.go b/pkg/interceptor/message.go index 5c84f24..9fb480a 100644 --- a/pkg/interceptor/message.go +++ b/pkg/interceptor/message.go @@ -55,7 +55,7 @@ type BaseMessage struct { message.BaseMessage } -func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) { +func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Marshallable, msg Message) (BaseMessage, error) { bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg) if err != nil { return BaseMessage{}, nil diff --git a/pkg/message/message.go b/pkg/message/message.go index a841e64..d71a6c1 100644 --- a/pkg/message/message.go +++ b/pkg/message/message.go @@ -41,6 +41,20 @@ const ( UnknownReceiver Receiver = "unknown.receiver" ) +type Marshallable interface { + // Marshal serializes the message to JSON format + Marshal() ([]byte, error) +} + +type Unmarshallable interface { + // Unmarshal deserializes the message from JSON format + Unmarshal([]byte) error +} + +func (p Payload) Marshal() ([]byte, error) { + return p, nil +} + // Message defines the interface that all message types must implement. // It provides methods for protocol identification, serialization, and // message nesting/unwrapping. @@ -56,11 +70,9 @@ type Message interface { // Returns nil, nil if there is no next message GetNext(Registry) (Message, error) - // Marshal serializes the message to JSON format - Marshal() ([]byte, error) + Marshallable - // Unmarshal deserializes the message from JSON format - Unmarshal([]byte) error + Unmarshallable } // Header contains common metadata for all messages @@ -127,7 +139,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) { return nil, ErrNoPayload } - return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload)) + return registry.Unmarshal(m.NextProtocol, m.NextPayload) } // Marshal serializes the message to JSON format @@ -140,7 +152,7 @@ func (m *BaseMessage) Unmarshal(data []byte) error { return json.Unmarshal(data, m) } -func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (BaseMessage, error) { +func NewBaseMessage(nextProtocol Protocol, nextPayload Marshallable, msg Message) (BaseMessage, error) { var inner json.RawMessage = nil if nextPayload != nil { if nextProtocol == NoneProtocol { diff --git a/pkg/message/registry.go b/pkg/message/registry.go index 9570553..3455dab 100644 --- a/pkg/message/registry.go +++ b/pkg/message/registry.go @@ -21,6 +21,8 @@ type Registry interface { // Returns an error if the protocol is already registered Register(Protocol, Factory) error + Check(protocol Protocol) bool + // Create instantiates a new message for the given protocol // Returns an error if the protocol is not registered Create(Protocol) (Message, error) @@ -65,6 +67,14 @@ func NewRegistry() *DefaultRegistry { } } +func (r *DefaultRegistry) Check(protocol Protocol) bool { + r.mux.RLock() + defer r.mux.RUnlock() + + _, exists := r.factories[protocol] + return exists +} + // Register adds a message factory for a protocol // Returns an error if the protocol is already registered // This method is thread-safe diff --git a/pkg/middleware/encrypt/encryptor/aes256.go b/pkg/middleware/encrypt/encryptor/aes256.go index 44da387..d13d573 100644 --- a/pkg/middleware/encrypt/encryptor/aes256.go +++ b/pkg/middleware/encrypt/encryptor/aes256.go @@ -56,25 +56,28 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error) a.mux.Lock() defer a.mux.Unlock() + m, ok := msg.(*EncryptedMessage) + if !ok { + return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE + } + nonce := types.Nonce{} if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { return nil, fmt.Errorf("failed to generate nonce: %w", err) } - protocol := msg.GetProtocol() + encryptedData := a.encryptor.Seal(nil, nonce[:], m.NextPayload, a.sessionID[:]) - data, err := msg.Marshal() - if err != nil { - return nil, fmt.Errorf("failed to marshal message: %w", err) - } + m.NextPayload = encryptedData - encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:]) - - return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID) + return m, nil } func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) { + a.mux.Lock() + defer a.mux.Unlock() + m, ok := msg.(*EncryptedMessage) if !ok { return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE diff --git a/pkg/middleware/encrypt/encryptor/encryptor_messages.go b/pkg/middleware/encrypt/encryptor/encryptor_messages.go index b1d7f4b..8daf617 100644 --- a/pkg/middleware/encrypt/encryptor/encryptor_messages.go +++ b/pkg/middleware/encrypt/encryptor/encryptor_messages.go @@ -17,21 +17,50 @@ type EncryptedMessage struct { SessionID types.EncryptionSessionID } -func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) { - msg := &EncryptedMessage{ - Nonce: nonce, +func NewEncryptedMessage(msg message.Message) (*EncryptedMessage, error) { + em := &EncryptedMessage{ Timestamp: time.Now(), - SessionID: id, } - bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg) + bmsg, err := interceptor.NewBaseMessage(msg.GetProtocol(), msg, em) if err != nil { return nil, err } - msg.BaseMessage = bmsg + em.BaseMessage = bmsg - return msg, nil + return em, nil +} + +func (m *EncryptedMessage) WriteProcess(_i interceptor.Interceptor, connection interceptor.Connection) error { + i, ok := _i.(interfaces.CanGetState) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + s, err := i.GetState(connection) + if err != nil { + return err + } + + ss, ok := s.(interfaces.CanEncrypt) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + encmsg, err := ss.Encrypt(m) + if err != nil { + return err + } + + msg, ok := encmsg.(*EncryptedMessage) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + m.NextPayload = msg.NextPayload + + return nil } func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error { diff --git a/pkg/middleware/encrypt/interceptor.go b/pkg/middleware/encrypt/interceptor.go index a4b596d..aae8caa 100644 --- a/pkg/middleware/encrypt/interceptor.go +++ b/pkg/middleware/encrypt/interceptor.go @@ -8,6 +8,7 @@ import ( "github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider" @@ -17,15 +18,16 @@ import ( type Interceptor struct { interceptor.NoOpInterceptor - nonceValidator NonceValidator - keyExchangeManager interfaces.KeyExchangeManager - keyProvider keyprovider.KeyProvider - stateManager interfaces.StateManager - config config.Config + localMessageRegistry message.Registry + nonceValidator NonceValidator + keyExchangeManager interfaces.KeyExchangeManager + keyProvider keyprovider.KeyProvider + stateManager interfaces.StateManager + config config.Config } func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) { - ctx, cancel := context.WithCancel(i.Ctx) + ctx, cancel := context.WithCancel(i.Ctx()) newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader) if err != nil { @@ -49,7 +51,7 @@ func (i *Interceptor) Init(connection interceptor.Connection) error { return err } - ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second) + ctx, cancel := context.WithTimeout(i.Ctx(), 10*time.Second) defer cancel() waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted) @@ -63,30 +65,20 @@ func (i *Interceptor) Init(connection interceptor.Connection) error { func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error { - s, err := i.GetState(conn) + if i.localMessageRegistry.Check(msg.GetProtocol()) { + return writer.Write(conn, msg) + } + + m, err := encryptor.NewEncryptedMessage(msg) if err != nil { return err } - ss, ok := s.(interfaces.CanEncrypt) - if !ok { + if err := m.WriteProcess(i, conn); err != nil { return err } - encrypted, err := ss.Encrypt(msg) - if err != nil { - if !s.GetConfig().RequireEncryption { - return writer.Write(conn, msg) - } - return err - } - - iMessage, ok := encrypted.(interceptor.Message) - if !ok { - return writer.Write(conn) - } - - return writer.Write(conn, encrypted) + return writer.Write(conn, m) }) } @@ -97,22 +89,23 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept return msg, err } - s, err := i.GetState(conn) - if err != nil { - return nil, err + if !i.localMessageRegistry.Check(msg.GetProtocol()) { + if !i.config.RequireEncryption { + return msg, nil + } + return nil, encryptionerr.ErrInvalidInterceptor } - ss, ok := s.(interfaces.CanDecrypt) + m, ok := msg.(interceptor.Message) if !ok { - return msg, err + return nil, encryptionerr.ErrInvalidInterceptor } - m, err := ss.Decrypt(msg) - if err != nil { + if err := m.ReadProcess(i, conn); err != nil { return nil, err } - return m, nil + return m.GetNext(i.GetMessageRegistry()) }) }