mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-05 15:46:52 +08:00
added InterceptSocketWriter and InterceptSocketReader functions on encryption interceptor
This commit is contained in:
@@ -55,7 +55,7 @@ type BaseMessage struct {
|
||||
message.BaseMessage
|
||||
}
|
||||
|
||||
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
|
||||
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Marshallable, msg Message) (BaseMessage, error) {
|
||||
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
|
||||
if err != nil {
|
||||
return BaseMessage{}, nil
|
||||
|
@@ -41,6 +41,20 @@ const (
|
||||
UnknownReceiver Receiver = "unknown.receiver"
|
||||
)
|
||||
|
||||
type Marshallable interface {
|
||||
// Marshal serializes the message to JSON format
|
||||
Marshal() ([]byte, error)
|
||||
}
|
||||
|
||||
type Unmarshallable interface {
|
||||
// Unmarshal deserializes the message from JSON format
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
func (p Payload) Marshal() ([]byte, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Message defines the interface that all message types must implement.
|
||||
// It provides methods for protocol identification, serialization, and
|
||||
// message nesting/unwrapping.
|
||||
@@ -56,11 +70,9 @@ type Message interface {
|
||||
// Returns nil, nil if there is no next message
|
||||
GetNext(Registry) (Message, error)
|
||||
|
||||
// Marshal serializes the message to JSON format
|
||||
Marshal() ([]byte, error)
|
||||
Marshallable
|
||||
|
||||
// Unmarshal deserializes the message from JSON format
|
||||
Unmarshal([]byte) error
|
||||
Unmarshallable
|
||||
}
|
||||
|
||||
// Header contains common metadata for all messages
|
||||
@@ -127,7 +139,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
|
||||
return nil, ErrNoPayload
|
||||
}
|
||||
|
||||
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
|
||||
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
|
||||
}
|
||||
|
||||
// Marshal serializes the message to JSON format
|
||||
@@ -140,7 +152,7 @@ func (m *BaseMessage) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, m)
|
||||
}
|
||||
|
||||
func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (BaseMessage, error) {
|
||||
func NewBaseMessage(nextProtocol Protocol, nextPayload Marshallable, msg Message) (BaseMessage, error) {
|
||||
var inner json.RawMessage = nil
|
||||
if nextPayload != nil {
|
||||
if nextProtocol == NoneProtocol {
|
||||
|
@@ -21,6 +21,8 @@ type Registry interface {
|
||||
// Returns an error if the protocol is already registered
|
||||
Register(Protocol, Factory) error
|
||||
|
||||
Check(protocol Protocol) bool
|
||||
|
||||
// Create instantiates a new message for the given protocol
|
||||
// Returns an error if the protocol is not registered
|
||||
Create(Protocol) (Message, error)
|
||||
@@ -65,6 +67,14 @@ func NewRegistry() *DefaultRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultRegistry) Check(protocol Protocol) bool {
|
||||
r.mux.RLock()
|
||||
defer r.mux.RUnlock()
|
||||
|
||||
_, exists := r.factories[protocol]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Register adds a message factory for a protocol
|
||||
// Returns an error if the protocol is already registered
|
||||
// This method is thread-safe
|
||||
|
@@ -56,25 +56,28 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
|
||||
m, ok := msg.(*EncryptedMessage)
|
||||
if !ok {
|
||||
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||
}
|
||||
|
||||
nonce := types.Nonce{}
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
protocol := msg.GetProtocol()
|
||||
encryptedData := a.encryptor.Seal(nil, nonce[:], m.NextPayload, a.sessionID[:])
|
||||
|
||||
data, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
m.NextPayload = encryptedData
|
||||
|
||||
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
|
||||
|
||||
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
|
||||
m, ok := msg.(*EncryptedMessage)
|
||||
if !ok {
|
||||
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||
|
@@ -17,21 +17,50 @@ type EncryptedMessage struct {
|
||||
SessionID types.EncryptionSessionID
|
||||
}
|
||||
|
||||
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
|
||||
msg := &EncryptedMessage{
|
||||
Nonce: nonce,
|
||||
func NewEncryptedMessage(msg message.Message) (*EncryptedMessage, error) {
|
||||
em := &EncryptedMessage{
|
||||
Timestamp: time.Now(),
|
||||
SessionID: id,
|
||||
}
|
||||
|
||||
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
|
||||
bmsg, err := interceptor.NewBaseMessage(msg.GetProtocol(), msg, em)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg.BaseMessage = bmsg
|
||||
em.BaseMessage = bmsg
|
||||
|
||||
return msg, nil
|
||||
return em, nil
|
||||
}
|
||||
|
||||
func (m *EncryptedMessage) WriteProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
i, ok := _i.(interfaces.CanGetState)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := i.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.CanEncrypt)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
encmsg, err := ss.Encrypt(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, ok := encmsg.(*EncryptedMessage)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
m.NextPayload = msg.NextPayload
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||
@@ -17,15 +18,16 @@ import (
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
nonceValidator NonceValidator
|
||||
keyExchangeManager interfaces.KeyExchangeManager
|
||||
keyProvider keyprovider.KeyProvider
|
||||
stateManager interfaces.StateManager
|
||||
config config.Config
|
||||
localMessageRegistry message.Registry
|
||||
nonceValidator NonceValidator
|
||||
keyExchangeManager interfaces.KeyExchangeManager
|
||||
keyProvider keyprovider.KeyProvider
|
||||
stateManager interfaces.StateManager
|
||||
config config.Config
|
||||
}
|
||||
|
||||
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
|
||||
ctx, cancel := context.WithCancel(i.Ctx)
|
||||
ctx, cancel := context.WithCancel(i.Ctx())
|
||||
|
||||
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
|
||||
if err != nil {
|
||||
@@ -49,7 +51,7 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(i.Ctx(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
|
||||
@@ -63,30 +65,20 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
|
||||
|
||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
|
||||
s, err := i.GetState(conn)
|
||||
if i.localMessageRegistry.Check(msg.GetProtocol()) {
|
||||
return writer.Write(conn, msg)
|
||||
}
|
||||
|
||||
m, err := encryptor.NewEncryptedMessage(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.CanEncrypt)
|
||||
if !ok {
|
||||
if err := m.WriteProcess(i, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encrypted, err := ss.Encrypt(msg)
|
||||
if err != nil {
|
||||
if !s.GetConfig().RequireEncryption {
|
||||
return writer.Write(conn, msg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
iMessage, ok := encrypted.(interceptor.Message)
|
||||
if !ok {
|
||||
return writer.Write(conn)
|
||||
}
|
||||
|
||||
return writer.Write(conn, encrypted)
|
||||
return writer.Write(conn, m)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,22 +89,23 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
|
||||
return msg, err
|
||||
}
|
||||
|
||||
s, err := i.GetState(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !i.localMessageRegistry.Check(msg.GetProtocol()) {
|
||||
if !i.config.RequireEncryption {
|
||||
return msg, nil
|
||||
}
|
||||
return nil, encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.CanDecrypt)
|
||||
m, ok := msg.(interceptor.Message)
|
||||
if !ok {
|
||||
return msg, err
|
||||
return nil, encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
m, err := ss.Decrypt(msg)
|
||||
if err != nil {
|
||||
if err := m.ReadProcess(i, conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
return m.GetNext(i.GetMessageRegistry())
|
||||
})
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user