updated encryption interceptor to meet good programming principle: added modular key exchange mechanism, updated interface design, restructured interfaces to support interface segregation

This commit is contained in:
harshabose
2025-05-01 01:09:43 +05:30
parent 8f381b77f4
commit aed3b8c928
16 changed files with 465 additions and 205 deletions

View File

@@ -39,6 +39,9 @@ type Message interface {
//
// Returns an error if processing fails
ReadProcess(Interceptor, Connection) error
SetReceiver(message.Receiver)
SetSender(message.Sender)
}
// BaseMessage provides a default implementation of the Message interface.
@@ -94,3 +97,11 @@ func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error {
// Derived-types should override this method with specific processing logic
return nil
}
func (m *BaseMessage) SetSender(sender message.Sender) {
m.CurrentHeader.Sender = sender
}
func (m *BaseMessage) SetReceiver(receiver message.Receiver) {
m.CurrentHeader.Receiver = receiver
}

View File

@@ -2,21 +2,24 @@ package encrypt
import (
"context"
"time"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Interceptor struct {
interceptor.NoOpInterceptor
nonceValidator NonceValidator
keyExchangeManager keyexchange.Manager
keyExchangeManager interfaces.KeyExchangeManager
keyProvider keyprovider.KeyProvider
stateManager state.Manager
stateManager interfaces.StateManager
config config.Config
}
@@ -36,14 +39,21 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
}
func (i *Interceptor) Init(connection interceptor.Connection) error {
_state, err := i.stateManager.GetState(connection)
s, err := i.stateManager.GetState(connection)
if err != nil {
return err
}
if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
if err := i.keyExchangeManager.Init(s, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
return err
}
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second)
defer cancel()
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
return i.Process(waiter, s)
}
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
@@ -62,16 +72,15 @@ func (i *Interceptor) Close() error {
}
func GetState(_i interceptor.Interceptor, connection interceptor.Connection) (*state.State, error) {
i, ok := _i.(*Interceptor)
if !ok {
return nil, encryptionerr.ErrInvalidInterceptor
}
s, err := i.stateManager.GetState(connection)
if err != nil {
return nil, err
}
return s, nil
func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) {
return i.stateManager.GetState(connection)
}
func (i *Interceptor) Process(msg interfaces.CanProcess, state interfaces.State) error {
processor, ok := i.keyExchangeManager.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
return processor.Process(msg, state)
}

View File

