mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-05 23:56:54 +08:00
added InterceptSocketWriter and InterceptSocketReader functions on encryption interceptor
This commit is contained in:
@@ -55,7 +55,7 @@ type BaseMessage struct {
|
|||||||
message.BaseMessage
|
message.BaseMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
|
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Marshallable, msg Message) (BaseMessage, error) {
|
||||||
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
|
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return BaseMessage{}, nil
|
return BaseMessage{}, nil
|
||||||
|
@@ -41,6 +41,20 @@ const (
|
|||||||
UnknownReceiver Receiver = "unknown.receiver"
|
UnknownReceiver Receiver = "unknown.receiver"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Marshallable interface {
|
||||||
|
// Marshal serializes the message to JSON format
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Unmarshallable interface {
|
||||||
|
// Unmarshal deserializes the message from JSON format
|
||||||
|
Unmarshal([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Payload) Marshal() ([]byte, error) {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Message defines the interface that all message types must implement.
|
// Message defines the interface that all message types must implement.
|
||||||
// It provides methods for protocol identification, serialization, and
|
// It provides methods for protocol identification, serialization, and
|
||||||
// message nesting/unwrapping.
|
// message nesting/unwrapping.
|
||||||
@@ -56,11 +70,9 @@ type Message interface {
|
|||||||
// Returns nil, nil if there is no next message
|
// Returns nil, nil if there is no next message
|
||||||
GetNext(Registry) (Message, error)
|
GetNext(Registry) (Message, error)
|
||||||
|
|
||||||
// Marshal serializes the message to JSON format
|
Marshallable
|
||||||
Marshal() ([]byte, error)
|
|
||||||
|
|
||||||
// Unmarshal deserializes the message from JSON format
|
Unmarshallable
|
||||||
Unmarshal([]byte) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header contains common metadata for all messages
|
// Header contains common metadata for all messages
|
||||||
@@ -127,7 +139,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
|
|||||||
return nil, ErrNoPayload
|
return nil, ErrNoPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
|
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal serializes the message to JSON format
|
// Marshal serializes the message to JSON format
|
||||||
@@ -140,7 +152,7 @@ func (m *BaseMessage) Unmarshal(data []byte) error {
|
|||||||
return json.Unmarshal(data, m)
|
return json.Unmarshal(data, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (BaseMessage, error) {
|
func NewBaseMessage(nextProtocol Protocol, nextPayload Marshallable, msg Message) (BaseMessage, error) {
|
||||||
var inner json.RawMessage = nil
|
var inner json.RawMessage = nil
|
||||||
if nextPayload != nil {
|
if nextPayload != nil {
|
||||||
if nextProtocol == NoneProtocol {
|
if nextProtocol == NoneProtocol {
|
||||||
|
@@ -21,6 +21,8 @@ type Registry interface {
|
|||||||
// Returns an error if the protocol is already registered
|
// Returns an error if the protocol is already registered
|
||||||
Register(Protocol, Factory) error
|
Register(Protocol, Factory) error
|
||||||
|
|
||||||
|
Check(protocol Protocol) bool
|
||||||
|
|
||||||
// Create instantiates a new message for the given protocol
|
// Create instantiates a new message for the given protocol
|
||||||
// Returns an error if the protocol is not registered
|
// Returns an error if the protocol is not registered
|
||||||
Create(Protocol) (Message, error)
|
Create(Protocol) (Message, error)
|
||||||
@@ -65,6 +67,14 @@ func NewRegistry() *DefaultRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *DefaultRegistry) Check(protocol Protocol) bool {
|
||||||
|
r.mux.RLock()
|
||||||
|
defer r.mux.RUnlock()
|
||||||
|
|
||||||
|
_, exists := r.factories[protocol]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
// Register adds a message factory for a protocol
|
// Register adds a message factory for a protocol
|
||||||
// Returns an error if the protocol is already registered
|
// Returns an error if the protocol is already registered
|
||||||
// This method is thread-safe
|
// This method is thread-safe
|
||||||
|
@@ -56,25 +56,28 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
|
|||||||
a.mux.Lock()
|
a.mux.Lock()
|
||||||
defer a.mux.Unlock()
|
defer a.mux.Unlock()
|
||||||
|
|
||||||
|
m, ok := msg.(*EncryptedMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||||
|
}
|
||||||
|
|
||||||
nonce := types.Nonce{}
|
nonce := types.Nonce{}
|
||||||
|
|
||||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol := msg.GetProtocol()
|
encryptedData := a.encryptor.Seal(nil, nonce[:], m.NextPayload, a.sessionID[:])
|
||||||
|
|
||||||
data, err := msg.Marshal()
|
m.NextPayload = encryptedData
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
|
return m, nil
|
||||||
|
|
||||||
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
|
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
|
||||||
|
a.mux.Lock()
|
||||||
|
defer a.mux.Unlock()
|
||||||
|
|
||||||
m, ok := msg.(*EncryptedMessage)
|
m, ok := msg.(*EncryptedMessage)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||||
|
@@ -17,21 +17,50 @@ type EncryptedMessage struct {
|
|||||||
SessionID types.EncryptionSessionID
|
SessionID types.EncryptionSessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
|
func NewEncryptedMessage(msg message.Message) (*EncryptedMessage, error) {
|
||||||
msg := &EncryptedMessage{
|
em := &EncryptedMessage{
|
||||||
Nonce: nonce,
|
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
SessionID: id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
|
bmsg, err := interceptor.NewBaseMessage(msg.GetProtocol(), msg, em)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg.BaseMessage = bmsg
|
em.BaseMessage = bmsg
|
||||||
|
|
||||||
return msg, nil
|
return em, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *EncryptedMessage) WriteProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
|
||||||
|
i, ok := _i.(interfaces.CanGetState)
|
||||||
|
if !ok {
|
||||||
|
return encryptionerr.ErrInvalidInterceptor
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := i.GetState(connection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ss, ok := s.(interfaces.CanEncrypt)
|
||||||
|
if !ok {
|
||||||
|
return encryptionerr.ErrInvalidInterceptor
|
||||||
|
}
|
||||||
|
|
||||||
|
encmsg, err := ss.Encrypt(m)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := encmsg.(*EncryptedMessage)
|
||||||
|
if !ok {
|
||||||
|
return encryptionerr.ErrInvalidInterceptor
|
||||||
|
}
|
||||||
|
|
||||||
|
m.NextPayload = msg.NextPayload
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
||||||
|
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/harshabose/socket-comm/pkg/message"
|
"github.com/harshabose/socket-comm/pkg/message"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
|
|
||||||
type Interceptor struct {
|
type Interceptor struct {
|
||||||
interceptor.NoOpInterceptor
|
interceptor.NoOpInterceptor
|
||||||
|
localMessageRegistry message.Registry
|
||||||
nonceValidator NonceValidator
|
nonceValidator NonceValidator
|
||||||
keyExchangeManager interfaces.KeyExchangeManager
|
keyExchangeManager interfaces.KeyExchangeManager
|
||||||
keyProvider keyprovider.KeyProvider
|
keyProvider keyprovider.KeyProvider
|
||||||
@@ -25,7 +27,7 @@ type Interceptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
|
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
|
||||||
ctx, cancel := context.WithCancel(i.Ctx)
|
ctx, cancel := context.WithCancel(i.Ctx())
|
||||||
|
|
||||||
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
|
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,7 +51,7 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second)
|
ctx, cancel := context.WithTimeout(i.Ctx(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
|
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
|
||||||
@@ -63,30 +65,20 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
|
|||||||
|
|
||||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||||
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
|
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
|
||||||
s, err := i.GetState(conn)
|
if i.localMessageRegistry.Check(msg.GetProtocol()) {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ss, ok := s.(interfaces.CanEncrypt)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
encrypted, err := ss.Encrypt(msg)
|
|
||||||
if err != nil {
|
|
||||||
if !s.GetConfig().RequireEncryption {
|
|
||||||
return writer.Write(conn, msg)
|
return writer.Write(conn, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m, err := encryptor.NewEncryptedMessage(msg)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
iMessage, ok := encrypted.(interceptor.Message)
|
if err := m.WriteProcess(i, conn); err != nil {
|
||||||
if !ok {
|
return err
|
||||||
return writer.Write(conn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return writer.Write(conn, encrypted)
|
return writer.Write(conn, m)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,22 +89,23 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
|
|||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := i.GetState(conn)
|
if !i.localMessageRegistry.Check(msg.GetProtocol()) {
|
||||||
if err != nil {
|
if !i.config.RequireEncryption {
|
||||||
return nil, err
|
return msg, nil
|
||||||
|
}
|
||||||
|
return nil, encryptionerr.ErrInvalidInterceptor
|
||||||
}
|
}
|
||||||
|
|
||||||
ss, ok := s.(interfaces.CanDecrypt)
|
m, ok := msg.(interceptor.Message)
|
||||||
if !ok {
|
if !ok {
|
||||||
return msg, err
|
return nil, encryptionerr.ErrInvalidInterceptor
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := ss.Decrypt(msg)
|
if err := m.ReadProcess(i, conn); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m.GetNext(i.GetMessageRegistry())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user