mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-06 08:06:59 +08:00
first commit
This commit is contained in:
8
go.mod
Normal file
8
go.mod
Normal file
@@ -0,0 +1,8 @@
|
||||
module github.com/harshabose/socket-comm
|
||||
|
||||
go 1.24
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
)
|
4
go.sum
Normal file
4
go.sum
Normal file
@@ -0,0 +1,4 @@
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
88
internal/util/merr.go
Normal file
88
internal/util/merr.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MultiError is a collection of errors that implements the error interface
|
||||
type MultiError struct {
|
||||
errors []error
|
||||
}
|
||||
|
||||
// NewMultiError creates a new empty MultiError
|
||||
func NewMultiError() *MultiError {
|
||||
return &MultiError{errors: []error{}}
|
||||
}
|
||||
|
||||
// Add appends an error to the collection if it's not nil
|
||||
func (multiErr *MultiError) Add(err error) {
|
||||
if err != nil {
|
||||
multiErr.errors = append(multiErr.errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AddAll appends multiple errors to the collection, ignoring nil errors
|
||||
func (multiErr *MultiError) AddAll(errs ...error) {
|
||||
for _, err := range errs {
|
||||
multiErr.Add(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the number of errors in the collection
|
||||
func (multiErr *MultiError) Len() int {
|
||||
return len(multiErr.errors)
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (multiErr *MultiError) Error() string {
|
||||
if multiErr.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if multiErr.Len() == 1 {
|
||||
return multiErr.errors[0].Error()
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("%d errors occurred:\n", multiErr.Len()))
|
||||
|
||||
for i, err := range multiErr.errors {
|
||||
sb.WriteString(fmt.Sprintf(" * %s\n", err.Error()))
|
||||
if i < multiErr.Len()-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ErrorOrNil returns nil if the collection is empty, otherwise returns the flattened MultiError
|
||||
func (multiErr *MultiError) ErrorOrNil() error {
|
||||
if multiErr.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return multiErr.Flatten()
|
||||
}
|
||||
|
||||
// Errors returns all errors in the collection
|
||||
func (multiErr *MultiError) Errors() []error {
|
||||
return multiErr.errors
|
||||
}
|
||||
|
||||
// Flatten returns a new MultiError with all nested MultiErrors flattened
|
||||
func (multiErr *MultiError) Flatten() *MultiError {
|
||||
flattened := NewMultiError()
|
||||
|
||||
for _, err := range multiErr.errors {
|
||||
var merr *MultiError
|
||||
if errors.As(err, &merr) {
|
||||
flattened.AddAll(merr.Flatten().Errors()...)
|
||||
} else {
|
||||
flattened.Add(err)
|
||||
}
|
||||
}
|
||||
|
||||
return flattened
|
||||
}
|
69
pkg/interceptor/chain.go
Normal file
69
pkg/interceptor/chain.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package interceptor
|
||||
|
||||
import "github.com/harshabose/socket-comm/internal/util"
|
||||
|
||||
type Chain struct {
|
||||
interceptors []Interceptor
|
||||
}
|
||||
|
||||
func CreateChain(interceptors []Interceptor) *Chain {
|
||||
return &Chain{interceptors: interceptors}
|
||||
}
|
||||
|
||||
func (chain *Chain) BindSocketConnection(connection Connection, writer Writer, reader Reader) (Writer, Reader, error) {
|
||||
var (
|
||||
w Writer
|
||||
r Reader
|
||||
err error
|
||||
)
|
||||
|
||||
for _, interceptor := range chain.interceptors {
|
||||
if w, r, err = interceptor.BindSocketConnection(connection, chain.InterceptSocketWriter(writer), chain.InterceptSocketReader(reader)); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return w, r, nil
|
||||
}
|
||||
|
||||
func (chain *Chain) Init(connection Connection) error {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
if err := interceptor.Init(connection); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (chain *Chain) InterceptSocketWriter(writer Writer) Writer {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
writer = interceptor.InterceptSocketWriter(writer)
|
||||
}
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
func (chain *Chain) InterceptSocketReader(reader Reader) Reader {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
reader = interceptor.InterceptSocketReader(reader)
|
||||
}
|
||||
|
||||
return reader
|
||||
}
|
||||
|
||||
func (chain *Chain) UnBindSocketConnection(connection Connection) {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.UnBindSocketConnection(connection)
|
||||
}
|
||||
}
|
||||
|
||||
func (chain *Chain) Close() error {
|
||||
var merr util.MultiError
|
||||
|
||||
for _, interceptor := range chain.interceptors {
|
||||
merr.Add(interceptor.Close())
|
||||
}
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
}
|
104
pkg/interceptor/interceptor.go
Normal file
104
pkg/interceptor/interceptor.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package interceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
factories []Factory
|
||||
}
|
||||
|
||||
func (registry *Registry) Register(factory Factory) {
|
||||
registry.factories = append(registry.factories, factory)
|
||||
}
|
||||
|
||||
func (registry *Registry) Build(ctx context.Context, id string) (Interceptor, error) {
|
||||
if len(registry.factories) == 0 {
|
||||
return &NoOpInterceptor{}, nil
|
||||
}
|
||||
|
||||
interceptors := make([]Interceptor, 0)
|
||||
for _, factory := range registry.factories {
|
||||
interceptor, err := factory.NewInterceptor(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
interceptors = append(interceptors, interceptor)
|
||||
}
|
||||
|
||||
return CreateChain(interceptors), nil
|
||||
}
|
||||
|
||||
type Factory interface {
|
||||
NewInterceptor(context.Context, string) (Interceptor, error)
|
||||
}
|
||||
|
||||
type Connection interface {
|
||||
Write(ctx context.Context, p []byte) error
|
||||
Read(ctx context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
type Interceptor interface {
|
||||
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
|
||||
|
||||
Init(Connection) error
|
||||
|
||||
InterceptSocketWriter(Writer) Writer
|
||||
|
||||
InterceptSocketReader(Reader) Reader
|
||||
|
||||
UnBindSocketConnection(Connection)
|
||||
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type Writer interface {
|
||||
Write(conn Connection, message Message) error
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
Read(conn Connection) (Message, error)
|
||||
}
|
||||
|
||||
type ReaderFunc func(conn Connection) (Message, error)
|
||||
|
||||
func (f ReaderFunc) Read(conn Connection) (Message, error) {
|
||||
return f(conn)
|
||||
}
|
||||
|
||||
type WriterFunc func(conn Connection, message Message) error
|
||||
|
||||
func (f WriterFunc) Write(conn Connection, message Message) error {
|
||||
return f(conn, message)
|
||||
}
|
||||
|
||||
type NoOpInterceptor struct {
|
||||
ID string
|
||||
Mutex sync.RWMutex
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) Init(_ Connection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) InterceptSocketWriter(writer Writer) Writer {
|
||||
return writer
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) InterceptSocketReader(reader Reader) Reader {
|
||||
return reader
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) UnBindSocketConnection(_ Connection) {}
|
||||
|
||||
func (interceptor *NoOpInterceptor) Close() error {
|
||||
return nil
|
||||
}
|
96
pkg/interceptor/message.go
Normal file
96
pkg/interceptor/message.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Package interceptor provides a middleware system for processing WebSocket messages.
|
||||
// It builds on the message package to add processing capabilities to messages.
|
||||
package interceptor
|
||||
|
||||
import "github.com/harshabose/socket-comm/pkg/message"
|
||||
|
||||
// Message extends the base message.Message interface with processing capabilities.
|
||||
// This interface allows messages to be processed by interceptors in the communication chain.
|
||||
// Types implementing this interface can define custom behavior for how they interact
|
||||
// with specific interceptors.
|
||||
type Message interface {
|
||||
// Message Embed the base Message interface
|
||||
message.Message
|
||||
|
||||
// WriteProcess handles interceptor processing for outgoing messages.
|
||||
// This method is called when a message is being written to a connection.
|
||||
//
|
||||
// The implementation should handle any message-specific processing required
|
||||
// for the given interceptor type. For example, an encryption message would
|
||||
// encrypt data when this method is called.
|
||||
//
|
||||
// Parameters:
|
||||
// - interceptor: The interceptor that should process this message
|
||||
// - connection: The network connection associated with this message
|
||||
//
|
||||
// Returns an error if processing fails
|
||||
WriteProcess(Interceptor, Connection) error
|
||||
|
||||
// ReadProcess handles interceptor processing for incoming messages.
|
||||
// This method is called when a message is being read from a connection.
|
||||
//
|
||||
// The implementation should handle any message-specific processing required
|
||||
// for the given interceptor type. For example, an encryption message would
|
||||
// decrypt data when this method is called.
|
||||
//
|
||||
// Parameters:
|
||||
// - interceptor: The interceptor that should process this message
|
||||
// - connection: The network connection associated with this message
|
||||
//
|
||||
// Returns an error if processing fails
|
||||
ReadProcess(Interceptor, Connection) error
|
||||
}
|
||||
|
||||
// BaseMessage provides a default implementation of the Message interface.
|
||||
// It embeds message.BaseMessage to inherit its functionality and adds
|
||||
// a no-op Process method that can be overridden by specific message types.
|
||||
//
|
||||
// Custom interceptor message types should embed this struct and override
|
||||
// the Process method with their specific processing logic.
|
||||
type BaseMessage struct {
|
||||
// Embed the base message implementation
|
||||
message.BaseMessage
|
||||
}
|
||||
|
||||
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
|
||||
func NewBaseMessage(protocol message.Protocol, sender message.Sender, receiver message.Receiver) BaseMessage {
|
||||
return BaseMessage{
|
||||
BaseMessage: message.BaseMessage{
|
||||
CurrentProtocol: protocol,
|
||||
CurrentHeader: message.NewV1Header(sender, receiver),
|
||||
NextProtocol: message.NoneProtocol,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WriteProcess handles interceptor processing for outgoing messages.
|
||||
// This method is called when a message is being written to a connection.
|
||||
// It should be overridden by specific message types to implement
|
||||
// their custom outgoing message processing logic.
|
||||
//
|
||||
// Parameters:
|
||||
// - interceptor: The interceptor that should process this message
|
||||
// - connection: The network connection associated with this message
|
||||
//
|
||||
// Returns nil by default, indicating no processing was performed
|
||||
func (m *BaseMessage) WriteProcess(_ Interceptor, _ Connection) error {
|
||||
// Default implementation does nothing
|
||||
// Derived-types should override this method with specific processing logic
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadProcess handles interceptor processing for incoming messages.
|
||||
// This method is called when a message is being read from a connection.
|
||||
// It should be overridden by specific message types to implement
|
||||
// their custom incoming message processing logic.
|
||||
//
|
||||
// Parameters:
|
||||
// - interceptor: The interceptor that should process this message
|
||||
// - connection: The network connection associated with this message
|
||||
//
|
||||
// Returns nil by default, indicating no processing was performed
|
||||
func (m *BaseMessage) ReadProcess(_ Interceptor, _ Connection) error {
|
||||
// Default implementation does nothing
|
||||
// Derived-types should override this method with specific processing logic
|
||||
return nil
|
||||
}
|
19
pkg/message/errors.go
Normal file
19
pkg/message/errors.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Package message provides a type-safe, extensible message system for WebSocket communication.
|
||||
package message
|
||||
|
||||
import "errors"
|
||||
|
||||
// Error definitions for the message package
|
||||
var (
|
||||
// ErrNoProtocolMatch is returned when attempting to create or unmarshal
|
||||
// a message with a protocol that is not registered
|
||||
ErrNoProtocolMatch = errors.New("no protocol in the registry")
|
||||
|
||||
// ErrNoPayload is returned when a message has a non-none NextProtocol
|
||||
// but is missing the corresponding NextPayload data
|
||||
ErrNoPayload = errors.New("protocol is not none but payload is nil")
|
||||
|
||||
// ErrInvalidMessageData is returned when raw message data cannot be
|
||||
// properly identified or does not contain a valid protocol field
|
||||
ErrInvalidMessageData = errors.New("invalid message data")
|
||||
)
|
117
pkg/message/message.go
Normal file
117
pkg/message/message.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Package message provides a type-safe, extensible message system for WebSocket communication.
|
||||
// It implements a nested message structure that allows for message interception and transformation
|
||||
// through an interceptor chain pattern.
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// Type aliases for improved readability and type safety
|
||||
type (
|
||||
// Protocol identifies the message type or format
|
||||
Protocol string
|
||||
|
||||
// Payload contains the serialized data of the message
|
||||
Payload json.RawMessage
|
||||
|
||||
// Sender identifies the source of the message
|
||||
Sender string
|
||||
|
||||
// Receiver identifies the intended recipient of the message
|
||||
Receiver string
|
||||
|
||||
// Version specifies the message protocol version
|
||||
Version string
|
||||
)
|
||||
|
||||
// Protocol constants
|
||||
const (
|
||||
// NoneProtocol indicates no nested message exists
|
||||
NoneProtocol Protocol = "none"
|
||||
|
||||
// Version1 is the current message protocol version
|
||||
Version1 Version = "v1.0"
|
||||
|
||||
// UnknownReceiver receiverID initialising
|
||||
UnknownReceiver Receiver = "unknown"
|
||||
)
|
||||
|
||||
// Message defines the interface that all message types must implement.
|
||||
// It provides methods for protocol identification, serialization, and
|
||||
// message nesting/unwrapping.
|
||||
type Message interface {
|
||||
// GetProtocol returns the protocol identifier for this message
|
||||
GetProtocol() Protocol
|
||||
|
||||
// GetNext retrieves the next message in the chain, if one exists
|
||||
// Returns nil, nil if there is no next message
|
||||
GetNext(Registry) (Message, error)
|
||||
|
||||
// Marshal serializes the message to JSON format
|
||||
Marshal() ([]byte, error)
|
||||
|
||||
// Unmarshal deserializes the message from JSON format
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
// Header contains common metadata for all messages
|
||||
type Header struct {
|
||||
Sender Sender `json:"sender"` // Sender identifies the message source
|
||||
Receiver Receiver `json:"receiver"` // Receiver identifies the intended recipient
|
||||
Version Version `json:"version"` // Version specifies the protocol version
|
||||
}
|
||||
|
||||
// NewV1Header creates a new header with Version1
|
||||
// This is a convenience constructor for common header creation
|
||||
func NewV1Header(sender Sender, receiver Receiver) Header {
|
||||
return Header{
|
||||
Sender: sender,
|
||||
Receiver: receiver,
|
||||
Version: Version1,
|
||||
}
|
||||
}
|
||||
|
||||
// BaseMessage provides a foundation for all message types.
|
||||
// It implements the Message interface and manages message nesting.
|
||||
// Custom message types should embed this struct to inherit its functionality.
|
||||
type BaseMessage struct {
|
||||
// CURRENT MESSAGE PROCESSOR
|
||||
CurrentProtocol Protocol `json:"protocol"` // CurrentProtocol identifies this message's type
|
||||
CurrentHeader Header `json:"header"` // CurrentHeader contains metadata for this message
|
||||
// CURRENT OTHER FIELDS...
|
||||
|
||||
// NEXT MESSAGE PROCESSOR
|
||||
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
|
||||
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
|
||||
}
|
||||
|
||||
// GetProtocol returns this message's protocol identifier
|
||||
func (m *BaseMessage) GetProtocol() Protocol {
|
||||
return m.CurrentProtocol
|
||||
}
|
||||
|
||||
// GetNext retrieves the next message in the chain, if one exists.
|
||||
// Returns nil, nil if NextProtocol is NoneProtocol.
|
||||
// Uses the provided Registry to create and unmarshal the next message.
|
||||
func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
|
||||
if m.NextProtocol == NoneProtocol {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if m.NextPayload == nil {
|
||||
return nil, ErrNoPayload
|
||||
}
|
||||
|
||||
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
|
||||
}
|
||||
|
||||
// Marshal serializes the message to JSON format
|
||||
func (m *BaseMessage) Marshal() ([]byte, error) {
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
// Unmarshal deserializes the message from JSON format
|
||||
func (m *BaseMessage) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, m)
|
||||
}
|
134
pkg/message/registry.go
Normal file
134
pkg/message/registry.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Package message provides a type-safe, extensible message system for WebSocket communication.
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Envelope is a lightweight struct used to extract just the protocol
|
||||
// information from a raw message without full deserialization.
|
||||
// This enables protocol-based routing of incoming messages.
|
||||
type Envelope struct {
|
||||
Protocol Protocol `json:"protocol"`
|
||||
}
|
||||
|
||||
// Registry defines the interface for message type registration and instantiation.
|
||||
// It provides a centralized mechanism for creating, deserializing, and inspecting messages.
|
||||
type Registry interface {
|
||||
// Register adds a message factory for a specific protocol
|
||||
// Returns an error if the protocol is already registered
|
||||
Register(Protocol, Factory) error
|
||||
|
||||
// Create instantiates a new message for the given protocol
|
||||
// Returns an error if the protocol is not registered
|
||||
Create(Protocol) (Message, error)
|
||||
|
||||
// Unmarshal creates and deserializes a message for a protocol
|
||||
// The provided data is parsed into the appropriate message type
|
||||
Unmarshal(Protocol, json.RawMessage) (Message, error)
|
||||
|
||||
// UnmarshalRaw deserializes a message when the protocol is unknown
|
||||
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
|
||||
UnmarshalRaw(json.RawMessage) (Message, error)
|
||||
}
|
||||
|
||||
// Factory defines the interface for creating new message instances.
|
||||
// Each message type should have a corresponding factory.
|
||||
type Factory interface {
|
||||
// Create instantiates a new instance of a message type
|
||||
Create() (Message, error)
|
||||
}
|
||||
|
||||
// FactoryFunc is a function type that implements the Factory interface.
|
||||
// It allows simple functions to be used as factories without creating a separate type.
|
||||
type FactoryFunc func() Message
|
||||
|
||||
// Create implements the Factory interface for FactoryFunc
|
||||
func (f FactoryFunc) Create() Message {
|
||||
return f()
|
||||
}
|
||||
|
||||
// DefaultRegistry provides a thread-safe implementation of the Registry interface.
|
||||
// It maintains a map of protocols to their corresponding factories.
|
||||
type DefaultRegistry struct {
|
||||
factories map[Protocol]Factory
|
||||
mux sync.RWMutex // Protects concurrent access to factories
|
||||
}
|
||||
|
||||
// NewRegistry creates a new message registry
|
||||
// The returned registry is ready to use but contains no registered message types
|
||||
func NewRegistry() *DefaultRegistry {
|
||||
return &DefaultRegistry{
|
||||
factories: make(map[Protocol]Factory),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a message factory for a protocol
|
||||
// Returns an error if the protocol is already registered
|
||||
// This method is thread-safe
|
||||
func (r *DefaultRegistry) Register(protocol Protocol, factory Factory) error {
|
||||
r.mux.Lock()
|
||||
defer r.mux.Unlock()
|
||||
|
||||
if _, exists := r.factories[protocol]; exists {
|
||||
return fmt.Errorf("protocol %s is already registered", protocol)
|
||||
}
|
||||
|
||||
r.factories[protocol] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create instantiates a new message for a protocol
|
||||
// Returns an error if the protocol is not registered
|
||||
// This method is thread-safe
|
||||
func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
|
||||
r.mux.RLock()
|
||||
defer r.mux.RUnlock()
|
||||
|
||||
factory, exists := r.factories[protocol]
|
||||
if !exists {
|
||||
return nil, ErrNoProtocolMatch
|
||||
}
|
||||
|
||||
msg, err := factory.Create()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Unmarshal creates and deserializes a message for a protocol
|
||||
// The provided data is parsed into the appropriate message type
|
||||
// This method is thread-safe
|
||||
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) {
|
||||
msg, err := r.Create(protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := msg.Unmarshal(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalRaw deserializes a message when the protocol is unknown
|
||||
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
|
||||
// This method is particularly useful for handling incoming WebSocket messages
|
||||
// This method is thread-safe
|
||||
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) {
|
||||
var envelope Envelope
|
||||
if err := json.Unmarshal(data, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("failed to extract protocol: %w", err)
|
||||
}
|
||||
|
||||
if envelope.Protocol == "" {
|
||||
return nil, ErrInvalidMessageData
|
||||
}
|
||||
|
||||
return r.Unmarshal(envelope.Protocol, data)
|
||||
}
|
136
pkg/middleware/encrypt/config/config.go
Normal file
136
pkg/middleware/encrypt/config/config.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
EnvProvider Provider = "env"
|
||||
FileProvider Provider = "file"
|
||||
VaultProvider Provider = "vault"
|
||||
)
|
||||
|
||||
// KeyProviderConfig defines the source of cryptographic keys
|
||||
type KeyProviderConfig struct {
|
||||
Provider Provider `json:"provider"`
|
||||
VaultConfig *VaultKeyConfig `json:"vault_config,omitempty"`
|
||||
FileConfig *FileKeyConfig `json:"file_config,omitempty"`
|
||||
EnvConfig *EnvKeyConfig `json:"env_config,omitempty"`
|
||||
}
|
||||
|
||||
type VaultKeyConfig struct {
|
||||
Address string `json:"address"`
|
||||
Path string `json:"path"`
|
||||
TokenPath string `json:"token_path,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
SecretID string `json:"secret_id,omitempty"`
|
||||
KeyName string `json:"key_name"`
|
||||
}
|
||||
|
||||
type FileKeyConfig struct {
|
||||
KeyPath string `json:"key_path"`
|
||||
PrivateKeyPath string `json:"private_key_path"`
|
||||
PublicKeyPath string `json:"public_key_path"`
|
||||
Permissions string `json:"permissions,omitempty"`
|
||||
AutoCreate bool `json:"auto_create,omitempty"`
|
||||
}
|
||||
|
||||
type EnvKeyConfig struct {
|
||||
PrivateKeyVar string `json:"private_key_var"`
|
||||
PublicKeyVar string `json:"public_key_var"`
|
||||
}
|
||||
|
||||
// Config provides complete configuration for the encryption system
|
||||
type Config struct {
|
||||
// General settings
|
||||
IsServer bool `json:"is_server"`
|
||||
RequireEncryption bool `json:"require_encryption"`
|
||||
DisableEncryption bool `json:"disable_encryption,omitempty"`
|
||||
|
||||
// Timeout settings
|
||||
KeyExchangeTimeout time.Duration `json:"key_exchange_timeout"`
|
||||
SessionTimeout time.Duration `json:"session_timeout"`
|
||||
|
||||
// Key rotation
|
||||
EnableKeyRotation bool `json:"enable_key_rotation"`
|
||||
KeyRotationInterval time.Duration `json:"key_rotation_interval,omitempty"`
|
||||
|
||||
// EncryptionProtocol settings
|
||||
EncryptionProtocol types.EncryptionProtocol `json:"encryption_protocol"`
|
||||
// KeyExchangeProtocolOptions []keyexchange.ProtocolFactoryOption
|
||||
EncryptionFallbackProtocols []types.EncryptionProtocol `json:"encryption_fallback_protocols,omitempty"`
|
||||
|
||||
// Key management
|
||||
KeyProvider KeyProviderConfig `json:"key_provider"`
|
||||
|
||||
// Security settings
|
||||
ReplayProtection bool `json:"replay_protection"`
|
||||
NonceReplayWindow time.Duration `json:"nonce_replay_window,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultConfig provides sensible defaults for the encryption system
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
RequireEncryption: true,
|
||||
DisableEncryption: false,
|
||||
KeyExchangeTimeout: 30 * time.Second,
|
||||
SessionTimeout: 24 * time.Hour,
|
||||
EnableKeyRotation: true,
|
||||
KeyRotationInterval: 1 * time.Hour,
|
||||
EncryptionProtocol: types.ProtocolV2,
|
||||
// KeyExchangeProtocolOptions: []keyexchange.ProtocolFactoryOption{keyexchange.WithKeySignature()},
|
||||
EncryptionFallbackProtocols: []types.EncryptionProtocol{types.ProtocolV1},
|
||||
KeyProvider: KeyProviderConfig{
|
||||
Provider: EnvProvider,
|
||||
EnvConfig: &EnvKeyConfig{
|
||||
PrivateKeyVar: "SERVER_ENCRYPT_PRIV_KEY",
|
||||
PublicKeyVar: "SERVER_ENCRYPT_PUB_KEY",
|
||||
},
|
||||
},
|
||||
ReplayProtection: true,
|
||||
NonceReplayWindow: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig checks if a configuration is valid and complete
|
||||
func ValidateConfig(config Config) error {
|
||||
if config.KeyExchangeTimeout < 5*time.Second {
|
||||
return encryptionerr.ErrInvalidConfig
|
||||
}
|
||||
|
||||
if config.RequireEncryption && config.DisableEncryption {
|
||||
return encryptionerr.ErrInvalidConfig
|
||||
}
|
||||
|
||||
if config.EnableKeyRotation && config.KeyRotationInterval < time.Minute {
|
||||
return encryptionerr.ErrInvalidConfig
|
||||
}
|
||||
|
||||
// Validate key provider configuration
|
||||
switch config.KeyProvider.Provider {
|
||||
case VaultProvider:
|
||||
if config.KeyProvider.VaultConfig == nil {
|
||||
return encryptionerr.ErrInvalidProvider
|
||||
}
|
||||
|
||||
case FileProvider:
|
||||
if config.KeyProvider.FileConfig == nil {
|
||||
return encryptionerr.ErrInvalidProvider
|
||||
}
|
||||
|
||||
case EnvProvider:
|
||||
if config.KeyProvider.EnvConfig == nil {
|
||||
return encryptionerr.ErrInvalidProvider
|
||||
}
|
||||
|
||||
default:
|
||||
return encryptionerr.ErrInvalidProvider
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
35
pkg/middleware/encrypt/encryptionerr/errors.go
Normal file
35
pkg/middleware/encrypt/encryptionerr/errors.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package encryptionerr
|
||||
|
||||
import "errors"
|
||||
|
||||
// Common error definitions for the encryption interceptor
|
||||
var (
|
||||
// General errors
|
||||
ErrConnectionNotFound = errors.New("connection not registered")
|
||||
ErrConnectionExists = errors.New("connection already exists")
|
||||
ErrInvalidInterceptor = errors.New("inappropriate interceptor for the payload")
|
||||
|
||||
// Key exchange errors
|
||||
ErrInvalidMessageType = errors.New("message does not implement required key exchange MessageProcessor interface")
|
||||
ErrKeyExchangeTimeout = errors.New("key exchange timed out")
|
||||
ErrInvalidSignature = errors.New("signature verification failed")
|
||||
ErrProtocolNotFound = errors.New("key exchange protocol not found")
|
||||
ErrExchangeInProgress = errors.New("key exchange already in progress")
|
||||
ErrSessionNotFound = errors.New("key exchange session not found")
|
||||
ErrInvalidSessionState = errors.New("key exchange state is not valid")
|
||||
ErrExchangeNotComplete = errors.New("key exchange not completed")
|
||||
|
||||
// Encryption errors
|
||||
ErrEncryptionNotReady = errors.New("encryption not ready")
|
||||
ErrInvalidKey = errors.New("invalid encryption key")
|
||||
ErrInvalidNonce = errors.New("invalid nonce")
|
||||
ErrNonceReused = errors.New("nonce has been used before")
|
||||
|
||||
// Configuration errors
|
||||
ErrInvalidConfig = errors.New("invalid configuration")
|
||||
ErrInvalidProvider = errors.New("invalid key provider")
|
||||
|
||||
// Security errors
|
||||
ErrInvalidServerRequest = errors.New("invalid request to server")
|
||||
ErrSecurityViolation = errors.New("security violation detected")
|
||||
)
|
28
pkg/middleware/encrypt/encryptor/encryptor.go
Normal file
28
pkg/middleware/encrypt/encryptor/encryptor.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package encryptor
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
// Encryptor defines the interface for message encryption and decryption
|
||||
type Encryptor interface {
|
||||
// SetKeys configures the encryption and decryption keys
|
||||
SetKeys(encryptKey, decryptKey types.Key) error
|
||||
|
||||
// SetSessionID sets the session identifier for this encryption session
|
||||
SetSessionID(id types.SessionID)
|
||||
|
||||
// Encrypt encrypts a message between sender and receiver
|
||||
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)
|
||||
|
||||
// Decrypt decrypts an encrypted message in-place
|
||||
Decrypt(message message.Message) error
|
||||
|
||||
// Ready checks if the encryptor is properly initialized and ready to use
|
||||
Ready() bool
|
||||
|
||||
io.Closer
|
||||
}
|
5
pkg/middleware/encrypt/encryptor/factory.go
Normal file
5
pkg/middleware/encrypt/encryptor/factory.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package encryptor
|
||||
|
||||
func NewEncryptor(cipherSuite string) (Encryptor, error) {
|
||||
return nil, nil
|
||||
}
|
62
pkg/middleware/encrypt/interceptor.go
Normal file
62
pkg/middleware/encrypt/interceptor.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
nonceValidator NonceValidator
|
||||
keyExchangeManager keyexchange.Manager
|
||||
keyProvider keyprovider.KeyProvider
|
||||
stateManager state.Manager
|
||||
config config.Config
|
||||
}
|
||||
|
||||
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (interceptor.Writer, interceptor.Reader, error) {
|
||||
ctx, cancel := context.WithCancel(i.Ctx)
|
||||
|
||||
newState, err := state.NewState(ctx, cancel, i.config, connection, writer, reader)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := i.stateManager.SetState(connection, newState); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return writer, reader, nil
|
||||
}
|
||||
|
||||
func (i *Interceptor) Init(connection interceptor.Connection) error {
|
||||
_state, err := i.stateManager.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := i.keyExchangeManager.Init(_state, keyexchange.WithKeySignature(i.keyProvider)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
|
||||
}
|
||||
|
||||
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
|
||||
|
||||
}
|
||||
|
||||
func (i *Interceptor) Close() error {
|
||||
|
||||
}
|
16
pkg/middleware/encrypt/interfaces.go
Normal file
16
pkg/middleware/encrypt/interfaces.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
// NonceValidator provides protection against replay attacks
|
||||
type NonceValidator interface {
|
||||
// Validate checks if a nonce is valid and hasn't been seen before
|
||||
Validate(nonce []byte, sessionID types.SessionID) error
|
||||
|
||||
// Cleanup removes expired nonces
|
||||
Cleanup(before time.Time)
|
||||
}
|
90
pkg/middleware/encrypt/keyexchange/curve25519protocol.go
Normal file
90
pkg/middleware/encrypt/keyexchange/curve25519protocol.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package keyexchange
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/messages"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type Curve25519Protocol struct {
|
||||
privKey types.PrivateKey // also in protocol curve25519protocol.go
|
||||
pubKey types.PublicKey
|
||||
peerPubKey types.PublicKey
|
||||
salt types.Salt // also in protocol curve25519protocol.go
|
||||
sessionID types.SessionID
|
||||
sharedSecret []byte
|
||||
encKey types.Key
|
||||
decKey types.Key
|
||||
state SessionState
|
||||
options Curve25519Options
|
||||
|
||||
// TODO: mutex is needed here for SessionState and Keys
|
||||
}
|
||||
|
||||
type Curve25519Options struct {
|
||||
SigningKey ed25519.PrivateKey
|
||||
VerificationKey ed25519.PublicKey
|
||||
RequireSignature bool
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) Init(s *state.State) error {
|
||||
if _, err := io.ReadFull(rand.Reader, p.privKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
curve25519.ScalarBaseMult((*[32]byte)(&p.pubKey), (*[32]byte)(&p.privKey))
|
||||
|
||||
if s.InterceptorConfig.IsServer && p.options.RequireSignature {
|
||||
if _, err := io.ReadFull(rand.Reader, p.salt[:]); err != nil {
|
||||
p.state = SessionStateError
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, p.sessionID[:]); err != nil {
|
||||
p.state = SessionStateError
|
||||
return err
|
||||
}
|
||||
|
||||
sign := ed25519.Sign(p.options.SigningKey, append(p.pubKey[:], p.salt[:]...))
|
||||
|
||||
if err := s.Writer.Write(s.Connection, messages.NewInit("", s.PeerID, p.pubKey, sign, p.sessionID, p.salt)); err != nil {
|
||||
p.state = SessionStateError
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.state = SessionStateInProgress
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) GetKeys() (encKey types.Key, decKey types.Key, err error) {
|
||||
if p.state != SessionStateCompleted {
|
||||
return types.Key{}, types.Key{}, encryptionerr.ErrExchangeNotComplete
|
||||
}
|
||||
|
||||
return p.encKey, p.decKey, nil
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) GetState() SessionState {
|
||||
return p.state
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) IsComplete() bool {
|
||||
return p.state == SessionStateCompleted
|
||||
}
|
||||
|
||||
func (p *Curve25519Protocol) Process(msg MessageProcessor, s *state.State) error {
|
||||
if err := msg.Process(p, s); err != nil {
|
||||
p.state = SessionStateError
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
82
pkg/middleware/encrypt/keyexchange/keyexchange_manager.go
Normal file
82
pkg/middleware/encrypt/keyexchange/keyexchange_manager.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package keyexchange
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
protocol Protocol
|
||||
state *state.State
|
||||
createdAt time.Time
|
||||
completedAt time.Time
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
registry map[types.KeyExchangeProtocol]ProtocolFactory
|
||||
sessions map[types.KeyExchangeSessionID]*Session
|
||||
}
|
||||
|
||||
func (m *Manager) Init(s *state.State, options ...ProtocolFactoryOption) error {
|
||||
sessionID := s.GenerateKeyExchangeSessionID()
|
||||
_, exists := m.sessions[sessionID]
|
||||
if exists {
|
||||
return encryptionerr.ErrExchangeInProgress
|
||||
}
|
||||
|
||||
factory, exists := m.registry[s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s", encryptionerr.ErrProtocolNotFound, s.InterceptorConfig.EncryptionProtocol.KeyExchangeProtocol)
|
||||
}
|
||||
|
||||
p, err := factory(options...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.Init(s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.sessions[sessionID] = &Session{
|
||||
protocol: p,
|
||||
state: s,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Process(s *state.State, msg MessageProcessor) error {
|
||||
session, exists := m.sessions[s.KeyExchangeSessionID]
|
||||
if !exists {
|
||||
return encryptionerr.ErrSessionNotFound
|
||||
}
|
||||
|
||||
return session.protocol.Process(msg, s)
|
||||
}
|
||||
|
||||
// Derive generates encryption keys from shared secret
|
||||
func Derive(shared []byte, salt types.Salt, info string) (types.Key, types.Key, error) {
|
||||
hkdfReader := hkdf.New(sha256.New, shared, salt[:], []byte(info))
|
||||
|
||||
key1 := types.Key{}
|
||||
if _, err := io.ReadFull(hkdfReader, key1[:]); err != nil {
|
||||
return types.Key{}, types.Key{}, err
|
||||
}
|
||||
|
||||
key2 := types.Key{}
|
||||
if _, err := io.ReadFull(hkdfReader, key2[:]); err != nil {
|
||||
return types.Key{}, types.Key{}, err
|
||||
}
|
||||
|
||||
return key1, key2, nil
|
||||
}
|
49
pkg/middleware/encrypt/keyexchange/protocol.go
Normal file
49
pkg/middleware/encrypt/keyexchange/protocol.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package keyexchange
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyprovider"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type SessionState int
|
||||
|
||||
const (
|
||||
SessionStateInitial SessionState = iota
|
||||
SessionStateInProgress
|
||||
SessionStateCompleted
|
||||
SessionStateError
|
||||
)
|
||||
|
||||
type MessageProcessor interface {
|
||||
Process(Protocol, *state.State) error
|
||||
}
|
||||
|
||||
type Protocol interface {
|
||||
Init(s *state.State) error
|
||||
GetKeys() (encKey types.Key, decKey types.Key, err error)
|
||||
Process(MessageProcessor, *state.State) error
|
||||
GetState() SessionState
|
||||
IsComplete() bool
|
||||
}
|
||||
|
||||
type ProtocolFactoryOption func(Protocol) error
|
||||
|
||||
func WithKeySignature(keyProvider keyprovider.KeyProvider) ProtocolFactoryOption {
|
||||
return func(protocol Protocol) error {
|
||||
curveProtocol, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return fmt.Errorf("WithKeySignature only supports Curve25519Protocol")
|
||||
}
|
||||
|
||||
curveProtocol.options.SigningKey = keyProvider.GetSigningKey()
|
||||
curveProtocol.options.VerificationKey = keyProvider.GetVerificationKey()
|
||||
curveProtocol.options.RequireSignature = true
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type ProtocolFactory func(options ...ProtocolFactoryOption) (Protocol, error)
|
10
pkg/middleware/encrypt/keyprovider/keyprovider.go
Normal file
10
pkg/middleware/encrypt/keyprovider/keyprovider.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package keyprovider
|
||||
|
||||
import "crypto/ed25519"
|
||||
|
||||
// KeyProvider interface for secure access to cryptographic keys
|
||||
type KeyProvider interface {
|
||||
GetSigningKey() ed25519.PrivateKey
|
||||
GetVerificationKey() ed25519.PublicKey
|
||||
Close() error
|
||||
}
|
97
pkg/middleware/encrypt/messages/init.go
Normal file
97
pkg/middleware/encrypt/messages/init.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/keyexchange"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/state"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
// Protocol constants
|
||||
const (
|
||||
InitProtocol message.Protocol = "curve25519.init"
|
||||
ResponseProtocol message.Protocol = "curve25519.response"
|
||||
ConfirmProtocol message.Protocol = "curve25519.confirm"
|
||||
)
|
||||
|
||||
type Init struct {
|
||||
interceptor.BaseMessage
|
||||
PublicKey types.PublicKey `json:"public_key"`
|
||||
Signature []byte `json:"signature"`
|
||||
SessionID types.SessionID `json:"session_id"`
|
||||
Salt types.Salt `json:"salt"`
|
||||
}
|
||||
|
||||
func NewInit(sender message.Sender, receiver message.Receiver, key types.PublicKey, sign []byte, sessionID types.SessionID, salt types.Salt) *Init {
|
||||
return &Init{
|
||||
BaseMessage: interceptor.NewBaseMessage(InitProtocol, sender, receiver),
|
||||
PublicKey: key,
|
||||
Signature: sign,
|
||||
SessionID: sessionID,
|
||||
Salt: salt,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Init) WriteProcess(_ interceptor.Interceptor, _ interceptor.Connection) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Init) ReadProcess(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
i, ok := _interceptor.(*encrypt.Interceptor)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := i.stateManager.GetState(connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return i.keyExchangeManager.Process(s, m)
|
||||
}
|
||||
|
||||
func (m *Init) Process(protocol keyexchange.Protocol, s *state.State) error {
|
||||
p, ok := protocol.(*keyexchange.Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
if p.GetState() != keyexchange.SessionStateInitial {
|
||||
return encryptionerr.ErrInvalidSessionState
|
||||
}
|
||||
|
||||
sign := append(m.PublicKey[:], m.Salt[:]...)
|
||||
if !ed25519.Verify(p.options.VerificationKey, sign, m.Signature) {
|
||||
return encryptionerr.ErrInvalidSignature
|
||||
}
|
||||
|
||||
p.salt = m.Salt
|
||||
shared, err := curve25519.X25519(p.privKey[:], m.PublicKey[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compute shared secret: %w", err)
|
||||
}
|
||||
|
||||
encKey, decKey, err := keyexchange.Derive(shared, p.salt, "") // TODO: ADD INFO STRING
|
||||
if err != nil {
|
||||
return fmt.Errorf("key derivation failed: %w", err)
|
||||
}
|
||||
|
||||
p.encKey = encKey
|
||||
p.decKey = decKey
|
||||
p.sessionID = m.SessionID
|
||||
|
||||
if err := s.Writer.Write(s.Connection, nil); err != nil {
|
||||
return err
|
||||
} // TODO: ADD RESPONSE MESSAGE
|
||||
|
||||
p.state = SessionStateCompleted
|
||||
return nil
|
||||
}
|
60
pkg/middleware/encrypt/state/state.go
Normal file
60
pkg/middleware/encrypt/state/state.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
InterceptorConfig config.Config // copy of interceptor config
|
||||
PeerID message.Receiver
|
||||
privKey types.PrivateKey // also in protocol curve25519protocol.go
|
||||
salt types.Salt // also in protocol curve25519protocol.go
|
||||
sessionID types.SessionID // encryption sessionID
|
||||
encryptor encryptor.Encryptor
|
||||
Connection interceptor.Connection
|
||||
Writer interceptor.Writer
|
||||
Reader interceptor.Reader
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
KeyExchangeSessionID types.KeyExchangeSessionID // Used for key exchange tracking
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func NewState(ctx context.Context, cancel context.CancelFunc, config config.Config, connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) (*State, error) {
|
||||
newEncryptor, err := encryptor.NewEncryptor(config.EncryptionProtocol.CipherSuite)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &State{
|
||||
InterceptorConfig: config,
|
||||
peerID: message.UnknownReceiver,
|
||||
privKey: types.PrivateKey{},
|
||||
salt: types.Salt{},
|
||||
encryptor: newEncryptor,
|
||||
Connection: connection,
|
||||
Writer: writer,
|
||||
Reader: reader,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||
if s.KeyExchangeSessionID != "" {
|
||||
fmt.Println("KeyExchangeSessionID already exists; creating new")
|
||||
}
|
||||
s.KeyExchangeSessionID = types.KeyExchangeSessionID(uuid.NewString())
|
||||
|
||||
return s.KeyExchangeSessionID
|
||||
}
|
66
pkg/middleware/encrypt/state/state_manager.go
Normal file
66
pkg/middleware/encrypt/state/state_manager.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
states map[interceptor.Connection]*State
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *Manager) GetState(connection interceptor.Connection) (*State, error) {
|
||||
m.mux.RLock()
|
||||
defer m.mux.RUnlock()
|
||||
|
||||
state, exists := m.states[connection]
|
||||
if !exists {
|
||||
return nil, encryptionerr.ErrConnectionNotFound
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (m *Manager) SetState(connection interceptor.Connection, s *State) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if _, exists := m.states[connection]; exists {
|
||||
return encryptionerr.ErrConnectionExists
|
||||
}
|
||||
|
||||
m.states[connection] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveState removes a Connection's state
|
||||
func (m *Manager) RemoveState(connection interceptor.Connection) (*State, error) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
state, exists := m.states[connection]
|
||||
if !exists {
|
||||
return nil, encryptionerr.ErrConnectionNotFound
|
||||
}
|
||||
|
||||
delete(m.states, connection)
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// ForEach executes the provided function for each state in the manager
|
||||
func (m *Manager) ForEach(fn func(connection interceptor.Connection, state *State) error) []error {
|
||||
m.mux.RLock()
|
||||
defer m.mux.RUnlock()
|
||||
|
||||
var errs []error
|
||||
for conn, state := range m.states {
|
||||
if err := fn(conn, state); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
51
pkg/middleware/encrypt/types/types.go
Normal file
51
pkg/middleware/encrypt/types/types.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package types
|
||||
|
||||
// Crypto-related type definitions for improved type safety
|
||||
type (
|
||||
PrivateKey [32]byte
|
||||
PublicKey [32]byte
|
||||
Salt [16]byte
|
||||
SessionID [16]byte
|
||||
Nonce [12]byte
|
||||
Key [32]byte
|
||||
KeyExchangeProtocol string
|
||||
KeyExchangeSessionID string
|
||||
)
|
||||
|
||||
// EncryptionProtocol defines the capabilities of a specific protocol version
|
||||
type EncryptionProtocol struct {
|
||||
Version uint8 `json:"version"`
|
||||
CipherSuite string `json:"cipher_suite"`
|
||||
KeyExchangeProtocol KeyExchangeProtocol `json:"key_exchange"`
|
||||
Authenticated bool `json:"authenticated"`
|
||||
SupportedExtensions map[string]string `json:"supported_extensions,omitempty"`
|
||||
}
|
||||
|
||||
// Protocol versions
|
||||
var (
|
||||
// ProtocolV1 is the original protocol
|
||||
ProtocolV1 = EncryptionProtocol{
|
||||
Version: 1,
|
||||
CipherSuite: "AES-256-GCM",
|
||||
KeyExchangeProtocol: "curve25519-ed25519",
|
||||
Authenticated: true,
|
||||
}
|
||||
|
||||
// ProtocolV2 adds support for key rotation
|
||||
ProtocolV2 = EncryptionProtocol{
|
||||
Version: 2,
|
||||
CipherSuite: "AES-256-GCM",
|
||||
KeyExchangeProtocol: "curve25519-ed25519",
|
||||
Authenticated: true,
|
||||
SupportedExtensions: map[string]string{
|
||||
"key_rotation": "supported",
|
||||
"forward_secrecy": "enabled",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// IsZero is a generic function to check if a value is the zero value for its type
|
||||
func IsZero[T comparable](value T) bool {
|
||||
var zero T
|
||||
return value == zero
|
||||
}
|
19
pkg/middleware/encrypt/util.go
Normal file
19
pkg/middleware/encrypt/util.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FormatDuration formats a duration in a human-readable way
|
||||
func FormatDuration(d time.Duration) string {
|
||||
if d < time.Second {
|
||||
return fmt.Sprintf("%d ms", d.Milliseconds())
|
||||
} else if d < time.Minute {
|
||||
return fmt.Sprintf("%.1f s", d.Seconds())
|
||||
} else if d < time.Hour {
|
||||
return fmt.Sprintf("%.1f min", d.Minutes())
|
||||
} else {
|
||||
return fmt.Sprintf("%.1f h", d.Hours())
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user