first commit

This commit is contained in:
harshabose
2025-04-26 14:36:25 +05:30
commit f6b6a7f70c
24 changed files with 1445 additions and 0 deletions

8
go.mod Normal file
View File

@@ -0,0 +1,8 @@
module github.com/harshabose/socket-comm
go 1.24
require (
github.com/google/uuid v1.6.0
golang.org/x/crypto v0.37.0
)

4
go.sum Normal file
View File

@@ -0,0 +1,4 @@
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=

88
internal/util/merr.go Normal file
View File

@@ -0,0 +1,88 @@
package util
import (
"errors"
"fmt"
"strings"
)
// MultiError is a collection of errors that implements the error interface
type MultiError struct {
errors []error
}
// NewMultiError creates a new empty MultiError
func NewMultiError() *MultiError {
return &MultiError{errors: []error{}}
}
// Add appends an error to the collection if it's not nil
func (multiErr *MultiError) Add(err error) {
if err != nil {
multiErr.errors = append(multiErr.errors, err)
}
}
// AddAll appends multiple errors to the collection, ignoring nil errors
func (multiErr *MultiError) AddAll(errs ...error) {
for _, err := range errs {
multiErr.Add(err)
}
}
// Len returns the number of errors in the collection
func (multiErr *MultiError) Len() int {
return len(multiErr.errors)
}
// Error implements the error interface
func (multiErr *MultiError) Error() string {
if multiErr.Len() == 0 {
return ""
}
if multiErr.Len() == 1 {
return multiErr.errors[0].Error()
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("%d errors occurred:\n", multiErr.Len()))
for i, err := range multiErr.errors {
sb.WriteString(fmt.Sprintf(" * %s\n", err.Error()))
if i < multiErr.Len()-1 {
sb.WriteString("\n")
}
}
return sb.String()
}
// ErrorOrNil returns nil if the collection is empty, otherwise returns the flattened MultiError
func (multiErr *MultiError) ErrorOrNil() error {
if multiErr.Len() == 0 {
return nil
}
return multiErr.Flatten()
}
// Errors returns all errors in the collection
func (multiErr *MultiError) Errors() []error {
return multiErr.errors
}
// Flatten returns a new MultiError with all nested MultiErrors flattened
func (multiErr *MultiError) Flatten() *MultiError {
flattened := NewMultiError()
for _, err := range multiErr.errors {
var merr *MultiError
if errors.As(err, &merr) {
flattened.AddAll(merr.Flatten().Errors()...)
} else {
flattened.Add(err)
}
}
return flattened
}

69
pkg/interceptor/chain.go Normal file
View File

@@ -0,0 +1,69 @@
package interceptor
import "github.com/harshabose/socket-comm/internal/util"
type Chain struct {
interceptors []Interceptor
}
func CreateChain(interceptors []Interceptor) *Chain {
return &Chain{interceptors: interceptors}
}
func (chain *Chain) BindSocketConnection(connection Connection, writer Writer, reader Reader) (Writer, Reader, error) {
var (
w Writer
r Reader
err error
)
for _, interceptor := range chain.interceptors {
if w, r, err = interceptor.BindSocketConnection(connection, chain.InterceptSocketWriter(writer), chain.InterceptSocketReader(reader)); err != nil {
return nil, nil, err
}
}
return w, r, nil
}
func (chain *Chain) Init(connection Connection) error {
for _, interceptor := range chain.interceptors {
if err := interceptor.Init(connection); err != nil {
return err
}
}
return nil
}
func (chain *Chain) InterceptSocketWriter(writer Writer) Writer {
for _, interceptor := range chain.interceptors {
writer = interceptor.InterceptSocketWriter(writer)
}
return writer
}
func (chain *Chain) InterceptSocketReader(reader Reader) Reader {
for _, interceptor := range chain.interceptors {
reader = interceptor.InterceptSocketReader(reader)
}
return reader
}
func (chain *Chain) UnBindSocketConnection(connection Connection) {
for _, interceptor := range chain.interceptors {
interceptor.UnBindSocketConnection(connection)
}
}
func (chain *Chain) Close() error {
var merr util.MultiError
for _, interceptor := range chain.interceptors {
merr.Add(interceptor.Close())
}
return merr.ErrorOrNil()
}

View File

