mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-05 07:36:54 +08:00
general commit
This commit is contained in:
@@ -46,6 +46,7 @@ type Factory interface {
|
|||||||
type Connection interface {
|
type Connection interface {
|
||||||
Write(ctx context.Context, p []byte) error
|
Write(ctx context.Context, p []byte) error
|
||||||
Read(ctx context.Context) ([]byte, error)
|
Read(ctx context.Context) ([]byte, error)
|
||||||
|
io.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interceptor interface {
|
type Interceptor interface {
|
||||||
|
@@ -59,4 +59,12 @@ func (f *InterceptorFactory) NewInterceptor(ctx context.Context, id string, regi
|
|||||||
writeProcessMessages: message.NewDefaultRegistry(),
|
writeProcessMessages: message.NewDefaultRegistry(),
|
||||||
states: state.NewManager(),
|
states: state.NewManager(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, option := range f.options {
|
||||||
|
if err := option(i); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return i, nil
|
||||||
}
|
}
|
||||||
|
@@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||||
"github.com/harshabose/socket-comm/pkg/message"
|
"github.com/harshabose/socket-comm/pkg/message"
|
||||||
"github.com/harshabose/socket-comm/pkg/transport/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type API struct {
|
type API struct {
|
||||||
@@ -43,13 +42,8 @@ func NewAPI(options ...APIOption) (*API, error) {
|
|||||||
|
|
||||||
// TODO: MAKE REGISTRIES TO NON POINTERS
|
// TODO: MAKE REGISTRIES TO NON POINTERS
|
||||||
|
|
||||||
func (a *API) NewSocket(ctx context.Context, id types.SocketID, options ...Option) (*Socket, error) {
|
func (a *API) NewSocket(ctx context.Context, options ...Option) (*Socket, error) {
|
||||||
s := &Socket{
|
s := NewSocket(NewDefaultSettings(), a.messagesRegistry)
|
||||||
ID: id,
|
|
||||||
settings: NewDefaultSettings(),
|
|
||||||
messageRegistry: a.messagesRegistry,
|
|
||||||
ctx: ctx,
|
|
||||||
}
|
|
||||||
|
|
||||||
interceptors, err := a.interceptorRegistry.Build(s.ctx, string(s.ID))
|
interceptors, err := a.interceptorRegistry.Build(s.ctx, string(s.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -3,6 +3,9 @@ package socket
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
)
|
)
|
||||||
@@ -10,35 +13,190 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrNotSupportedMessageType = errors.New("not supported message type")
|
ErrNotSupportedMessageType = errors.New("not supported message type")
|
||||||
ErrConnectionClosed = errors.New("connection closed")
|
ErrConnectionClosed = errors.New("connection closed")
|
||||||
ErrServerClosed = errors.New("server closed")
|
|
||||||
ErrReaderWriterAlreadySet = errors.New("reader and writer already set")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type adaptor struct {
|
type adaptor struct {
|
||||||
|
connectionSettings
|
||||||
id string
|
id string
|
||||||
conn *websocket.Conn
|
conn *websocket.Conn
|
||||||
|
readQ Buffer[[]byte]
|
||||||
|
writeQ Buffer[[]byte]
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
closeErr error
|
||||||
|
closeErrMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAdaptor(id string, conn *websocket.Conn) *adaptor {
|
func newAdaptor(ctx context.Context, id string, conn *websocket.Conn, readTimeout time.Duration, writeTimeout time.Duration) *adaptor {
|
||||||
|
// Create a child context with cancellation
|
||||||
|
childCtx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &adaptor{
|
return &adaptor{
|
||||||
|
connectionSettings: connectionSettings{
|
||||||
|
ReadTimeout: readTimeout,
|
||||||
|
WriteTimeout: writeTimeout,
|
||||||
|
},
|
||||||
|
ctx: childCtx,
|
||||||
|
cancel: cancel,
|
||||||
id: id,
|
id: id,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write pushes the message of the type '[]byte' to the WriteQ, which will be later sent through the socket
|
||||||
func (a *adaptor) Write(ctx context.Context, p []byte) error {
|
func (a *adaptor) Write(ctx context.Context, p []byte) error {
|
||||||
return a.conn.Write(ctx, websocket.MessageText, p)
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a copy of the message to prevent data races
|
||||||
|
msgCopy := make([]byte, len(p))
|
||||||
|
copy(msgCopy, p)
|
||||||
|
|
||||||
|
return a.writeQ.Push(ctx, msgCopy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read reads a message of the type '[]byte' from the ReadQ, which was read from the websocket.
|
||||||
func (a *adaptor) Read(ctx context.Context) ([]byte, error) {
|
func (a *adaptor) Read(ctx context.Context) ([]byte, error) {
|
||||||
msgType, p, err := a.conn.Read(ctx)
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := a.readQ.Pop(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if msgType != websocket.MessageText {
|
// Return a copy to prevent data races
|
||||||
return nil, ErrNotSupportedMessageType
|
dataCopy := make([]byte, len(data))
|
||||||
|
copy(dataCopy, data)
|
||||||
|
|
||||||
|
return dataCopy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *adaptor) StartReaderWriter() {
|
||||||
|
a.wg.Add(2) // One for the reader, one for the writer
|
||||||
|
go a.Writer()
|
||||||
|
go a.Reader()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *adaptor) Writer() {
|
||||||
|
defer a.wg.Done()
|
||||||
|
defer a.closeWithError(ErrConnectionClosed)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-a.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Use a timeout context for the write operation
|
||||||
|
writeCtx, cancel := context.WithTimeout(a.ctx, a.WriteTimeout)
|
||||||
|
p, err := a.writeQ.Pop(writeCtx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||||
|
// Check if context was cancelled due to connection close
|
||||||
|
select {
|
||||||
|
case <-a.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Just a timeout, continue
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return p, nil
|
fmt.Printf("Error while popping message from WriteQ; err: %s\n", err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a timeout context for the websocket write
|
||||||
|
writeCtx, cancel = context.WithTimeout(a.ctx, a.WriteTimeout)
|
||||||
|
err = a.conn.Write(writeCtx, websocket.MessageText, p)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error while writing message to socket; err: %s\n", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *adaptor) Reader() {
|
||||||
|
defer a.wg.Done()
|
||||||
|
defer a.closeWithError(ErrConnectionClosed)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-a.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Use a timeout context for the read operation
|
||||||
|
readCtx, cancel := context.WithTimeout(a.ctx, a.ReadTimeout)
|
||||||
|
msgType, p, err := a.conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error while reading message from socket; err: %s\n", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != websocket.MessageText {
|
||||||
|
fmt.Printf("Error while reading message from socket; err: %s\n", ErrNotSupportedMessageType.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a background context since we want to buffer the message even if the operation takes time
|
||||||
|
err = a.readQ.Push(a.ctx, p)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error while pushing message to ReadQ; err: %s\n", err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close initiates a graceful shutdown of the connection
|
||||||
|
func (a *adaptor) Close() error {
|
||||||
|
var err error
|
||||||
|
a.closeOnce.Do(func() {
|
||||||
|
// Cancel the context to signal all goroutines to stop
|
||||||
|
a.cancel()
|
||||||
|
|
||||||
|
// Try to send a close message to the peer
|
||||||
|
closeErr := a.conn.Close(websocket.StatusNormalClosure, "connection closed")
|
||||||
|
if closeErr != nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the buffers
|
||||||
|
a.readQ.Close()
|
||||||
|
a.writeQ.Close()
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeWithError stores the error that caused the connection to close
|
||||||
|
func (a *adaptor) closeWithError(err error) {
|
||||||
|
a.closeErrMu.Lock()
|
||||||
|
if a.closeErr == nil {
|
||||||
|
a.closeErr = err
|
||||||
|
}
|
||||||
|
a.closeErrMu.Unlock()
|
||||||
|
a.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCloseError returns the error that caused the connection to close
|
||||||
|
func (a *adaptor) GetCloseError() error {
|
||||||
|
a.closeErrMu.Lock()
|
||||||
|
defer a.closeErrMu.Unlock()
|
||||||
|
return a.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitUntilClose blocks until all reader and writer goroutines exit
|
||||||
|
func (a *adaptor) WaitUntilClose() {
|
||||||
|
a.wg.Wait()
|
||||||
}
|
}
|
||||||
|
@@ -2,158 +2,39 @@ package socket
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// TODO: LIMIT AND SELF KILL BUFFER.
|
||||||
ErrorElementUnallocated = errors.New("encountered nil in the buffer. this should not happen. check usage")
|
// ELEMENT IS RESPONSIBLE TO KILL AND FREE ITS MEMORY. IT SHOULD ALSO DELETE ITSELF FROM THE PARENT
|
||||||
ErrorChannelBufferClose = errors.New("channel buffer has be closed. cannot perform this operation")
|
// HARD LIMIT ON NUMBER OF BUFFERS
|
||||||
)
|
|
||||||
|
|
||||||
type Pool[T any] interface {
|
type Buffered[T any] struct {
|
||||||
Get() T
|
element T
|
||||||
Put(T)
|
|
||||||
Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Buffer[T any] interface {
|
|
||||||
Push(context.Context, T) error
|
|
||||||
Pop(ctx context.Context) (T, error)
|
|
||||||
Size() int
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChannelBuffer[T any] struct {
|
|
||||||
pool Pool[T]
|
|
||||||
bufferChannel chan T
|
|
||||||
inputBuffer chan T
|
|
||||||
closed bool
|
|
||||||
mux sync.RWMutex
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateChannelBuffer[T any](ctx context.Context, size int, pool Pool[T]) *ChannelBuffer[T] {
|
type Buffer[T any] interface {
|
||||||
buffer := &ChannelBuffer[T]{
|
Pop(context.Context) (T, error)
|
||||||
pool: pool,
|
Push(context.Context, T) error
|
||||||
bufferChannel: make(chan T, size),
|
Close()
|
||||||
inputBuffer: make(chan T),
|
|
||||||
closed: false,
|
|
||||||
ctx: ctx,
|
|
||||||
}
|
|
||||||
go buffer.loop()
|
|
||||||
return buffer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) Push(ctx context.Context, element T) error {
|
type LimitKillBuffer[T any] struct {
|
||||||
buffer.mux.RLock()
|
buffer chan Buffered[T]
|
||||||
defer buffer.mux.RUnlock()
|
|
||||||
|
|
||||||
if buffer.closed {
|
|
||||||
return errors.New("buffer closed")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case buffer.inputBuffer <- element:
|
|
||||||
// WARN: LACKS CHECKS FOR CLOSED CHANNEL
|
|
||||||
return nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) Pop(ctx context.Context) (T, error) {
|
func (b *LimitKillBuffer[T]) Pop(ctx context.Context) (T, error) {
|
||||||
buffer.mux.RLock()
|
|
||||||
defer buffer.mux.RUnlock()
|
|
||||||
|
|
||||||
if buffer.closed {
|
|
||||||
var t T
|
|
||||||
return t, errors.New("buffer closed")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
var t T
|
|
||||||
return t, ctx.Err()
|
|
||||||
case data, ok := <-buffer.bufferChannel:
|
|
||||||
if !ok {
|
|
||||||
var t T
|
|
||||||
return t, ErrorChannelBufferClose
|
|
||||||
}
|
|
||||||
if data == nil {
|
|
||||||
var t T
|
|
||||||
return t, ErrorElementUnallocated
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) Generate() T {
|
func (b *LimitKillBuffer[T]) Push(ctx context.Context, element T) error {
|
||||||
return buffer.pool.Get()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) PutBack(element T) {
|
func (b *LimitKillBuffer[T]) Close() {
|
||||||
if buffer.pool != nil {
|
|
||||||
buffer.pool.Put(element)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) GetChannel() chan T {
|
func (b *LimitKillBuffer[T]) manager() {
|
||||||
return buffer.bufferChannel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) Size() int {
|
|
||||||
return len(buffer.bufferChannel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) loop() {
|
|
||||||
defer buffer.close()
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-buffer.ctx.Done():
|
|
||||||
return
|
|
||||||
case element, ok := <-buffer.inputBuffer:
|
|
||||||
if !ok || element == nil {
|
|
||||||
continue loop
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case buffer.bufferChannel <- element: // SUCCESSFULLY BUFFERED
|
|
||||||
continue loop
|
|
||||||
default:
|
|
||||||
select {
|
|
||||||
case oldElement := <-buffer.bufferChannel:
|
|
||||||
buffer.PutBack(oldElement)
|
|
||||||
select {
|
|
||||||
case buffer.bufferChannel <- element:
|
|
||||||
continue loop
|
|
||||||
default:
|
|
||||||
fmt.Println("unexpected buffer state. skipping the element..")
|
|
||||||
buffer.PutBack(element)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (buffer *ChannelBuffer[T]) close() {
|
|
||||||
buffer.mux.Lock()
|
|
||||||
buffer.closed = true
|
|
||||||
buffer.mux.Unlock()
|
|
||||||
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case element := <-buffer.bufferChannel:
|
|
||||||
if buffer.pool != nil {
|
|
||||||
buffer.pool.Put(element)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
close(buffer.bufferChannel)
|
|
||||||
close(buffer.inputBuffer)
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if buffer.pool != nil {
|
|
||||||
buffer.pool.Release()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@@ -1,16 +1,24 @@
|
|||||||
package socket
|
package socket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrSettingsInvalid = errors.New("server settings invalid")
|
||||||
|
|
||||||
type connectionSettings struct {
|
type connectionSettings struct {
|
||||||
ReadTimout time.Duration
|
ReadTimeout time.Duration
|
||||||
WriteTimout time.Duration
|
WriteTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
connectionSettings
|
connectionSettings
|
||||||
|
Address string
|
||||||
|
Port uint16
|
||||||
|
ReadHeaderTimeout time.Duration
|
||||||
IdleTimout time.Duration
|
IdleTimout time.Duration
|
||||||
ShutdownTimout time.Duration
|
ShutdownTimout time.Duration
|
||||||
|
|
||||||
@@ -20,19 +28,58 @@ type Settings struct {
|
|||||||
|
|
||||||
MaxConnections int
|
MaxConnections int
|
||||||
ConnectionTimeout time.Duration
|
ConnectionTimeout time.Duration
|
||||||
|
|
||||||
|
PopMessageTimeout time.Duration
|
||||||
|
PushMessageTimout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultSettings() *Settings {
|
func (s Settings) Validate() error {
|
||||||
return &Settings{
|
// TODO: IMPLEMENT THIS
|
||||||
|
return ErrSettingsInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDefaultSettings() Settings {
|
||||||
|
return Settings{
|
||||||
connectionSettings: connectionSettings{
|
connectionSettings: connectionSettings{
|
||||||
ReadTimout: 30 * time.Second,
|
ReadTimeout: time.Second,
|
||||||
WriteTimout: 30 * time.Second,
|
WriteTimeout: time.Second,
|
||||||
},
|
},
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
ConnectionTimeout: 10 * time.Second,
|
||||||
IdleTimout: 120 * time.Second,
|
IdleTimout: 120 * time.Second,
|
||||||
ShutdownTimout: 30 * time.Second,
|
ShutdownTimout: 30 * time.Second,
|
||||||
TLSCertFile: "",
|
TLSCertFile: "",
|
||||||
TLSKeyFile: "",
|
TLSKeyFile: "",
|
||||||
MaxConnections: 1000,
|
MaxConnections: 1000,
|
||||||
ConnectionTimeout: 10 * time.Second,
|
PopMessageTimeout: 30 * time.Second,
|
||||||
|
PushMessageTimout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTLSV1(tlsCertPath, tlsKeyFile string) (*tls.Config, error) {
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
if tlsCertPath != "" && tlsKeyFile != "" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading TLS certificates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP256},
|
||||||
|
CipherSuites: []uint16{
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("invalid cert or/and file path")
|
||||||
|
}
|
||||||
|
@@ -2,7 +2,9 @@ package socket
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -19,55 +21,100 @@ const DEBUG = true
|
|||||||
|
|
||||||
type Option func(*Socket) error
|
type Option func(*Socket) error
|
||||||
|
|
||||||
|
// Metrics holds server statistics
|
||||||
|
type Metrics struct {
|
||||||
|
ActiveConnections int
|
||||||
|
TotalConnections int
|
||||||
|
FailedConnections int
|
||||||
|
mux sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
type Socket struct {
|
type Socket struct {
|
||||||
ID types.SocketID `json:"id"`
|
ID types.SocketID `json:"id"`
|
||||||
server *http.Server
|
server *http.Server
|
||||||
router *http.ServeMux
|
router *http.ServeMux
|
||||||
handlerFunc http.HandlerFunc
|
settings Settings
|
||||||
settings *Settings
|
|
||||||
interceptor interceptor.Interceptor
|
interceptor interceptor.Interceptor
|
||||||
|
connections map[string]interceptor.Connection
|
||||||
|
metrics *Metrics
|
||||||
messageRegistry message.Registry
|
messageRegistry message.Registry
|
||||||
mux sync.RWMutex
|
cancel context.CancelFunc
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
mux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSocket(ctx context.Context, settings Settings, registry message.Registry) *Socket {
|
||||||
|
ctx2, cancel := context.WithCancel(ctx)
|
||||||
|
return &Socket{
|
||||||
|
ID: types.SocketID(uuid.NewString()),
|
||||||
|
settings: settings,
|
||||||
|
messageRegistry: registry,
|
||||||
|
connections: make(map[string]interceptor.Connection),
|
||||||
|
metrics: &Metrics{},
|
||||||
|
cancel: cancel,
|
||||||
|
ctx: ctx2,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Socket) GetID() types.SocketID {
|
func (s *Socket) GetID() types.SocketID {
|
||||||
return s.ID
|
return s.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Socket) Ctx(l net.Listener) context.Context {
|
||||||
|
return s.ctx
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Socket) Init() error {
|
func (s *Socket) Init() error {
|
||||||
s.mux.Lock()
|
if err := s.settings.Validate(); err != nil {
|
||||||
defer s.mux.Unlock()
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err := GetTLSV1(s.settings.TLSCertFile, s.settings.TLSKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
s.router = http.NewServeMux()
|
s.router = http.NewServeMux()
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
ReadTimeout: s.settings.ReadTimout,
|
Addr: fmt.Sprintf("%s:%d", s.settings.Address, s.settings.Port),
|
||||||
WriteTimeout: s.settings.WriteTimout,
|
ReadTimeout: s.settings.PopMessageTimeout,
|
||||||
|
WriteTimeout: s.settings.PushMessageTimout,
|
||||||
|
ReadHeaderTimeout: s.settings.ReadHeaderTimeout,
|
||||||
IdleTimeout: s.settings.IdleTimout,
|
IdleTimeout: s.settings.IdleTimout,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
Handler: s.router,
|
||||||
|
BaseContext: s.Ctx,
|
||||||
// TODO: MAYBE ADD MORE
|
// TODO: MAYBE ADD MORE
|
||||||
}
|
}
|
||||||
s.handlerFunc = s.handler
|
|
||||||
|
|
||||||
|
// Set up routes
|
||||||
|
s.router.HandleFunc("/ws", s.handleWebSocket)
|
||||||
|
s.router.HandleFunc("/health", s.handleHealth)
|
||||||
|
s.router.HandleFunc("/metrics", s.handleMetrics)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Socket) Serve() error {
|
func (s *Socket) Serve() error {
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-s.ctx.Done():
|
case <-s.ctx.Done():
|
||||||
return nil // TODO: add error
|
return nil // TODO: add error
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if DEBUG {
|
if s.server.TLSConfig != nil {
|
||||||
if err := s.server.ListenAndServe(); err != nil {
|
if err := s.server.ListenAndServe(); err != nil {
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
fmt.Println("error while serving; err: ", err.Error())
|
fmt.Println("error while serving; err: ", err.Error())
|
||||||
fmt.Println("retrying...")
|
fmt.Println("retrying...")
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if err := s.server.ListenAndServeTLS(s.settings.TLSCertFile, s.settings.TLSKeyFile); err != nil {
|
if err := s.server.ListenAndServeTLS(s.settings.TLSCertFile, s.settings.TLSKeyFile); err != nil {
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
fmt.Println("error while serving; err: ", err.Error())
|
fmt.Println("error while serving; err: ", err.Error())
|
||||||
fmt.Println("retrying...")
|
fmt.Println("retrying...")
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
@@ -76,12 +123,8 @@ func (s *Socket) Serve() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Socket) Close() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (message.Message, error) {
|
func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (message.Message, error) {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, s.settings.ReadTimout)
|
ctx, cancel := context.WithTimeout(s.ctx, s.settings.PopMessageTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
data, err := connection.Read(ctx)
|
data, err := connection.Read(ctx)
|
||||||
@@ -93,7 +136,7 @@ func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, msg message.Message) error {
|
func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, msg message.Message) error {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, s.settings.WriteTimout)
|
ctx, cancel := context.WithTimeout(s.ctx, s.settings.PushMessageTimout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
data, err := msg.Marshal()
|
data, err := msg.Marshal()
|
||||||
@@ -104,38 +147,119 @@ func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, m
|
|||||||
return connection.Write(ctx, data)
|
return connection.Write(ctx, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Socket) handler(writer http.ResponseWriter, request *http.Request) {
|
// registerConnection adds a new connection
|
||||||
|
func (s *Socket) registerConnection(id string, conn interceptor.Connection) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.connections[id] = conn
|
||||||
|
|
||||||
|
s.metrics.mux.Lock()
|
||||||
|
s.metrics.ActiveConnections++
|
||||||
|
s.metrics.TotalConnections++
|
||||||
|
s.metrics.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// unregisterConnection removes a connection
|
||||||
|
func (s *Socket) unregisterConnection(id string) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
if _, exists := s.connections[id]; exists {
|
||||||
|
delete(s.connections, id)
|
||||||
|
|
||||||
|
s.metrics.mux.Lock()
|
||||||
|
s.metrics.ActiveConnections--
|
||||||
|
s.metrics.mux.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeAllConnections closes all active connections
|
||||||
|
func (s *Socket) closeAllConnections() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
for _, conn := range s.connections {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
fmt.Println("error while closing a connection; err:", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) handleWebSocket(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
s.metrics.mux.RLock()
|
||||||
|
currentConnections := s.metrics.ActiveConnections
|
||||||
|
s.metrics.mux.RUnlock()
|
||||||
|
|
||||||
|
if currentConnections >= s.settings.MaxConnections {
|
||||||
|
fmt.Println("Connection limit reached", "active", currentConnections)
|
||||||
|
http.Error(writer, "Service Unavailable", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := websocket.Accept(writer, request, nil)
|
conn, err := websocket.Accept(writer, request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.metrics.mux.Lock()
|
||||||
|
s.metrics.FailedConnections++
|
||||||
|
s.metrics.mux.Unlock()
|
||||||
|
|
||||||
fmt.Println("error while handling client; removing client...")
|
fmt.Println("error while handling client; removing client...")
|
||||||
}
|
|
||||||
connection := newAdaptor(uuid.NewString(), conn)
|
|
||||||
|
|
||||||
w, r, err := s.interceptor.BindSocketConnection(connection, s, s)
|
|
||||||
if err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
iD := uuid.NewString()
|
||||||
select {
|
connection := newAdaptor(request.Context(), iD, conn, s.settings.PopMessageTimeout, s.settings.PushMessageTimout)
|
||||||
case <-request.Context().Done():
|
|
||||||
return
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
msg, err := r.Read(s.ctx, connection)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println(msg)
|
s.registerConnection(iD, connection)
|
||||||
}
|
defer s.unregisterConnection(iD)
|
||||||
}
|
|
||||||
|
|
||||||
|
connection.StartReaderWriter()
|
||||||
|
|
||||||
|
if _, _, err := s.interceptor.BindSocketConnection(connection, s, s); err != nil {
|
||||||
|
fmt.Println(fmt.Errorf("error while binding socket to interceptors; err: %s", err.Error()))
|
||||||
|
fmt.Println("dropping client...")
|
||||||
|
return
|
||||||
|
}
|
||||||
defer s.interceptor.UnBindSocketConnection(connection)
|
defer s.interceptor.UnBindSocketConnection(connection)
|
||||||
|
|
||||||
if err := s.interceptor.Init(connection); err != nil {
|
if err := s.interceptor.Init(connection); err != nil {
|
||||||
fmt.Println("error while connection init; dropping client")
|
fmt.Println("error while connection init; dropping client")
|
||||||
|
fmt.Println("dropping client...")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
connection.WaitUntilClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) ShutDown(ctx context.Context) error {
|
||||||
|
ctx2, cancel := context.WithTimeout(ctx, s.settings.ShutdownTimout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s.cancel()
|
||||||
|
if err := s.server.Shutdown(ctx2); err != nil {
|
||||||
|
return fmt.Errorf("server shutdown error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.closeAllConnections()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHealth provides a health check endpoint
|
||||||
|
func (s *Socket) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMetrics exposes server metrics
|
||||||
|
func (s *Socket) handleMetrics(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
s.metrics.mux.RLock()
|
||||||
|
defer s.metrics.mux.RUnlock()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = fmt.Fprintf(w, `{
|
||||||
|
"active_connections": %d,
|
||||||
|
"total_connections": %d,
|
||||||
|
"failed_connections": %d
|
||||||
|
}`, s.metrics.ActiveConnections, s.metrics.TotalConnections,
|
||||||
|
s.metrics.FailedConnections)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user