added aes256.go

This commit is contained in:
harshabose
2025-05-02 12:53:30 +05:30
parent 29d5ca8ade
commit 9398d0bd5d
9 changed files with 110 additions and 20 deletions

View File

@@ -3,6 +3,7 @@ package interceptor
import "github.com/harshabose/socket-comm/internal/util"
type Chain struct {
NoOpInterceptor
interceptors []Interceptor
}

View File

@@ -44,6 +44,12 @@ type Connection interface {
}
type Interceptor interface {
ID() string
Ctx() context.Context
GetMessageRegistry() message.Registry
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
Init(Connection) error
@@ -78,10 +84,18 @@ func (f WriterFunc) Write(conn Connection, message message.Message) error {
}
type NoOpInterceptor struct {
ID string
iD string
messageRegistry message.Registry
Mutex sync.RWMutex
Ctx context.Context
ctx context.Context
}
func (interceptor *NoOpInterceptor) Ctx() context.Context {
return interceptor.ctx
}
func (interceptor *NoOpInterceptor) ID() string {
return interceptor.iD
}
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {

View File

@@ -55,7 +55,6 @@ type BaseMessage struct {
message.BaseMessage
}
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
if err != nil {

View File

@@ -48,6 +48,10 @@ type Message interface {
// GetProtocol returns the protocol identifier for this message
GetProtocol() Protocol
GetNextPayload() (Payload, error)
GetNextProtocol() Protocol
// GetNext retrieves the next message in the chain, if one exists
// Returns nil, nil if there is no next message
GetNext(Registry) (Message, error)
@@ -86,8 +90,8 @@ type BaseMessage struct {
// CURRENT OTHER FIELDS...
// NEXT MESSAGE PROCESSOR
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
NextPayload Payload `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
}
// GetProtocol returns this message's protocol identifier
@@ -95,6 +99,22 @@ func (m *BaseMessage) GetProtocol() Protocol {
return m.CurrentProtocol
}
func (m *BaseMessage) GetNextPayload() (Payload, error) {
if m.NextProtocol == NoneProtocol {
return nil, nil
}
if m.NextPayload == nil {
return nil, ErrNoPayload
}
return m.NextPayload, nil
}
func (m *BaseMessage) GetNextProtocol() Protocol {
return m.NextProtocol
}
// GetNext retrieves the next message in the chain, if one exists.
// Returns nil, nil if NextProtocol is NoneProtocol.
// Uses the provided Registry to create and unmarshal the next message.
@@ -107,7 +127,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
return nil, ErrNoPayload
}
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
}
// Marshal serializes the message to JSON format
@@ -136,7 +156,7 @@ func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (Ba
return BaseMessage{
CurrentProtocol: msg.GetProtocol(),
CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver),
NextPayload: inner,
NextPayload: Payload(inner),
NextProtocol: nextProtocol,
}, nil
}

View File

@@ -27,7 +27,7 @@ type Registry interface {
// Unmarshal creates and deserializes a message for a protocol
// The provided data is parsed into the appropriate message type
Unmarshal(Protocol, json.RawMessage) (Message, error)
Unmarshal(Protocol, Payload) (Message, error)
// UnmarshalRaw deserializes a message when the protocol is unknown
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
@@ -103,7 +103,7 @@ func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
// Unmarshal creates and deserializes a message for a protocol
// The provided data is parsed into the appropriate message type
// This method is thread-safe
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) {
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data Payload) (Message, error) {
msg, err := r.Create(protocol)
if err != nil {
return nil, err
@@ -120,7 +120,7 @@ func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Me
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
// This method is particularly useful for handling incoming WebSocket messages
// This method is thread-safe
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) {
func (r *DefaultRegistry) UnmarshalRaw(data Payload) (Message, error) {
var envelope Envelope
if err := json.Unmarshal(data, &envelope); err != nil {
return nil, fmt.Errorf("failed to extract protocol: %w", err)

View File

@@ -62,15 +62,32 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
protocol := msg.GetProtocol()
data, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal message: %w", err)
}
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
}
func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) {
// TODO implement me
panic("implement me")
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
m, ok := msg.(*EncryptedMessage)
if !ok {
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
}
data, err := a.decryptor.Open(nil, m.Nonce[:], m.NextPayload, a.sessionID[:])
if err != nil {
return nil, fmt.Errorf("decryption failed: %w", err)
}
m.NextPayload = data
return m, nil
}
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {

View File

@@ -1,14 +1,37 @@
package encryptor
import (
"time"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type EncryptedMessage struct {
interceptor.BaseMessage
Nonce types.Nonce
Timestamp time.Time
SessionID types.EncryptionSessionID
}
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
msg := &EncryptedMessage{
Nonce: nonce,
Timestamp: time.Now(),
SessionID: id,
}
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
if err != nil {
return nil, err
}
msg.BaseMessage = bmsg
return msg, nil
}
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
@@ -27,16 +50,17 @@ func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn intercep
return encryptionerr.ErrInvalidInterceptor
}
msg, err := ss.Decrypt(m)
decmsg, err := ss.Decrypt(m)
if err != nil {
return err
}
// TODO: message.Registry is not implemented yet
decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload)
if err != nil {
return err
msg, ok := decmsg.(*EncryptedMessage)
if !ok {
return encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
}
m.NextPayload = msg.NextPayload // JUST MAKING SURE
return nil
}

View File

@@ -81,6 +81,11 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
return err
}
iMessage, ok := encrypted.(interceptor.Message)
if !ok {
return writer.Write(conn)
}
return writer.Write(conn, encrypted)
})
}

View File

@@ -63,14 +63,24 @@ func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
return s.keyExchangeSessionID
}
func (s *State) GetConnection() interceptor.Connection {
return s.connection
}
func (s *State) WriteMessage(msg interceptor.Message) error {
s.mux.Lock()
defer s.mux.Unlock()
msg.SetReceiver(s.peerID)
// TODO: MANAGE CLIENT DISCOVERY
return s.writer.Write(s.connection, msg)
}
func (s *State) ReadMessage(msg interceptor.Message) error {
s.mux.Lock()
defer s.mux.Unlock()
return s.reader.Read()
}
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
s.mux.Lock()
defer s.mux.Unlock()