diff --git a/pkg/interceptor/message.go b/pkg/interceptor/message.go index 4c7cd55..eabe1c7 100644 --- a/pkg/interceptor/message.go +++ b/pkg/interceptor/message.go @@ -39,6 +39,9 @@ type Message interface { // // Returns an error if processing fails ReadProcess(Interceptor, Connection) error + + SetReceiver(message.Receiver) + SetSender(message.Sender) } // BaseMessage provides a default implementation of the Message interface. @@ -94,3 +97,11 @@ func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error { // Derived-types should override this method with specific processing logic return nil } + +func (m *BaseMessage) SetSender(sender message.Sender) { + m.CurrentHeader.Sender = sender +} + +func (m *BaseMessage) SetReceiver(receiver message.Receiver) { + m.CurrentHeader.Receiver = receiver +} diff --git a/pkg/middleware/encrypt/interceptor.go b/pkg/middleware/encrypt/interceptor.go index 20f22ba..4bd441e 100644 --- a/pkg/middleware/encrypt/interceptor.go +++ b/pkg/middleware/encrypt/interceptor.go @@ -2,21 +2,24 @@ package encrypt import ( "context" + "time" "github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) type Interceptor struct { interceptor.NoOpInterceptor nonceValidator NonceValidator - keyExchangeManager keyexchange.Manager + keyExchangeManager interfaces.KeyExchangeManager keyProvider keyprovider.KeyProvider - stateManager state.Manager + stateManager interfaces.StateManager config config.Config } @@ -36,14 +39,21 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr } func (i *Interceptor) Init(connection interceptor.Connection) error { - _state, err := i.stateManager.GetState(connection) + s, err := i.stateManager.GetState(connection) if err != nil { return err } - if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil { + if err := i.keyExchangeManager.Init(s, keyexchange.WithKeySignature(i.keyProvider)); err != nil { return err } + + ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second) + defer cancel() + + waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted) + + return i.Process(waiter, s) } func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { @@ -62,16 +72,15 @@ func (i *Interceptor) Close() error { } -func GetState(_i interceptor.Interceptor, connection interceptor.Connection) (*state.State, error) { - i, ok := _i.(*Interceptor) - if !ok { - return nil, encryptionerr.ErrInvalidInterceptor - } - - s, err := i.stateManager.GetState(connection) - if err != nil { - return nil, err - } - - return s, nil +func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) { + return i.stateManager.GetState(connection) +} + +func (i *Interceptor) Process(msg interfaces.CanProcess, state interfaces.State) error { + processor, ok := i.keyExchangeManager.(interfaces.ProtocolProcessor) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + return processor.Process(msg, state) } diff --git a/pkg/middleware/encrypt/interfaces.go b/pkg/middleware/encrypt/interfaces.go index 412d34d..6596987 100644 --- a/pkg/middleware/encrypt/interfaces.go +++ b/pkg/middleware/encrypt/interfaces.go @@ -9,7 +9,7 @@ import ( // NonceValidator provides protection against replay attacks type NonceValidator interface { // Validate checks if a nonce is valid and hasn't been seen before - Validate(nonce []byte, sessionID types.SessionID) error + Validate(nonce []byte, sessionID types.EncryptionSessionID) error // Cleanup removes expired nonces Cleanup(before time.Time) diff --git a/pkg/middleware/encrypt/encryptor/encryptor.go b/pkg/middleware/encrypt/interfaces/encryptor.go similarity index 75% rename from pkg/middleware/encrypt/encryptor/encryptor.go rename to pkg/middleware/encrypt/interfaces/encryptor.go index ed67f2e..ffbc7a2 100644 --- a/pkg/middleware/encrypt/encryptor/encryptor.go +++ b/pkg/middleware/encrypt/interfaces/encryptor.go @@ -1,4 +1,4 @@ -package encryptor +package interfaces import ( "io" @@ -7,13 +7,21 @@ import ( "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) +type KeySetter interface { + // SetKeys configures the encryption and decryption keys + SetKeys(encryptorKey, decryptorKey types.Key) error +} + +type KeyGetter interface { + GetKeys() (encKey, decKey types.Key, err error) +} + // Encryptor defines the interface for message encryption and decryption type Encryptor interface { - // SetKeys configures the encryption and decryption keys - SetKeys(encryptKey, decryptKey types.Key) error + KeySetter // SetSessionID sets the session identifier for this encryption session - SetSessionID(id types.SessionID) + SetSessionID(id types.EncryptionSessionID) // Encrypt encrypts a message between sender and receiver Encrypt(senderID, receiverID string, message message.Message) (message.Message, error) diff --git a/pkg/middleware/encrypt/interfaces/keyexchange.go b/pkg/middleware/encrypt/interfaces/keyexchange.go new file mode 100644 index 0000000..64ff9f2 --- /dev/null +++ b/pkg/middleware/encrypt/interfaces/keyexchange.go @@ -0,0 +1,26 @@ +package interfaces + +import ( + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" +) + +type ProtocolProcessor interface { + Process(CanProcess, State) error +} + +type KeyExchangeManager interface { + Init(state State, options ...ProtocolFactoryOption) error +} + +type CanGetSessionState interface { + GetState() types.SessionState +} + +type Protocol interface { + Init(s State) error + IsComplete() bool +} + +type CanProcess interface { + Process(Protocol, State) error +} diff --git a/pkg/middleware/encrypt/interfaces/state.go b/pkg/middleware/encrypt/interfaces/state.go new file mode 100644 index 0000000..2573f27 --- /dev/null +++ b/pkg/middleware/encrypt/interfaces/state.go @@ -0,0 +1,35 @@ +package interfaces + +import ( + "github.com/harshabose/socket-comm/pkg/interceptor" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" +) + +type ProtocolFactoryOption func(Protocol) error + +type State interface { + GenerateKeyExchangeSessionID() types.KeyExchangeSessionID + GetKeyExchangeSessionID() types.KeyExchangeSessionID + WriteMessage(interceptor.Message) error + GetConfig() config.Config +} + +type CanGetState interface { + GetState(interceptor.Connection) (State, error) +} + +type CanSetState interface { + SetState(interceptor.Connection, State) error +} + +type CanRemoveState interface { + RemoveState(interceptor.Connection) error +} + +type StateManager interface { + CanGetState + CanSetState + CanRemoveState + ForEach(func(interceptor.Connection, State) error) error +} diff --git a/pkg/middleware/encrypt/keyexchange/curve25519messages.go b/pkg/middleware/encrypt/keyexchange/curve25519messages.go new file mode 100644 index 0000000..9f89af5 --- /dev/null +++ b/pkg/middleware/encrypt/keyexchange/curve25519messages.go @@ -0,0 +1,190 @@ +package keyexchange + +import ( + "fmt" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" + + "github.com/harshabose/socket-comm/pkg/interceptor" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" +) + +type Init struct { + // TODO: MANAGE STATE USING KEY EXCHANGE SESSION ID + interceptor.BaseMessage + PublicKey types.PublicKey `json:"public_key"` + Signature []byte `json:"signature"` + SessionID types.EncryptionSessionID `json:"session_id"` + Salt types.Salt `json:"salt"` +} + +func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error { + return nil +} + +func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error { + ss, ok := _interceptor.(interfaces.CanGetState) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + s, err := ss.GetState(connection) + if err != nil { + return err + } + + pp, ok := _interceptor.(interfaces.ProtocolProcessor) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + return pp.Process(m, s) +} + +func (m *Init) Process(protocol interfaces.Protocol, s interfaces.State) error { + p, ok := protocol.(*Curve25519Protocol) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + if p.state != types.SessionStateInitial { + return encryptionerr.ErrInvalidSessionState + } + + sign := append(m.PublicKey[:], m.Salt[:]...) + if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) { + return encryptionerr.ErrInvalidSignature + } + + p.salt = m.Salt + shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:]) + if err != nil { + return fmt.Errorf("failed to compute shared secret: %w", err) + } + + encKey, decKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING + if err != nil { + return fmt.Errorf("key derivation failed: %w", err) + } + + p.encKey = encKey + p.decKey = decKey + p.sessionID = m.SessionID + + if err := s.WriteMessage(nil); err != nil { + + } // TODO: ADD RESPONSE MESSAGE + + p.state = types.SessionStateInProgress + return nil +} + +type Response struct { + interceptor.BaseMessage + PublicKey types.PublicKey `json:"public_key"` +} + +func (m *Response) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error { + return nil +} + +func (m *Response) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error { + ss, ok := _i.(interfaces.CanGetState) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + s, err := ss.GetState(connection) + if err != nil { + return err + } + + pp, ok := _i.(interfaces.ProtocolProcessor) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + return pp.Process(m, s) +} + +func (m *Response) Process(protocol interfaces.Protocol, s interfaces.State) error { + p, ok := protocol.(*Curve25519Protocol) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + if p.state != types.SessionStateInitial { + return encryptionerr.ErrInvalidSessionState + } + + shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:]) + if err != nil { + return fmt.Errorf("failed to compute shared secret: %w", err) + } + + decKey, encKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING + if err != nil { + return fmt.Errorf("key derivation failed: %w", err) + } + + p.encKey = encKey + p.decKey = decKey + + if err := s.WriteMessage(nil); err != nil { + return err + } // TODO: Send Done message + + p.state = types.SessionStateInProgress + return nil +} + +type Done struct { + interceptor.BaseMessage + Timestamp time.Time `json:"timestamp"` +} + +func (m *Done) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error { + return nil +} + +func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error { + ss, ok := _i.(interfaces.CanGetState) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + s, err := ss.GetState(connection) + if err != nil { + return err + } + + pp, ok := _i.(interfaces.ProtocolProcessor) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + return pp.Process(m, s) +} + +func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error { + ss, ok := _s.(interfaces.KeySetter) + if !ok { + return encryptionerr.ErrInvalidInterceptor + } + + p, ok := protocol.(*Curve25519Protocol) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + if err := ss.SetKeys(p.encKey, p.decKey); err != nil { + return err + } + + p.state = types.SessionStateCompleted + return nil +} diff --git a/pkg/middleware/encrypt/keyexchange/curve25519process.go b/pkg/middleware/encrypt/keyexchange/curve25519process.go new file mode 100644 index 0000000..7a9674c --- /dev/null +++ b/pkg/middleware/encrypt/keyexchange/curve25519process.go @@ -0,0 +1,47 @@ +package keyexchange + +import ( + "context" + "fmt" + "time" + + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" +) + +type SessionStateTargetWaiter struct { + target types.SessionState + ctx context.Context +} + +func NewSessionStateTargetWaiter(ctx context.Context, target types.SessionState) SessionStateTargetWaiter { + return SessionStateTargetWaiter{ + target: target, + ctx: ctx, + } +} + +func (w SessionStateTargetWaiter) Process(protocol interfaces.Protocol, _ interfaces.State) error { + p, ok := protocol.(interfaces.CanGetSessionState) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + // TODO: Implement BLOCKing (ideally using sync.Cond, or simple ticker) + select { + case <-ticker.C: + if p.GetState() == w.target { + return nil + } + case <-w.ctx.Done(): + return fmt.Errorf("timeout waiting for state %v: %w", w.target, w.ctx.Err()) + } + } +} + +// TODO: WRITE MORE KEY EXCHANGE PROCESSES HERE diff --git a/pkg/middleware/encrypt/keyexchange/curve25519protocol.go b/pkg/middleware/encrypt/keyexchange/curve25519protocol.go index 6fdfa44..a320b45 100644 --- a/pkg/middleware/encrypt/keyexchange/curve25519protocol.go +++ b/pkg/middleware/encrypt/keyexchange/curve25519protocol.go @@ -8,8 +8,7 @@ import ( "golang.org/x/crypto/ed25519" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) @@ -18,11 +17,11 @@ type Curve25519Protocol struct { pubKey types.PublicKey peerPubKey types.PublicKey salt types.Salt // also in protocol curve25519protocol.go - sessionID types.SessionID + sessionID types.EncryptionSessionID sharedSecret []byte encKey types.Key decKey types.Key - state SessionState + state types.SessionState options Curve25519Options // TODO: mutex is needed here for SessionState and Keys @@ -34,55 +33,55 @@ type Curve25519Options struct { RequireSignature bool } -func (p *Curve25519Protocol) Init(s *state.State) error { +func (p *Curve25519Protocol) Init(s interfaces.State) error { if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil { return err } curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey)) - if s.InterceptorConfig.IsServer && p.options.RequireSignature { + if s.GetConfig().IsServer && p.options.RequireSignature { if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil { - p.state = SessionStateError + p.state = types.SessionStateError return err } if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil { - p.state = SessionStateError + p.state = types.SessionStateError return err } - sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...)) + _ = ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...)) - if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil { - p.state = SessionStateError + if err := s.WriteMessage(nil); err != nil { // TODO: SEND INIT MESSAGE + p.state = types.SessionStateError return err } } - p.state = SessionStateInProgress + p.state = types.SessionStateInitial return nil } func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) { - if p.state != SessionStateCompleted { + if p.state != types.SessionStateCompleted { return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete } return p.encKey, p.decKey, nil } -func (p *Curve25519Protocol) GetState() SessionState { +func (p *Curve25519Protocol) GetState() types.SessionState { return p.state } func (p *Curve25519Protocol) IsComplete() bool { - return p.state == SessionStateCompleted + return p.state == types.SessionStateCompleted } -func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error { +func (p *Curve25519Protocol) Process(msg interfaces.CanProcess, s interfaces.State) error { if err := msg.Process(p, s); err != nil { - p.state = SessionStateError + p.state = types.SessionStateError return err } diff --git a/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go b/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go index a9c6bd0..fe8c3d8 100644 --- a/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go +++ b/pkg/middleware/encrypt/keyexchange/keyexchange_manager.go @@ -9,13 +9,13 @@ import ( "golang.org/x/crypto/hkdf" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) type Session struct { - protocol Protocol - state *state.State + protocol interfaces.Protocol + state interfaces.State createdAt time.Time completedAt time.Time } @@ -25,16 +25,16 @@ type Manager struct { sessions map[types.KeyExchangeSessionID]*Session } -func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error { +func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactoryOption) error { sessionID := s.GenerateKeyExchangeSessionID() _, exists := m.sessions[sessionID] if exists { return encryptionerr.ErrExchangeInProgress } - factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol] + factory, exists := m.registry[s.GetConfig().EncryptionProtocol.KeyExchangeProtocol] if !exists { - return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol) + return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.GetConfig().EncryptionProtocol.KeyExchangeProtocol) } p, err := factory(options...) @@ -55,13 +55,18 @@ func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error { return nil } -func (m *Manager) Process(s *state.State, msg MessageProcessor) error { - session, exists := m.sessions[s.KeyExchangeSessionID] +func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error { + session, exists := m.sessions[s.GetKeyExchangeSessionID()] if !exists { return encryptionerr.ErrSessionNotFound } - return session.protocol.Process(msg, s) + processor, ok := session.protocol.(interfaces.ProtocolProcessor) + if !ok { + return encryptionerr.ErrInvalidMessageType + } + + return processor.Process(msg, s) } // Derive generates encryption keys from shared secret diff --git a/pkg/middleware/encrypt/keyexchange/protocol.go b/pkg/middleware/encrypt/keyexchange/protocol.go index 6f3cc29..99e8a06 100644 --- a/pkg/middleware/encrypt/keyexchange/protocol.go +++ b/pkg/middleware/encrypt/keyexchange/protocol.go @@ -3,36 +3,12 @@ package keyexchange import ( "fmt" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) -type SessionState int - -const ( - SessionStateInitial SessionState = iota - SessionStateInProgress - SessionStateCompleted - SessionStateError -) - -type MessageProcessor interface { - Process(Protocol, *state.State) error -} - -type Protocol interface { - Init(s *state.State) error - GetKeys() (encKey types.Key, decKey types.Key, err error) - Process(MessageProcessor, *state.State) error - GetState() SessionState - IsComplete() bool -} - -type ProtocolFactoryOption func(Protocol) error - -func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption { - return func(protocol Protocol) error { +func WithKeySignature(keyProvider keyprovider.KeyProvider) interfaces.ProtocolFactoryOption { + return func(protocol interfaces.Protocol) error { curveProtocol, ok := protocol.(*Curve25519Protocol) if !ok { return fmt.Errorf("WithKeySignature only supports Curve25519Protocol") @@ -46,4 +22,4 @@ func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption } } -type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error) +type ProtocolFactory func(options ...interfaces.ProtocolFactoryOption) (interfaces.Protocol, error) diff --git a/pkg/middleware/encrypt/messages/init.go b/pkg/middleware/encrypt/messages/init.go deleted file mode 100644 index 6f17e81..0000000 --- a/pkg/middleware/encrypt/messages/init.go +++ /dev/null @@ -1,92 +0,0 @@ -package messages - -import ( - "fmt" - - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" - - "github.com/harshabose/socket-comm/pkg/interceptor" - "github.com/harshabose/socket-comm/pkg/message" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" - "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" -) - -// Protocol constants -const ( - InitProtocol message.Protocol = "curve25519.init" - ResponseProtocol message.Protocol = "curve25519.response" - ConfirmProtocol message.Protocol = "curve25519.confirm" -) - -type Init struct { - interceptor.BaseMessage - PublicKey types.PublicKey `json:"public_key"` - Signature []byte `json:"signature"` - SessionID types.SessionID `json:"session_id"` - Salt types.Salt `json:"salt"` -} - -func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init { - return &Init{ - BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver), - PublicKey: key, - Signature: sign, - SessionID: sessionID, - Salt: salt, - } -} - -func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error { - return nil -} - -func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error { - s, err := encrypt.GetState(_interceptor, connection) - if err != nil { - return err - } - - return i.keyExchangeManager.Process(s, m) -} - -func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error { - p, ok := protocol.(*keyexchange.Curve25519Protocol) - if !ok { - return encryptionerr.ErrInvalidMessageType - } - - if p.GetState() != keyexchange.SessionStateInitial { - return encryptionerr.ErrInvalidSessionState - } - - sign := append(m.PublicKey[:], m.Salt[:]...) - if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) { - return encryptionerr.ErrInvalidSignature - } - - p.salt = m.Salt - shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:]) - if err != nil { - return fmt.Errorf("failed to compute shared secret: %w", err) - } - - encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING - if err != nil { - return fmt.Errorf("key derivation failed: %w", err) - } - - p.encKey = encKey - p.decKey = decKey - p.sessionID = m.SessionID - - if err := s.Writer.Write(s.Connection, nil); err != nil { - return err - } // TODO: ADD RESPONSE MESSAGE - - p.state = SessionStateCompleted - return nil -} diff --git a/pkg/middleware/encrypt/state/state.go b/pkg/middleware/encrypt/state/state.go index 47fe782..cb79682 100644 --- a/pkg/middleware/encrypt/state/state.go +++ b/pkg/middleware/encrypt/state/state.go @@ -11,22 +11,23 @@ import ( "github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" ) type State struct { - InterceptorConfig config.Config // copy of interceptor config - PeerID message.Receiver - privKey types.PrivateKey // also in protocol curve25519protocol.go - salt types.Salt // also in protocol curve25519protocol.go - sessionID types.SessionID // encryption sessionID - encryptor encryptor.Encryptor - Connection interceptor.Connection - Writer interceptor.Writer - Reader interceptor.Reader + currentConfig config.Config // copy of interceptor config + peerID message.Receiver + privKey types.PrivateKey // also in protocol curve25519protocol.go + salt types.Salt // also in protocol curve25519protocol.go + encryptSessionID types.EncryptionSessionID // encryption encryptSessionID + encryptor interfaces.Encryptor + connection interceptor.Connection + writer interceptor.Writer + reader interceptor.Reader cancel context.CancelFunc ctx context.Context - KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking + keyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking mux sync.RWMutex } @@ -37,24 +38,56 @@ func NewState(ctx context.Context, cancel context.CancelFunc, config config.Conf } return &State{ - InterceptorConfig: config, - peerID: message.UnknownReceiver, - privKey: types.PrivateKey{}, - salt: types.Salt{}, - encryptor: newEncryptor, - Connection: connection, - Writer: writer, - Reader: reader, - cancel: cancel, - ctx: ctx, + currentConfig: config, + peerID: message.UnknownReceiver, + privKey: types.PrivateKey{}, + salt: types.Salt{}, + encryptor: newEncryptor, + connection: connection, + writer: writer, + reader: reader, + cancel: cancel, + ctx: ctx, }, nil } func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID { - if s.KeyExchangeSessionID != "" { - fmt.Println("KeyExchangeSessionID already exists; creating new") - } - s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString()) + s.mux.Lock() + defer s.mux.Unlock() - return s.KeyExchangeSessionID + if s.keyExchangeSessionID != "" { + fmt.Println("keyExchangeSessionID already exists; creating new") + } + s.keyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString()) + + return s.keyExchangeSessionID +} + +func (s *State) WriteMessage(msg interceptor.Message) error { + s.mux.Lock() + defer s.mux.Unlock() + + msg.SetReceiver(s.peerID) + return s.writer.Write(s.connection, msg) +} + +func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID { + s.mux.Lock() + defer s.mux.Unlock() + + return s.keyExchangeSessionID +} + +func (s *State) GetConfig() config.Config { + s.mux.Lock() + defer s.mux.Unlock() + + return s.currentConfig +} + +func (s *State) SetKeys(encKey, decKey types.Key) error { + s.mux.Lock() + defer s.mux.Unlock() + + return s.encryptor.SetKeys(encKey, decKey) } diff --git a/pkg/middleware/encrypt/state/state_manager.go b/pkg/middleware/encrypt/state/state_manager.go index 42fd1ef..43feb81 100644 --- a/pkg/middleware/encrypt/state/state_manager.go +++ b/pkg/middleware/encrypt/state/state_manager.go @@ -3,16 +3,18 @@ package state import ( "sync" + "github.com/harshabose/socket-comm/internal/util" "github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" + "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces" ) type Manager struct { - states map[interceptor.Connection]*State + states map[interceptor.Connection]interfaces.State mux sync.RWMutex } -func (m *Manager) GetState(connection interceptor.Connection) (*State, error) { +func (m *Manager) GetState(connection interceptor.Connection) (interfaces.State, error) { m.mux.RLock() defer m.mux.RUnlock() @@ -24,7 +26,7 @@ func (m *Manager) GetState(connection interceptor.Connection) (*State, error) { return state, nil } -func (m *Manager) SetState(connection interceptor.Connection, s *State) error { +func (m *Manager) SetState(connection interceptor.Connection, s interfaces.State) error { m.mux.Lock() defer m.mux.Unlock() @@ -36,31 +38,31 @@ func (m *Manager) SetState(connection interceptor.Connection, s *State) error { return nil } -// RemoveState removes a Connection's state -func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) { +// RemoveState removes a connection's state +func (m *Manager) RemoveState(connection interceptor.Connection) error { m.mux.Lock() defer m.mux.Unlock() - state, exists := m.states[connection] + _, exists := m.states[connection] if !exists { - return nil, encryptionerr.ErrConnectionNotFound + return encryptionerr.ErrConnectionNotFound } delete(m.states, connection) - return state, nil + return nil } // ForEach executes the provided function for each state in the manager -func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error { +func (m *Manager) ForEach(fn func(connection interceptor.Connection, state interfaces.State) error) error { m.mux.RLock() defer m.mux.RUnlock() - var errs []error + var errs util.MultiError for conn, state := range m.states { if err := fn(conn, state); err != nil { - errs = append(errs, err) + errs.Add(err) } } - return errs + return errs.ErrorOrNil() } diff --git a/pkg/middleware/encrypt/types/types.go b/pkg/middleware/encrypt/types/common.go similarity index 97% rename from pkg/middleware/encrypt/types/types.go rename to pkg/middleware/encrypt/types/common.go index 6008db1..cf9a924 100644 --- a/pkg/middleware/encrypt/types/types.go +++ b/pkg/middleware/encrypt/types/common.go @@ -5,7 +5,7 @@ type ( PrivateKey [32]byte PublicKey [32]byte Salt [16]byte - SessionID [16]byte + EncryptionSessionID [16]byte Nonce [12]byte Key [32]byte KeyExchangeProtocol string diff --git a/pkg/middleware/encrypt/types/keyexchange.go b/pkg/middleware/encrypt/types/keyexchange.go new file mode 100644 index 0000000..a6002cc --- /dev/null +++ b/pkg/middleware/encrypt/types/keyexchange.go @@ -0,0 +1,11 @@ +package types + +type SessionState int + +const ( + SessionStateNotStart SessionState = iota + SessionStateInitial + SessionStateInProgress + SessionStateCompleted + SessionStateError +)