added InterceptSocketWriter and InterceptSocketReader functions on encryption interceptor

This commit is contained in:
harshabose
2025-05-02 19:31:13 +05:30
parent 9398d0bd5d
commit 162c4108c2
6 changed files with 101 additions and 54 deletions

View File

@@ -55,7 +55,7 @@ type BaseMessage struct {
message.BaseMessage
}
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Marshallable, msg Message) (BaseMessage, error) {
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
if err != nil {
return BaseMessage{}, nil

View File

@@ -41,6 +41,20 @@ const (
UnknownReceiver Receiver = "unknown.receiver"
)
type Marshallable interface {
// Marshal serializes the message to JSON format
Marshal() ([]byte, error)
}
type Unmarshallable interface {
// Unmarshal deserializes the message from JSON format
Unmarshal([]byte) error
}
func (p Payload) Marshal() ([]byte, error) {
return p, nil
}
// Message defines the interface that all message types must implement.
// It provides methods for protocol identification, serialization, and
// message nesting/unwrapping.
@@ -56,11 +70,9 @@ type Message interface {
// Returns nil, nil if there is no next message
GetNext(Registry) (Message, error)
// Marshal serializes the message to JSON format
Marshal() ([]byte, error)
Marshallable
// Unmarshal deserializes the message from JSON format
Unmarshal([]byte) error
Unmarshallable
}
// Header contains common metadata for all messages
@@ -127,7 +139,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
return nil, ErrNoPayload
}
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
}
// Marshal serializes the message to JSON format
@@ -140,7 +152,7 @@ func (m *BaseMessage) Unmarshal(data []byte) error {
return json.Unmarshal(data, m)
}
func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (BaseMessage, error) {
func NewBaseMessage(nextProtocol Protocol, nextPayload Marshallable, msg Message) (BaseMessage, error) {
var inner json.RawMessage = nil
if nextPayload != nil {
if nextProtocol == NoneProtocol {

View File

@@ -21,6 +21,8 @@ type Registry interface {
// Returns an error if the protocol is already registered
Register(Protocol, Factory) error
Check(protocol Protocol) bool
// Create instantiates a new message for the given protocol
// Returns an error if the protocol is not registered
Create(Protocol) (Message, error)
@@ -65,6 +67,14 @@ func NewRegistry() *DefaultRegistry {
}
}
func (r *DefaultRegistry) Check(protocol Protocol) bool {
r.mux.RLock()
defer r.mux.RUnlock()
_, exists := r.factories[protocol]
return exists
}
// Register adds a message factory for a protocol
// Returns an error if the protocol is already registered
// This method is thread-safe

View File

@@ -56,25 +56,28 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
a.mux.Lock()
defer a.mux.Unlock()
m, ok := msg.(*EncryptedMessage)
if !ok {
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
}
nonce := types.Nonce{}
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
protocol := msg.GetProtocol()
encryptedData := a.encryptor.Seal(nil, nonce[:], m.NextPayload, a.sessionID[:])
data, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal message: %w", err)
}
m.NextPayload = encryptedData
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
return m, nil
}
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
a.mux.Lock()
defer a.mux.Unlock()
m, ok := msg.(*EncryptedMessage)
if !ok {
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE

View File

@@ -17,21 +17,50 @@ type EncryptedMessage struct {
SessionID types.EncryptionSessionID
}
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
msg := &EncryptedMessage{
Nonce: nonce,
func NewEncryptedMessage(msg message.Message) (*EncryptedMessage, error) {
em := &EncryptedMessage{
Timestamp: time.Now(),
SessionID: id,
}
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
bmsg, err := interceptor.NewBaseMessage(msg.GetProtocol(), msg, em)
if err != nil {
return nil, err
}
msg.BaseMessage = bmsg
em.BaseMessage = bmsg
return msg, nil
return em, nil
}
func (m *EncryptedMessage) WriteProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
i, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := i.GetState(connection)
if err != nil {
return err
}
ss, ok := s.(interfaces.CanEncrypt)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
encmsg, err := ss.Encrypt(m)
if err != nil {
return err
}
msg, ok := encmsg.(*EncryptedMessage)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
m.NextPayload = msg.NextPayload
return nil
}
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {

View File

@@ -8,6 +8,7 @@ import (
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
@@ -17,6 +18,7 @@ import (
type Interceptor struct {
interceptor.NoOpInterceptor
localMessageRegistry message.Registry
nonceValidator NonceValidator
keyExchangeManager interfaces.KeyExchangeManager
keyProvider keyprovider.KeyProvider
@@ -25,7 +27,7 @@ type Interceptor struct {
}
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
ctx, cancel := context.WithCancel(i.Ctx)
ctx, cancel := context.WithCancel(i.Ctx())
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
if err != nil {
@@ -49,7 +51,7 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
return err
}
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second)
ctx, cancel := context.WithTimeout(i.Ctx(), 10*time.Second)
defer cancel()
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
@@ -63,30 +65,20 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
s, err := i.GetState(conn)
if err != nil {
return err
}
ss, ok := s.(interfaces.CanEncrypt)
if !ok {
return err
}
encrypted, err := ss.Encrypt(msg)
if err != nil {
if !s.GetConfig().RequireEncryption {
if i.localMessageRegistry.Check(msg.GetProtocol()) {
return writer.Write(conn, msg)
}
m, err := encryptor.NewEncryptedMessage(msg)
if err != nil {
return err
}
iMessage, ok := encrypted.(interceptor.Message)
if !ok {
return writer.Write(conn)
if err := m.WriteProcess(i, conn); err != nil {
return err
}
return writer.Write(conn, encrypted)
return writer.Write(conn, m)
})
}
@@ -97,22 +89,23 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
return msg, err
}
s, err := i.GetState(conn)
if err != nil {
return nil, err
if !i.localMessageRegistry.Check(msg.GetProtocol()) {
if !i.config.RequireEncryption {
return msg, nil
}
return nil, encryptionerr.ErrInvalidInterceptor
}
ss, ok := s.(interfaces.CanDecrypt)
m, ok := msg.(interceptor.Message)
if !ok {
return msg, err
return nil, encryptionerr.ErrInvalidInterceptor
}
m, err := ss.Decrypt(msg)
if err != nil {
if err := m.ReadProcess(i, conn); err != nil {
return nil, err
}
return m, nil
return m.GetNext(i.GetMessageRegistry())
})
}