added aes256.go

This commit is contained in:
harshabose
2025-05-02 12:53:30 +05:30
parent 29d5ca8ade
commit 9398d0bd5d
9 changed files with 110 additions and 20 deletions

View File

@@ -3,6 +3,7 @@ package interceptor
import "github.com/harshabose/socket-comm/internal/util" import "github.com/harshabose/socket-comm/internal/util"
type Chain struct { type Chain struct {
NoOpInterceptor
interceptors []Interceptor interceptors []Interceptor
} }

View File

@@ -44,6 +44,12 @@ type Connection interface {
} }
type Interceptor interface { type Interceptor interface {
ID() string
Ctx() context.Context
GetMessageRegistry() message.Registry
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error) BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
Init(Connection) error Init(Connection) error
@@ -78,10 +84,18 @@ func (f WriterFunc) Write(conn Connection, message message.Message) error {
} }
type NoOpInterceptor struct { type NoOpInterceptor struct {
ID string iD string
messageRegistry message.Registry messageRegistry message.Registry
Mutex sync.RWMutex Mutex sync.RWMutex
Ctx context.Context ctx context.Context
}
func (interceptor *NoOpInterceptor) Ctx() context.Context {
return interceptor.ctx
}
func (interceptor *NoOpInterceptor) ID() string {
return interceptor.iD
} }
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry { func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {

View File

@@ -55,7 +55,6 @@ type BaseMessage struct {
message.BaseMessage message.BaseMessage
} }
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) { func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg) bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
if err != nil { if err != nil {

View File

@@ -48,6 +48,10 @@ type Message interface {
// GetProtocol returns the protocol identifier for this message // GetProtocol returns the protocol identifier for this message
GetProtocol() Protocol GetProtocol() Protocol
GetNextPayload() (Payload, error)
GetNextProtocol() Protocol
// GetNext retrieves the next message in the chain, if one exists // GetNext retrieves the next message in the chain, if one exists
// Returns nil, nil if there is no next message // Returns nil, nil if there is no next message
GetNext(Registry) (Message, error) GetNext(Registry) (Message, error)
@@ -86,8 +90,8 @@ type BaseMessage struct {
// CURRENT OTHER FIELDS... // CURRENT OTHER FIELDS...
// NEXT MESSAGE PROCESSOR // NEXT MESSAGE PROCESSOR
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain NextPayload Payload `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
} }
// GetProtocol returns this message's protocol identifier // GetProtocol returns this message's protocol identifier
@@ -95,6 +99,22 @@ func (m *BaseMessage) GetProtocol() Protocol {
return m.CurrentProtocol return m.CurrentProtocol
} }
func (m *BaseMessage) GetNextPayload() (Payload, error) {
if m.NextProtocol == NoneProtocol {
return nil, nil
}
if m.NextPayload == nil {
return nil, ErrNoPayload
}
return m.NextPayload, nil
}
func (m *BaseMessage) GetNextProtocol() Protocol {
return m.NextProtocol
}
// GetNext retrieves the next message in the chain, if one exists. // GetNext retrieves the next message in the chain, if one exists.
// Returns nil, nil if NextProtocol is NoneProtocol. // Returns nil, nil if NextProtocol is NoneProtocol.
// Uses the provided Registry to create and unmarshal the next message. // Uses the provided Registry to create and unmarshal the next message.
@@ -107,7 +127,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
return nil, ErrNoPayload return nil, ErrNoPayload
} }
return registry.Unmarshal(m.NextProtocol, m.NextPayload) return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
} }
// Marshal serializes the message to JSON format // Marshal serializes the message to JSON format
@@ -136,7 +156,7 @@ func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (Ba
return BaseMessage{ return BaseMessage{
CurrentProtocol: msg.GetProtocol(), CurrentProtocol: msg.GetProtocol(),
CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver), CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver),
NextPayload: inner, NextPayload: Payload(inner),
NextProtocol: nextProtocol, NextProtocol: nextProtocol,
}, nil }, nil
} }

View File

@@ -27,7 +27,7 @@ type Registry interface {
// Unmarshal creates and deserializes a message for a protocol // Unmarshal creates and deserializes a message for a protocol
// The provided data is parsed into the appropriate message type // The provided data is parsed into the appropriate message type
Unmarshal(Protocol, json.RawMessage) (Message, error) Unmarshal(Protocol, Payload) (Message, error)
// UnmarshalRaw deserializes a message when the protocol is unknown // UnmarshalRaw deserializes a message when the protocol is unknown
// It first inspects the envelope to determine the protocol, then unmarshals accordingly // It first inspects the envelope to determine the protocol, then unmarshals accordingly
@@ -103,7 +103,7 @@ func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
// Unmarshal creates and deserializes a message for a protocol // Unmarshal creates and deserializes a message for a protocol
// The provided data is parsed into the appropriate message type // The provided data is parsed into the appropriate message type
// This method is thread-safe // This method is thread-safe
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) { func (r *DefaultRegistry) Unmarshal(protocol Protocol, data Payload) (Message, error) {
msg, err := r.Create(protocol) msg, err := r.Create(protocol)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -120,7 +120,7 @@ func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Me
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type // It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
// This method is particularly useful for handling incoming WebSocket messages // This method is particularly useful for handling incoming WebSocket messages
// This method is thread-safe // This method is thread-safe
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) { func (r *DefaultRegistry) UnmarshalRaw(data Payload) (Message, error) {
var envelope Envelope var envelope Envelope
if err := json.Unmarshal(data, &envelope); err != nil { if err := json.Unmarshal(data, &envelope); err != nil {
return nil, fmt.Errorf("failed to extract protocol: %w", err) return nil, fmt.Errorf("failed to extract protocol: %w", err)

View File

@@ -62,15 +62,32 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
return nil, fmt.Errorf("failed to generate nonce: %w", err) return nil, fmt.Errorf("failed to generate nonce: %w", err)
} }
protocol := msg.GetProtocol()
data, err := msg.Marshal() data, err := msg.Marshal()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal message: %w", err) return nil, fmt.Errorf("failed to marshal message: %w", err)
} }
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
} }
func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) { func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
// TODO implement me m, ok := msg.(*EncryptedMessage)
panic("implement me") if !ok {
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
}
data, err := a.decryptor.Open(nil, m.Nonce[:], m.NextPayload, a.sessionID[:])
if err != nil {
return nil, fmt.Errorf("decryption failed: %w", err)
}
m.NextPayload = data
return m, nil
} }
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) { func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {

View File

@@ -1,14 +1,37 @@
package encryptor package encryptor
import ( import (
"time"
"github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
type EncryptedMessage struct { type EncryptedMessage struct {
interceptor.BaseMessage interceptor.BaseMessage
Nonce types.Nonce
Timestamp time.Time
SessionID types.EncryptionSessionID
}
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
msg := &EncryptedMessage{
Nonce: nonce,
Timestamp: time.Now(),
SessionID: id,
}
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
if err != nil {
return nil, err
}
msg.BaseMessage = bmsg
return msg, nil
} }
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error { func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
@@ -27,16 +50,17 @@ func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn intercep
return encryptionerr.ErrInvalidInterceptor return encryptionerr.ErrInvalidInterceptor
} }
msg, err := ss.Decrypt(m) decmsg, err := ss.Decrypt(m)
if err != nil { if err != nil {
return err return err
} }
// TODO: message.Registry is not implemented yet msg, ok := decmsg.(*EncryptedMessage)
decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload) if !ok {
if err != nil { return encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
return err
} }
m.NextPayload = msg.NextPayload // JUST MAKING SURE
return nil return nil
} }

View File

@@ -81,6 +81,11 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
return err return err
} }
iMessage, ok := encrypted.(interceptor.Message)
if !ok {
return writer.Write(conn)
}
return writer.Write(conn, encrypted) return writer.Write(conn, encrypted)
}) })
} }

View File

@@ -63,14 +63,24 @@ func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
return s.keyExchangeSessionID return s.keyExchangeSessionID
} }
func (s *State) GetConnection() interceptor.Connection {
return s.connection
}
func (s *State) WriteMessage(msg interceptor.Message) error { func (s *State) WriteMessage(msg interceptor.Message) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
// TODO: MANAGE CLIENT DISCOVERY
msg.SetReceiver(s.peerID)
return s.writer.Write(s.connection, msg) return s.writer.Write(s.connection, msg)
} }
func (s *State) ReadMessage(msg interceptor.Message) error {
s.mux.Lock()
defer s.mux.Unlock()
return s.reader.Read()
}
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID { func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()