@@ -0,0 +1,104 @@
package interceptor
import (
"context"
"io"
"sync"
)
type Registry struct {
factories []Factory
}
func (registry *Registry) Register(factory Factory) {
registry.factories = append(registry.factories, factory)
}
func (registry *Registry) Build(ctx context.Context, id string) (Interceptor, error) {
if len(registry.factories) == 0 {
return &NoOpInterceptor{}, nil
}
interceptors := make([]Interceptor, 0)
for _, factory := range registry.factories {
interceptor, err := factory.NewInterceptor(ctx, id)
if err != nil {
return nil, err
}
interceptors = append(interceptors, interceptor)
}
return CreateChain(interceptors), nil
}
type Factory interface {
NewInterceptor(context.Context, string) (Interceptor, error)
}
type Connection interface {
Write(ctx context.Context, p []byte) error
Read(ctx context.Context) ([]byte, error)
}
type Interceptor interface {
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
Init(Connection) error
InterceptSocketWriter(Writer) Writer
InterceptSocketReader(Reader) Reader
UnBindSocketConnection(Connection)
io.Closer
}
type Writer interface {
Write(conn Connection, message Message) error
}
type Reader interface {
Read(conn Connection) (Message, error)
}
type ReaderFunc func(conn Connection) (Message, error)
func (f ReaderFunc) Read(conn Connection) (Message, error) {
return f(conn)
}
type WriterFunc func(conn Connection, message Message) error
func (f WriterFunc) Write(conn Connection, message Message) error {
return f(conn, message)
}
type NoOpInterceptor struct {
ID string
Mutex sync.RWMutex
Ctx context.Context
}
func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) {
return nil, nil, nil
}
func (interceptor *NoOpInterceptor) Init(_ Connection) error {
return nil
}
func (interceptor *NoOpInterceptor) InterceptSocketWriter(writer Writer) Writer {
return writer
}
func (interceptor *NoOpInterceptor) InterceptSocketReader(reader Reader) Reader {
return reader
}
func (interceptor *NoOpInterceptor) UnBindSocketConnection(_ Connection) {}
func (interceptor *NoOpInterceptor) Close() error {
return nil
}

View File

@@ -0,0 +1,96 @@
// Package interceptor provides a middleware system for processing WebSocket messages.
// It builds on the message package to add processing capabilities to messages.
package interceptor
import "github.com/harshabose/socket-comm/pkg/message"
// Message extends the base message.Message interface with processing capabilities.
// This interface allows messages to be processed by interceptors in the communication chain.
// Types implementing this interface can define custom behavior for how they interact
// with specific interceptors.
type Message interface {
// Message Embed the base Message interface
message.Message
// WriteProcess handles interceptor processing for outgoing messages.
// This method is called when a message is being written to a connection.
//
// The implementation should handle any message-specific processing required
// for the given interceptor type. For example, an encryption message would
// encrypt data when this method is called.
//
// Parameters:
// - interceptor: The interceptor that should process this message
// - connection: The network connection associated with this message
//
// Returns an error if processing fails
WriteProcess(Interceptor, Connection) error
// ReadProcess handles interceptor processing for incoming messages.
// This method is called when a message is being read from a connection.
//
// The implementation should handle any message-specific processing required
// for the given interceptor type. For example, an encryption message would
// decrypt data when this method is called.
//
// Parameters:
// - interceptor: The interceptor that should process this message
// - connection: The network connection associated with this message
//
// Returns an error if processing fails
ReadProcess(Interceptor, Connection) error
}
// BaseMessage provides a default implementation of the Message interface.
// It embeds message.BaseMessage to inherit its functionality and adds
// a no-op Process method that can be overridden by specific message types.
//
// Custom interceptor message types should embed this struct and override
// the Process method with their specific processing logic.
type BaseMessage struct {
// Embed the base message implementation
message.BaseMessage
}
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
func NewBaseMessage(protocol message.Protocol, sender message.Sender, receiver message.Receiver) BaseMessage {
return BaseMessage{
BaseMessage: message.BaseMessage{
CurrentProtocol: protocol,
CurrentHeader: message.NewV1Header(sender, receiver),
NextProtocol: message.NoneProtocol,
},
}
}
// WriteProcess handles interceptor processing for outgoing messages.
// This method is called when a message is being written to a connection.
// It should be overridden by specific message types to implement
// their custom outgoing message processing logic.
//
// Parameters:
// - interceptor: The interceptor that should process this message
// - connection: The network connection associated with this message
//
// Returns nil by default, indicating no processing was performed
func (m *BaseMessage) WriteProcess(_ Interceptor, _ Connection) error {
// Default implementation does nothing
// Derived-types should override this method with specific processing logic
return nil
}
// ReadProcess handles interceptor processing for incoming messages.
// This method is called when a message is being read from a connection.
// It should be overridden by specific message types to implement
// their custom incoming message processing logic.
//
// Parameters:
// - interceptor: The interceptor that should process this message
// - connection: The network connection associated with this message
//
// Returns nil by default, indicating no processing was performed
func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error {
// Default implementation does nothing
// Derived-types should override this method with specific processing logic
return nil
}

