updated encryption interceptor to meet good programming principle: added modular key exchange mechanism, updated interface design, restructured interfaces to support interface segregation

This commit is contained in:
harshabose
2025-05-01 01:09:43 +05:30
parent 8f381b77f4
commit aed3b8c928
16 changed files with 465 additions and 205 deletions

View File

@@ -39,6 +39,9 @@ type Message interface {
// //
// Returns an error if processing fails // Returns an error if processing fails
ReadProcess(Interceptor, Connection) error ReadProcess(Interceptor, Connection) error
SetReceiver(message.Receiver)
SetSender(message.Sender)
} }
// BaseMessage provides a default implementation of the Message interface. // BaseMessage provides a default implementation of the Message interface.
@@ -94,3 +97,11 @@ func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error {
// Derived-types should override this method with specific processing logic // Derived-types should override this method with specific processing logic
return nil return nil
} }
func (m *BaseMessage) SetSender(sender message.Sender) {
m.CurrentHeader.Sender = sender
}
func (m *BaseMessage) SetReceiver(receiver message.Receiver) {
m.CurrentHeader.Receiver = receiver
}

View File

@@ -2,21 +2,24 @@ package encrypt
import ( import (
"context" "context"
"time"
"github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
type Interceptor struct { type Interceptor struct {
interceptor.NoOpInterceptor interceptor.NoOpInterceptor
nonceValidator NonceValidator nonceValidator NonceValidator
keyExchangeManager keyexchange.Manager keyExchangeManager interfaces.KeyExchangeManager
keyProvider keyprovider.KeyProvider keyProvider keyprovider.KeyProvider
stateManager state.Manager stateManager interfaces.StateManager
config config.Config config config.Config
} }
@@ -36,14 +39,21 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
} }
func (i *Interceptor) Init(connection interceptor.Connection) error { func (i *Interceptor) Init(connection interceptor.Connection) error {
_state, err := i.stateManager.GetState(connection) s, err := i.stateManager.GetState(connection)
if err != nil { if err != nil {
return err return err
} }
if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil { if err := i.keyExchangeManager.Init(s, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
return err return err
} }
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second)
defer cancel()
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
return i.Process(waiter, s)
} }
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
@@ -62,16 +72,15 @@ func (i *Interceptor) Close() error {
} }
func GetState(_i interceptor.Interceptor, connection interceptor.Connection) (*state.State, error) { func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) {
i, ok := _i.(*Interceptor) return i.stateManager.GetState(connection)
if !ok { }
return nil, encryptionerr.ErrInvalidInterceptor
} func (i *Interceptor) Process(msg interfaces.CanProcess, state interfaces.State) error {
processor, ok := i.keyExchangeManager.(interfaces.ProtocolProcessor)
s, err := i.stateManager.GetState(connection) if !ok {
if err != nil { return encryptionerr.ErrInvalidMessageType
return nil, err }
}
return processor.Process(msg, state)
return s, nil
} }

View File

