added NewBaseMessage function in message.go

This commit is contained in:
harshabose
2025-05-02 02:52:05 +05:30
parent 83414564ec
commit 29d5ca8ade
10 changed files with 345 additions and 25 deletions

View File

@@ -4,6 +4,8 @@ import (
"context" "context"
"io" "io"
"sync" "sync"
"github.com/harshabose/socket-comm/pkg/message"
) )
type Registry struct { type Registry struct {
@@ -56,29 +58,34 @@ type Interceptor interface {
} }
type Writer interface { type Writer interface {
Write(conn Connection, message Message) error Write(conn Connection, message message.Message) error
} }
type Reader interface { type Reader interface {
Read(conn Connection) (Message, error) Read(conn Connection) (message.Message, error)
} }
type ReaderFunc func(conn Connection) (Message, error) type ReaderFunc func(conn Connection) (message.Message, error)
func (f ReaderFunc) Read(conn Connection) (Message, error) { func (f ReaderFunc) Read(conn Connection) (message.Message, error) {
return f(conn) return f(conn)
} }
type WriterFunc func(conn Connection, message Message) error type WriterFunc func(conn Connection, message message.Message) error
func (f WriterFunc) Write(conn Connection, message Message) error { func (f WriterFunc) Write(conn Connection, message message.Message) error {
return f(conn, message) return f(conn, message)
} }
type NoOpInterceptor struct { type NoOpInterceptor struct {
ID string ID string
Mutex sync.RWMutex messageRegistry message.Registry
Ctx context.Context Mutex sync.RWMutex
Ctx context.Context
}
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {
return interceptor.messageRegistry
} }
func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) { func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) {

View File

@@ -0,0 +1,93 @@
package encryptor
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
"sync"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type AES256Encryptor struct {
encryptor cipher.AEAD
decryptor cipher.AEAD
sessionID types.EncryptionSessionID
mux sync.RWMutex
}
func (a *AES256Encryptor) SetKeys(encryptorKey, decryptorKey types.Key) error {
a.mux.Lock()
defer a.mux.Unlock()
// Setup encryption AEAD
encBlock, err := aes.NewCipher(encryptorKey[:])
if err != nil {
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
}
encGCM, err := cipher.NewGCM(encBlock)
if err != nil {
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
}
// Setup decryption AEAD
decBlock, err := aes.NewCipher(decryptorKey[:])
if err != nil {
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
}
decGCM, err := cipher.NewGCM(decBlock)
if err != nil {
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
}
a.encryptor = encGCM
a.decryptor = decGCM
return nil
}
func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error) {
a.mux.Lock()
defer a.mux.Unlock()
nonce := types.Nonce{}
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
data, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal message: %w", err)
}
}
func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) {
// TODO implement me
panic("implement me")
}
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {
a.mux.Lock()
defer a.mux.Unlock()
a.sessionID = id
}
func (a *AES256Encryptor) Ready() bool {
a.mux.RLock()
defer a.mux.RUnlock()
return a.encryptor != nil && a.decryptor != nil
}
func (a *AES256Encryptor) Close() error {
// TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,42 @@
package encryptor
import (
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
)
type EncryptedMessage struct {
interceptor.BaseMessage
}
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
i, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := i.GetState(conn)
if err != nil {
return err
}
ss, ok := s.(interfaces.CanDecrypt)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
msg, err := ss.Decrypt(m)
if err != nil {
return err
}
// TODO: message.Registry is not implemented yet
decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,64 @@
package encryptor
import (
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type SkipEncryptionChecker func(message.Message) bool
// NewProtocolSkipChecker creates a checker that skips encryption for specific protocols
func NewProtocolSkipChecker(protocolsToSkip ...message.Protocol) SkipEncryptionChecker {
skipMap := make(map[message.Protocol]bool)
for _, protocol := range protocolsToSkip {
skipMap[protocol] = true
}
return func(msg message.Message) bool {
return skipMap[msg.GetProtocol()]
}
}
type SkipperEncryptor struct {
wrapped interfaces.Encryptor
skip SkipEncryptionChecker
}
func NewSkipperEncryptor(wrapped interfaces.Encryptor, skip SkipEncryptionChecker) *SkipperEncryptor {
return &SkipperEncryptor{
wrapped: wrapped,
skip: skip,
}
}
func (e *SkipperEncryptor) SetSessionID(id types.EncryptionSessionID) {
e.wrapped.SetSessionID(id)
}
func (e *SkipperEncryptor) SetKeys(encryptorKey, decryptorKey types.Key) error {
return e.wrapped.SetKeys(encryptorKey, decryptorKey)
}
func (e *SkipperEncryptor) Encrypt(msg message.Message) (message.Message, error) {
if e.skip(msg) {
return msg, nil
}
return e.wrapped.Encrypt(msg)
}
func (e *SkipperEncryptor) Decrypt(msg message.Message) (message.Message, error) {
if !e.skip(msg) {
return e.wrapped.Decrypt(msg)
}
return msg, nil
}
func (e *SkipperEncryptor) Ready() bool {
return e.wrapped.Ready()
}
func (e *SkipperEncryptor) Close() error {
return e.wrapped.Close()
}

View File

@@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
@@ -53,23 +54,71 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted) waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
return i.Process(waiter, s) if err := i.Process(waiter, s); err != nil {
return err
}
return i.keyExchangeManager.Finalise(s)
} }
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
s, err := i.GetState(conn)
if err != nil {
return err
}
ss, ok := s.(interfaces.CanEncrypt)
if !ok {
return err
}
encrypted, err := ss.Encrypt(msg)
if err != nil {
if !s.GetConfig().RequireEncryption {
return writer.Write(conn, msg)
}
return err
}
return writer.Write(conn, encrypted)
})
} }
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader { func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (message.Message, error) {
msg, err := reader.Read(conn)
if err != nil {
return msg, err
}
s, err := i.GetState(conn)
if err != nil {
return nil, err
}
ss, ok := s.(interfaces.CanDecrypt)
if !ok {
return msg, err
}
m, err := ss.Decrypt(msg)
if err != nil {
return nil, err
}
return m, nil
})
} }
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) { func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
// TODO: Implement full closing
} }
func (i *Interceptor) Close() error { func (i *Interceptor) Close() error {
// TODO: Use UnBindSocketConnection to close all
// TODO: Close interceptor
return nil
} }
func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) { func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) {

View File

@@ -16,19 +16,25 @@ type KeyGetter interface {
GetKeys() (encKey, decKey types.Key, err error) GetKeys() (encKey, decKey types.Key, err error)
} }
type CanEncrypt interface {
// Encrypt encrypts a message between sender and receiver
Encrypt(message message.Message) (message.Message, error)
}
type CanDecrypt interface {
// Decrypt decrypts an encrypted message in-place
Decrypt(message message.Message) (message.Message, error)
}
// Encryptor defines the interface for message encryption and decryption // Encryptor defines the interface for message encryption and decryption
type Encryptor interface { type Encryptor interface {
KeySetter KeySetter
CanEncrypt
CanDecrypt
// SetSessionID sets the session identifier for this encryption session // SetSessionID sets the session identifier for this encryption session
SetSessionID(id types.EncryptionSessionID) SetSessionID(id types.EncryptionSessionID)
// Encrypt encrypts a message between sender and receiver
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)
// Decrypt decrypts an encrypted message in-place
Decrypt(message message.Message) error
// Ready checks if the encryptor is properly initialized and ready to use // Ready checks if the encryptor is properly initialized and ready to use
Ready() bool Ready() bool

View File

@@ -10,6 +10,7 @@ type ProtocolProcessor interface {
type KeyExchangeManager interface { type KeyExchangeManager interface {
Init(state State, options ...ProtocolFactoryOption) error Init(state State, options ...ProtocolFactoryOption) error
Finalise(state State) error
} }
type CanGetSessionState interface { type CanGetSessionState interface {

View File

@@ -225,21 +225,50 @@ func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Co
return pp.Process(m, s) return pp.Process(m, s)
} }
func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error { func (m *Done) Process(protocol interfaces.Protocol, s interfaces.State) error {
ss, ok := _s.(interfaces.KeySetter)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
p, ok := protocol.(*Curve25519Protocol) p, ok := protocol.(*Curve25519Protocol)
if !ok { if !ok {
return encryptionerr.ErrInvalidMessageType return encryptionerr.ErrInvalidMessageType
} }
if err := ss.SetKeys(p.encKey, p.decKey); err != nil { msg, err := NewDoneResponse()
if err != nil {
return err
}
if err := s.WriteMessage(msg); err != nil {
return err return err
} }
p.state = types.SessionStateCompleted p.state = types.SessionStateCompleted
return nil return nil
} }
type DoneResponse struct {
Done
}
func NewDoneResponse() (*DoneResponse, error) {
msg := &DoneResponse{
Done: Done{
Timestamp: time.Now(),
},
}
bmsg, err := interceptor.NewBaseMessage(message.NoneProtocol, nil, msg)
if err != nil {
return nil, err
}
msg.BaseMessage = bmsg
return msg, nil
}
func (m *DoneResponse) Process(protocol interfaces.Protocol, s interfaces.State) error {
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
p.state = types.SessionStateCompleted
return nil
}

View File

@@ -55,6 +55,31 @@ func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactory
return nil return nil
} }
func (m *Manager) Finalise(s interfaces.State) error {
sessionID := s.GetKeyExchangeSessionID()
session, exists := m.sessions[sessionID]
if !exists {
return encryptionerr.ErrExchangeNotComplete
}
ss, ok := s.(interfaces.KeySetter)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
p, ok := (session.protocol).(interfaces.KeyGetter)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
encKey, decKey, err := p.GetKeys()
if err != nil {
return err
}
return ss.SetKeys(encKey, decKey)
}
func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error { func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error {
session, exists := m.sessions[s.GetKeyExchangeSessionID()] session, exists := m.sessions[s.GetKeyExchangeSessionID()]
if !exists { if !exists {

View File

@@ -91,3 +91,7 @@ func (s *State) SetKeys(encKey, decKey types.Key) error {
return s.encryptor.SetKeys(encKey, decKey) return s.encryptor.SetKeys(encKey, decKey)
} }
func (s *State) Decrypt(msg message.Message) (message.Message, error) {
return s.encryptor.Decrypt(msg)
}