mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-07 00:22:49 +08:00
first commit
This commit is contained in:
8
go.mod
Normal file
8
go.mod
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
module github.com/harshabose/socket-comm
|
||||||
|
|
||||||
|
go 1.24
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
golang.org/x/crypto v0.37.0
|
||||||
|
)
|
4
go.sum
Normal file
4
go.sum
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||||
|
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
88
internal/util/merr.go
Normal file
88
internal/util/merr.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MultiError is a collection of errors that implements the error interface
|
||||||
|
type MultiError struct {
|
||||||
|
errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMultiError creates a new empty MultiError
|
||||||
|
func NewMultiError() *MultiError {
|
||||||
|
return &MultiError{errors: []error{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add appends an error to the collection if it's not nil
|
||||||
|
func (multiErr *MultiError) Add(err error) {
|
||||||
|
if err != nil {
|
||||||
|
multiErr.errors = append(multiErr.errors, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAll appends multiple errors to the collection, ignoring nil errors
|
||||||
|
func (multiErr *MultiError) AddAll(errs ...error) {
|
||||||
|
for _, err := range errs {
|
||||||
|
multiErr.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of errors in the collection
|
||||||
|
func (multiErr *MultiError) Len() int {
|
||||||
|
return len(multiErr.errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface
|
||||||
|
func (multiErr *MultiError) Error() string {
|
||||||
|
if multiErr.Len() == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if multiErr.Len() == 1 {
|
||||||
|
return multiErr.errors[0].Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString(fmt.Sprintf("%d errors occurred:\n", multiErr.Len()))
|
||||||
|
|
||||||
|
for i, err := range multiErr.errors {
|
||||||
|
sb.WriteString(fmt.Sprintf(" * %s\n", err.Error()))
|
||||||
|
if i < multiErr.Len()-1 {
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorOrNil returns nil if the collection is empty, otherwise returns the flattened MultiError
|
||||||
|
func (multiErr *MultiError) ErrorOrNil() error {
|
||||||
|
if multiErr.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return multiErr.Flatten()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errors returns all errors in the collection
|
||||||
|
func (multiErr *MultiError) Errors() []error {
|
||||||
|
return multiErr.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten returns a new MultiError with all nested MultiErrors flattened
|
||||||
|
func (multiErr *MultiError) Flatten() *MultiError {
|
||||||
|
flattened := NewMultiError()
|
||||||
|
|
||||||
|
for _, err := range multiErr.errors {
|
||||||
|
var merr *MultiError
|
||||||
|
if errors.As(err, &merr) {
|
||||||
|
flattened.AddAll(merr.Flatten().Errors()...)
|
||||||
|
} else {
|
||||||
|
flattened.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return flattened
|
||||||
|
}
|
69
pkg/interceptor/chain.go
Normal file
69
pkg/interceptor/chain.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package interceptor
|
||||||
|
|
||||||
|
import "github.com/harshabose/socket-comm/internal/util"
|
||||||
|
|
||||||
|
type Chain struct {
|
||||||
|
interceptors []Interceptor
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateChain(interceptors []Interceptor) *Chain {
|
||||||
|
return &Chain{interceptors: interceptors}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain *Chain) BindSocketConnection(connection Connection, writer Writer, reader Reader) (Writer, Reader, error) {
|
||||||
|
var (
|
||||||
|
w Writer
|
||||||
|
r Reader
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, interceptor := range chain.interceptors {
|
||||||
|
if w, r, err = interceptor.BindSocketConnection(connection, chain.InterceptSocketWriter(writer), chain.InterceptSocketReader(reader)); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return w, r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain *Chain) Init(connection Connection) error {
|
||||||
|
for _, interceptor := range chain.interceptors {
|
||||||
|
if err := interceptor.Init(connection); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain *Chain) InterceptSocketWriter(writer Writer) Writer {
|
||||||
|
for _, interceptor := range chain.interceptors {
|
||||||
|
writer = interceptor.InterceptSocketWriter(writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain *Chain) InterceptSocketReader(reader Reader) Reader {
|
||||||
|
for _, interceptor := range chain.interceptors {
|
||||||
|
reader = interceptor.InterceptSocketReader(reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain *Chain) UnBindSocketConnection(connection Connection) {
|
||||||
|
for _, interceptor := range chain.interceptors {
|
||||||
|
interceptor.UnBindSocketConnection(connection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain *Chain) Close() error {
|
||||||
|
var merr util.MultiError
|
||||||
|
|
||||||
|
for _, interceptor := range chain.interceptors {
|
||||||
|
merr.Add(interceptor.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
return merr.ErrorOrNil()
|
||||||
|
}
|
104
pkg/interceptor/interceptor.go
Normal file
104
pkg/interceptor/interceptor.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package interceptor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Registry struct {
|
||||||
|
factories []Factory
|
||||||
|
}
|
||||||
|
|
||||||
|
func (registry *Registry) Register(factory Factory) {
|
||||||
|
registry.factories = append(registry.factories, factory)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (registry *Registry) Build(ctx context.Context, id string) (Interceptor, error) {
|
||||||
|
if len(registry.factories) == 0 {
|
||||||
|
return &NoOpInterceptor{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
interceptors := make([]Interceptor, 0)
|
||||||
|
for _, factory := range registry.factories {
|
||||||
|
interceptor, err := factory.NewInterceptor(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
interceptors = append(interceptors, interceptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
return CreateChain(interceptors), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Factory interface {
|
||||||
|
NewInterceptor(context.Context, string) (Interceptor, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Connection interface {
|
||||||
|
Write(ctx context.Context, p []byte) error
|
||||||
|
Read(ctx context.Context) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Interceptor interface {
|
||||||
|
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
|
||||||
|
|
||||||
|
Init(Connection) error
|
||||||
|
|
||||||
|
InterceptSocketWriter(Writer) Writer
|
||||||
|
|
||||||
|
InterceptSocketReader(Reader) Reader
|
||||||
|
|
||||||
|
UnBindSocketConnection(Connection)
|
||||||
|
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
type Writer interface {
|
||||||
|
Write(conn Connection, message Message) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Reader interface {
|
||||||
|
Read(conn Connection) (Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReaderFunc func(conn Connection) (Message, error)
|
||||||
|
|
||||||
|
func (f ReaderFunc) Read(conn Connection) (Message, error) {
|
||||||
|
return f(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WriterFunc func(conn Connection, message Message) error
|
||||||
|
|
||||||
|
func (f WriterFunc) Write(conn Connection, message Message) error {
|
||||||
|
return f(conn, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NoOpInterceptor struct {
|
||||||
|
ID string
|
||||||
|
Mutex sync.RWMutex
|
||||||
|
Ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) Init(_ Connection) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) InterceptSocketWriter(writer Writer) Writer {
|
||||||
|
return writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) InterceptSocketReader(reader Reader) Reader {
|
||||||
|
return reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) UnBindSocketConnection(_ Connection) {}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
96
pkg/interceptor/message.go
Normal file
96
pkg/interceptor/message.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
// Package interceptor provides a middleware system for processing WebSocket messages.
|
||||||
|
// It builds on the message package to add processing capabilities to messages.
|
||||||
|
package interceptor
|
||||||
|
|
||||||
|
import "github.com/harshabose/socket-comm/pkg/message"
|
||||||
|
|
||||||
|
// Message extends the base message.Message interface with processing capabilities.
|
||||||
|
// This interface allows messages to be processed by interceptors in the communication chain.
|
||||||
|
// Types implementing this interface can define custom behavior for how they interact
|
||||||
|
// with specific interceptors.
|
||||||
|
type Message interface {
|
||||||
|
// Message Embed the base Message interface
|
||||||
|
message.Message
|
||||||
|
|
||||||
|
// WriteProcess handles interceptor processing for outgoing messages.
|
||||||
|
// This method is called when a message is being written to a connection.
|
||||||
|
//
|
||||||
|
// The implementation should handle any message-specific processing required
|
||||||
|
// for the given interceptor type. For example, an encryption message would
|
||||||
|
// encrypt data when this method is called.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - interceptor: The interceptor that should process this message
|
||||||
|
// - connection: The network connection associated with this message
|
||||||
|
//
|
||||||
|
// Returns an error if processing fails
|
||||||
|
WriteProcess(Interceptor, Connection) error
|
||||||
|
|
||||||
|
// ReadProcess handles interceptor processing for incoming messages.
|
||||||
|
// This method is called when a message is being read from a connection.
|
||||||
|
//
|
||||||
|
// The implementation should handle any message-specific processing required
|
||||||
|
// for the given interceptor type. For example, an encryption message would
|
||||||
|
// decrypt data when this method is called.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - interceptor: The interceptor that should process this message
|
||||||
|
// - connection: The network connection associated with this message
|
||||||
|
//
|
||||||
|
// Returns an error if processing fails
|
||||||
|
ReadProcess(Interceptor, Connection) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseMessage provides a default implementation of the Message interface.
|
||||||
|
// It embeds message.BaseMessage to inherit its functionality and adds
|
||||||
|
// a no-op Process method that can be overridden by specific message types.
|
||||||
|
//
|
||||||
|
// Custom interceptor message types should embed this struct and override
|
||||||
|
// the Process method with their specific processing logic.
|
||||||
|
type BaseMessage struct {
|
||||||
|
// Embed the base message implementation
|
||||||
|
message.BaseMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
|
||||||
|
func NewBaseMessage(protocol message.Protocol, sender message.Sender, receiver message.Receiver) BaseMessage {
|
||||||
|
return BaseMessage{
|
||||||
|
BaseMessage: message.BaseMessage{
|
||||||
|
CurrentProtocol: protocol,
|
||||||
|
CurrentHeader: message.NewV1Header(sender, receiver),
|
||||||
|
NextProtocol: message.NoneProtocol,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteProcess handles interceptor processing for outgoing messages.
|
||||||
|
// This method is called when a message is being written to a connection.
|
||||||
|
// It should be overridden by specific message types to implement
|
||||||
|
// their custom outgoing message processing logic.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - interceptor: The interceptor that should process this message
|
||||||
|
// - connection: The network connection associated with this message
|
||||||
|
//
|
||||||
|
// Returns nil by default, indicating no processing was performed
|
||||||
|
func (m *BaseMessage) WriteProcess(_ Interceptor, _ Connection) error {
|
||||||
|
// Default implementation does nothing
|
||||||
|
// Derived-types should override this method with specific processing logic
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadProcess handles interceptor processing for incoming messages.
|
||||||
|
// This method is called when a message is being read from a connection.
|
||||||
|
// It should be overridden by specific message types to implement
|
||||||
|
// their custom incoming message processing logic.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - interceptor: The interceptor that should process this message
|
||||||
|
// - connection: The network connection associated with this message
|
||||||
|
//
|
||||||
|
// Returns nil by default, indicating no processing was performed
|
||||||
|
func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error {
|
||||||
|
// Default implementation does nothing
|
||||||
|
// Derived-types should override this method with specific processing logic
|
||||||
|
return nil
|
||||||
|
}
|
19
pkg/message/errors.go
Normal file
19
pkg/message/errors.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
// Package message provides a type-safe, extensible message system for WebSocket communication.
|
||||||
|
package message
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// Error definitions for the message package
|
||||||
|
var (
|
||||||
|
// ErrNoProtocolMatch is returned when attempting to create or unmarshal
|
||||||
|
// a message with a protocol that is not registered
|
||||||
|
ErrNoProtocolMatch = errors.New("no protocol in the registry")
|
||||||
|
|
||||||
|
// ErrNoPayload is returned when a message has a non-none NextProtocol
|
||||||
|
// but is missing the corresponding NextPayload data
|
||||||
|
ErrNoPayload = errors.New("protocol is not none but payload is nil")
|
||||||
|
|
||||||
|
// ErrInvalidMessageData is returned when raw message data cannot be
|
||||||
|
// properly identified or does not contain a valid protocol field
|
||||||
|
ErrInvalidMessageData = errors.New("invalid message data")
|
||||||
|
)
|
117
pkg/message/message.go
Normal file
117
pkg/message/message.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
// Package message provides a type-safe, extensible message system for WebSocket communication.
|
||||||
|
// It implements a nested message structure that allows for message interception and transformation
|
||||||
|
// through an interceptor chain pattern.
|
||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type aliases for improved readability and type safety
|
||||||
|
type (
|
||||||
|
// Protocol identifies the message type or format
|
||||||
|
Protocol string
|
||||||
|
|
||||||
|
// Payload contains the serialized data of the message
|
||||||
|
Payload json.RawMessage
|
||||||
|
|
||||||
|
// Sender identifies the source of the message
|
||||||
|
Sender string
|
||||||
|
|
||||||
|
// Receiver identifies the intended recipient of the message
|
||||||
|
Receiver string
|
||||||
|
|
||||||
|
// Version specifies the message protocol version
|
||||||
|
Version string
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protocol constants
|
||||||
|
const (
|
||||||
|
// NoneProtocol indicates no nested message exists
|
||||||
|
NoneProtocol Protocol = "none"
|
||||||
|
|
||||||
|
// Version1 is the current message protocol version
|
||||||
|
Version1 Version = "v1.0"
|
||||||
|
|
||||||
|
// UnknownReceiver receiverID initialising
|
||||||
|
UnknownReceiver Receiver = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message defines the interface that all message types must implement.
|
||||||
|
// It provides methods for protocol identification, serialization, and
|
||||||
|
// message nesting/unwrapping.
|
||||||
|
type Message interface {
|
||||||
|
// GetProtocol returns the protocol identifier for this message
|
||||||
|
GetProtocol() Protocol
|
||||||
|
|
||||||
|
// GetNext retrieves the next message in the chain, if one exists
|
||||||
|
// Returns nil, nil if there is no next message
|
||||||
|
GetNext(Registry) (Message, error)
|
||||||
|
|
||||||
|
// Marshal serializes the message to JSON format
|
||||||
|
Marshal() ([]byte, error)
|
||||||
|
|
||||||
|
// Unmarshal deserializes the message from JSON format
|
||||||
|
Unmarshal([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header contains common metadata for all messages
|
||||||
|
type Header struct {
|
||||||
|
Sender Sender `json:"sender"` // Sender identifies the message source
|
||||||
|
Receiver Receiver `json:"receiver"` // Receiver identifies the intended recipient
|
||||||
|
Version Version `json:"version"` // Version specifies the protocol version
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewV1Header creates a new header with Version1
|
||||||
|
// This is a convenience constructor for common header creation
|
||||||
|
func NewV1Header(sender Sender, receiver Receiver) Header {
|
||||||
|
return Header{
|
||||||
|
Sender: sender,
|
||||||
|
Receiver: receiver,
|
||||||
|
Version: Version1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseMessage provides a foundation for all message types.
|
||||||
|
// It implements the Message interface and manages message nesting.
|
||||||
|
// Custom message types should embed this struct to inherit its functionality.
|
||||||
|
type BaseMessage struct {
|
||||||
|
// CURRENT MESSAGE PROCESSOR
|
||||||
|
CurrentProtocol Protocol `json:"protocol"` // CurrentProtocol identifies this message's type
|
||||||
|
CurrentHeader Header `json:"header"` // CurrentHeader contains metadata for this message
|
||||||
|
// CURRENT OTHER FIELDS...
|
||||||
|
|
||||||
|
// NEXT MESSAGE PROCESSOR
|
||||||
|
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
|
||||||
|
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProtocol returns this message's protocol identifier
|
||||||
|
func (m *BaseMessage) GetProtocol() Protocol {
|
||||||
|
return m.CurrentProtocol
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNext retrieves the next message in the chain, if one exists.
|
||||||
|
// Returns nil, nil if NextProtocol is NoneProtocol.
|
||||||
|
// Uses the provided Registry to create and unmarshal the next message.
|
||||||
|
func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
|
||||||
|
if m.NextProtocol == NoneProtocol {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.NextPayload == nil {
|
||||||
|
return nil, ErrNoPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal serializes the message to JSON format
|
||||||
|
func (m *BaseMessage) Marshal() ([]byte, error) {
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal deserializes the message from JSON format
|
||||||
|
func (m *BaseMessage) Unmarshal(data []byte) error {
|
||||||
|
return json.Unmarshal(data, m)
|
||||||
|
}
|
134
pkg/message/registry.go
Normal file
134
pkg/message/registry.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
// Package message provides a type-safe, extensible message system for WebSocket communication.
|
||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Envelope is a lightweight struct used to extract just the protocol
|
||||||
|
// information from a raw message without full deserialization.
|
||||||
|
// This enables protocol-based routing of incoming messages.
|
||||||
|
type Envelope struct {
|
||||||
|
Protocol Protocol `json:"protocol"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry defines the interface for message type registration and instantiation.
|
||||||
|
// It provides a centralized mechanism for creating, deserializing, and inspecting messages.
|
||||||
|
type Registry interface {
|
||||||
|
// Register adds a message factory for a specific protocol
|
||||||
|
// Returns an error if the protocol is already registered
|
||||||
|
Register(Protocol, Factory) error
|
||||||
|
|
||||||
|
// Create instantiates a new message for the given protocol
|
||||||
|
// Returns an error if the protocol is not registered
|
||||||
|
Create(Protocol) (Message, error)
|
||||||
|
|
||||||
|
// Unmarshal creates and deserializes a message for a protocol
|
||||||
|
// The provided data is parsed into the appropriate message type
|
||||||
|
Unmarshal(Protocol, json.RawMessage) (Message, error)
|
||||||
|
|
||||||
|
// UnmarshalRaw deserializes a message when the protocol is unknown
|
||||||
|
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
|
||||||
|
UnmarshalRaw(json.RawMessage) (Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Factory defines the interface for creating new message instances.
|
||||||
|
// Each message type should have a corresponding factory.
|
||||||
|
type Factory interface {
|
||||||
|
// Create instantiates a new instance of a message type
|
||||||
|
Create() (Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactoryFunc is a function type that implements the Factory interface.
|
||||||
|
// It allows simple functions to be used as factories without creating a separate type.
|
||||||
|
type FactoryFunc func() Message
|
||||||
|
|
||||||
|
// Create implements the Factory interface for FactoryFunc
|
||||||
|
func (f FactoryFunc) Create() Message {
|
||||||
|
return f()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRegistry provides a thread-safe implementation of the Registry interface.
|
||||||
|
// It maintains a map of protocols to their corresponding factories.
|
||||||
|
type DefaultRegistry struct {
|
||||||
|
factories map[Protocol]Factory
|
||||||
|
mux sync.RWMutex // Protects concurrent access to factories
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry creates a new message registry
|
||||||
|
// The returned registry is ready to use but contains no registered message types
|
||||||
|
func NewRegistry() *DefaultRegistry {
|
||||||
|
return &DefaultRegistry{
|
||||||
|
factories: make(map[Protocol]Factory),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a message factory for a protocol
|
||||||
|
// Returns an error if the protocol is already registered
|
||||||
|
// This method is thread-safe
|
||||||
|
func (r *DefaultRegistry) Register(protocol Protocol, factory Factory) error {
|
||||||
|
r.mux.Lock()
|
||||||
|
defer r.mux.Unlock()
|
||||||
|
|
||||||
|
if _, exists := r.factories[protocol]; exists {
|
||||||
|
return fmt.Errorf("protocol %s is already registered", protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.factories[protocol] = factory
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create instantiates a new message for a protocol
|
||||||
|
// Returns an error if the protocol is not registered
|
||||||
|
// This method is thread-safe
|
||||||
|
func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
|
||||||
|
r.mux.RLock()
|
||||||
|
defer r.mux.RUnlock()
|
||||||
|
|
||||||
|
factory, exists := r.factories[protocol]
|
||||||
|
if !exists {
|
||||||
|
return nil, ErrNoProtocolMatch
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := factory.Create()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal creates and deserializes a message for a protocol
|
||||||
|
// The provided data is parsed into the appropriate message type
|
||||||
|
// This method is thread-safe
|
||||||
|
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) {
|
||||||
|
msg, err := r.Create(protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := msg.Unmarshal(data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalRaw deserializes a message when the protocol is unknown
|
||||||
|
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
|
||||||
|
// This method is particularly useful for handling incoming WebSocket messages
|
||||||
|
// This method is thread-safe
|
||||||
|
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) {
|
||||||
|
var envelope Envelope
|
||||||
|
if err := json.Unmarshal(data, &envelope); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to extract protocol: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if envelope.Protocol == "" {
|
||||||
|
return nil, ErrInvalidMessageData
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Unmarshal(envelope.Protocol, data)
|
||||||
|
}
|
136
pkg/middleware/encrypt/config/config.go
Normal file
136
pkg/middleware/encrypt/config/config.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Provider string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EnvProvider Provider = "env"
|
||||||
|
FileProvider Provider = "file"
|
||||||
|
VaultProvider Provider = "vault"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyProviderConfig defines the source of cryptographic keys
|
||||||
|
type KeyProviderConfig struct {
|
||||||
|
Provider Provider `json:"provider"`
|
||||||
|
VaultConfig *VaultKeyConfig `json:"vault_config,omitempty"`
|
||||||
|
FileConfig *FileKeyConfig `json:"file_config,omitempty"`
|
||||||
|
EnvConfig *EnvKeyConfig `json:"env_config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type VaultKeyConfig struct {
|
||||||
|
Address string `json:"address"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
TokenPath string `json:"token_path,omitempty"`
|
||||||
|
RoleID string `json:"role_id,omitempty"`
|
||||||
|
SecretID string `json:"secret_id,omitempty"`
|
||||||
|
KeyName string `json:"key_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileKeyConfig struct {
|
||||||
|
KeyPath string `json:"key_path"`
|
||||||
|
PrivateKeyPath string `json:"private_key_path"`
|
||||||
|
PublicKeyPath string `json:"public_key_path"`
|
||||||
|
Permissions string `json:"permissions,omitempty"`
|
||||||
|
AutoCreate bool `json:"auto_create,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnvKeyConfig struct {
|
||||||
|
PrivateKeyVar string `json:"private_key_var"`
|
||||||
|
PublicKeyVar string `json:"public_key_var"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config provides complete configuration for the encryption system
|
||||||
|
type Config struct {
|
||||||
|
// General settings
|
||||||
|
IsServer bool `json:"is_server"`
|
||||||
|
RequireEncryption bool `json:"require_encryption"`
|
||||||
|
DisableEncryption bool `json:"disable_encryption,omitempty"`
|
||||||
|
|
||||||
|
// Timeout settings
|
||||||
|
KeyExchangeTimeout time.Duration `json:"key_exchange_timeout"`
|
||||||
|
SessionTimeout time.Duration `json:"session_timeout"`
|
||||||
|
|
||||||
|
// Key rotation
|
||||||
|
EnableKeyRotation bool `json:"enable_key_rotation"`
|
||||||
|
KeyRotationInterval time.Duration `json:"key_rotation_interval,omitempty"`
|
||||||
|
|
||||||
|
// EncryptionProtocol settings
|
||||||
|
EncryptionProtocol types.EncryptionProtocol `json:"encryption_protocol"`
|
||||||
|
// KeyExchangeProtocolOptions []keyexchange.ProtocolFactoryOption
|
||||||
|
EncryptionFallbackProtocols []types.EncryptionProtocol `json:"encryption_fallback_protocols,omitempty"`
|
||||||
|
|
||||||
|
// Key management
|
||||||
|
KeyProvider KeyProviderConfig `json:"key_provider"`
|
||||||
|
|
||||||
|
// Security settings
|
||||||
|
ReplayProtection bool `json:"replay_protection"`
|
||||||
|
NonceReplayWindow time.Duration `json:"nonce_replay_window,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig provides sensible defaults for the encryption system
|
||||||
|
func DefaultConfig() Config {
|
||||||
|
return Config{
|
||||||
|
RequireEncryption: true,
|
||||||
|
DisableEncryption: false,
|
||||||
|
KeyExchangeTimeout: 30 * time.Second,
|
||||||
|
SessionTimeout: 24 * time.Hour,
|
||||||
|
EnableKeyRotation: true,
|
||||||
|
KeyRotationInterval: 1 * time.Hour,
|
||||||
|
EncryptionProtocol: types.ProtocolV2,
|
||||||
|
// KeyExchangeProtocolOptions: []keyexchange.ProtocolFactoryOption{keyexchange.WithKeySignature()},
|
||||||
|
EncryptionFallbackProtocols: []types.EncryptionProtocol{types.ProtocolV1},
|
||||||
|
KeyProvider: KeyProviderConfig{
|
||||||
|
Provider: EnvProvider,
|
||||||
|
EnvConfig: &EnvKeyConfig{
|
||||||
|
PrivateKeyVar: "SERVER_ENCRYPT_PRIV_KEY",
|
||||||
|
PublicKeyVar: "SERVER_ENCRYPT_PUB_KEY",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ReplayProtection: true,
|
||||||
|
NonceReplayWindow: 5 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig checks if a configuration is valid and complete
|
||||||
|
func ValidateConfig(config Config) error {
|
||||||
|
if config.KeyExchangeTimeout < 5*time.Second {
|
||||||
|
return encryptionerr.ErrInvalidConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.RequireEncryption && config.DisableEncryption {
|
||||||
|
return encryptionerr.ErrInvalidConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.EnableKeyRotation && config.KeyRotationInterval < time.Minute {
|
||||||
|
return encryptionerr.ErrInvalidConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key provider configuration
|
||||||
|
switch config.KeyProvider.Provider {
|
||||||
|
case VaultProvider:
|
||||||
|
if config.KeyProvider.VaultConfig == nil {
|
||||||
|
return encryptionerr.ErrInvalidProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
case FileProvider:
|
||||||
|
if config.KeyProvider.FileConfig == nil {
|
||||||
|
return encryptionerr.ErrInvalidProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
case EnvProvider:
|
||||||
|
if config.KeyProvider.EnvConfig == nil {
|
||||||
|
return encryptionerr.ErrInvalidProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return encryptionerr.ErrInvalidProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
35
pkg/middleware/encrypt/encryptionerr/errors.go
Normal file
35
pkg/middleware/encrypt/encryptionerr/errors.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package encryptionerr
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// Common error definitions for the encryption interceptor
|
||||||
|
var (
|
||||||
|
// General errors
|
||||||
|
ErrConnectionNotFound = errors.New("connection not registered")
|
||||||
|
ErrConnectionExists = errors.New("connection already exists")
|
||||||
|
ErrInvalidInterceptor = errors.New("inappropriate interceptor for the payload")
|
||||||
|
|
||||||
|
// Key exchange errors
|
||||||
|
ErrInvalidMessageType = errors.New("message does not implement required key exchange MessageProcessor interface")
|
||||||
|
ErrKeyExchangeTimeout = errors.New("key exchange timed out")
|
||||||
|
ErrInvalidSignature = errors.New("signature verification failed")
|
||||||
|
ErrProtocolNotFound = errors.New("key exchange protocol not found")
|
||||||
|
ErrExchangeInProgress = errors.New("key exchange already in progress")
|
||||||
|
ErrSessionNotFound = errors.New("key exchange session not found")
|
||||||
|
ErrInvalidSessionState = errors.New("key exchange state is not valid")
|
||||||
|
ErrExchangeNotComplete = errors.New("key exchange not completed")
|
||||||
|
|
||||||
|
// Encryption errors
|
||||||
|
ErrEncryptionNotReady = errors.New("encryption not ready")
|
||||||
|
ErrInvalidKey = errors.New("invalid encryption key")
|
||||||
|
ErrInvalidNonce = errors.New("invalid nonce")
|
||||||
|
ErrNonceReused = errors.New("nonce has been used before")
|
||||||
|
|
||||||
|
// Configuration errors
|
||||||
|
ErrInvalidConfig = errors.New("invalid configuration")
|
||||||
|
ErrInvalidProvider = errors.New("invalid key provider")
|
||||||
|
|
||||||
|
// Security errors
|
||||||
|
ErrInvalidServerRequest = errors.New("invalid request to server")
|
||||||
|
ErrSecurityViolation = errors.New("security violation detected")
|
||||||
|
)
|
28
pkg/middleware/encrypt/encryptor/encryptor.go
Normal file
28
pkg/middleware/encrypt/encryptor/encryptor.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package encryptor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/message"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encryptor defines the interface for message encryption and decryption
|
||||||
|
type Encryptor interface {
|
||||||
|
// SetKeys configures the encryption and decryption keys
|
||||||
|
SetKeys(encryptKey, decryptKey types.Key) error
|
||||||
|
|
||||||
|
// SetSessionID sets the session identifier for this encryption session
|
||||||
|
SetSessionID(id types.SessionID)
|
||||||
|
|
||||||
|
// Encrypt encrypts a message between sender and receiver
|
||||||
|
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)
|
||||||
|
|
||||||
|
// Decrypt decrypts an encrypted message in-place
|
||||||
|
Decrypt(message message.Message) error
|
||||||
|
|
||||||
|
// Ready checks if the encryptor is properly initialized and ready to use
|
||||||
|
Ready() bool
|
||||||
|
|
||||||
|
io.Closer
|
||||||
|
}
|
5
pkg/middleware/encrypt/encryptor/factory.go
Normal file
5
pkg/middleware/encrypt/encryptor/factory.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package encryptor
|
||||||
|
|
||||||
|
func NewEncryptor(cipherSuite string) (Encryptor, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
62
pkg/middleware/encrypt/interceptor.go
Normal file
62
pkg/middleware/encrypt/interceptor.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package encrypt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Interceptor struct {
|
||||||
|
interceptor.NoOpInterceptor
|
||||||
|
nonceValidator NonceValidator
|
||||||
|
keyExchangeManager keyexchange.Manager
|
||||||
|
keyProvider keyprovider.KeyProvider
|
||||||
|
stateManager state.Manager
|
||||||
|
config config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
|
||||||
|
ctx, cancel := context.WithCancel(i.Ctx)
|
||||||
|
|
||||||
|
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := i.stateManager.SetState(connection, newState); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return writer, reader, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interceptor) Init(connection interceptor.Connection) error {
|
||||||
|
_state, err := i.stateManager.GetState(connection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interceptor) Close() error {
|
||||||
|
|
||||||
|
}
|
16
pkg/middleware/encrypt/interfaces.go
Normal file
16
pkg/middleware/encrypt/interfaces.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package encrypt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NonceValidator provides protection against replay attacks
|
||||||
|
type NonceValidator interface {
|
||||||
|
// Validate checks if a nonce is valid and hasn't been seen before
|
||||||
|
Validate(nonce []byte, sessionID types.SessionID) error
|
||||||
|
|
||||||
|
// Cleanup removes expired nonces
|
||||||
|
Cleanup(before time.Time)
|
||||||
|
}
|
90
pkg/middleware/encrypt/keyexchange/curve25519protocol.go
Normal file
90
pkg/middleware/encrypt/keyexchange/curve25519protocol.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package keyexchange
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Curve25519Protocol struct {
|
||||||
|
privKey types.PrivateKey // also in protocol curve25519protocol.go
|
||||||
|
pubKey types.PublicKey
|
||||||
|
peerPubKey types.PublicKey
|
||||||
|
salt types.Salt // also in protocol curve25519protocol.go
|
||||||
|
sessionID types.SessionID
|
||||||
|
sharedSecret []byte
|
||||||
|
encKey types.Key
|
||||||
|
decKey types.Key
|
||||||
|
state SessionState
|
||||||
|
options Curve25519Options
|
||||||
|
|
||||||
|
// TODO: mutex is needed here for SessionState and Keys
|
||||||
|
}
|
||||||
|
|
||||||
|
type Curve25519Options struct {
|
||||||
|
SigningKey ed25519.PrivateKey
|
||||||
|
VerificationKey ed25519.PublicKey
|
||||||
|
RequireSignature bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Curve25519Protocol) Init(s *state.State) error {
|
||||||
|
if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey))
|
||||||
|
|
||||||
|
if s.InterceptorConfig.IsServer && p.options.RequireSignature {
|
||||||
|
if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil {
|
||||||
|
p.state = SessionStateError
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil {
|
||||||
|
p.state = SessionStateError
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
|
||||||
|
|
||||||
|
if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil {
|
||||||
|
p.state = SessionStateError
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.state = SessionStateInProgress
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) {
|
||||||
|
if p.state != SessionStateCompleted {
|
||||||
|
return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.encKey, p.decKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Curve25519Protocol) GetState() SessionState {
|
||||||
|
return p.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Curve25519Protocol) IsComplete() bool {
|
||||||
|
return p.state == SessionStateCompleted
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error {
|
||||||
|
if err := msg.Process(p, s); err != nil {
|
||||||
|
p.state = SessionStateError
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
82
pkg/middleware/encrypt/keyexchange/keyexchange_manager.go
Normal file
82
pkg/middleware/encrypt/keyexchange/keyexchange_manager.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package keyexchange
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
protocol Protocol
|
||||||
|
state *state.State
|
||||||
|
createdAt time.Time
|
||||||
|
completedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
registry map[types.KeyExchangeProtocol]ProtocolFactory
|
||||||
|
sessions map[types.KeyExchangeSessionID]*Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
|
||||||
|
sessionID := s.GenerateKeyExchangeSessionID()
|
||||||
|
_, exists := m.sessions[sessionID]
|
||||||
|
if exists {
|
||||||
|
return encryptionerr.ErrExchangeInProgress
|
||||||
|
}
|
||||||
|
|
||||||
|
factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := factory(options...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.Init(s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.sessions[sessionID] = &Session{
|
||||||
|
protocol: p,
|
||||||
|
state: s,
|
||||||
|
createdAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Process(s *state.State, msg MessageProcessor) error {
|
||||||
|
session, exists := m.sessions[s.KeyExchangeSessionID]
|
||||||
|
if !exists {
|
||||||
|
return encryptionerr.ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return session.protocol.Process(msg, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive generates encryption keys from shared secret
|
||||||
|
func Derive(shared []byte, salt types.Salt, info string) (types.Key, types.Key, error) {
|
||||||
|
hkdfReader := hkdf.New(sha256.New, shared, salt[:], []byte(info))
|
||||||
|
|
||||||
|
key1 := types.Key{}
|
||||||
|
if _, err := io.ReadFull(hkdfReader, key1[:]); err != nil {
|
||||||
|
return types.Key{}, types.Key{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
key2 := types.Key{}
|
||||||
|
if _, err := io.ReadFull(hkdfReader, key2[:]); err != nil {
|
||||||
|
return types.Key{}, types.Key{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return key1, key2, nil
|
||||||
|
}
|
49
pkg/middleware/encrypt/keyexchange/protocol.go
Normal file
49
pkg/middleware/encrypt/keyexchange/protocol.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package keyexchange
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionStateInitial SessionState = iota
|
||||||
|
SessionStateInProgress
|
||||||
|
SessionStateCompleted
|
||||||
|
SessionStateError
|
||||||
|
)
|
||||||
|
|
||||||
|
type MessageProcessor interface {
|
||||||
|
Process(Protocol, *state.State) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Protocol interface {
|
||||||
|
Init(s *state.State) error
|
||||||
|
GetKeys() (encKey types.Key, decKey types.Key, err error)
|
||||||
|
Process(MessageProcessor, *state.State) error
|
||||||
|
GetState() SessionState
|
||||||
|
IsComplete() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProtocolFactoryOption func(Protocol) error
|
||||||
|
|
||||||
|
func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption {
|
||||||
|
return func(protocol Protocol) error {
|
||||||
|
curveProtocol, ok := protocol.(*Curve25519Protocol)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("WithKeySignature only supports Curve25519Protocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
curveProtocol.options.SigningKey = keyProvider.GetSigningKey()
|
||||||
|
curveProtocol.options.VerificationKey = keyProvider.GetVerificationKey()
|
||||||
|
curveProtocol.options.RequireSignature = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error)
|
10
pkg/middleware/encrypt/keyprovider/keyprovider.go
Normal file
10
pkg/middleware/encrypt/keyprovider/keyprovider.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package keyprovider
|
||||||
|
|
||||||
|
import "crypto/ed25519"
|
||||||
|
|
||||||
|
// KeyProvider interface for secure access to cryptographic keys
|
||||||
|
type KeyProvider interface {
|
||||||
|
GetSigningKey() ed25519.PrivateKey
|
||||||
|
GetVerificationKey() ed25519.PublicKey
|
||||||
|
Close() error
|
||||||
|
}
|
97
pkg/middleware/encrypt/messages/init.go
Normal file
97
pkg/middleware/encrypt/messages/init.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/message"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protocol constants
|
||||||
|
const (
|
||||||
|
InitProtocol message.Protocol = "curve25519.init"
|
||||||
|
ResponseProtocol message.Protocol = "curve25519.response"
|
||||||
|
ConfirmProtocol message.Protocol = "curve25519.confirm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Init struct {
|
||||||
|
interceptor.BaseMessage
|
||||||
|
PublicKey types.PublicKey `json:"public_key"`
|
||||||
|
Signature []byte `json:"signature"`
|
||||||
|
SessionID types.SessionID `json:"session_id"`
|
||||||
|
Salt types.Salt `json:"salt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init {
|
||||||
|
return &Init{
|
||||||
|
BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver),
|
||||||
|
PublicKey: key,
|
||||||
|
Signature: sign,
|
||||||
|
SessionID: sessionID,
|
||||||
|
Salt: salt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
|
||||||
|
i, ok := _interceptor.(*encrypt.Interceptor)
|
||||||
|
if !ok {
|
||||||
|
return encryptionerr.ErrInvalidInterceptor
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := i.stateManager.GetState(connection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return i.keyExchangeManager.Process(s, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error {
|
||||||
|
p, ok := protocol.(*keyexchange.Curve25519Protocol)
|
||||||
|
if !ok {
|
||||||
|
return encryptionerr.ErrInvalidMessageType
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.GetState() != keyexchange.SessionStateInitial {
|
||||||
|
return encryptionerr.ErrInvalidSessionState
|
||||||
|
}
|
||||||
|
|
||||||
|
sign := append(m.PublicKey[:], m.Salt[:]...)
|
||||||
|
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
|
||||||
|
return encryptionerr.ErrInvalidSignature
|
||||||
|
}
|
||||||
|
|
||||||
|
p.salt = m.Salt
|
||||||
|
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to compute shared secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("key derivation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.encKey = encKey
|
||||||
|
p.decKey = decKey
|
||||||
|
p.sessionID = m.SessionID
|
||||||
|
|
||||||
|
if err := s.Writer.Write(s.Connection, nil); err != nil {
|
||||||
|
return err
|
||||||
|
} // TODO: ADD RESPONSE MESSAGE
|
||||||
|
|
||||||
|
p.state = SessionStateCompleted
|
||||||
|
return nil
|
||||||
|
}
|
60
pkg/middleware/encrypt/state/state.go
Normal file
60
pkg/middleware/encrypt/state/state.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/message"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
InterceptorConfig config.Config // copy of interceptor config
|
||||||
|
PeerID message.Receiver
|
||||||
|
privKey types.PrivateKey // also in protocol curve25519protocol.go
|
||||||
|
salt types.Salt // also in protocol curve25519protocol.go
|
||||||
|
sessionID types.SessionID // encryption sessionID
|
||||||
|
encryptor encryptor.Encryptor
|
||||||
|
Connection interceptor.Connection
|
||||||
|
Writer interceptor.Writer
|
||||||
|
Reader interceptor.Reader
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ctx context.Context
|
||||||
|
KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
|
||||||
|
mux sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewState(ctx context.Context, cancel context.CancelFunc, config config.Config, connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (*State, error) {
|
||||||
|
newEncryptor, err := encryptor.NewEncryptor(config.EncryptionProtocol.CipherSuite)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &State{
|
||||||
|
InterceptorConfig: config,
|
||||||
|
peerID: message.UnknownReceiver,
|
||||||
|
privKey: types.PrivateKey{},
|
||||||
|
salt: types.Salt{},
|
||||||
|
encryptor: newEncryptor,
|
||||||
|
Connection: connection,
|
||||||
|
Writer: writer,
|
||||||
|
Reader: reader,
|
||||||
|
cancel: cancel,
|
||||||
|
ctx: ctx,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||||
|
if s.KeyExchangeSessionID != "" {
|
||||||
|
fmt.Println("KeyExchangeSessionID already exists; creating new")
|
||||||
|
}
|
||||||
|
s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
|
||||||
|
|
||||||
|
return s.KeyExchangeSessionID
|
||||||
|
}
|
66
pkg/middleware/encrypt/state/state_manager.go
Normal file
66
pkg/middleware/encrypt/state/state_manager.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
states map[interceptor.Connection]*State
|
||||||
|
mux sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
|
||||||
|
m.mux.RLock()
|
||||||
|
defer m.mux.RUnlock()
|
||||||
|
|
||||||
|
state, exists := m.states[connection]
|
||||||
|
if !exists {
|
||||||
|
return nil, encryptionerr.ErrConnectionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
if _, exists := m.states[connection]; exists {
|
||||||
|
return encryptionerr.ErrConnectionExists
|
||||||
|
}
|
||||||
|
|
||||||
|
m.states[connection] = s
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveState removes a Connection's state
|
||||||
|
func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
state, exists := m.states[connection]
|
||||||
|
if !exists {
|
||||||
|
return nil, encryptionerr.ErrConnectionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.states, connection)
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEach executes the provided function for each state in the manager
|
||||||
|
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error {
|
||||||
|
m.mux.RLock()
|
||||||
|
defer m.mux.RUnlock()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for conn, state := range m.states {
|
||||||
|
if err := fn(conn, state); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
51
pkg/middleware/encrypt/types/types.go
Normal file
51
pkg/middleware/encrypt/types/types.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
// Crypto-related type definitions for improved type safety
|
||||||
|
type (
|
||||||
|
PrivateKey [32]byte
|
||||||
|
PublicKey [32]byte
|
||||||
|
Salt [16]byte
|
||||||
|
SessionID [16]byte
|
||||||
|
Nonce [12]byte
|
||||||
|
Key [32]byte
|
||||||
|
KeyExchangeProtocol string
|
||||||
|
KeyExchangeSessionID string
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncryptionProtocol defines the capabilities of a specific protocol version
|
||||||
|
type EncryptionProtocol struct {
|
||||||
|
Version uint8 `json:"version"`
|
||||||
|
CipherSuite string `json:"cipher_suite"`
|
||||||
|
KeyExchangeProtocol KeyExchangeProtocol `json:"key_exchange"`
|
||||||
|
Authenticated bool `json:"authenticated"`
|
||||||
|
SupportedExtensions map[string]string `json:"supported_extensions,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol versions
|
||||||
|
var (
|
||||||
|
// ProtocolV1 is the original protocol
|
||||||
|
ProtocolV1 = EncryptionProtocol{
|
||||||
|
Version: 1,
|
||||||
|
CipherSuite: "AES-256-GCM",
|
||||||
|
KeyExchangeProtocol: "curve25519-ed25519",
|
||||||
|
Authenticated: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtocolV2 adds support for key rotation
|
||||||
|
ProtocolV2 = EncryptionProtocol{
|
||||||
|
Version: 2,
|
||||||
|
CipherSuite: "AES-256-GCM",
|
||||||
|
KeyExchangeProtocol: "curve25519-ed25519",
|
||||||
|
Authenticated: true,
|
||||||
|
SupportedExtensions: map[string]string{
|
||||||
|
"key_rotation": "supported",
|
||||||
|
"forward_secrecy": "enabled",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsZero is a generic function to check if a value is the zero value for its type
|
||||||
|
func IsZero[T comparable](value T) bool {
|
||||||
|
var zero T
|
||||||
|
return value == zero
|
||||||
|
}
|
19
pkg/middleware/encrypt/util.go
Normal file
19
pkg/middleware/encrypt/util.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package encrypt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FormatDuration formats a duration in a human-readable way
|
||||||
|
func FormatDuration(d time.Duration) string {
|
||||||
|
if d < time.Second {
|
||||||
|
return fmt.Sprintf("%d ms", d.Milliseconds())
|
||||||
|
} else if d < time.Minute {
|
||||||
|
return fmt.Sprintf("%.1f s", d.Seconds())
|
||||||
|
} else if d < time.Hour {
|
||||||
|
return fmt.Sprintf("%.1f min", d.Minutes())
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%.1f h", d.Hours())
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user