@@ -9,7 +9,7 @@ import (
// NonceValidator provides protection against replay attacks // NonceValidator provides protection against replay attacks
type NonceValidator interface { type NonceValidator interface {
// Validate checks if a nonce is valid and hasn't been seen before // Validate checks if a nonce is valid and hasn't been seen before
Validate(nonce []byte, sessionID types.SessionID) error Validate(nonce []byte, sessionID types.EncryptionSessionID) error
// Cleanup removes expired nonces // Cleanup removes expired nonces
Cleanup(before time.Time) Cleanup(before time.Time)

View File

@@ -1,4 +1,4 @@
package encryptor package interfaces
import ( import (
"io" "io"
@@ -7,13 +7,21 @@ import (
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
type KeySetter interface {
// SetKeys configures the encryption and decryption keys
SetKeys(encryptorKey, decryptorKey types.Key) error
}
type KeyGetter interface {
GetKeys() (encKey, decKey types.Key, err error)
}
// Encryptor defines the interface for message encryption and decryption // Encryptor defines the interface for message encryption and decryption
type Encryptor interface { type Encryptor interface {
// SetKeys configures the encryption and decryption keys KeySetter
SetKeys(encryptKey, decryptKey types.Key) error
// SetSessionID sets the session identifier for this encryption session // SetSessionID sets the session identifier for this encryption session
SetSessionID(id types.SessionID) SetSessionID(id types.EncryptionSessionID)
// Encrypt encrypts a message between sender and receiver // Encrypt encrypts a message between sender and receiver
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error) Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)

View File

@@ -0,0 +1,26 @@
package interfaces
import (
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type ProtocolProcessor interface {
Process(CanProcess, State) error
}
type KeyExchangeManager interface {
Init(state State, options ...ProtocolFactoryOption) error
}
type CanGetSessionState interface {
GetState() types.SessionState
}
type Protocol interface {
Init(s State) error
IsComplete() bool
}
type CanProcess interface {
Process(Protocol, State) error
}

View File

@@ -0,0 +1,35 @@
package interfaces
import (
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type ProtocolFactoryOption func(Protocol) error
type State interface {
GenerateKeyExchangeSessionID() types.KeyExchangeSessionID
GetKeyExchangeSessionID() types.KeyExchangeSessionID
WriteMessage(interceptor.Message) error
GetConfig() config.Config
}
type CanGetState interface {
GetState(interceptor.Connection) (State, error)
}
type CanSetState interface {
SetState(interceptor.Connection, State) error
}
type CanRemoveState interface {
RemoveState(interceptor.Connection) error
}
type StateManager interface {
CanGetState
CanSetState
CanRemoveState
ForEach(func(interceptor.Connection, State) error) error
}

View File

@@ -0,0 +1,190 @@
package keyexchange
import (
"fmt"
"time"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Init struct {
// TODO: MANAGE STATE USING KEY EXCHANGE SESSION ID
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
Signature []byte `json:"signature"`
SessionID types.EncryptionSessionID `json:"session_id"`
Salt types.Salt `json:"salt"`
}
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
ss, ok := _interceptor.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := ss.GetState(connection)
if err != nil {
return err
}
pp, ok := _interceptor.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
return pp.Process(m, s)
}
func (m *Init) Process(protocol interfaces.Protocol, s interfaces.State) error {
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.state != types.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
sign := append(m.PublicKey[:], m.Salt[:]...)
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
return encryptionerr.ErrInvalidSignature
}
p.salt = m.Salt
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
encKey, decKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
p.sessionID = m.SessionID
if err := s.WriteMessage(nil); err != nil {
} // TODO: ADD RESPONSE MESSAGE
p.state = types.SessionStateInProgress
return nil
}
type Response struct {
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
}
func (m *Response) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Response) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
ss, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := ss.GetState(connection)
if err != nil {
return err
}
pp, ok := _i.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
return pp.Process(m, s)
}
func (m *Response) Process(protocol interfaces.Protocol, s interfaces.State) error {
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.state != types.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
decKey, encKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
if err := s.WriteMessage(nil); err != nil {
return err
} // TODO: Send Done message
p.state = types.SessionStateInProgress
return nil
}
type Done struct {
interceptor.BaseMessage
Timestamp time.Time `json:"timestamp"`
}
func (m *Done) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
ss, ok := _i.(interfaces.CanGetState)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := ss.GetState(connection)
if err != nil {
return err
}
pp, ok := _i.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
return pp.Process(m, s)
}
func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error {
ss, ok := _s.(interfaces.KeySetter)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
p, ok := protocol.(*Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if err := ss.SetKeys(p.encKey, p.decKey); err != nil {
return err
}
p.state = types.SessionStateCompleted
return nil
}

View File

@@ -0,0 +1,47 @@
package keyexchange
import (
"context"
"fmt"
"time"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type SessionStateTargetWaiter struct {
target types.SessionState
ctx context.Context
}
func NewSessionStateTargetWaiter(ctx context.Context, target types.SessionState) SessionStateTargetWaiter {
return SessionStateTargetWaiter{
target: target,
ctx: ctx,
}
}
func (w SessionStateTargetWaiter) Process(protocol interfaces.Protocol, _ interfaces.State) error {
p, ok := protocol.(interfaces.CanGetSessionState)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
// TODO: Implement BLOCKing (ideally using sync.Cond, or simple ticker)
select {
case <-ticker.C:
if p.GetState() == w.target {
return nil
}
case <-w.ctx.Done():
return fmt.Errorf("timeout waiting for state %v: %w", w.target, w.ctx.Err())
}
}
}
// TODO: WRITE MORE KEY EXCHANGE PROCESSES HERE

View File

@@ -8,8 +8,7 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
@@ -18,11 +17,11 @@ type Curve25519Protocol struct {
pubKey types.PublicKey pubKey types.PublicKey
peerPubKey types.PublicKey peerPubKey types.PublicKey
salt types.Salt // also in protocol curve25519protocol.go salt types.Salt // also in protocol curve25519protocol.go
sessionID types.SessionID sessionID types.EncryptionSessionID
sharedSecret []byte sharedSecret []byte
encKey types.Key encKey types.Key
decKey types.Key decKey types.Key
state SessionState state types.SessionState
options Curve25519Options options Curve25519Options
// TODO: mutex is needed here for SessionState and Keys // TODO: mutex is needed here for SessionState and Keys
@@ -34,55 +33,55 @@ type Curve25519Options struct {
RequireSignature bool RequireSignature bool
} }
func (p *Curve25519Protocol) Init(s *state.State) error { func (p *Curve25519Protocol) Init(s interfaces.State) error {
if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil { if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil {
return err return err
} }
curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey)) curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey))
if s.InterceptorConfig.IsServer && p.options.RequireSignature { if s.GetConfig().IsServer && p.options.RequireSignature {
if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil { if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil {
p.state = SessionStateError p.state = types.SessionStateError
return err return err
} }
if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil { if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil {
p.state = SessionStateError p.state = types.SessionStateError
return err return err
} }
sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...)) _ = ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil { if err := s.WriteMessage(nil); err != nil { // TODO: SEND INIT MESSAGE
p.state = SessionStateError p.state = types.SessionStateError
return err return err
} }
} }
p.state = SessionStateInProgress p.state = types.SessionStateInitial
return nil return nil
} }
func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) { func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) {
if p.state != SessionStateCompleted { if p.state != types.SessionStateCompleted {
return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete
} }
return p.encKey, p.decKey, nil return p.encKey, p.decKey, nil
} }
func (p *Curve25519Protocol) GetState() SessionState { func (p *Curve25519Protocol) GetState() types.SessionState {
return p.state return p.state
} }
func (p *Curve25519Protocol) IsComplete() bool { func (p *Curve25519Protocol) IsComplete() bool {
return p.state == SessionStateCompleted return p.state == types.SessionStateCompleted
} }
func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error { func (p *Curve25519Protocol) Process(msg interfaces.CanProcess, s interfaces.State) error {
if err := msg.Process(p, s); err != nil { if err := msg.Process(p, s); err != nil {
p.state = SessionStateError p.state = types.SessionStateError
return err return err
} }

View File

@@ -9,13 +9,13 @@ import (
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
type Session struct { type Session struct {
protocol Protocol protocol interfaces.Protocol
state *state.State state interfaces.State
createdAt time.Time createdAt time.Time
completedAt time.Time completedAt time.Time
} }
@@ -25,16 +25,16 @@ type Manager struct {
sessions map[types.KeyExchangeSessionID]*Session sessions map[types.KeyExchangeSessionID]*Session
} }
func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error { func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactoryOption) error {
sessionID := s.GenerateKeyExchangeSessionID() sessionID := s.GenerateKeyExchangeSessionID()
_, exists := m.sessions[sessionID] _, exists := m.sessions[sessionID]
if exists { if exists {
return encryptionerr.ErrExchangeInProgress return encryptionerr.ErrExchangeInProgress
} }
factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol] factory, exists := m.registry[s.GetConfig().EncryptionProtocol.KeyExchangeProtocol]
if !exists { if !exists {
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol) return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.GetConfig().EncryptionProtocol.KeyExchangeProtocol)
} }
p, err := factory(options...) p, err := factory(options...)
@@ -55,13 +55,18 @@ func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
return nil return nil
} }
func (m *Manager) Process(s *state.State, msg MessageProcessor) error { func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error {
session, exists := m.sessions[s.KeyExchangeSessionID] session, exists := m.sessions[s.GetKeyExchangeSessionID()]
if !exists { if !exists {
return encryptionerr.ErrSessionNotFound return encryptionerr.ErrSessionNotFound
} }
return session.protocol.Process(msg, s) processor, ok := session.protocol.(interfaces.ProtocolProcessor)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
return processor.Process(msg, s)
} }
// Derive generates encryption keys from shared secret // Derive generates encryption keys from shared secret