19
pkg/message/errors.go Normal file
View File

@@ -0,0 +1,19 @@
// Package message provides a type-safe, extensible message system for WebSocket communication.
package message
import "errors"
// Error definitions for the message package
var (
// ErrNoProtocolMatch is returned when attempting to create or unmarshal
// a message with a protocol that is not registered
ErrNoProtocolMatch = errors.New("no protocol in the registry")
// ErrNoPayload is returned when a message has a non-none NextProtocol
// but is missing the corresponding NextPayload data
ErrNoPayload = errors.New("protocol is not none but payload is nil")
// ErrInvalidMessageData is returned when raw message data cannot be
// properly identified or does not contain a valid protocol field
ErrInvalidMessageData = errors.New("invalid message data")
)

117
pkg/message/message.go Normal file
View File

@@ -0,0 +1,117 @@
// Package message provides a type-safe, extensible message system for WebSocket communication.
// It implements a nested message structure that allows for message interception and transformation
// through an interceptor chain pattern.
package message
import (
"encoding/json"
)
// Type aliases for improved readability and type safety
type (
// Protocol identifies the message type or format
Protocol string
// Payload contains the serialized data of the message
Payload json.RawMessage
// Sender identifies the source of the message
Sender string
// Receiver identifies the intended recipient of the message
Receiver string
// Version specifies the message protocol version
Version string
)
// Protocol constants
const (
// NoneProtocol indicates no nested message exists
NoneProtocol Protocol = "none"
// Version1 is the current message protocol version
Version1 Version = "v1.0"
// UnknownReceiver receiverID initialising
UnknownReceiver Receiver = "unknown"
)
// Message defines the interface that all message types must implement.
// It provides methods for protocol identification, serialization, and
// message nesting/unwrapping.
type Message interface {
// GetProtocol returns the protocol identifier for this message
GetProtocol() Protocol
// GetNext retrieves the next message in the chain, if one exists
// Returns nil, nil if there is no next message
GetNext(Registry) (Message, error)
// Marshal serializes the message to JSON format
Marshal() ([]byte, error)
// Unmarshal deserializes the message from JSON format
Unmarshal([]byte) error
}
// Header contains common metadata for all messages
type Header struct {
Sender Sender `json:"sender"` // Sender identifies the message source
Receiver Receiver `json:"receiver"` // Receiver identifies the intended recipient
Version Version `json:"version"` // Version specifies the protocol version
}
// NewV1Header creates a new header with Version1
// This is a convenience constructor for common header creation
func NewV1Header(sender Sender, receiver Receiver) Header {
return Header{
Sender: sender,
Receiver: receiver,
Version: Version1,
}
}
// BaseMessage provides a foundation for all message types.
// It implements the Message interface and manages message nesting.
// Custom message types should embed this struct to inherit its functionality.
type BaseMessage struct {
// CURRENT MESSAGE PROCESSOR
CurrentProtocol Protocol `json:"protocol"` // CurrentProtocol identifies this message's type
CurrentHeader Header `json:"header"` // CurrentHeader contains metadata for this message
// CURRENT OTHER FIELDS...
// NEXT MESSAGE PROCESSOR
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
}
// GetProtocol returns this message's protocol identifier
func (m *BaseMessage) GetProtocol() Protocol {
return m.CurrentProtocol
}
// GetNext retrieves the next message in the chain, if one exists.
// Returns nil, nil if NextProtocol is NoneProtocol.
// Uses the provided Registry to create and unmarshal the next message.
func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
if m.NextProtocol == NoneProtocol {
return nil, nil
}
if m.NextPayload == nil {
return nil, ErrNoPayload
}
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
}
// Marshal serializes the message to JSON format
func (m *BaseMessage) Marshal() ([]byte, error) {
return json.Marshal(m)
}
// Unmarshal deserializes the message from JSON format
func (m *BaseMessage) Unmarshal(data []byte) error {
return json.Unmarshal(data, m)
}

