mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-05 23:56:54 +08:00
added NewBaseMessage function in message.go
This commit is contained in:
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
@@ -56,31 +58,36 @@ type Interceptor interface {
|
||||
}
|
||||
|
||||
type Writer interface {
|
||||
Write(conn Connection, message Message) error
|
||||
Write(conn Connection, message message.Message) error
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
Read(conn Connection) (Message, error)
|
||||
Read(conn Connection) (message.Message, error)
|
||||
}
|
||||
|
||||
type ReaderFunc func(conn Connection) (Message, error)
|
||||
type ReaderFunc func(conn Connection) (message.Message, error)
|
||||
|
||||
func (f ReaderFunc) Read(conn Connection) (Message, error) {
|
||||
func (f ReaderFunc) Read(conn Connection) (message.Message, error) {
|
||||
return f(conn)
|
||||
}
|
||||
|
||||
type WriterFunc func(conn Connection, message Message) error
|
||||
type WriterFunc func(conn Connection, message message.Message) error
|
||||
|
||||
func (f WriterFunc) Write(conn Connection, message Message) error {
|
||||
func (f WriterFunc) Write(conn Connection, message message.Message) error {
|
||||
return f(conn, message)
|
||||
}
|
||||
|
||||
type NoOpInterceptor struct {
|
||||
ID string
|
||||
messageRegistry message.Registry
|
||||
Mutex sync.RWMutex
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) GetMessageRegistry() message.Registry {
|
||||
return interceptor.messageRegistry
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) BindSocketConnection(_ Connection, _ Writer, _ Reader) (Writer, Reader, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
93
pkg/middleware/encrypt/encryptor/aes256.go
Normal file
93
pkg/middleware/encrypt/encryptor/aes256.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package encryptor
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type AES256Encryptor struct {
|
||||
encryptor cipher.AEAD
|
||||
decryptor cipher.AEAD
|
||||
sessionID types.EncryptionSessionID
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) SetKeys(encryptorKey, decryptorKey types.Key) error {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
|
||||
// Setup encryption AEAD
|
||||
encBlock, err := aes.NewCipher(encryptorKey[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
|
||||
}
|
||||
|
||||
encGCM, err := cipher.NewGCM(encBlock)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
|
||||
}
|
||||
|
||||
// Setup decryption AEAD
|
||||
decBlock, err := aes.NewCipher(decryptorKey[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
|
||||
}
|
||||
|
||||
decGCM, err := cipher.NewGCM(decBlock)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", encryptionerr.ErrInvalidKey, err)
|
||||
}
|
||||
|
||||
a.encryptor = encGCM
|
||||
a.decryptor = decGCM
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) Encrypt(msg message.Message) (message.Message, error) {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
|
||||
nonce := types.Nonce{}
|
||||
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
data, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) Decrypt(message message.Message) (message.Message, error) {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) SetSessionID(id types.EncryptionSessionID) {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
|
||||
a.sessionID = id
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) Ready() bool {
|
||||
a.mux.RLock()
|
||||
defer a.mux.RUnlock()
|
||||
|
||||
return a.encryptor != nil && a.decryptor != nil
|
||||
}
|
||||
|
||||
func (a *AES256Encryptor) Close() error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
42
pkg/middleware/encrypt/encryptor/encryptor_messages.go
Normal file
42
pkg/middleware/encrypt/encryptor/encryptor_messages.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package encryptor
|
||||
|
||||
import (
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
)
|
||||
|
||||
type EncryptedMessage struct {
|
||||
interceptor.BaseMessage
|
||||
}
|
||||
|
||||
func (m *EncryptedMessage) ReadProcess(_i interceptor.Interceptor, conn interceptor.Connection) error {
|
||||
i, ok := _i.(interfaces.CanGetState)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
s, err := i.GetState(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.CanDecrypt)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
msg, err := ss.Decrypt(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: message.Registry is not implemented yet
|
||||
decrytpedMsg, err := message.Registry().Unmarshal(m.NextProtocol, m.NextPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
64
pkg/middleware/encrypt/encryptor/skipper.go
Normal file
64
pkg/middleware/encrypt/encryptor/skipper.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package encryptor
|
||||
|
||||
import (
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/types"
|
||||
)
|
||||
|
||||
type SkipEncryptionChecker func(message.Message) bool
|
||||
|
||||
// NewProtocolSkipChecker creates a checker that skips encryption for specific protocols
|
||||
func NewProtocolSkipChecker(protocolsToSkip ...message.Protocol) SkipEncryptionChecker {
|
||||
skipMap := make(map[message.Protocol]bool)
|
||||
for _, protocol := range protocolsToSkip {
|
||||
skipMap[protocol] = true
|
||||
}
|
||||
|
||||
return func(msg message.Message) bool {
|
||||
return skipMap[msg.GetProtocol()]
|
||||
}
|
||||
}
|
||||
|
||||
type SkipperEncryptor struct {
|
||||
wrapped interfaces.Encryptor
|
||||
skip SkipEncryptionChecker
|
||||
}
|
||||
|
||||
func NewSkipperEncryptor(wrapped interfaces.Encryptor, skip SkipEncryptionChecker) *SkipperEncryptor {
|
||||
return &SkipperEncryptor{
|
||||
wrapped: wrapped,
|
||||
skip: skip,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *SkipperEncryptor) SetSessionID(id types.EncryptionSessionID) {
|
||||
e.wrapped.SetSessionID(id)
|
||||
}
|
||||
|
||||
func (e *SkipperEncryptor) SetKeys(encryptorKey, decryptorKey types.Key) error {
|
||||
return e.wrapped.SetKeys(encryptorKey, decryptorKey)
|
||||
}
|
||||
|
||||
func (e *SkipperEncryptor) Encrypt(msg message.Message) (message.Message, error) {
|
||||
if e.skip(msg) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
return e.wrapped.Encrypt(msg)
|
||||
}
|
||||
|
||||
func (e *SkipperEncryptor) Decrypt(msg message.Message) (message.Message, error) {
|
||||
if !e.skip(msg) {
|
||||
return e.wrapped.Decrypt(msg)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (e *SkipperEncryptor) Ready() bool {
|
||||
return e.wrapped.Ready()
|
||||
}
|
||||
|
||||
func (e *SkipperEncryptor) Close() error {
|
||||
return e.wrapped.Close()
|
||||
}
|
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/config"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/encryptionerr"
|
||||
"github.com/harshabose/socket-comm/pkg/middleware/encrypt/interfaces"
|
||||
@@ -53,23 +54,71 @@ func (i *Interceptor) Init(connection interceptor.Connection) error {
|
||||
|
||||
waiter := keyexchange.NewSessionStateTargetWaiter(ctx, types.SessionStateCompleted)
|
||||
|
||||
return i.Process(waiter, s)
|
||||
if err := i.Process(waiter, s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return i.keyExchangeManager.Finalise(s)
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn interceptor.Connection, msg message.Message) error {
|
||||
s, err := i.GetState(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.CanEncrypt)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
encrypted, err := ss.Encrypt(msg)
|
||||
if err != nil {
|
||||
if !s.GetConfig().RequireEncryption {
|
||||
return writer.Write(conn, msg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return writer.Write(conn, encrypted)
|
||||
})
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn interceptor.Connection) (message.Message, error) {
|
||||
msg, err := reader.Read(conn)
|
||||
if err != nil {
|
||||
return msg, err
|
||||
}
|
||||
|
||||
s, err := i.GetState(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.CanDecrypt)
|
||||
if !ok {
|
||||
return msg, err
|
||||
}
|
||||
|
||||
m, err := ss.Decrypt(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
|
||||
|
||||
// TODO: Implement full closing
|
||||
}
|
||||
|
||||
func (i *Interceptor) Close() error {
|
||||
|
||||
// TODO: Use UnBindSocketConnection to close all
|
||||
// TODO: Close interceptor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Interceptor) GetState(connection interceptor.Connection) (interfaces.State, error) {
|
||||
|
@@ -16,19 +16,25 @@ type KeyGetter interface {
|
||||
GetKeys() (encKey, decKey types.Key, err error)
|
||||
}
|
||||
|
||||
type CanEncrypt interface {
|
||||
// Encrypt encrypts a message between sender and receiver
|
||||
Encrypt(message message.Message) (message.Message, error)
|
||||
}
|
||||
|
||||
type CanDecrypt interface {
|
||||
// Decrypt decrypts an encrypted message in-place
|
||||
Decrypt(message message.Message) (message.Message, error)
|
||||
}
|
||||
|
||||
// Encryptor defines the interface for message encryption and decryption
|
||||
type Encryptor interface {
|
||||
KeySetter
|
||||
CanEncrypt
|
||||
CanDecrypt
|
||||
|
||||
// SetSessionID sets the session identifier for this encryption session
|
||||
SetSessionID(id types.EncryptionSessionID)
|
||||
|
||||
// Encrypt encrypts a message between sender and receiver
|
||||
Encrypt(senderID, receiverID string, message message.Message) (message.Message, error)
|
||||
|
||||
// Decrypt decrypts an encrypted message in-place
|
||||
Decrypt(message message.Message) error
|
||||
|
||||
// Ready checks if the encryptor is properly initialized and ready to use
|
||||
Ready() bool
|
||||
|
||||
|
@@ -10,6 +10,7 @@ type ProtocolProcessor interface {
|
||||
|
||||
type KeyExchangeManager interface {
|
||||
Init(state State, options ...ProtocolFactoryOption) error
|
||||
Finalise(state State) error
|
||||
}
|
||||
|
||||
type CanGetSessionState interface {
|
||||
|
@@ -225,21 +225,50 @@ func (m *Done) ReadProcess(_i interceptor.Interceptor, connection interceptor.Co
|
||||
return pp.Process(m, s)
|
||||
}
|
||||
|
||||
func (m *Done) Process(protocol interfaces.Protocol, _s interfaces.State) error {
|
||||
ss, ok := _s.(interfaces.KeySetter)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
func (m *Done) Process(protocol interfaces.Protocol, s interfaces.State) error {
|
||||
p, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
if err := ss.SetKeys(p.encKey, p.decKey); err != nil {
|
||||
msg, err := NewDoneResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.WriteMessage(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.state = types.SessionStateCompleted
|
||||
return nil
|
||||
}
|
||||
|
||||
type DoneResponse struct {
|
||||
Done
|
||||
}
|
||||
|
||||
func NewDoneResponse() (*DoneResponse, error) {
|
||||
msg := &DoneResponse{
|
||||
Done: Done{
|
||||
Timestamp: time.Now(),
|
||||
},
|
||||
}
|
||||
bmsg, err := interceptor.NewBaseMessage(message.NoneProtocol, nil, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg.BaseMessage = bmsg
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *DoneResponse) Process(protocol interfaces.Protocol, s interfaces.State) error {
|
||||
p, ok := protocol.(*Curve25519Protocol)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidMessageType
|
||||
}
|
||||
|
||||
p.state = types.SessionStateCompleted
|
||||
return nil
|
||||
}
|
||||
|
@@ -55,6 +55,31 @@ func (m *Manager) Init(s interfaces.State, options ...interfaces.ProtocolFactory
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Finalise(s interfaces.State) error {
|
||||
sessionID := s.GetKeyExchangeSessionID()
|
||||
session, exists := m.sessions[sessionID]
|
||||
if !exists {
|
||||
return encryptionerr.ErrExchangeNotComplete
|
||||
}
|
||||
|
||||
ss, ok := s.(interfaces.KeySetter)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
p, ok := (session.protocol).(interfaces.KeyGetter)
|
||||
if !ok {
|
||||
return encryptionerr.ErrInvalidInterceptor
|
||||
}
|
||||
|
||||
encKey, decKey, err := p.GetKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ss.SetKeys(encKey, decKey)
|
||||
}
|
||||
|
||||
func (m *Manager) Process(msg interfaces.CanProcess, s interfaces.State) error {
|
||||
session, exists := m.sessions[s.GetKeyExchangeSessionID()]
|
||||
if !exists {
|
||||
|
@@ -91,3 +91,7 @@ func (s *State) SetKeys(encKey, decKey types.Key) error {
|
||||
|
||||
return s.encryptor.SetKeys(encKey, decKey)
|
||||
}
|
||||
|
||||
func (s *State) Decrypt(msg message.Message) (message.Message, error) {
|
||||
return s.encryptor.Decrypt(msg)
|
||||
}
|
||||
|
Reference in New Issue
Block a user