From 9398d0bd5de265b942613adb241b94dc77e0a13e Mon Sep 17 00:00:00 2001 From: harshabose Date: Fri, 2 May 2025 12:53:30 +0530 Subject: [PATCH] added aes256.go --- pkg/interceptor/chain.go | 1 + pkg/interceptor/interceptor.go | 18 ++++++++-- pkg/interceptor/message.go | 1 - pkg/message/message.go | 28 ++++++++++++--- pkg/message/registry.go | 6 ++-- pkg/middleware/encrypt/encryptor/aes256.go | 23 +++++++++++-- .../encrypt/encryptor/encryptor_messages.go | 34 ++++++++++++++++--- pkg/middleware/encrypt/interceptor.go | 5 +++ pkg/middleware/encrypt/state/state.go | 14 ++++++-- 9 files changed, 110 insertions(+), 20 deletions(-) diff --git a/pkg/interceptor/chain.go b/pkg/interceptor/chain.go index 1f501fc..9360a53 100644 --- a/pkg/interceptor/chain.go +++ b/pkg/interceptor/chain.go @@ -3,6 +3,7 @@ package interceptor import "github.com/harshabose/socket-comm/internal/util" type Chain struct { + NoOpInterceptor interceptors []Interceptor } diff --git a/pkg/interceptor/interceptor.go b/pkg/interceptor/interceptor.go index 0189366..00e9467 100644 --- a/pkg/interceptor/interceptor.go +++ b/pkg/interceptor/interceptor.go @@ -44,6 +44,12 @@ type Connection interface { } type Interceptor interface { + ID() string + + Ctx() context.Context + + GetMessageRegistry() message.Registry + BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error) Init(Connection) error @@ -78,10 +84,18 @@ func (f WriterFunc) Write(conn Connection, message message.Message) error { } type NoOpInterceptor struct { - ID string + iD string messageRegistry message.Registry Mutex sync.RWMutex - Ctx context.Context + ctx context.Context +} + +func (interceptor *NoOpInterceptor) Ctx() context.Context { + return interceptor.ctx +} + +func (interceptor *NoOpInterceptor) ID() string { + return interceptor.iD } func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry { diff --git a/pkg/interceptor/message.go b/pkg/interceptor/message.go index b0af086..5c84f24 100644 --- a/pkg/interceptor/message.go +++ b/pkg/interceptor/message.go @@ -55,7 +55,6 @@ type BaseMessage struct { message.BaseMessage } -// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) { bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg) if err != nil { diff --git a/pkg/message/message.go b/pkg/message/message.go index f2d7b84..a841e64 100644 --- a/pkg/message/message.go +++ b/pkg/message/message.go @@ -48,6 +48,10 @@ type Message interface { // GetProtocol returns the protocol identifier for this message GetProtocol() Protocol + GetNextPayload() (Payload, error) + + GetNextProtocol() Protocol + // GetNext retrieves the next message in the chain, if one exists // Returns nil, nil if there is no next message GetNext(Registry) (Message, error) @@ -86,8 +90,8 @@ type BaseMessage struct { // CURRENT OTHER FIELDS... // NEXT MESSAGE PROCESSOR - NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain - NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain + NextPayload Payload `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain + NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain } // GetProtocol returns this message's protocol identifier @@ -95,6 +99,22 @@ func (m *BaseMessage) GetProtocol() Protocol { return m.CurrentProtocol } +func (m *BaseMessage) GetNextPayload() (Payload, error) { + if m.NextProtocol == NoneProtocol { + return nil, nil + } + + if m.NextPayload == nil { + return nil, ErrNoPayload + } + + return m.NextPayload, nil +} + +func (m *BaseMessage) GetNextProtocol() Protocol { + return m.NextProtocol +} + // GetNext retrieves the next message in the chain, if one exists. // Returns nil, nil if NextProtocol is NoneProtocol. // Uses the provided Registry to create and unmarshal the next message. @@ -107,7 +127,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) { return nil, ErrNoPayload } - return registry.Unmarshal(m.NextProtocol, m.NextPayload) + return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload)) } // Marshal serializes the message to JSON format @@ -136,7 +156,7 @@ func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (Ba return BaseMessage{ CurrentProtocol: msg.GetProtocol(), CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver), - NextPayload: inner, + NextPayload: Payload(inner), NextProtocol: nextProtocol, }, nil } diff --git a/pkg/message/registry.go b/pkg/message/registry.go index 6b5615d..9570553 100644 --- a/pkg/message/registry.go +++ b/pkg/message/registry.go @@ -27,7 +27,7 @@ type Registry interface { // Unmarshal creates and deserializes a message for a protocol // The provided data is parsed into the appropriate message type - Unmarshal(Protocol, json.RawMessage) (Message, error) + Unmarshal(Protocol, Payload) (Message, error) // UnmarshalRaw deserializes a message when the protocol is unknown // It first inspects the envelope to determine the protocol, then unmarshals accordingly @@ -103,7 +103,7 @@ func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) { // Unmarshal creates and deserializes a message for a protocol // The provided data is parsed into the appropriate message type // This method is thread-safe -func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) { +func (r *DefaultRegistry) Unmarshal(protocol Protocol, data Payload) (Message, error) { msg, err := r.Create(protocol) if err != nil { return nil, err @@ -120,7 +120,7 @@ func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Me // It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type // This method is particularly useful for handling incoming WebSocket messages // This method is thread-safe -func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) { +func (r *DefaultRegistry) UnmarshalRaw(data Payload) (Message, error) { var envelope Envelope if err := json.Unmarshal(data, &envelope); err != nil { return nil, fmt.Errorf("failed to extract protocol: %w", err) diff --git a/pkg/middleware/encrypt/encryptor/aes256.go b/pkg/middleware/encrypt/encryptor/aes256.go index ac8b4a2..44da387 100644 --- a/pkg/middleware/encrypt/encryptor/aes256.go +++ b/pkg/middleware/encrypt/encryptor/aes256.go @@ -62,15 +62,32 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error) return nil, fmt.Errorf("failed to generate nonce: %w", err) } + protocol := msg.GetProtocol() + data, err := msg.Marshal() if err != nil { return nil, fmt.Errorf("failed to marshal message: %w", err) } + + encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:]) + + return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID) } -func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) { - // TODO implement me - panic("implement me") +func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) { + m, ok := msg.(*EncryptedMessage) + if !ok { + return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE + } + + data, err := a.decryptor.Open(nil, m.Nonce[:], m.NextPayload, a.sessionID[:]) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + + m.NextPayload = data + + return m, nil } func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) { diff --git a/pkg/middleware/encrypt/encryptor/encryptor_messages.go b/pkg/middleware/encrypt/encryptor/encryptor_messages.go index 195349f..b1d7f4b 100644 --- a/pkg/middleware/encrypt/encryptor/encryptor_messages.go +++ b/pkg/middleware/encrypt/encryptor/encryptor_messages.go @@ -1,14 +1,37 @@ package encryptor import ( + "time" + "github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) type EncryptedMessage struct { interceptor.BaseMessage + Nonce types.Nonce + Timestamp time.Time + SessionID types.EncryptionSessionID +} + +func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) { + msg := &EncryptedMessage{ + Nonce: nonce, + Timestamp: time.Now(), + SessionID: id, + } + + bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg) + if err != nil { + return nil, err + } + + msg.BaseMessage = bmsg + + return msg, nil } func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error { @@ -27,16 +50,17 @@ func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn intercep return encryptionerr.ErrInvalidInterceptor } - msg, err := ss.Decrypt(m) + decmsg, err := ss.Decrypt(m) if err != nil { return err } - // TODO: message.Registry is not implemented yet - decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload) - if err != nil { - return err + msg, ok := decmsg.(*EncryptedMessage) + if !ok { + return encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE } + m.NextPayload = msg.NextPayload // JUST MAKING SURE + return nil } diff --git a/pkg/middleware/encrypt/interceptor.go b/pkg/middleware/encrypt/interceptor.go index a36bdc7..a4b596d 100644 --- a/pkg/middleware/encrypt/interceptor.go +++ b/pkg/middleware/encrypt/interceptor.go @@ -81,6 +81,11 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept return err } + iMessage, ok := encrypted.(interceptor.Message) + if !ok { + return writer.Write(conn) + } + return writer.Write(conn, encrypted) }) } diff --git a/pkg/middleware/encrypt/state/state.go b/pkg/middleware/encrypt/state/state.go index d938809..7dbc441 100644 --- a/pkg/middleware/encrypt/state/state.go +++ b/pkg/middleware/encrypt/state/state.go @@ -63,14 +63,24 @@ func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID { return s.keyExchangeSessionID } +func (s *State) GetConnection() interceptor.Connection { + return s.connection +} + func (s *State) WriteMessage(msg interceptor.Message) error { s.mux.Lock() defer s.mux.Unlock() - - msg.SetReceiver(s.peerID) + // TODO: MANAGE CLIENT DISCOVERY return s.writer.Write(s.connection, msg) } +func (s *State) ReadMessage(msg interceptor.Message) error { + s.mux.Lock() + defer s.mux.Unlock() + + return s.reader.Read() +} + func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID { s.mux.Lock() defer s.mux.Unlock()