134
pkg/message/registry.go Normal file
View File

@@ -0,0 +1,134 @@
// Package message provides a type-safe, extensible message system for WebSocket communication.
package message
import (
"encoding/json"
"fmt"
"sync"
)
// Envelope is a lightweight struct used to extract just the protocol
// information from a raw message without full deserialization.
// This enables protocol-based routing of incoming messages.
type Envelope struct {
Protocol Protocol `json:"protocol"`
}
// Registry defines the interface for message type registration and instantiation.
// It provides a centralized mechanism for creating, deserializing, and inspecting messages.
type Registry interface {
// Register adds a message factory for a specific protocol
// Returns an error if the protocol is already registered
Register(Protocol, Factory) error
// Create instantiates a new message for the given protocol
// Returns an error if the protocol is not registered
Create(Protocol) (Message, error)
// Unmarshal creates and deserializes a message for a protocol
// The provided data is parsed into the appropriate message type
Unmarshal(Protocol, json.RawMessage) (Message, error)
// UnmarshalRaw deserializes a message when the protocol is unknown
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
UnmarshalRaw(json.RawMessage) (Message, error)
}
// Factory defines the interface for creating new message instances.
// Each message type should have a corresponding factory.
type Factory interface {
// Create instantiates a new instance of a message type
Create() (Message, error)
}
// FactoryFunc is a function type that implements the Factory interface.
// It allows simple functions to be used as factories without creating a separate type.
type FactoryFunc func() Message
// Create implements the Factory interface for FactoryFunc
func (f FactoryFunc) Create() Message {
return f()
}
// DefaultRegistry provides a thread-safe implementation of the Registry interface.
// It maintains a map of protocols to their corresponding factories.
type DefaultRegistry struct {
factories map[Protocol]Factory
mux sync.RWMutex // Protects concurrent access to factories
}
// NewRegistry creates a new message registry
// The returned registry is ready to use but contains no registered message types
func NewRegistry() *DefaultRegistry {
return &DefaultRegistry{
factories: make(map[Protocol]Factory),
}
}
// Register adds a message factory for a protocol
// Returns an error if the protocol is already registered
// This method is thread-safe
func (r *DefaultRegistry) Register(protocol Protocol, factory Factory) error {
r.mux.Lock()
defer r.mux.Unlock()
if _, exists := r.factories[protocol]; exists {
return fmt.Errorf("protocol %s is already registered", protocol)
}
r.factories[protocol] = factory
return nil
}
// Create instantiates a new message for a protocol
// Returns an error if the protocol is not registered
// This method is thread-safe
func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
r.mux.RLock()
defer r.mux.RUnlock()
factory, exists := r.factories[protocol]
if !exists {
return nil, ErrNoProtocolMatch
}
msg, err := factory.Create()
if err != nil {
return nil, err
}
return msg, nil
}
// Unmarshal creates and deserializes a message for a protocol
// The provided data is parsed into the appropriate message type
// This method is thread-safe
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) {
msg, err := r.Create(protocol)
if err != nil {
return nil, err
}
if err := msg.Unmarshal(data); err != nil {
return nil, err
}
return msg, nil
}
// UnmarshalRaw deserializes a message when the protocol is unknown
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
// This method is particularly useful for handling incoming WebSocket messages
// This method is thread-safe
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) {
var envelope Envelope
if err := json.Unmarshal(data, &envelope); err != nil {
return nil, fmt.Errorf("failed to extract protocol: %w", err)
}
if envelope.Protocol == "" {
return nil, ErrInvalidMessageData
}
return r.Unmarshal(envelope.Protocol, data)
}

View File

