mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-08 17:10:03 +08:00
added aes256.go
This commit is contained in:
@@ -3,6 +3,7 @@ package interceptor
|
||||
import "github.com/harshabose/socket-comm/internal/util"
|
||||
|
||||
type Chain struct {
|
||||
NoOpInterceptor
|
||||
interceptors []Interceptor
|
||||
}
|
||||
|
||||
|
@@ -44,6 +44,12 @@ type Connection interface {
|
||||
}
|
||||
|
||||
type Interceptor interface {
|
||||
ID() string
|
||||
|
||||
Ctx() context.Context
|
||||
|
||||
GetMessageRegistry() message.Registry
|
||||
|
||||
BindSocketConnection(Connection, Writer, Reader) (Writer, Reader, error)
|
||||
|
||||
Init(Connection) error
|
||||
@@ -78,10 +84,18 @@ func (f WriterFunc) Write(conn Connection, message message.Message) error {
|
||||
}
|
||||
|
||||
type NoOpInterceptor struct {
|
||||
ID string
|
||||
iD string
|
||||
messageRegistry message.Registry
|
||||
Mutex sync.RWMutex
|
||||
Ctx context.Context
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) Ctx() context.Context {
|
||||
return interceptor.ctx
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) ID() string {
|
||||
return interceptor.iD
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {
|
||||
|
@@ -55,7 +55,6 @@ type BaseMessage struct {
|
||||
message.BaseMessage
|
||||
}
|
||||
|
||||
// NewBaseMessage creates a properly initialized interceptor BaseMessage for the key exchange module
|
||||
func NewBaseMessage(nextProtocol message.Protocol, nextPayload message.Message, msg Message) (BaseMessage, error) {
|
||||
bmsg, err := message.NewBaseMessage(nextProtocol, nextPayload, msg)
|
||||
if err != nil {
|
||||
|
@@ -48,6 +48,10 @@ type Message interface {
|
||||
// GetProtocol returns the protocol identifier for this message
|
||||
GetProtocol() Protocol
|
||||
|
||||
GetNextPayload() (Payload, error)
|
||||
|
||||
GetNextProtocol() Protocol
|
||||
|
||||
// GetNext retrieves the next message in the chain, if one exists
|
||||
// Returns nil, nil if there is no next message
|
||||
GetNext(Registry) (Message, error)
|
||||
@@ -86,8 +90,8 @@ type BaseMessage struct {
|
||||
// CURRENT OTHER FIELDS...
|
||||
|
||||
// NEXT MESSAGE PROCESSOR
|
||||
NextPayload json.RawMessage `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
|
||||
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
|
||||
NextPayload Payload `json:"next,omitempty"` // NextPayload contains the serialized next message in the chain
|
||||
NextProtocol Protocol `json:"next_protocol"` // NextProtocol identifies the type of the next message. NoneProtocol indicates end of chain
|
||||
}
|
||||
|
||||
// GetProtocol returns this message's protocol identifier
|
||||
@@ -95,6 +99,22 @@ func (m *BaseMessage) GetProtocol() Protocol {
|
||||
return m.CurrentProtocol
|
||||
}
|
||||
|
||||
func (m *BaseMessage) GetNextPayload() (Payload, error) {
|
||||
if m.NextProtocol == NoneProtocol {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if m.NextPayload == nil {
|
||||
return nil, ErrNoPayload
|
||||
}
|
||||
|
||||
return m.NextPayload, nil
|
||||
}
|
||||
|
||||
func (m *BaseMessage) GetNextProtocol() Protocol {
|
||||
return m.NextProtocol
|
||||
}
|
||||
|
||||
// GetNext retrieves the next message in the chain, if one exists.
|
||||
// Returns nil, nil if NextProtocol is NoneProtocol.
|
||||
// Uses the provided Registry to create and unmarshal the next message.
|
||||
@@ -107,7 +127,7 @@ func (m *BaseMessage) GetNext(registry Registry) (Message, error) {
|
||||
return nil, ErrNoPayload
|
||||
}
|
||||
|
||||
return registry.Unmarshal(m.NextProtocol, m.NextPayload)
|
||||
return registry.Unmarshal(m.NextProtocol, json.RawMessage(m.NextPayload))
|
||||
}
|
||||
|
||||
// Marshal serializes the message to JSON format
|
||||
@@ -136,7 +156,7 @@ func NewBaseMessage(nextProtocol Protocol, nextPayload Message, msg Message) (Ba
|
||||
return BaseMessage{
|
||||
CurrentProtocol: msg.GetProtocol(),
|
||||
CurrentHeader: NewV1Header(UnknownSender, UnknownReceiver),
|
||||
NextPayload: inner,
|
||||
NextPayload: Payload(inner),
|
||||
NextProtocol: nextProtocol,
|
||||
}, nil
|
||||
}
|
||||
|
@@ -27,7 +27,7 @@ type Registry interface {
|
||||
|
||||
// Unmarshal creates and deserializes a message for a protocol
|
||||
// The provided data is parsed into the appropriate message type
|
||||
Unmarshal(Protocol, json.RawMessage) (Message, error)
|
||||
Unmarshal(Protocol, Payload) (Message, error)
|
||||
|
||||
// UnmarshalRaw deserializes a message when the protocol is unknown
|
||||
// It first inspects the envelope to determine the protocol, then unmarshals accordingly
|
||||
@@ -103,7 +103,7 @@ func (r *DefaultRegistry) Create(protocol Protocol) (Message, error) {
|
||||
// Unmarshal creates and deserializes a message for a protocol
|
||||
// The provided data is parsed into the appropriate message type
|
||||
// This method is thread-safe
|
||||
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Message, error) {
|
||||
func (r *DefaultRegistry) Unmarshal(protocol Protocol, data Payload) (Message, error) {
|
||||
msg, err := r.Create(protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -120,7 +120,7 @@ func (r *DefaultRegistry) Unmarshal(protocol Protocol, data json.RawMessage) (Me
|
||||
// It first extracts just the protocol from the data, then creates and unmarshals the appropriate message type
|
||||
// This method is particularly useful for handling incoming WebSocket messages
|
||||
// This method is thread-safe
|
||||
func (r *DefaultRegistry) UnmarshalRaw(data json.RawMessage) (Message, error) {
|
||||
func (r *DefaultRegistry) UnmarshalRaw(data Payload) (Message, error) {
|
||||
var envelope Envelope
|
||||
if err := json.Unmarshal(data, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("failed to extract protocol: %w", err)
|
||||
|
@@ -62,15 +62,32 @@ func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error)
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
protocol := msg.GetProtocol()
|
||||
|
||||
data, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
encryptedData := a.encryptor.Seal(nil, nonce[:], data, a.sessionID[:])
|
||||
|
||||
return NewEncryptedMessage(encryptedData, protocol, nonce, a.sessionID)
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
func (a *AES256Encryptor) Decrypt(msg message.Message) (message.Message, error) {
|
||||
m, ok := msg.(*EncryptedMessage)
|
||||
if !ok {
|
||||
return nil, encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||
}
|
||||
|
||||
data, err := a.decryptor.Open(nil, m.Nonce[:], m.NextPayload, a.sessionID[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
m.NextPayload = data
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {
|
||||
|
@@ -1,14 +1,37 @@
|
||||
package encryptor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type EncryptedMessage struct {
|
||||
interceptor.BaseMessage
|
||||
Nonce types.Nonce
|
||||
Timestamp time.Time
|
||||
SessionID types.EncryptionSessionID
|
||||
}
|
||||
|
||||
func NewEncryptedMessage(encrypted message.Payload, protocol message.Protocol, nonce types.Nonce, id types.EncryptionSessionID) (*EncryptedMessage, error) {
|
||||
msg := &EncryptedMessage{
|
||||
Nonce: nonce,
|
||||
Timestamp: time.Now(),
|
||||
SessionID: id,
|
||||
}
|
||||
|
||||
bmsg, err := interceptor.NewBaseMessage(protocol, encrypted, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg.BaseMessage = bmsg
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
||||
@@ -27,16 +50,17 @@ func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn intercep
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
msg, err := ss.Decrypt(m)
|
||||
decmsg, err := ss.Decrypt(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: message.Registry is not implemented yet
|
||||
decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
msg, ok := decmsg.(*EncryptedMessage)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor // JUST TO BE SURE
|
||||
}
|
||||
|
||||
m.NextPayload = msg.NextPayload // JUST MAKING SURE
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -81,6 +81,11 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
|
||||
return err
|
||||
}
|
||||
|
||||
iMessage, ok := encrypted.(interceptor.Message)
|
||||
if !ok {
|
||||
return writer.Write(conn)
|
||||
}
|
||||
|
||||
return writer.Write(conn, encrypted)
|
||||
})
|
||||
}
|
||||
|
@@ -63,14 +63,24 @@ func (s *State) GenerateKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||
return s.keyExchangeSessionID
|
||||
}
|
||||
|
||||
func (s *State) GetConnection() interceptor.Connection {
|
||||
return s.connection
|
||||
}
|
||||
|
||||
func (s *State) WriteMessage(msg interceptor.Message) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
msg.SetReceiver(s.peerID)
|
||||
// TODO: MANAGE CLIENT DISCOVERY
|
||||
return s.writer.Write(s.connection, msg)
|
||||
}
|
||||
|
||||
func (s *State) ReadMessage(msg interceptor.Message) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return s.reader.Read()
|
||||
}
|
||||
|
||||
func (s *State) GetKeyExchangeSessionID() types.KeyExchangeSessionID {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
Reference in New Issue
Block a user