View File

@@ -3,36 +3,12 @@ package keyexchange
import ( import (
"fmt" "fmt"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
type SessionState int func WithKeySignature(keyProvider keyprovider.KeyProvider) interfaces.ProtocolFactoryOption {
return func(protocol interfaces.Protocol) error {
const (
SessionStateInitial SessionState = iota
SessionStateInProgress
SessionStateCompleted
SessionStateError
)
type MessageProcessor interface {
Process(Protocol, *state.State) error
}
type Protocol interface {
Init(s *state.State) error
GetKeys() (encKey types.Key, decKey types.Key, err error)
Process(MessageProcessor, *state.State) error
GetState() SessionState
IsComplete() bool
}
type ProtocolFactoryOption func(Protocol) error
func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption {
return func(protocol Protocol) error {
curveProtocol, ok := protocol.(*Curve25519Protocol) curveProtocol, ok := protocol.(*Curve25519Protocol)
if !ok { if !ok {
return fmt.Errorf("WithKeySignature only supports Curve25519Protocol") return fmt.Errorf("WithKeySignature only supports Curve25519Protocol")
@@ -46,4 +22,4 @@ func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption
} }
} }
type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error) type ProtocolFactory func(options ...interfaces.ProtocolFactoryOption) (interfaces.Protocol, error)

View File

@@ -1,92 +0,0 @@
package messages
import (
"fmt"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
// Protocol constants
const (
InitProtocol message.Protocol = "curve25519.init"
ResponseProtocol message.Protocol = "curve25519.response"
ConfirmProtocol message.Protocol = "curve25519.confirm"
)
type Init struct {
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
Signature []byte `json:"signature"`
SessionID types.SessionID `json:"session_id"`
Salt types.Salt `json:"salt"`
}
func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init {
return &Init{
BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver),
PublicKey: key,
Signature: sign,
SessionID: sessionID,
Salt: salt,
}
}
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
s, err := encrypt.GetState(_interceptor, connection)
if err != nil {
return err
}
return i.keyExchangeManager.Process(s, m)
}
func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error {
p, ok := protocol.(*keyexchange.Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.GetState() != keyexchange.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
sign := append(m.PublicKey[:], m.Salt[:]...)
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
return encryptionerr.ErrInvalidSignature
}
p.salt = m.Salt
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
p.sessionID = m.SessionID
if err := s.Writer.Write(s.Connection, nil); err != nil {
return err
} // TODO: ADD RESPONSE MESSAGE
p.state = SessionStateCompleted
return nil
}

View File

@@ -11,22 +11,23 @@ import (
"github.com/harshabose/socket-comm/pkg/message" "github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
) )
type State struct { type State struct {
InterceptorConfig config.Config // copy of interceptor config currentConfig config.Config // copy of interceptor config
PeerID message.Receiver peerID message.Receiver
privKey types.PrivateKey // also in protocol curve25519protocol.go privKey types.PrivateKey // also in protocol curve25519protocol.go
salt types.Salt // also in protocol curve25519protocol.go salt types.Salt // also in protocol curve25519protocol.go
sessionID types.SessionID // encryption sessionID encryptSessionID types.EncryptionSessionID // encryption encryptSessionID
encryptor encryptor.Encryptor encryptor interfaces.Encryptor
Connection interceptor.Connection connection interceptor.Connection
Writer interceptor.Writer writer interceptor.Writer
Reader interceptor.Reader reader interceptor.Reader
cancel context.CancelFunc cancel context.CancelFunc
ctx context.Context ctx context.Context
KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking keyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
mux sync.RWMutex mux sync.RWMutex
} }
@@ -37,24 +38,56 @@ func NewState(ctx context.Context, cancel context.CancelFunc, config config.Conf
} }
return &State{ return &State{
InterceptorConfig: config, currentConfig: config,
peerID: message.UnknownReceiver, peerID: message.UnknownReceiver,
privKey: types.PrivateKey{}, privKey: types.PrivateKey{},
salt: types.Salt{}, salt: types.Salt{},
encryptor: newEncryptor, encryptor: newEncryptor,
Connection: connection, connection: connection,
Writer: writer, writer: writer,
Reader: reader, reader: reader,
cancel: cancel, cancel: cancel,
ctx: ctx, ctx: ctx,
}, nil }, nil
} }
func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID { func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
if s.KeyExchangeSessionID != "" { s.mux.Lock()
fmt.Println("KeyExchangeSessionID already exists; creating new") defer s.mux.Unlock()
}
s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
return s.KeyExchangeSessionID if s.keyExchangeSessionID != "" {
fmt.Println("keyExchangeSessionID already exists; creating new")
}
s.keyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
return s.keyExchangeSessionID
}
func (s *State) WriteMessage(msg interceptor.Message) error {
s.mux.Lock()
defer s.mux.Unlock()
msg.SetReceiver(s.peerID)
return s.writer.Write(s.connection, msg)
}
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
s.mux.Lock()
defer s.mux.Unlock()
return s.keyExchangeSessionID
}
func (s *State) GetConfig() config.Config {
s.mux.Lock()
defer s.mux.Unlock()
return s.currentConfig
}
func (s *State) SetKeys(encKey, decKey types.Key) error {
s.mux.Lock()
defer s.mux.Unlock()
return s.encryptor.SetKeys(encKey, decKey)
} }