@@ -0,0 +1,136 @@
package config
import (
"time"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Provider string
const (
EnvProvider Provider = "env"
FileProvider Provider = "file"
VaultProvider Provider = "vault"
)
// KeyProviderConfig defines the source of cryptographic keys
type KeyProviderConfig struct {
Provider Provider `json:"provider"`
VaultConfig *VaultKeyConfig `json:"vault_config,omitempty"`
FileConfig *FileKeyConfig `json:"file_config,omitempty"`
EnvConfig *EnvKeyConfig `json:"env_config,omitempty"`
}
type VaultKeyConfig struct {
Address string `json:"address"`
Path string `json:"path"`
TokenPath string `json:"token_path,omitempty"`
RoleID string `json:"role_id,omitempty"`
SecretID string `json:"secret_id,omitempty"`
KeyName string `json:"key_name"`
}
type FileKeyConfig struct {
KeyPath string `json:"key_path"`
PrivateKeyPath string `json:"private_key_path"`
PublicKeyPath string `json:"public_key_path"`
Permissions string `json:"permissions,omitempty"`
AutoCreate bool `json:"auto_create,omitempty"`
}
type EnvKeyConfig struct {
PrivateKeyVar string `json:"private_key_var"`
PublicKeyVar string `json:"public_key_var"`
}
// Config provides complete configuration for the encryption system
type Config struct {
// General settings
IsServer bool `json:"is_server"`
RequireEncryption bool `json:"require_encryption"`
DisableEncryption bool `json:"disable_encryption,omitempty"`
// Timeout settings
KeyExchangeTimeout time.Duration `json:"key_exchange_timeout"`
SessionTimeout time.Duration `json:"session_timeout"`
// Key rotation
EnableKeyRotation bool `json:"enable_key_rotation"`
KeyRotationInterval time.Duration `json:"key_rotation_interval,omitempty"`
// EncryptionProtocol settings
EncryptionProtocol types.EncryptionProtocol `json:"encryption_protocol"`
// KeyExchangeProtocolOptions []keyexchange.ProtocolFactoryOption
EncryptionFallbackProtocols []types.EncryptionProtocol `json:"encryption_fallback_protocols,omitempty"`
// Key management
KeyProvider KeyProviderConfig `json:"key_provider"`
// Security settings
ReplayProtection bool `json:"replay_protection"`
NonceReplayWindow time.Duration `json:"nonce_replay_window,omitempty"`
}
// DefaultConfig provides sensible defaults for the encryption system
func DefaultConfig() Config {
return Config{
RequireEncryption: true,
DisableEncryption: false,
KeyExchangeTimeout: 30 * time.Second,
SessionTimeout: 24 * time.Hour,
EnableKeyRotation: true,
KeyRotationInterval: 1 * time.Hour,
EncryptionProtocol: types.ProtocolV2,
// KeyExchangeProtocolOptions: []keyexchange.ProtocolFactoryOption{keyexchange.WithKeySignature()},
EncryptionFallbackProtocols: []types.EncryptionProtocol{types.ProtocolV1},
KeyProvider: KeyProviderConfig{
Provider: EnvProvider,
EnvConfig: &EnvKeyConfig{
PrivateKeyVar: "SERVER_ENCRYPT_PRIV_KEY",
PublicKeyVar: "SERVER_ENCRYPT_PUB_KEY",
},
},
ReplayProtection: true,
NonceReplayWindow: 5 * time.Minute,
}
}
// ValidateConfig checks if a configuration is valid and complete
func ValidateConfig(config Config) error {
if config.KeyExchangeTimeout < 5*time.Second {
return encryptionerr.ErrInvalidConfig
}
if config.RequireEncryption && config.DisableEncryption {
return encryptionerr.ErrInvalidConfig
}
if config.EnableKeyRotation && config.KeyRotationInterval < time.Minute {
return encryptionerr.ErrInvalidConfig
}
// Validate key provider configuration
switch config.KeyProvider.Provider {
case VaultProvider:
if config.KeyProvider.VaultConfig == nil {
return encryptionerr.ErrInvalidProvider
}
case FileProvider:
if config.KeyProvider.FileConfig == nil {
return encryptionerr.ErrInvalidProvider
}
case EnvProvider:
if config.KeyProvider.EnvConfig == nil {
return encryptionerr.ErrInvalidProvider
}
default:
return encryptionerr.ErrInvalidProvider
}
return nil
}

View File

@@ -0,0 +1,35 @@
package encryptionerr
import "errors"
// Common error definitions for the encryption interceptor
var (
// General errors
ErrConnectionNotFound = errors.New("connection not registered")
ErrConnectionExists = errors.New("connection already exists")
ErrInvalidInterceptor = errors.New("inappropriate interceptor for the payload")
// Key exchange errors
ErrInvalidMessageType = errors.New("message does not implement required key exchange MessageProcessor interface")
ErrKeyExchangeTimeout = errors.New("key exchange timed out")
ErrInvalidSignature = errors.New("signature verification failed")
ErrProtocolNotFound = errors.New("key exchange protocol not found")
ErrExchangeInProgress = errors.New("key exchange already in progress")
ErrSessionNotFound = errors.New("key exchange session not found")
ErrInvalidSessionState = errors.New("key exchange state is not valid")
ErrExchangeNotComplete = errors.New("key exchange not completed")
// Encryption errors
ErrEncryptionNotReady = errors.New("encryption not ready")
ErrInvalidKey = errors.New("invalid encryption key")
ErrInvalidNonce = errors.New("invalid nonce")
ErrNonceReused = errors.New("nonce has been used before")
// Configuration errors
ErrInvalidConfig = errors.New("invalid configuration")
ErrInvalidProvider = errors.New("invalid key provider")
// Security errors
ErrInvalidServerRequest = errors.New("invalid request to server")
ErrSecurityViolation = errors.New("security violation detected")
)

View File

@@ -0,0 +1,28 @@
package encryptor
import (
"io"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
// Encryptor defines the interface for message encryption and decryption
type Encryptor interface {
// SetKeys configures the encryption and decryption keys
SetKeys(encryptKey, decryptKey types.Key) error
// SetSessionID sets the session identifier for this encryption session
SetSessionID(id types.SessionID)
// Encrypt encrypts a message between sender and receiver
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)
// Decrypt decrypts an encrypted message in-place
Decrypt(message message.Message) error
// Ready checks if the encryptor is properly initialized and ready to use
Ready() bool
io.Closer
}

View File

@@ -0,0 +1,5 @@
package encryptor
func NewEncryptor(cipherSuite string) (Encryptor, error) {
return nil, nil
}

View File

@@ -0,0 +1,62 @@
package encrypt
import (
"context"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
)
type Interceptor struct {
interceptor.NoOpInterceptor
nonceValidator NonceValidator
keyExchangeManager keyexchange.Manager
keyProvider keyprovider.KeyProvider
stateManager state.Manager
config config.Config
}
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
ctx, cancel := context.WithCancel(i.Ctx)
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
if err != nil {
return nil, nil, err
}
if err := i.stateManager.SetState(connection, newState); err != nil {
return nil, nil, err
}
return writer, reader, nil
}
func (i *Interceptor) Init(connection interceptor.Connection) error {
_state, err := i.stateManager.GetState(connection)
if err != nil {
return err
}
if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
return err
}
}
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
}
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
}
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
}
func (i *Interceptor) Close() error {
}

