mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-11 02:20:06 +08:00
updated encryption interceptor to meet good programming principle: added modular key exchange mechanism, updated interface design, restructured interfaces to support interface segregation
This commit is contained in:
@@ -39,6 +39,9 @@ type Message interface {
|
||||
//
|
||||
// Returns an error if processing fails
|
||||
ReadProcess(Interceptor, Connection) error
|
||||
|
||||
SetReceiver(message.Receiver)
|
||||
SetSender(message.Sender)
|
||||
}
|
||||
|
||||
// BaseMessage provides a default implementation of the Message interface.
|
||||
@@ -94,3 +97,11 @@ func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error {
|
||||
// Derived-types should override this method with specific processing logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BaseMessage) SetSender(sender message.Sender) {
|
||||
m.CurrentHeader.Sender = sender
|
||||
}
|
||||
|
||||
func (m *BaseMessage) SetReceiver(receiver message.Receiver) {
|
||||
m.CurrentHeader.Receiver = receiver
|
||||
}
|
||||
|
@@ -2,21 +2,24 @@ package encrypt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
nonceValidator NonceValidator
|
||||
keyExchangeManager keyexchange.Manager
|
||||
keyExchangeManager interfaces.KeyExchangeManager
|
||||
keyProvider keyprovider.KeyProvider
|
||||
stateManager state.Manager
|
||||
stateManager interfaces.StateManager
|
||||
config config.Config
|
||||
}
|
||||
|
||||
@@ -36,14 +39,21 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
|
||||
}
|
||||
|
||||
func (i *Interceptor) Init(connection interceptor.Connection) error {
|
||||
_state, err := i.stateManager.GetState(connection)
|
||||
s, err := i.stateManager.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
|
||||
if err := i.keyExchangeManager.Init(s, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(i.Ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
|
||||
|
||||
return i.Process(waiter, s)
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
@@ -62,16 +72,15 @@ func (i *Interceptor) Close() error {
|
||||
|
||||
}
|
||||
|
||||
func GetState(_i interceptor.Interceptor, connection interceptor.Connection) (*state.State, error) {
|
||||
i, ok := _i.(*Interceptor)
|
||||
if !ok {
|
||||
return nil, encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := i.stateManager.GetState(connection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) {
|
||||
return i.stateManager.GetState(connection)
|
||||
}
|
||||
|
||||
func (i *Interceptor) Process(msg interfaces.CanProcess, state interfaces.State) error {
|
||||
processor, ok := i.keyExchangeManager.(interfaces.ProtocolProcessor)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
return processor.Process(msg, state)
|
||||
}
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
// NonceValidator provides protection against replay attacks
|
||||
type NonceValidator interface {
|
||||
// Validate checks if a nonce is valid and hasn't been seen before
|
||||
Validate(nonce []byte, sessionID types.SessionID) error
|
||||
Validate(nonce []byte, sessionID types.EncryptionSessionID) error
|
||||
|
||||
// Cleanup removes expired nonces
|
||||
Cleanup(before time.Time)
|
||||
|
@@ -1,4 +1,4 @@
|
||||
package encryptor
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"io"
|
||||
@@ -7,13 +7,21 @@ import (
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type KeySetter interface {
|
||||
// SetKeys configures the encryption and decryption keys
|
||||
SetKeys(encryptorKey, decryptorKey types.Key) error
|
||||
}
|
||||
|
||||
type KeyGetter interface {
|
||||
GetKeys() (encKey, decKey types.Key, err error)
|
||||
}
|
||||
|
||||
// Encryptor defines the interface for message encryption and decryption
|
||||
type Encryptor interface {
|
||||
// SetKeys configures the encryption and decryption keys
|
||||
SetKeys(encryptKey, decryptKey types.Key) error
|
||||
KeySetter
|
||||
|
||||
// SetSessionID sets the session identifier for this encryption session
|
||||
SetSessionID(id types.SessionID)
|
||||
SetSessionID(id types.EncryptionSessionID)
|
||||
|
||||
// Encrypt encrypts a message between sender and receiver
|
||||
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)
|
26
pkg/middleware/encrypt/interfaces/keyexchange.go
Normal file
26
pkg/middleware/encrypt/interfaces/keyexchange.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type ProtocolProcessor interface {
|
||||
Process(CanProcess, State) error
|
||||
}
|
||||
|
||||
type KeyExchangeManager interface {
|
||||
Init(state State, options ...ProtocolFactoryOption) error
|
||||
}
|
||||
|
||||
type CanGetSessionState interface {
|
||||
GetState() types.SessionState
|
||||
}
|
||||
|
||||
type Protocol interface {
|
||||
Init(s State) error
|
||||
IsComplete() bool
|
||||
}
|
||||
|
||||
type CanProcess interface {
|
||||
Process(Protocol, State) error
|
||||
}
|
35
pkg/middleware/encrypt/interfaces/state.go
Normal file
35
pkg/middleware/encrypt/interfaces/state.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type ProtocolFactoryOption func(Protocol) error
|
||||
|
||||
type State interface {
|
||||
GenerateKeyExchangeSessionID() types.KeyExchangeSessionID
|
||||
GetKeyExchangeSessionID() types.KeyExchangeSessionID
|
||||
WriteMessage(interceptor.Message) error
|
||||
GetConfig() config.Config
|
||||
}
|
||||
|
||||
type CanGetState interface {
|
||||
GetState(interceptor.Connection) (State, error)
|
||||
}
|
||||
|
||||
type CanSetState interface {
|
||||
SetState(interceptor.Connection, State) error
|
||||
}
|
||||
|
||||
type CanRemoveState interface {
|
||||
RemoveState(interceptor.Connection) error
|
||||
}
|
||||
|
||||
type StateManager interface {
|
||||
CanGetState
|
||||
CanSetState
|
||||
CanRemoveState
|
||||
ForEach(func(interceptor.Connection, State) error) error
|
||||
}
|
190
pkg/middleware/encrypt/keyexchange/curve25519messages.go
Normal file
190
pkg/middleware/encrypt/keyexchange/curve25519messages.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package keyexchange
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type Init struct {
|
||||
// TODO: MANAGE STATE USING KEY EXCHANGE SESSION ID
|
||||
interceptor.BaseMessage
|
||||
PublicKey types.PublicKey `json:"public_key"`
|
||||
Signature []byte `json:"signature"`
|
||||
SessionID types.EncryptionSessionID `json:"session_id"`
|
||||
Salt types.Salt `json:"salt"`
|
||||
}
|
||||
|
||||
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
ss, ok := _interceptor.(interfaces.CanGetState)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := ss.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pp, ok := _interceptor.(interfaces.ProtocolProcessor)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
return pp.Process(m, s)
|
||||
}
|
||||
|
||||
func (m *Init) Process(protocol interfaces.Protocol, s interfaces.State) error {
|
||||
p, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
if p.state != types.SessionStateInitial {
|
||||
return encryptionerr.ErrInvalidSessionState
|
||||
}
|
||||
|
||||
sign := append(m.PublicKey[:], m.Salt[:]...)
|
||||
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
|
||||
return encryptionerr.ErrInvalidSignature
|
||||
}
|
||||
|
||||
p.salt = m.Salt
|
||||
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compute shared secret: %w", err)
|
||||
}
|
||||
|
||||
encKey, decKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING
|
||||
if err != nil {
|
||||
return fmt.Errorf("key derivation failed: %w", err)
|
||||
}
|
||||
|
||||
p.encKey = encKey
|
||||
p.decKey = decKey
|
||||
p.sessionID = m.SessionID
|
||||
|
||||
if err := s.WriteMessage(nil); err != nil {
|
||||
|
||||
} // TODO: ADD RESPONSE MESSAGE
|
||||
|
||||
p.state = types.SessionStateInProgress
|
||||
return nil
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
interceptor.BaseMessage
|
||||
PublicKey types.PublicKey `json:"public_key"`
|
||||
}
|
||||
|
||||
func (m *Response) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Response) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
ss, ok := _i.(interfaces.CanGetState)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := ss.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pp, ok := _i.(interfaces.ProtocolProcessor)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
return pp.Process(m, s)
|
||||
}
|
||||
|
||||
func (m *Response) Process(protocol interfaces.Protocol, s interfaces.State) error {
|
||||
p, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
if p.state != types.SessionStateInitial {
|
||||
return encryptionerr.ErrInvalidSessionState
|
||||
}
|
||||
|
||||
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compute shared secret: %w", err)
|
||||
}
|
||||
|
||||
decKey, encKey, err := Derive(shared, p.salt, "") // TODO: ADD INFO STRING
|
||||
if err != nil {
|
||||
return fmt.Errorf("key derivation failed: %w", err)
|
||||
}
|
||||
|
||||
p.encKey = encKey
|
||||
p.decKey = decKey
|
||||
|
||||
if err := s.WriteMessage(nil); err != nil {
|
||||
return err
|
||||
} // TODO: Send Done message
|
||||
|
||||
p.state = types.SessionStateInProgress
|
||||
return nil
|
||||
}
|
||||
|
||||
type Done struct {
|
||||
interceptor.BaseMessage
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
func (m *Done) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
ss, ok := _i.(interfaces.CanGetState)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := ss.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pp, ok := _i.(interfaces.ProtocolProcessor)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
return pp.Process(m, s)
|
||||
}
|
||||
|
||||
func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error {
|
||||
ss, ok := _s.(interfaces.KeySetter)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
p, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
if err := ss.SetKeys(p.encKey, p.decKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.state = types.SessionStateCompleted
|
||||
return nil
|
||||
}
|
47
pkg/middleware/encrypt/keyexchange/curve25519process.go
Normal file
47
pkg/middleware/encrypt/keyexchange/curve25519process.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package keyexchange
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type SessionStateTargetWaiter struct {
|
||||
target types.SessionState
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewSessionStateTargetWaiter(ctx context.Context, target types.SessionState) SessionStateTargetWaiter {
|
||||
return SessionStateTargetWaiter{
|
||||
target: target,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (w SessionStateTargetWaiter) Process(protocol interfaces.Protocol, _ interfaces.State) error {
|
||||
p, ok := protocol.(interfaces.CanGetSessionState)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
// TODO: Implement BLOCKing (ideally using sync.Cond, or simple ticker)
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if p.GetState() == w.target {
|
||||
return nil
|
||||
}
|
||||
case <-w.ctx.Done():
|
||||
return fmt.Errorf("timeout waiting for state %v: %w", w.target, w.ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: WRITE MORE KEY EXCHANGE PROCESSES HERE
|
@@ -8,8 +8,7 @@ import (
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
@@ -18,11 +17,11 @@ type Curve25519Protocol struct {
|
||||
pubKey types.PublicKey
|
||||
peerPubKey types.PublicKey
|
||||
salt types.Salt // also in protocol curve25519protocol.go
|
||||
sessionID types.SessionID
|
||||
sessionID types.EncryptionSessionID
|
||||
sharedSecret []byte
|
||||
encKey types.Key
|
||||
decKey types.Key
|
||||
state SessionState
|
||||
state types.SessionState
|
||||
options Curve25519Options
|
||||
|
||||
// TODO: mutex is needed here for SessionState and Keys
|
||||
@@ -34,55 +33,55 @@ type Curve25519Options struct {
|
||||
RequireSignature bool
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) Init(s *state.State) error {
|
||||
func (p *Curve25519Protocol) Init(s interfaces.State) error {
|
||||
if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey))
|
||||
|
||||
if s.InterceptorConfig.IsServer && p.options.RequireSignature {
|
||||
if s.GetConfig().IsServer && p.options.RequireSignature {
|
||||
if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil {
|
||||
p.state = SessionStateError
|
||||
p.state = types.SessionStateError
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil {
|
||||
p.state = SessionStateError
|
||||
p.state = types.SessionStateError
|
||||
return err
|
||||
}
|
||||
|
||||
sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
|
||||
_ = ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
|
||||
|
||||
if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil {
|
||||
p.state = SessionStateError
|
||||
if err := s.WriteMessage(nil); err != nil { // TODO: SEND INIT MESSAGE
|
||||
p.state = types.SessionStateError
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.state = SessionStateInProgress
|
||||
p.state = types.SessionStateInitial
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) {
|
||||
if p.state != SessionStateCompleted {
|
||||
if p.state != types.SessionStateCompleted {
|
||||
return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete
|
||||
}
|
||||
|
||||
return p.encKey, p.decKey, nil
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) GetState() SessionState {
|
||||
func (p *Curve25519Protocol) GetState() types.SessionState {
|
||||
return p.state
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) IsComplete() bool {
|
||||
return p.state == SessionStateCompleted
|
||||
return p.state == types.SessionStateCompleted
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error {
|
||||
func (p *Curve25519Protocol) Process(msg interfaces.CanProcess, s interfaces.State) error {
|
||||
if err := msg.Process(p, s); err != nil {
|
||||
p.state = SessionStateError
|
||||
p.state = types.SessionStateError
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -9,13 +9,13 @@ import (
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
protocol Protocol
|
||||
state *state.State
|
||||
protocol interfaces.Protocol
|
||||
state interfaces.State
|
||||
createdAt time.Time
|
||||
completedAt time.Time
|
||||
}
|
||||
@@ -25,16 +25,16 @@ type Manager struct {
|
||||
sessions map[types.KeyExchangeSessionID]*Session
|
||||
}
|
||||
|
||||
func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
|
||||
func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactoryOption) error {
|
||||
sessionID := s.GenerateKeyExchangeSessionID()
|
||||
_, exists := m.sessions[sessionID]
|
||||
if exists {
|
||||
return encryptionerr.ErrExchangeInProgress
|
||||
}
|
||||
|
||||
factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol]
|
||||
factory, exists := m.registry[s.GetConfig().EncryptionProtocol.KeyExchangeProtocol]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol)
|
||||
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.GetConfig().EncryptionProtocol.KeyExchangeProtocol)
|
||||
}
|
||||
|
||||
p, err := factory(options...)
|
||||
@@ -55,13 +55,18 @@ func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Process(s *state.State, msg MessageProcessor) error {
|
||||
session, exists := m.sessions[s.KeyExchangeSessionID]
|
||||
func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error {
|
||||
session, exists := m.sessions[s.GetKeyExchangeSessionID()]
|
||||
if !exists {
|
||||
return encryptionerr.ErrSessionNotFound
|
||||
}
|
||||
|
||||
return session.protocol.Process(msg, s)
|
||||
processor, ok := session.protocol.(interfaces.ProtocolProcessor)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
return processor.Process(msg, s)
|
||||
}
|
||||
|
||||
// Derive generates encryption keys from shared secret
|
||||
|
@@ -3,36 +3,12 @@ package keyexchange
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type SessionState int
|
||||
|
||||
const (
|
||||
SessionStateInitial SessionState = iota
|
||||
SessionStateInProgress
|
||||
SessionStateCompleted
|
||||
SessionStateError
|
||||
)
|
||||
|
||||
type MessageProcessor interface {
|
||||
Process(Protocol, *state.State) error
|
||||
}
|
||||
|
||||
type Protocol interface {
|
||||
Init(s *state.State) error
|
||||
GetKeys() (encKey types.Key, decKey types.Key, err error)
|
||||
Process(MessageProcessor, *state.State) error
|
||||
GetState() SessionState
|
||||
IsComplete() bool
|
||||
}
|
||||
|
||||
type ProtocolFactoryOption func(Protocol) error
|
||||
|
||||
func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption {
|
||||
return func(protocol Protocol) error {
|
||||
func WithKeySignature(keyProvider keyprovider.KeyProvider) interfaces.ProtocolFactoryOption {
|
||||
return func(protocol interfaces.Protocol) error {
|
||||
curveProtocol, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return fmt.Errorf("WithKeySignature only supports Curve25519Protocol")
|
||||
@@ -46,4 +22,4 @@ func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption
|
||||
}
|
||||
}
|
||||
|
||||
type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error)
|
||||
type ProtocolFactory func(options ...interfaces.ProtocolFactoryOption) (interfaces.Protocol, error)
|
||||
|
@@ -1,92 +0,0 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
// Protocol constants
|
||||
const (
|
||||
InitProtocol message.Protocol = "curve25519.init"
|
||||
ResponseProtocol message.Protocol = "curve25519.response"
|
||||
ConfirmProtocol message.Protocol = "curve25519.confirm"
|
||||
)
|
||||
|
||||
type Init struct {
|
||||
interceptor.BaseMessage
|
||||
PublicKey types.PublicKey `json:"public_key"`
|
||||
Signature []byte `json:"signature"`
|
||||
SessionID types.SessionID `json:"session_id"`
|
||||
Salt types.Salt `json:"salt"`
|
||||
}
|
||||
|
||||
func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init {
|
||||
return &Init{
|
||||
BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver),
|
||||
PublicKey: key,
|
||||
Signature: sign,
|
||||
SessionID: sessionID,
|
||||
Salt: salt,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
s, err := encrypt.GetState(_interceptor, connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return i.keyExchangeManager.Process(s, m)
|
||||
}
|
||||
|
||||
func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error {
|
||||
p, ok := protocol.(*keyexchange.Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
if p.GetState() != keyexchange.SessionStateInitial {
|
||||
return encryptionerr.ErrInvalidSessionState
|
||||
}
|
||||
|
||||
sign := append(m.PublicKey[:], m.Salt[:]...)
|
||||
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
|
||||
return encryptionerr.ErrInvalidSignature
|
||||
}
|
||||
|
||||
p.salt = m.Salt
|
||||
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compute shared secret: %w", err)
|
||||
}
|
||||
|
||||
encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING
|
||||
if err != nil {
|
||||
return fmt.Errorf("key derivation failed: %w", err)
|
||||
}
|
||||
|
||||
p.encKey = encKey
|
||||
p.decKey = decKey
|
||||
p.sessionID = m.SessionID
|
||||
|
||||
if err := s.Writer.Write(s.Connection, nil); err != nil {
|
||||
return err
|
||||
} // TODO: ADD RESPONSE MESSAGE
|
||||
|
||||
p.state = SessionStateCompleted
|
||||
return nil
|
||||
}
|
@@ -11,22 +11,23 @@ import (
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
InterceptorConfig config.Config // copy of interceptor config
|
||||
PeerID message.Receiver
|
||||
privKey types.PrivateKey // also in protocol curve25519protocol.go
|
||||
salt types.Salt // also in protocol curve25519protocol.go
|
||||
sessionID types.SessionID // encryption sessionID
|
||||
encryptor encryptor.Encryptor
|
||||
Connection interceptor.Connection
|
||||
Writer interceptor.Writer
|
||||
Reader interceptor.Reader
|
||||
currentConfig config.Config // copy of interceptor config
|
||||
peerID message.Receiver
|
||||
privKey types.PrivateKey // also in protocol curve25519protocol.go
|
||||
salt types.Salt // also in protocol curve25519protocol.go
|
||||
encryptSessionID types.EncryptionSessionID // encryption encryptSessionID
|
||||
encryptor interfaces.Encryptor
|
||||
connection interceptor.Connection
|
||||
writer interceptor.Writer
|
||||
reader interceptor.Reader
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
|
||||
keyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -37,24 +38,56 @@ func NewState(ctx context.Context, cancel context.CancelFunc, config config.Conf
|
||||
}
|
||||
|
||||
return &State{
|
||||
InterceptorConfig: config,
|
||||
peerID: message.UnknownReceiver,
|
||||
privKey: types.PrivateKey{},
|
||||
salt: types.Salt{},
|
||||
encryptor: newEncryptor,
|
||||
Connection: connection,
|
||||
Writer: writer,
|
||||
Reader: reader,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
currentConfig: config,
|
||||
peerID: message.UnknownReceiver,
|
||||
privKey: types.PrivateKey{},
|
||||
salt: types.Salt{},
|
||||
encryptor: newEncryptor,
|
||||
connection: connection,
|
||||
writer: writer,
|
||||
reader: reader,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||
if s.KeyExchangeSessionID != "" {
|
||||
fmt.Println("KeyExchangeSessionID already exists; creating new")
|
||||
}
|
||||
s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return s.KeyExchangeSessionID
|
||||
if s.keyExchangeSessionID != "" {
|
||||
fmt.Println("keyExchangeSessionID already exists; creating new")
|
||||
}
|
||||
s.keyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
|
||||
|
||||
return s.keyExchangeSessionID
|
||||
}
|
||||
|
||||
func (s *State) WriteMessage(msg interceptor.Message) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
msg.SetReceiver(s.peerID)
|
||||
return s.writer.Write(s.connection, msg)
|
||||
}
|
||||
|
||||
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return s.keyExchangeSessionID
|
||||
}
|
||||
|
||||
func (s *State) GetConfig() config.Config {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return s.currentConfig
|
||||
}
|
||||
|
||||
func (s *State) SetKeys(encKey, decKey types.Key) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return s.encryptor.SetKeys(encKey, decKey)
|
||||
}
|
||||
|
@@ -3,16 +3,18 @@ package state
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/harshabose/socket-comm/internal/util"
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
states map[interceptor.Connection]*State
|
||||
states map[interceptor.Connection]interfaces.State
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
|
||||
func (m *Manager) GetState(connection interceptor.Connection) (interfaces.State, error) {
|
||||
m.mux.RLock()
|
||||
defer m.mux.RUnlock()
|
||||
|
||||
@@ -24,7 +26,7 @@ func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
|
||||
func (m *Manager) SetState(connection interceptor.Connection, s interfaces.State) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
@@ -36,31 +38,31 @@ func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveState removes a Connection's state
|
||||
func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) {
|
||||
// RemoveState removes a connection's state
|
||||
func (m *Manager) RemoveState(connection interceptor.Connection) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
state, exists := m.states[connection]
|
||||
_, exists := m.states[connection]
|
||||
if !exists {
|
||||
return nil, encryptionerr.ErrConnectionNotFound
|
||||
return encryptionerr.ErrConnectionNotFound
|
||||
}
|
||||
|
||||
delete(m.states, connection)
|
||||
return state, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForEach executes the provided function for each state in the manager
|
||||
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error {
|
||||
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state interfaces.State) error) error {
|
||||
m.mux.RLock()
|
||||
defer m.mux.RUnlock()
|
||||
|
||||
var errs []error
|
||||
var errs util.MultiError
|
||||
for conn, state := range m.states {
|
||||
if err := fn(conn, state); err != nil {
|
||||
errs = append(errs, err)
|
||||
errs.Add(err)
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
return errs.ErrorOrNil()
|
||||
}
|
||||
|
@@ -5,7 +5,7 @@ type (
|
||||
PrivateKey [32]byte
|
||||
PublicKey [32]byte
|
||||
Salt [16]byte
|
||||
SessionID [16]byte
|
||||
EncryptionSessionID [16]byte
|
||||
Nonce [12]byte
|
||||
Key [32]byte
|
||||
KeyExchangeProtocol string
|
11
pkg/middleware/encrypt/types/keyexchange.go
Normal file
11
pkg/middleware/encrypt/types/keyexchange.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package types
|
||||
|
||||
type SessionState int
|
||||
|
||||
const (
|
||||
SessionStateNotStart SessionState = iota
|
||||
SessionStateInitial
|
||||
SessionStateInProgress
|
||||
SessionStateCompleted
|
||||
SessionStateError
|
||||
)
|
Reference in New Issue
Block a user