mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-08 17:10:03 +08:00
added aes256.go
This commit is contained in:
@@ -3,6 +3,7 @@ package interceptor
|
|||||||
import "github.com/harshabose/socket-comm/internal/util"
|
import "github.com/harshabose/socket-comm/internal/util"
|
||||||
|
|
||||||
type Chain struct {
|
type Chain struct {
|
||||||
|
NoOpInterceptor
|
||||||
interceptors []Interceptor
|
interceptors []Interceptor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -44,6 +44,12 @@ type Connection interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Interceptor interface {
|
type Interceptor interface {
|
||||||
|
ID() string
|
||||||
|
|
||||||
|
Ctx() context.Context
|
||||||
|
|
||||||
|
GetMessageRegistry() message.Registry
|
||||||
|
|
||||||
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
|
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
|
||||||
|
|
||||||
Init(Connection) error
|
Init(Connection) error
|
||||||
@@ -78,10 +84,18 @@ func (f WriterFunc) Write(conn Connection, message message.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NoOpInterceptor struct {
|
type NoOpInterceptor struct {
|
||||||
ID string
|
iD string
|
||||||
messageRegistry message.Registry
|
messageRegistry message.Registry
|
||||||
Mutex sync.RWMutex
|
Mutex sync.RWMutex
|
||||||
Ctx context.Context
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) Ctx() context.Context {
|
||||||
|
return interceptor.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (interceptor *NoOpInterceptor) ID() string {
|
||||||
|
return interceptor.iD
|
||||||
}
|
}
|
||||||
|
|
||||||
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {
|
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {
|
||||||
|
@@ -55,7 +55,6 @@ type BaseMessage struct {
|
|||||||
message.BaseMessage
|
message.BaseMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
|
|
||||||
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
|
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
|
||||||
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
|
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -48,6 +48,10 @@ type Message interface {
|
|||||||
// GetProtocol returns the protocol identifier for this message
|
// GetProtocol returns the protocol identifier for this message
|
||||||
GetProtocol() Protocol
|
GetProtocol() Protocol
|
||||||
|
|
||||||
|
GetNextPayload() (Payload, error)
|
||||||
|
|
||||||
|
GetNextProtocol() Protocol
|
||||||
|
|
||||||
// GetNext retrieves the next message in the chain, if one exists
|
// GetNext retrieves the next message in the chain, if one exists
|
||||||
// Returns nil, nil if there is no next message
|
// Returns nil, nil if there is no next message
|
||||||
GetNext(Registry) (Message, error)
|
GetNext(Registry) (Message, error)
|
||||||
@@ -86,8 +90,8 @@ type BaseMessage struct {
|
|||||||
// CURRENT OTHER FIELDS...
|
// CURRENT OTHER FIELDS...
|
||||||
|
|
||||||
// NEXT MESSAGE PROCESSOR
|
// NEXT MESSAGE PROCESSOR
|
||||||
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
|
NextPayload Payload `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
|
||||||
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
|
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProtocol returns this message's protocol identifier
|
// GetProtocol returns this message's protocol identifier
|
||||||
@@ -95,6 +99,22 @@ func (m *BaseMessage) GetProtocol() Protocol {
|
|||||||
return m.CurrentProtocol
|
return m.CurrentProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *BaseMessage) GetNextPayload() (Payload, error) {
|
||||||
|
if m.NextProtocol == NoneProtocol {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.NextPayload == nil {
|
||||||
|
return nil, ErrNoPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.NextPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *BaseMessage) GetNextProtocol() Protocol {
|
||||||
|
return m.NextProtocol
|
||||||
|
}
|
||||||
|
|
||||||
// GetNext retrieves the next message in the chain, if one exists.
|
// GetNext retrieves the next message in the chain, if one exists.
|
||||||
// Returns nil, nil if NextProtocol is NoneProtocol.
|
// Returns nil, nil if NextProtocol is NoneProtocol.
|
||||||
// Uses the provided Registry to create and unmarshal the next message.
|
// Uses the provided Registry to create and unmarshal the next message.
|
||||||
@@ -107,7 +127,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
|
|||||||
return nil, ErrNoPayload
|
return nil, ErrNoPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
|
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal serializes the message to JSON format
|
// Marshal serializes the message to JSON format
|
||||||
@@ -136,7 +156,7 @@ func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (Ba
|
|||||||
return BaseMessage{
|
return BaseMessage{
|
||||||
CurrentProtocol: msg.GetProtocol(),
|
CurrentProtocol: msg.GetProtocol(),
|
||||||
CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver),
|
CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver),
|
||||||
NextPayload: inner,
|
NextPayload: Payload(inner),
|
||||||
NextProtocol: nextProtocol,
|
NextProtocol: nextProtocol,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@@ -27,7 +27,7 @@ type Registry interface {
|
|||||||
|
|
||||||
// Unmarshal creates and deserializes a message for a protocol
|
// Unmarshal creates and deserializes a message for a protocol
|
||||||
// The provided data is parsed into the appropriate message type
|
// The provided data is parsed into the appropriate message type
|
||||||
Unmarshal(Protocol, json.RawMessage) (Message, error)
|
Unmarshal(Protocol, Payload) (Message, error)
|
||||||
|
|
||||||
// UnmarshalRaw deserializes a message when the protocol is unknown
|
// UnmarshalRaw deserializes a message when the protocol is unknown
|
||||||
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
|
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
|
||||||
@@ -103,7 +103,7 @@ func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
|
|||||||
// Unmarshal creates and deserializes a message for a protocol
|
// Unmarshal creates and deserializes a message for a protocol
|
||||||
// The provided data is parsed into the appropriate message type
|
// The provided data is parsed into the appropriate message type
|
||||||
// This method is thread-safe
|
// This method is thread-safe
|
||||||
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) {
|
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data Payload) (Message, error) {
|
||||||
msg, err := r.Create(protocol)
|
msg, err := r.Create(protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -120,7 +120,7 @@ func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Me
|
|||||||
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
|
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
|
||||||
// This method is particularly useful for handling incoming WebSocket messages
|
// This method is particularly useful for handling incoming WebSocket messages
|
||||||
// This method is thread-safe
|
// This method is thread-safe
|
||||||
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) {
|
func (r *DefaultRegistry) UnmarshalRaw(data Payload) (Message, error) {
|
||||||
var envelope Envelope
|
var envelope Envelope
|
||||||
if err := json.Unmarshal(data, &envelope); err != nil {
|
if err := json.Unmarshal(data, &envelope); err != nil {
|
||||||
return nil, fmt.Errorf("failed to extract protocol: %w", err)
|
return nil, fmt.Errorf("failed to extract protocol: %w", err)
|
||||||
|
@@ -62,15 +62,32 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
|
|||||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protocol := msg.GetProtocol()
|
||||||
|
|
||||||
data, err := msg.Marshal()
|
data, err := msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
|
||||||
|
|
||||||
|
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) {
|
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
|
||||||
// TODO implement me
|
m, ok := msg.(*EncryptedMessage)
|
||||||
panic("implement me")
|
if !ok {
|
||||||
|
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := a.decryptor.Open(nil, m.Nonce[:], m.NextPayload, a.sessionID[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decryption failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.NextPayload = data
|
||||||
|
|
||||||
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {
|
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {
|
||||||
|
@@ -1,14 +1,37 @@
|
|||||||
package encryptor
|
package encryptor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||||
"github.com/harshabose/socket-comm/pkg/message"
|
"github.com/harshabose/socket-comm/pkg/message"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||||
|
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EncryptedMessage struct {
|
type EncryptedMessage struct {
|
||||||
interceptor.BaseMessage
|
interceptor.BaseMessage
|
||||||
|
Nonce types.Nonce
|
||||||
|
Timestamp time.Time
|
||||||
|
SessionID types.EncryptionSessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
|
||||||
|
msg := &EncryptedMessage{
|
||||||
|
Nonce: nonce,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
SessionID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.BaseMessage = bmsg
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
||||||
@@ -27,16 +50,17 @@ func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn intercep
|
|||||||
return encryptionerr.ErrInvalidInterceptor
|
return encryptionerr.ErrInvalidInterceptor
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := ss.Decrypt(m)
|
decmsg, err := ss.Decrypt(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: message.Registry is not implemented yet
|
msg, ok := decmsg.(*EncryptedMessage)
|
||||||
decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload)
|
if !ok {
|
||||||
if err != nil {
|
return encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.NextPayload = msg.NextPayload // JUST MAKING SURE
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -81,6 +81,11 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
iMessage, ok := encrypted.(interceptor.Message)
|
||||||
|
if !ok {
|
||||||
|
return writer.Write(conn)
|
||||||
|
}
|
||||||
|
|
||||||
return writer.Write(conn, encrypted)
|
return writer.Write(conn, encrypted)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@@ -63,14 +63,24 @@ func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
|
|||||||
return s.keyExchangeSessionID
|
return s.keyExchangeSessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *State) GetConnection() interceptor.Connection {
|
||||||
|
return s.connection
|
||||||
|
}
|
||||||
|
|
||||||
func (s *State) WriteMessage(msg interceptor.Message) error {
|
func (s *State) WriteMessage(msg interceptor.Message) error {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
// TODO: MANAGE CLIENT DISCOVERY
|
||||||
msg.SetReceiver(s.peerID)
|
|
||||||
return s.writer.Write(s.connection, msg)
|
return s.writer.Write(s.connection, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *State) ReadMessage(msg interceptor.Message) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
return s.reader.Read()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
|
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
Reference in New Issue
Block a user