View File

@@ -0,0 +1,16 @@
package encrypt
import (
"time"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
// NonceValidator provides protection against replay attacks
type NonceValidator interface {
// Validate checks if a nonce is valid and hasn't been seen before
Validate(nonce []byte, sessionID types.SessionID) error
// Cleanup removes expired nonces
Cleanup(before time.Time)
}

View File

@@ -0,0 +1,90 @@
package keyexchange
import (
"crypto/rand"
"io"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Curve25519Protocol struct {
privKey types.PrivateKey // also in protocol curve25519protocol.go
pubKey types.PublicKey
peerPubKey types.PublicKey
salt types.Salt // also in protocol curve25519protocol.go
sessionID types.SessionID
sharedSecret []byte
encKey types.Key
decKey types.Key
state SessionState
options Curve25519Options
// TODO: mutex is needed here for SessionState and Keys
}
type Curve25519Options struct {
SigningKey ed25519.PrivateKey
VerificationKey ed25519.PublicKey
RequireSignature bool
}
func (p *Curve25519Protocol) Init(s *state.State) error {
if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil {
return err
}
curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey))
if s.InterceptorConfig.IsServer && p.options.RequireSignature {
if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil {
p.state = SessionStateError
return err
}
if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil {
p.state = SessionStateError
return err
}
sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil {
p.state = SessionStateError
return err
}
}
p.state = SessionStateInProgress
return nil
}
func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) {
if p.state != SessionStateCompleted {
return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete
}
return p.encKey, p.decKey, nil
}
func (p *Curve25519Protocol) GetState() SessionState {
return p.state
}
func (p *Curve25519Protocol) IsComplete() bool {
return p.state == SessionStateCompleted
}
func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error {
if err := msg.Process(p, s); err != nil {
p.state = SessionStateError
return err
}
return nil
}

View File

