added InterceptSocketWriter and InterceptSocketReader functions on encryption interceptor

This commit is contained in:
harshabose
2025-05-02 19:31:13 +05:30
parent 9398d0bd5d
commit 162c4108c2
6 changed files with 101 additions and 54 deletions

View File

@@ -55,7 +55,7 @@ type BaseMessage struct {
message.BaseMessage message.BaseMessage
} }
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) { func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Marshallable, msg Message) (BaseMessage, error) {
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg) bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
if err != nil { if err != nil {
return BaseMessage{}, nil return BaseMessage{}, nil

View File

@@ -41,6 +41,20 @@ const (
UnknownReceiver Receiver = "unknown.receiver" UnknownReceiver Receiver = "unknown.receiver"
) )
type Marshallable interface {
// Marshal serializes the message to JSON format
Marshal() ([]byte, error)
}
type Unmarshallable interface {
// Unmarshal deserializes the message from JSON format
Unmarshal([]byte) error
}
func (p Payload) Marshal() ([]byte, error) {
return p, nil
}
// Message defines the interface that all message types must implement. // Message defines the interface that all message types must implement.
// It provides methods for protocol identification, serialization, and // It provides methods for protocol identification, serialization, and
// message nesting/unwrapping. // message nesting/unwrapping.
@@ -56,11 +70,9 @@ type Message interface {
// Returns nil, nil if there is no next message // Returns nil, nil if there is no next message
GetNext(Registry) (Message, error) GetNext(Registry) (Message, error)
// Marshal serializes the message to JSON format Marshallable
Marshal() ([]byte, error)
// Unmarshal deserializes the message from JSON format Unmarshallable
Unmarshal([]byte) error
} }
// Header contains common metadata for all messages // Header contains common metadata for all messages
@@ -127,7 +139,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
return nil, ErrNoPayload return nil, ErrNoPayload
} }
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload)) return registry.Unmarshal(m.NextProtocol, m.NextPayload)
} }
// Marshal serializes the message to JSON format // Marshal serializes the message to JSON format
@@ -140,7 +152,7 @@ func (m *BaseMessage) Unmarshal(data []byte) error {
return json.Unmarshal(data, m) return json.Unmarshal(data, m)
} }
func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (BaseMessage, error) { func NewBaseMessage(nextProtocol Protocol, nextPayload Marshallable, msg Message) (BaseMessage, error) {
var inner json.RawMessage = nil var inner json.RawMessage = nil
if nextPayload != nil { if nextPayload != nil {
if nextProtocol == NoneProtocol { if nextProtocol == NoneProtocol {

View File

@@ -21,6 +21,8 @@ type Registry interface {
// Returns an error if the protocol is already registered // Returns an error if the protocol is already registered
Register(Protocol, Factory) error Register(Protocol, Factory) error
Check(protocol Protocol) bool
// Create instantiates a new message for the given protocol // Create instantiates a new message for the given protocol
// Returns an error if the protocol is not registered // Returns an error if the protocol is not registered
Create(Protocol) (Message, error) Create(Protocol) (Message, error)
@@ -65,6 +67,14 @@ func NewRegistry() *DefaultRegistry {
} }
} }
func (r *DefaultRegistry) Check(protocol Protocol) bool {
r.mux.RLock()
defer r.mux.RUnlock()
_, exists := r.factories[protocol]
return exists
}
// Register adds a message factory for a protocol // Register adds a message factory for a protocol
// Returns an error if the protocol is already registered // Returns an error if the protocol is already registered
// This method is thread-safe // This method is thread-safe

View File

@@ -56,25 +56,28 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
a.mux.Lock() a.mux.Lock()
defer a.mux.Unlock() defer a.mux.Unlock()
m, ok := msg.(*EncryptedMessage)
if !ok {
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
}
nonce := types.Nonce{} nonce := types.Nonce{}
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err) return nil, fmt.Errorf("failed to generate nonce: %w", err)
} }
protocol := msg.GetProtocol() encryptedData := a.encryptor.Seal(nil, nonce[:], m.NextPayload, a.sessionID[:])
data, err := msg.Marshal() m.NextPayload = encryptedData
if err != nil {
return nil, fmt.Errorf("failed to marshal message: %w", err)
}
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:]) return m, nil
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
} }
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) { func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
a.mux.Lock()
defer a.mux.Unlock()
m, ok := msg.(*EncryptedMessage) m, ok := msg.(*EncryptedMessage)
if !ok { if !ok {
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE

View File

@@ -17,21 +17,50 @@ type EncryptedMessage struct {
SessionID types.EncryptionSessionID SessionID types.EncryptionSessionID
} }
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) { func NewEncryptedMessage(msg message.Message) (*EncryptedMessage, error) {
msg := &EncryptedMessage{ em := &EncryptedMessage{
Nonce: nonce,
Timestamp: time.Now(), Timestamp: time.Now(),
SessionID: id,
} }
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg) bmsg, err := interceptor.NewBaseMessage(msg.GetProtocol(), msg, em)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg.BaseMessage = bmsg em.BaseMessage = bmsg
return msg, nil return em, nil
}
func (m *EncryptedMessage) WriteProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
i, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := i.GetState(connection)
if err != nil {
return err
}
ss, ok := s.(interfaces.CanEncrypt)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
encmsg, err := ss.Encrypt(m)
if err != nil {
return err
}
msg, ok := encmsg.(*EncryptedMessage)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
m.NextPayload = msg.NextPayload
return nil
} }
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error { func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {

View File

@@ -8,6 +8,7 @@ import (
"github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
@@ -17,6 +18,7 @@ import (
type Interceptor struct { type Interceptor struct {
interceptor.NoOpInterceptor interceptor.NoOpInterceptor
localMessageRegistry message.Registry
nonceValidator NonceValidator nonceValidator NonceValidator
keyExchangeManager interfaces.KeyExchangeManager keyExchangeManager interfaces.KeyExchangeManager
keyProvider keyprovider.KeyProvider keyProvider keyprovider.KeyProvider
@@ -25,7 +27,7 @@ type Interceptor struct {
} }
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) { func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
ctx, cancel := context.WithCancel(i.Ctx) ctx, cancel := context.WithCancel(i.Ctx())
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader) newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
if err != nil { if err != nil {
@@ -49,7 +51,7 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
return err return err
} }
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second) ctx, cancel := context.WithTimeout(i.Ctx(), 10*time.Second)
defer cancel() defer cancel()
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted) waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
@@ -63,30 +65,20 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error { return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
s, err := i.GetState(conn) if i.localMessageRegistry.Check(msg.GetProtocol()) {
if err != nil {
return err
}
ss, ok := s.(interfaces.CanEncrypt)
if !ok {
return err
}
encrypted, err := ss.Encrypt(msg)
if err != nil {
if !s.GetConfig().RequireEncryption {
return writer.Write(conn, msg) return writer.Write(conn, msg)
} }
m, err := encryptor.NewEncryptedMessage(msg)
if err != nil {
return err return err
} }
iMessage, ok := encrypted.(interceptor.Message) if err := m.WriteProcess(i, conn); err != nil {
if !ok { return err
return writer.Write(conn)
} }
return writer.Write(conn, encrypted) return writer.Write(conn, m)
}) })
} }
@@ -97,22 +89,23 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
return msg, err return msg, err
} }
s, err := i.GetState(conn) if !i.localMessageRegistry.Check(msg.GetProtocol()) {
if err != nil { if !i.config.RequireEncryption {
return nil, err return msg, nil
}
return nil, encryptionerr.ErrInvalidInterceptor
} }
ss, ok := s.(interfaces.CanDecrypt) m, ok := msg.(interceptor.Message)
if !ok { if !ok {
return msg, err return nil, encryptionerr.ErrInvalidInterceptor
} }
m, err := ss.Decrypt(msg) if err := m.ReadProcess(i, conn); err != nil {
if err != nil {
return nil, err return nil, err
} }
return m, nil return m.GetNext(i.GetMessageRegistry())
}) })
} }