View File

@@ -3,16 +3,18 @@ package state
import ( import (
"sync" "sync"
"github.com/harshabose/socket-comm/internal/util"
"github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr" "github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
) )
type Manager struct { type Manager struct {
states map[interceptor.Connection]*State states map[interceptor.Connection]interfaces.State
mux sync.RWMutex mux sync.RWMutex
} }
func (m *Manager) GetState(connection interceptor.Connection) (*State, error) { func (m *Manager) GetState(connection interceptor.Connection) (interfaces.State, error) {
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()
@@ -24,7 +26,7 @@ func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
return state, nil return state, nil
} }
func (m *Manager) SetState(connection interceptor.Connection, s *State) error { func (m *Manager) SetState(connection interceptor.Connection, s interfaces.State) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@@ -36,31 +38,31 @@ func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
return nil return nil
} }
// RemoveState removes a Connection's state // RemoveState removes a connection's state
func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) { func (m *Manager) RemoveState(connection interceptor.Connection) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
state, exists := m.states[connection] _, exists := m.states[connection]
if !exists { if !exists {
return nil, encryptionerr.ErrConnectionNotFound return encryptionerr.ErrConnectionNotFound
} }
delete(m.states, connection) delete(m.states, connection)
return state, nil return nil
} }
// ForEach executes the provided function for each state in the manager // ForEach executes the provided function for each state in the manager
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error { func (m *Manager) ForEach(fn func(connection interceptor.Connection, state interfaces.State) error) error {
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()
var errs []error var errs util.MultiError
for conn, state := range m.states { for conn, state := range m.states {
if err := fn(conn, state); err != nil { if err := fn(conn, state); err != nil {
errs = append(errs, err) errs.Add(err)
} }
} }
return errs return errs.ErrorOrNil()
} }

View File

@@ -5,7 +5,7 @@ type (
PrivateKey [32]byte PrivateKey [32]byte
PublicKey [32]byte PublicKey [32]byte
Salt [16]byte Salt [16]byte
SessionID [16]byte EncryptionSessionID [16]byte
Nonce [12]byte Nonce [12]byte
Key [32]byte Key [32]byte
KeyExchangeProtocol string KeyExchangeProtocol string

View File

@@ -0,0 +1,11 @@
package types
type SessionState int
const (
SessionStateNotStart SessionState = iota
SessionStateInitial
SessionStateInProgress
SessionStateCompleted
SessionStateError
)