@@ -0,0 +1,82 @@
package keyexchange
import (
"crypto/sha256"
"fmt"
"io"
"time"
"golang.org/x/crypto/hkdf"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type Session struct {
protocol Protocol
state *state.State
createdAt time.Time
completedAt time.Time
}
type Manager struct {
registry map[types.KeyExchangeProtocol]ProtocolFactory
sessions map[types.KeyExchangeSessionID]*Session
}
func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
sessionID := s.GenerateKeyExchangeSessionID()
_, exists := m.sessions[sessionID]
if exists {
return encryptionerr.ErrExchangeInProgress
}
factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol]
if !exists {
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol)
}
p, err := factory(options...)
if err != nil {
return err
}
if err := p.Init(s); err != nil {
return err
}
m.sessions[sessionID] = &Session{
protocol: p,
state: s,
createdAt: time.Now(),
}
return nil
}
func (m *Manager) Process(s *state.State, msg MessageProcessor) error {
session, exists := m.sessions[s.KeyExchangeSessionID]
if !exists {
return encryptionerr.ErrSessionNotFound
}
return session.protocol.Process(msg, s)
}
// Derive generates encryption keys from shared secret
func Derive(shared []byte, salt types.Salt, info string) (types.Key, types.Key, error) {
hkdfReader := hkdf.New(sha256.New, shared, salt[:], []byte(info))
key1 := types.Key{}
if _, err := io.ReadFull(hkdfReader, key1[:]); err != nil {
return types.Key{}, types.Key{}, err
}
key2 := types.Key{}
if _, err := io.ReadFull(hkdfReader, key2[:]); err != nil {
return types.Key{}, types.Key{}, err
}
return key1, key2, nil
}

View File

@@ -0,0 +1,49 @@
package keyexchange
import (
"fmt"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type SessionState int
const (
SessionStateInitial SessionState = iota
SessionStateInProgress
SessionStateCompleted
SessionStateError
)
type MessageProcessor interface {
Process(Protocol, *state.State) error
}
type Protocol interface {
Init(s *state.State) error
GetKeys() (encKey types.Key, decKey types.Key, err error)
Process(MessageProcessor, *state.State) error
GetState() SessionState
IsComplete() bool
}
type ProtocolFactoryOption func(Protocol) error
func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption {
return func(protocol Protocol) error {
curveProtocol, ok := protocol.(*Curve25519Protocol)
if !ok {
return fmt.Errorf("WithKeySignature only supports Curve25519Protocol")
}
curveProtocol.options.SigningKey = keyProvider.GetSigningKey()
curveProtocol.options.VerificationKey = keyProvider.GetVerificationKey()
curveProtocol.options.RequireSignature = true
return nil
}
}
type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error)

View File

@@ -0,0 +1,10 @@
package keyprovider
import "crypto/ed25519"
// KeyProvider interface for secure access to cryptographic keys
type KeyProvider interface {
GetSigningKey() ed25519.PrivateKey
GetVerificationKey() ed25519.PublicKey
Close() error
}

View File

@@ -0,0 +1,97 @@
package messages
import (
"fmt"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
// Protocol constants
const (
InitProtocol message.Protocol = "curve25519.init"
ResponseProtocol message.Protocol = "curve25519.response"
ConfirmProtocol message.Protocol = "curve25519.confirm"
)
type Init struct {
interceptor.BaseMessage
PublicKey types.PublicKey `json:"public_key"`
Signature []byte `json:"signature"`
SessionID types.SessionID `json:"session_id"`
Salt types.Salt `json:"salt"`
}
func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init {
return &Init{
BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver),
PublicKey: key,
Signature: sign,
SessionID: sessionID,
Salt: salt,
}
}
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
i, ok := _interceptor.(*encrypt.Interceptor)
if !ok {
return encryptionerr.ErrInvalidInterceptor
}
s, err := i.stateManager.GetState(connection)
if err != nil {
return err
}
return i.keyExchangeManager.Process(s, m)
}
func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error {
p, ok := protocol.(*keyexchange.Curve25519Protocol)
if !ok {
return encryptionerr.ErrInvalidMessageType
}
if p.GetState() != keyexchange.SessionStateInitial {
return encryptionerr.ErrInvalidSessionState
}
sign := append(m.PublicKey[:], m.Salt[:]...)
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
return encryptionerr.ErrInvalidSignature
}
p.salt = m.Salt
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
if err != nil {
return fmt.Errorf("failed to compute shared secret: %w", err)
}
encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING
if err != nil {
return fmt.Errorf("key derivation failed: %w", err)
}
p.encKey = encKey
p.decKey = decKey
p.sessionID = m.SessionID
if err := s.Writer.Write(s.Connection, nil); err != nil {
return err
} // TODO: ADD RESPONSE MESSAGE
p.state = SessionStateCompleted
return nil
}

View File