@@ -9,7 +9,7 @@ import (
// NonceValidator provides protection against replay attacks
type NonceValidator interface {
// Validate checks if a nonce is valid and hasn't been seen before
Validate(nonce []byte, sessionID types.SessionID) error
Validate(nonce []byte, sessionID types.EncryptionSessionID) error
// Cleanup removes expired nonces
Cleanup(before time.Time)

View File

@@ -1,4 +1,4 @@
package encryptor
package interfaces
import (
"io"
@@ -7,13 +7,21 @@ import (
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type KeySetter interface {
// SetKeys configures the encryption and decryption keys
SetKeys(encryptorKey, decryptorKey types.Key) error
}
type KeyGetter interface {
GetKeys() (encKey, decKey types.Key, err error)
}
// Encryptor defines the interface for message encryption and decryption
type Encryptor interface {
// SetKeys configures the encryption and decryption keys
SetKeys(encryptKey, decryptKey types.Key) error
KeySetter
// SetSessionID sets the session identifier for this encryption session
SetSessionID(id types.SessionID)
SetSessionID(id types.EncryptionSessionID)
// Encrypt encrypts a message between sender and receiver
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)

View File

@@ -0,0 +1,26 @@
package interfaces
import (
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type ProtocolProcessor interface {
Process(CanProcess, State) error
}
type KeyExchangeManager interface {
Init(state State, options ...ProtocolFactoryOption) error
}
type CanGetSessionState interface {
GetState() types.SessionState
}
type Protocol interface {
Init(s State) error
IsComplete() bool
}
type CanProcess interface {
Process(Protocol, State) error
}

View File

@@ -0,0 +1,35 @@
package interfaces
import (
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type ProtocolFactoryOption func(Protocol) error
type State interface {
GenerateKeyExchangeSessionID() types.KeyExchangeSessionID
GetKeyExchangeSessionID() types.KeyExchangeSessionID
WriteMessage(interceptor.Message) error
GetConfig() config.Config
}
type CanGetState interface {
GetState(interceptor.Connection) (State, error)
}
type CanSetState interface {
SetState(interceptor.Connection, State) error
}
type CanRemoveState interface {
RemoveState(interceptor.Connection) error
}
type StateManager interface {
CanGetState
CanSetState
CanRemoveState
ForEach(func(interceptor.Connection, State) error) error
}

View File

@@ -0,0 +1,190 @@
package keyexchange
import (
"fmt"
"time"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Init struct {
// TODO: MANAGE STATE USING KEY EXCHANGE SESSION ID
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
Signature []byte `json:"signature"`
SessionID types.EncryptionSessionID `json:"session_id"`
Salt types.Salt `json:"salt"`
}
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
ss, ok := _interceptor.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := ss.GetState(connection)
if err != nil {
return err
}
pp, ok := _interceptor.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
return pp.Process(m, s)
}
func (m *Init) Process(protocol interfaces.Protocol, s interfaces.State) error {
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.state != types.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
sign := append(m.PublicKey[:], m.Salt[:]...)
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
return encryptionerr.ErrInvalidSignature
}
p.salt = m.Salt
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
encKey, decKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
p.sessionID = m.SessionID
if err := s.WriteMessage(nil); err != nil {
} // TODO: ADD RESPONSE MESSAGE
p.state = types.SessionStateInProgress
return nil
}
type Response struct {
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
}
func (m *Response) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Response) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
ss, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := ss.GetState(connection)
if err != nil {
return err
}
pp, ok := _i.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
return pp.Process(m, s)
}
func (m *Response) Process(protocol interfaces.Protocol, s interfaces.State) error {
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.state != types.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
decKey, encKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
if err := s.WriteMessage(nil); err != nil {
return err
} // TODO: Send Done message
p.state = types.SessionStateInProgress
return nil
}
type Done struct {
interceptor.BaseMessage
Timestamp time.Time `json:"timestamp"`
}
func (m *Done) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
ss, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := ss.GetState(connection)
if err != nil {
return err
}
pp, ok := _i.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
return pp.Process(m, s)
}
func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error {
ss, ok := _s.(interfaces.KeySetter)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if err := ss.SetKeys(p.encKey, p.decKey); err != nil {
return err
}
p.state = types.SessionStateCompleted
return nil
}

View File

@@ -0,0 +1,47 @@
package keyexchange
import (
"context"
"fmt"
"time"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type SessionStateTargetWaiter struct {
target types.SessionState
ctx context.Context
}
func NewSessionStateTargetWaiter(ctx context.Context, target types.SessionState) SessionStateTargetWaiter {
return SessionStateTargetWaiter{
target: target,
ctx: ctx,
}
}
func (w SessionStateTargetWaiter) Process(protocol interfaces.Protocol, _ interfaces.State) error {
p, ok := protocol.(interfaces.CanGetSessionState)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
// TODO: Implement BLOCKing (ideally using sync.Cond, or simple ticker)
select {
case <-ticker.C:
if p.GetState() == w.target {
return nil
}
case <-w.ctx.Done():
return fmt.Errorf("timeout waiting for state %v: %w", w.target, w.ctx.Err())
}
}
}
// TODO: WRITE MORE KEY EXCHANGE PROCESSES HERE

View File

@@ -8,8 +8,7 @@ import (
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
@@ -18,11 +17,11 @@ type Curve25519Protocol struct {
pubKey types.PublicKey
peerPubKey types.PublicKey
salt types.Salt // also in protocol curve25519protocol.go
sessionID types.SessionID
sessionID types.EncryptionSessionID
sharedSecret []byte
encKey types.Key
decKey types.Key
state SessionState
state types.SessionState
options Curve25519Options
// TODO: mutex is needed here for SessionState and Keys
@@ -34,55 +33,55 @@ type Curve25519Options struct {
RequireSignature bool
}
func (p *Curve25519Protocol) Init(s *state.State) error {
func (p *Curve25519Protocol) Init(s interfaces.State) error {
if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil {
return err
}
curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey))
if s.InterceptorConfig.IsServer && p.options.RequireSignature {
if s.GetConfig().IsServer && p.options.RequireSignature {
if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil {
p.state = SessionStateError
p.state = types.SessionStateError
return err
}
if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil {
p.state = SessionStateError
p.state = types.SessionStateError
return err
}
sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
_ = ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil {
p.state = SessionStateError
if err := s.WriteMessage(nil); err != nil { // TODO: SEND INIT MESSAGE
p.state = types.SessionStateError
return err
}
}
p.state = SessionStateInProgress
p.state = types.SessionStateInitial
return nil
}
func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) {
if p.state != SessionStateCompleted {
if p.state != types.SessionStateCompleted {
return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete
}
return p.encKey, p.decKey, nil
}
func (p *Curve25519Protocol) GetState() SessionState {
func (p *Curve25519Protocol) GetState() types.SessionState {
return p.state
}
func (p *Curve25519Protocol) IsComplete() bool {
return p.state == SessionStateCompleted
return p.state == types.SessionStateCompleted
}
func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error {
func (p *Curve25519Protocol) Process(msg interfaces.CanProcess, s interfaces.State) error {
if err := msg.Process(p, s); err != nil {
p.state = SessionStateError
p.state = types.SessionStateError
return err
}

View File

@@ -9,13 +9,13 @@ import (
"golang.org/x/crypto/hkdf"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Session struct {
protocol Protocol
state *state.State
protocol interfaces.Protocol
state interfaces.State
createdAt time.Time
completedAt time.Time
}
@@ -25,16 +25,16 @@ type Manager struct {
sessions map[types.KeyExchangeSessionID]*Session
}
func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactoryOption) error {
sessionID := s.GenerateKeyExchangeSessionID()
_, exists := m.sessions[sessionID]
if exists {
return encryptionerr.ErrExchangeInProgress
}
factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol]
factory, exists := m.registry[s.GetConfig().EncryptionProtocol.KeyExchangeProtocol]
if !exists {
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol)
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.GetConfig().EncryptionProtocol.KeyExchangeProtocol)
}
p, err := factory(options...)
@@ -55,13 +55,18 @@ func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
return nil
}
func (m *Manager) Process(s *state.State, msg MessageProcessor) error {
session, exists := m.sessions[s.KeyExchangeSessionID]
func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error {
session, exists := m.sessions[s.GetKeyExchangeSessionID()]
if !exists {
return encryptionerr.ErrSessionNotFound
}
return session.protocol.Process(msg, s)
processor, ok := session.protocol.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
return processor.Process(msg, s)
}
// Derive generates encryption keys from shared secret

View File

@@ -3,36 +3,12 @@ package keyexchange
import (
"fmt"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type SessionState int
const (
SessionStateInitial SessionState = iota
SessionStateInProgress
SessionStateCompleted
SessionStateError
)
type MessageProcessor interface {
Process(Protocol, *state.State) error
}
type Protocol interface {
Init(s *state.State) error
GetKeys() (encKey types.Key, decKey types.Key, err error)
Process(MessageProcessor, *state.State) error
GetState() SessionState
IsComplete() bool
}
type ProtocolFactoryOption func(Protocol) error
func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption {
return func(protocol Protocol) error {
func WithKeySignature(keyProvider keyprovider.KeyProvider) interfaces.ProtocolFactoryOption {
return func(protocol interfaces.Protocol) error {
curveProtocol, ok := protocol.(*Curve25519Protocol)
if !ok {
return fmt.Errorf("WithKeySignature only supports Curve25519Protocol")
@@ -46,4 +22,4 @@ func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption
}
}
type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error)
type ProtocolFactory func(options ...interfaces.ProtocolFactoryOption) (interfaces.Protocol, error)

View File

@@ -1,92 +0,0 @@
package messages
import (
"fmt"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
// Protocol constants
const (
InitProtocol message.Protocol = "curve25519.init"
ResponseProtocol message.Protocol = "curve25519.response"
ConfirmProtocol message.Protocol = "curve25519.confirm"
)
type Init struct {
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
Signature []byte `json:"signature"`
SessionID types.SessionID `json:"session_id"`
Salt types.Salt `json:"salt"`
}
func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init {
return &Init{
BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver),
PublicKey: key,
Signature: sign,
SessionID: sessionID,
Salt: salt,
}
}
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
s, err := encrypt.GetState(_interceptor, connection)
if err != nil {
return err
}
return i.keyExchangeManager.Process(s, m)
}
func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error {
p, ok := protocol.(*keyexchange.Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.GetState() != keyexchange.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
sign := append(m.PublicKey[:], m.Salt[:]...)
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
return encryptionerr.ErrInvalidSignature
}
p.salt = m.Salt
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
p.sessionID = m.SessionID
if err := s.Writer.Write(s.Connection, nil); err != nil {
return err
} // TODO: ADD RESPONSE MESSAGE
p.state = SessionStateCompleted
return nil
}

View File

@@ -11,22 +11,23 @@ import (
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type State struct {
InterceptorConfig config.Config // copy of interceptor config
PeerID message.Receiver
privKey types.PrivateKey // also in protocol curve25519protocol.go
salt types.Salt // also in protocol curve25519protocol.go
sessionID types.SessionID // encryption sessionID
encryptor encryptor.Encryptor
Connection interceptor.Connection
Writer interceptor.Writer
Reader interceptor.Reader
currentConfig config.Config // copy of interceptor config
peerID message.Receiver
privKey types.PrivateKey // also in protocol curve25519protocol.go
salt types.Salt // also in protocol curve25519protocol.go
encryptSessionID types.EncryptionSessionID // encryption encryptSessionID
encryptor interfaces.Encryptor
connection interceptor.Connection
writer interceptor.Writer
reader interceptor.Reader
cancel context.CancelFunc
ctx context.Context
KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
keyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
mux sync.RWMutex
}
@@ -37,24 +38,56 @@ func NewState(ctx context.Context, cancel context.CancelFunc, config config.Conf
}
return &State{
InterceptorConfig: config,
peerID: message.UnknownReceiver,
privKey: types.PrivateKey{},
salt: types.Salt{},
encryptor: newEncryptor,
Connection: connection,
Writer: writer,
Reader: reader,
cancel: cancel,
ctx: ctx,
currentConfig: config,
peerID: message.UnknownReceiver,
privKey: types.PrivateKey{},
salt: types.Salt{},
encryptor: newEncryptor,
connection: connection,
writer: writer,
reader: reader,
cancel: cancel,
ctx: ctx,
}, nil
}
func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
if s.KeyExchangeSessionID != "" {
fmt.Println("KeyExchangeSessionID already exists; creating new")
}
s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
s.mux.Lock()
defer s.mux.Unlock()
return s.KeyExchangeSessionID
if s.keyExchangeSessionID != "" {
fmt.Println("keyExchangeSessionID already exists; creating new")
}
s.keyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
return s.keyExchangeSessionID
}
func (s *State) WriteMessage(msg interceptor.Message) error {
s.mux.Lock()
defer s.mux.Unlock()
msg.SetReceiver(s.peerID)
return s.writer.Write(s.connection, msg)
}
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
s.mux.Lock()
defer s.mux.Unlock()
return s.keyExchangeSessionID
}
func (s *State) GetConfig() config.Config {
s.mux.Lock()
defer s.mux.Unlock()
return s.currentConfig
}
func (s *State) SetKeys(encKey, decKey types.Key) error {
s.mux.Lock()
defer s.mux.Unlock()
return s.encryptor.SetKeys(encKey, decKey)
}

View File

@@ -3,16 +3,18 @@ package state
import (
"sync"
"github.com/harshabose/socket-comm/internal/util"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
)
type Manager struct {
states map[interceptor.Connection]*State
states map[interceptor.Connection]interfaces.State
mux sync.RWMutex
}
func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
func (m *Manager) GetState(connection interceptor.Connection) (interfaces.State, error) {
m.mux.RLock()
defer m.mux.RUnlock()
@@ -24,7 +26,7 @@ func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
return state, nil
}
func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
func (m *Manager) SetState(connection interceptor.Connection, s interfaces.State) error {
m.mux.Lock()
defer m.mux.Unlock()
@@ -36,31 +38,31 @@ func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
return nil
}
// RemoveState removes a Connection's state
func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) {
// RemoveState removes a connection's state
func (m *Manager) RemoveState(connection interceptor.Connection) error {
m.mux.Lock()
defer m.mux.Unlock()
state, exists := m.states[connection]
_, exists := m.states[connection]
if !exists {
return nil, encryptionerr.ErrConnectionNotFound
return encryptionerr.ErrConnectionNotFound
}
delete(m.states, connection)
return state, nil
return nil
}
// ForEach executes the provided function for each state in the manager
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error {
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state interfaces.State) error) error {
m.mux.RLock()
defer m.mux.RUnlock()
var errs []error
var errs util.MultiError
for conn, state := range m.states {
if err := fn(conn, state); err != nil {
errs = append(errs, err)
errs.Add(err)
}
}
return errs
return errs.ErrorOrNil()
}

View File

@@ -5,7 +5,7 @@ type (
PrivateKey [32]byte
PublicKey [32]byte
Salt [16]byte
SessionID [16]byte
EncryptionSessionID [16]byte
Nonce [12]byte
Key [32]byte
KeyExchangeProtocol string

View File

@@ -0,0 +1,11 @@
package types
type SessionState int
const (
SessionStateNotStart SessionState = iota
SessionStateInitial
SessionStateInProgress
SessionStateCompleted
SessionStateError
)