@@ -0,0 +1,60 @@
package state
import (
"context"
"fmt"
"sync"
"github.com/google/uuid"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
)
type State struct {
InterceptorConfig config.Config // copy of interceptor config
PeerID message.Receiver
privKey types.PrivateKey // also in protocol curve25519protocol.go
salt types.Salt // also in protocol curve25519protocol.go
sessionID types.SessionID // encryption sessionID
encryptor encryptor.Encryptor
Connection interceptor.Connection
Writer interceptor.Writer
Reader interceptor.Reader
cancel context.CancelFunc
ctx context.Context
KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
mux sync.RWMutex
}
func NewState(ctx context.Context, cancel context.CancelFunc, config config.Config, connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (*State, error) {
newEncryptor, err := encryptor.NewEncryptor(config.EncryptionProtocol.CipherSuite)
if err != nil {
return nil, err
}
return &State{
InterceptorConfig: config,
peerID: message.UnknownReceiver,
privKey: types.PrivateKey{},
salt: types.Salt{},
encryptor: newEncryptor,
Connection: connection,
Writer: writer,
Reader: reader,
cancel: cancel,
ctx: ctx,
}, nil
}
func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
if s.KeyExchangeSessionID != "" {
fmt.Println("KeyExchangeSessionID already exists; creating new")
}
s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
return s.KeyExchangeSessionID
}

View File

@@ -0,0 +1,66 @@
package state
import (
"sync"
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
)
type Manager struct {
states map[interceptor.Connection]*State
mux sync.RWMutex
}
func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
m.mux.RLock()
defer m.mux.RUnlock()
state, exists := m.states[connection]
if !exists {
return nil, encryptionerr.ErrConnectionNotFound
}
return state, nil
}
func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
m.mux.Lock()
defer m.mux.Unlock()
if _, exists := m.states[connection]; exists {
return encryptionerr.ErrConnectionExists
}
m.states[connection] = s
return nil
}
// RemoveState removes a Connection's state
func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) {
m.mux.Lock()
defer m.mux.Unlock()
state, exists := m.states[connection]
if !exists {
return nil, encryptionerr.ErrConnectionNotFound
}
delete(m.states, connection)
return state, nil
}
// ForEach executes the provided function for each state in the manager
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error {
m.mux.RLock()
defer m.mux.RUnlock()
var errs []error
for conn, state := range m.states {
if err := fn(conn, state); err != nil {
errs = append(errs, err)
}
}
return errs
}

View File

@@ -0,0 +1,51 @@
package types
// Crypto-related type definitions for improved type safety
type (
PrivateKey [32]byte
PublicKey [32]byte
Salt [16]byte
SessionID [16]byte
Nonce [12]byte
Key [32]byte
KeyExchangeProtocol string
KeyExchangeSessionID string
)
// EncryptionProtocol defines the capabilities of a specific protocol version
type EncryptionProtocol struct {
Version uint8 `json:"version"`
CipherSuite string `json:"cipher_suite"`
KeyExchangeProtocol KeyExchangeProtocol `json:"key_exchange"`
Authenticated bool `json:"authenticated"`
SupportedExtensions map[string]string `json:"supported_extensions,omitempty"`
}
// Protocol versions
var (
// ProtocolV1 is the original protocol
ProtocolV1 = EncryptionProtocol{
Version: 1,
CipherSuite: "AES-256-GCM",
KeyExchangeProtocol: "curve25519-ed25519",
Authenticated: true,
}
// ProtocolV2 adds support for key rotation
ProtocolV2 = EncryptionProtocol{
Version: 2,
CipherSuite: "AES-256-GCM",
KeyExchangeProtocol: "curve25519-ed25519",
Authenticated: true,
SupportedExtensions: map[string]string{
"key_rotation": "supported",
"forward_secrecy": "enabled",
},
}
)
// IsZero is a generic function to check if a value is the zero value for its type
func IsZero[T comparable](value T) bool {
var zero T
return value == zero
}

View File

@@ -0,0 +1,19 @@
package encrypt
import (
"fmt"
"time"
)
// FormatDuration formats a duration in a human-readable way
func FormatDuration(d time.Duration) string {
if d < time.Second {
return fmt.Sprintf("%d ms", d.Milliseconds())
} else if d < time.Minute {
return fmt.Sprintf("%.1f s", d.Seconds())
} else if d < time.Hour {
return fmt.Sprintf("%.1f min", d.Minutes())
} else {
return fmt.Sprintf("%.1f h", d.Hours())
}
}