mirror of
https://github.com/harshabose/socket-comm.git
synced 2025-10-04 23:32:42 +08:00
general commit
This commit is contained in:
@@ -46,6 +46,7 @@ type Factory interface {
|
||||
type Connection interface {
|
||||
Write(ctx context.Context, p []byte) error
|
||||
Read(ctx context.Context) ([]byte, error)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type Interceptor interface {
|
||||
|
@@ -59,4 +59,12 @@ func (f *InterceptorFactory) NewInterceptor(ctx context.Context, id string, regi
|
||||
writeProcessMessages: message.NewDefaultRegistry(),
|
||||
states: state.NewManager(),
|
||||
}
|
||||
|
||||
for _, option := range f.options {
|
||||
if err := option(i); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/harshabose/socket-comm/pkg/interceptor"
|
||||
"github.com/harshabose/socket-comm/pkg/message"
|
||||
"github.com/harshabose/socket-comm/pkg/transport/types"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
@@ -43,13 +42,8 @@ func NewAPI(options ...APIOption) (*API, error) {
|
||||
|
||||
// TODO: MAKE REGISTRIES TO NON POINTERS
|
||||
|
||||
func (a *API) NewSocket(ctx context.Context, id types.SocketID, options ...Option) (*Socket, error) {
|
||||
s := &Socket{
|
||||
ID: id,
|
||||
settings: NewDefaultSettings(),
|
||||
messageRegistry: a.messagesRegistry,
|
||||
ctx: ctx,
|
||||
}
|
||||
func (a *API) NewSocket(ctx context.Context, options ...Option) (*Socket, error) {
|
||||
s := NewSocket(NewDefaultSettings(), a.messagesRegistry)
|
||||
|
||||
interceptors, err := a.interceptorRegistry.Build(s.ctx, string(s.ID))
|
||||
if err != nil {
|
||||
|
@@ -3,6 +3,9 @@ package socket
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
@@ -10,35 +13,190 @@ import (
|
||||
var (
|
||||
ErrNotSupportedMessageType = errors.New("not supported message type")
|
||||
ErrConnectionClosed = errors.New("connection closed")
|
||||
ErrServerClosed = errors.New("server closed")
|
||||
ErrReaderWriterAlreadySet = errors.New("reader and writer already set")
|
||||
)
|
||||
|
||||
type adaptor struct {
|
||||
connectionSettings
|
||||
id string
|
||||
conn *websocket.Conn
|
||||
readQ Buffer[[]byte]
|
||||
writeQ Buffer[[]byte]
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
closeErr error
|
||||
closeErrMu sync.Mutex
|
||||
}
|
||||
|
||||
func newAdaptor(id string, conn *websocket.Conn) *adaptor {
|
||||
func newAdaptor(ctx context.Context, id string, conn *websocket.Conn, readTimeout time.Duration, writeTimeout time.Duration) *adaptor {
|
||||
// Create a child context with cancellation
|
||||
childCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &adaptor{
|
||||
connectionSettings: connectionSettings{
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
},
|
||||
ctx: childCtx,
|
||||
cancel: cancel,
|
||||
id: id,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Write pushes the message of the type '[]byte' to the WriteQ, which will be later sent through the socket
|
||||
func (a *adaptor) Write(ctx context.Context, p []byte) error {
|
||||
return a.conn.Write(ctx, websocket.MessageText, p)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Create a copy of the message to prevent data races
|
||||
msgCopy := make([]byte, len(p))
|
||||
copy(msgCopy, p)
|
||||
|
||||
return a.writeQ.Push(ctx, msgCopy)
|
||||
}
|
||||
|
||||
// Read reads a message of the type '[]byte' from the ReadQ, which was read from the websocket.
|
||||
func (a *adaptor) Read(ctx context.Context) ([]byte, error) {
|
||||
msgType, p, err := a.conn.Read(ctx)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
data, err := a.readQ.Pop(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msgType != websocket.MessageText {
|
||||
return nil, ErrNotSupportedMessageType
|
||||
// Return a copy to prevent data races
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
return dataCopy, nil
|
||||
}
|
||||
|
||||
return p, nil
|
||||
func (a *adaptor) StartReaderWriter() {
|
||||
a.wg.Add(2) // One for the reader, one for the writer
|
||||
go a.Writer()
|
||||
go a.Reader()
|
||||
}
|
||||
|
||||
func (a *adaptor) Writer() {
|
||||
defer a.wg.Done()
|
||||
defer a.closeWithError(ErrConnectionClosed)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Use a timeout context for the write operation
|
||||
writeCtx, cancel := context.WithTimeout(a.ctx, a.WriteTimeout)
|
||||
p, err := a.writeQ.Pop(writeCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
// Check if context was cancelled due to connection close
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Just a timeout, continue
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Error while popping message from WriteQ; err: %s\n", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// Use a timeout context for the websocket write
|
||||
writeCtx, cancel = context.WithTimeout(a.ctx, a.WriteTimeout)
|
||||
err = a.conn.Write(writeCtx, websocket.MessageText, p)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Error while writing message to socket; err: %s\n", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *adaptor) Reader() {
|
||||
defer a.wg.Done()
|
||||
defer a.closeWithError(ErrConnectionClosed)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Use a timeout context for the read operation
|
||||
readCtx, cancel := context.WithTimeout(a.ctx, a.ReadTimeout)
|
||||
msgType, p, err := a.conn.Read(readCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Error while reading message from socket; err: %s\n", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if msgType != websocket.MessageText {
|
||||
fmt.Printf("Error while reading message from socket; err: %s\n", ErrNotSupportedMessageType.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// Use a background context since we want to buffer the message even if the operation takes time
|
||||
err = a.readQ.Push(a.ctx, p)
|
||||
if err != nil {
|
||||
fmt.Printf("Error while pushing message to ReadQ; err: %s\n", err.Error())
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close initiates a graceful shutdown of the connection
|
||||
func (a *adaptor) Close() error {
|
||||
var err error
|
||||
a.closeOnce.Do(func() {
|
||||
// Cancel the context to signal all goroutines to stop
|
||||
a.cancel()
|
||||
|
||||
// Try to send a close message to the peer
|
||||
closeErr := a.conn.Close(websocket.StatusNormalClosure, "connection closed")
|
||||
if closeErr != nil {
|
||||
err = closeErr
|
||||
}
|
||||
|
||||
// Close the buffers
|
||||
a.readQ.Close()
|
||||
a.writeQ.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// closeWithError stores the error that caused the connection to close
|
||||
func (a *adaptor) closeWithError(err error) {
|
||||
a.closeErrMu.Lock()
|
||||
if a.closeErr == nil {
|
||||
a.closeErr = err
|
||||
}
|
||||
a.closeErrMu.Unlock()
|
||||
a.Close()
|
||||
}
|
||||
|
||||
// GetCloseError returns the error that caused the connection to close
|
||||
func (a *adaptor) GetCloseError() error {
|
||||
a.closeErrMu.Lock()
|
||||
defer a.closeErrMu.Unlock()
|
||||
return a.closeErr
|
||||
}
|
||||
|
||||
// WaitUntilClose blocks until all reader and writer goroutines exit
|
||||
func (a *adaptor) WaitUntilClose() {
|
||||
a.wg.Wait()
|
||||
}
|
||||
|
@@ -2,158 +2,39 @@ package socket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorElementUnallocated = errors.New("encountered nil in the buffer. this should not happen. check usage")
|
||||
ErrorChannelBufferClose = errors.New("channel buffer has be closed. cannot perform this operation")
|
||||
)
|
||||
// TODO: LIMIT AND SELF KILL BUFFER.
|
||||
// ELEMENT IS RESPONSIBLE TO KILL AND FREE ITS MEMORY. IT SHOULD ALSO DELETE ITSELF FROM THE PARENT
|
||||
// HARD LIMIT ON NUMBER OF BUFFERS
|
||||
|
||||
type Pool[T any] interface {
|
||||
Get() T
|
||||
Put(T)
|
||||
Release()
|
||||
}
|
||||
|
||||
type Buffer[T any] interface {
|
||||
Push(context.Context, T) error
|
||||
Pop(ctx context.Context) (T, error)
|
||||
Size() int
|
||||
}
|
||||
|
||||
type ChannelBuffer[T any] struct {
|
||||
pool Pool[T]
|
||||
bufferChannel chan T
|
||||
inputBuffer chan T
|
||||
closed bool
|
||||
mux sync.RWMutex
|
||||
type Buffered[T any] struct {
|
||||
element T
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func CreateChannelBuffer[T any](ctx context.Context, size int, pool Pool[T]) *ChannelBuffer[T] {
|
||||
buffer := &ChannelBuffer[T]{
|
||||
pool: pool,
|
||||
bufferChannel: make(chan T, size),
|
||||
inputBuffer: make(chan T),
|
||||
closed: false,
|
||||
ctx: ctx,
|
||||
}
|
||||
go buffer.loop()
|
||||
return buffer
|
||||
type Buffer[T any] interface {
|
||||
Pop(context.Context) (T, error)
|
||||
Push(context.Context, T) error
|
||||
Close()
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) Push(ctx context.Context, element T) error {
|
||||
buffer.mux.RLock()
|
||||
defer buffer.mux.RUnlock()
|
||||
|
||||
if buffer.closed {
|
||||
return errors.New("buffer closed")
|
||||
}
|
||||
select {
|
||||
case buffer.inputBuffer <- element:
|
||||
// WARN: LACKS CHECKS FOR CLOSED CHANNEL
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
type LimitKillBuffer[T any] struct {
|
||||
buffer chan Buffered[T]
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) Pop(ctx context.Context) (T, error) {
|
||||
buffer.mux.RLock()
|
||||
defer buffer.mux.RUnlock()
|
||||
func (b *LimitKillBuffer[T]) Pop(ctx context.Context) (T, error) {
|
||||
|
||||
if buffer.closed {
|
||||
var t T
|
||||
return t, errors.New("buffer closed")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
var t T
|
||||
return t, ctx.Err()
|
||||
case data, ok := <-buffer.bufferChannel:
|
||||
if !ok {
|
||||
var t T
|
||||
return t, ErrorChannelBufferClose
|
||||
}
|
||||
if data == nil {
|
||||
var t T
|
||||
return t, ErrorElementUnallocated
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) Generate() T {
|
||||
return buffer.pool.Get()
|
||||
func (b *LimitKillBuffer[T]) Push(ctx context.Context, element T) error {
|
||||
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) PutBack(element T) {
|
||||
if buffer.pool != nil {
|
||||
buffer.pool.Put(element)
|
||||
}
|
||||
func (b *LimitKillBuffer[T]) Close() {
|
||||
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) GetChannel() chan T {
|
||||
return buffer.bufferChannel
|
||||
}
|
||||
func (b *LimitKillBuffer[T]) manager() {
|
||||
|
||||
func (buffer *ChannelBuffer[T]) Size() int {
|
||||
return len(buffer.bufferChannel)
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) loop() {
|
||||
defer buffer.close()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-buffer.ctx.Done():
|
||||
return
|
||||
case element, ok := <-buffer.inputBuffer:
|
||||
if !ok || element == nil {
|
||||
continue loop
|
||||
}
|
||||
select {
|
||||
case buffer.bufferChannel <- element: // SUCCESSFULLY BUFFERED
|
||||
continue loop
|
||||
default:
|
||||
select {
|
||||
case oldElement := <-buffer.bufferChannel:
|
||||
buffer.PutBack(oldElement)
|
||||
select {
|
||||
case buffer.bufferChannel <- element:
|
||||
continue loop
|
||||
default:
|
||||
fmt.Println("unexpected buffer state. skipping the element..")
|
||||
buffer.PutBack(element)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (buffer *ChannelBuffer[T]) close() {
|
||||
buffer.mux.Lock()
|
||||
buffer.closed = true
|
||||
buffer.mux.Unlock()
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case element := <-buffer.bufferChannel:
|
||||
if buffer.pool != nil {
|
||||
buffer.pool.Put(element)
|
||||
}
|
||||
default:
|
||||
close(buffer.bufferChannel)
|
||||
close(buffer.inputBuffer)
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if buffer.pool != nil {
|
||||
buffer.pool.Release()
|
||||
}
|
||||
}
|
||||
|
@@ -1,16 +1,24 @@
|
||||
package socket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrSettingsInvalid = errors.New("server settings invalid")
|
||||
|
||||
type connectionSettings struct {
|
||||
ReadTimout time.Duration
|
||||
WriteTimout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
}
|
||||
|
||||
type Settings struct {
|
||||
connectionSettings
|
||||
Address string
|
||||
Port uint16
|
||||
ReadHeaderTimeout time.Duration
|
||||
IdleTimout time.Duration
|
||||
ShutdownTimout time.Duration
|
||||
|
||||
@@ -20,19 +28,58 @@ type Settings struct {
|
||||
|
||||
MaxConnections int
|
||||
ConnectionTimeout time.Duration
|
||||
|
||||
PopMessageTimeout time.Duration
|
||||
PushMessageTimout time.Duration
|
||||
}
|
||||
|
||||
func NewDefaultSettings() *Settings {
|
||||
return &Settings{
|
||||
func (s Settings) Validate() error {
|
||||
// TODO: IMPLEMENT THIS
|
||||
return ErrSettingsInvalid
|
||||
}
|
||||
|
||||
func NewDefaultSettings() Settings {
|
||||
return Settings{
|
||||
connectionSettings: connectionSettings{
|
||||
ReadTimout: 30 * time.Second,
|
||||
WriteTimout: 30 * time.Second,
|
||||
ReadTimeout: time.Second,
|
||||
WriteTimeout: time.Second,
|
||||
},
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
ConnectionTimeout: 10 * time.Second,
|
||||
IdleTimout: 120 * time.Second,
|
||||
ShutdownTimout: 30 * time.Second,
|
||||
TLSCertFile: "",
|
||||
TLSKeyFile: "",
|
||||
MaxConnections: 1000,
|
||||
ConnectionTimeout: 10 * time.Second,
|
||||
PopMessageTimeout: 30 * time.Second,
|
||||
PushMessageTimout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func GetTLSV1(tlsCertPath, tlsKeyFile string) (*tls.Config, error) {
|
||||
var tlsConfig *tls.Config
|
||||
if tlsCertPath != "" && tlsKeyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading TLS certificates: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP256},
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid cert or/and file path")
|
||||
}
|
||||
|
@@ -2,7 +2,9 @@ package socket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -19,55 +21,100 @@ const DEBUG = true
|
||||
|
||||
type Option func(*Socket) error
|
||||
|
||||
// Metrics holds server statistics
|
||||
type Metrics struct {
|
||||
ActiveConnections int
|
||||
TotalConnections int
|
||||
FailedConnections int
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
type Socket struct {
|
||||
ID types.SocketID `json:"id"`
|
||||
server *http.Server
|
||||
router *http.ServeMux
|
||||
handlerFunc http.HandlerFunc
|
||||
settings *Settings
|
||||
settings Settings
|
||||
interceptor interceptor.Interceptor
|
||||
connections map[string]interceptor.Connection
|
||||
metrics *Metrics
|
||||
messageRegistry message.Registry
|
||||
mux sync.RWMutex
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func NewSocket(ctx context.Context, settings Settings, registry message.Registry) *Socket {
|
||||
ctx2, cancel := context.WithCancel(ctx)
|
||||
return &Socket{
|
||||
ID: types.SocketID(uuid.NewString()),
|
||||
settings: settings,
|
||||
messageRegistry: registry,
|
||||
connections: make(map[string]interceptor.Connection),
|
||||
metrics: &Metrics{},
|
||||
cancel: cancel,
|
||||
ctx: ctx2,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) GetID() types.SocketID {
|
||||
return s.ID
|
||||
}
|
||||
|
||||
func (s *Socket) Ctx(l net.Listener) context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *Socket) Init() error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
if err := s.settings.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tlsConfig, err := GetTLSV1(s.settings.TLSCertFile, s.settings.TLSKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.router = http.NewServeMux()
|
||||
s.server = &http.Server{
|
||||
ReadTimeout: s.settings.ReadTimout,
|
||||
WriteTimeout: s.settings.WriteTimout,
|
||||
Addr: fmt.Sprintf("%s:%d", s.settings.Address, s.settings.Port),
|
||||
ReadTimeout: s.settings.PopMessageTimeout,
|
||||
WriteTimeout: s.settings.PushMessageTimout,
|
||||
ReadHeaderTimeout: s.settings.ReadHeaderTimeout,
|
||||
IdleTimeout: s.settings.IdleTimout,
|
||||
TLSConfig: tlsConfig,
|
||||
Handler: s.router,
|
||||
BaseContext: s.Ctx,
|
||||
// TODO: MAYBE ADD MORE
|
||||
}
|
||||
s.handlerFunc = s.handler
|
||||
|
||||
// Set up routes
|
||||
s.router.HandleFunc("/ws", s.handleWebSocket)
|
||||
s.router.HandleFunc("/health", s.handleHealth)
|
||||
s.router.HandleFunc("/metrics", s.handleMetrics)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Socket) Serve() error {
|
||||
defer s.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return nil // TODO: add error
|
||||
|
||||
default:
|
||||
if DEBUG {
|
||||
if s.server.TLSConfig != nil {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
fmt.Println("error while serving; err: ", err.Error())
|
||||
fmt.Println("retrying...")
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := s.server.ListenAndServeTLS(s.settings.TLSCertFile, s.settings.TLSKeyFile); err != nil {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
fmt.Println("error while serving; err: ", err.Error())
|
||||
fmt.Println("retrying...")
|
||||
time.Sleep(1 * time.Second)
|
||||
@@ -76,12 +123,8 @@ func (s *Socket) Serve() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) Close() {
|
||||
|
||||
}
|
||||
|
||||
func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (message.Message, error) {
|
||||
ctx, cancel := context.WithTimeout(s.ctx, s.settings.ReadTimout)
|
||||
ctx, cancel := context.WithTimeout(s.ctx, s.settings.PopMessageTimeout)
|
||||
defer cancel()
|
||||
|
||||
data, err := connection.Read(ctx)
|
||||
@@ -93,7 +136,7 @@ func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (m
|
||||
}
|
||||
|
||||
func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, msg message.Message) error {
|
||||
ctx, cancel := context.WithTimeout(s.ctx, s.settings.WriteTimout)
|
||||
ctx, cancel := context.WithTimeout(s.ctx, s.settings.PushMessageTimout)
|
||||
defer cancel()
|
||||
|
||||
data, err := msg.Marshal()
|
||||
@@ -104,38 +147,119 @@ func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, m
|
||||
return connection.Write(ctx, data)
|
||||
}
|
||||
|
||||
func (s *Socket) handler(writer http.ResponseWriter, request *http.Request) {
|
||||
// registerConnection adds a new connection
|
||||
func (s *Socket) registerConnection(id string, conn interceptor.Connection) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.connections[id] = conn
|
||||
|
||||
s.metrics.mux.Lock()
|
||||
s.metrics.ActiveConnections++
|
||||
s.metrics.TotalConnections++
|
||||
s.metrics.mux.Unlock()
|
||||
}
|
||||
|
||||
// unregisterConnection removes a connection
|
||||
func (s *Socket) unregisterConnection(id string) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if _, exists := s.connections[id]; exists {
|
||||
delete(s.connections, id)
|
||||
|
||||
s.metrics.mux.Lock()
|
||||
s.metrics.ActiveConnections--
|
||||
s.metrics.mux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// closeAllConnections closes all active connections
|
||||
func (s *Socket) closeAllConnections() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for _, conn := range s.connections {
|
||||
if err := conn.Close(); err != nil {
|
||||
fmt.Println("error while closing a connection; err:", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) handleWebSocket(writer http.ResponseWriter, request *http.Request) {
|
||||
s.metrics.mux.RLock()
|
||||
currentConnections := s.metrics.ActiveConnections
|
||||
s.metrics.mux.RUnlock()
|
||||
|
||||
if currentConnections >= s.settings.MaxConnections {
|
||||
fmt.Println("Connection limit reached", "active", currentConnections)
|
||||
http.Error(writer, "Service Unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(writer, request, nil)
|
||||
if err != nil {
|
||||
s.metrics.mux.Lock()
|
||||
s.metrics.FailedConnections++
|
||||
s.metrics.mux.Unlock()
|
||||
|
||||
fmt.Println("error while handling client; removing client...")
|
||||
}
|
||||
connection := newAdaptor(uuid.NewString(), conn)
|
||||
|
||||
w, r, err := s.interceptor.BindSocketConnection(connection, s, s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-request.Context().Done():
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
msg, err := r.Read(s.ctx, connection)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
iD := uuid.NewString()
|
||||
connection := newAdaptor(request.Context(), iD, conn, s.settings.PopMessageTimeout, s.settings.PushMessageTimout)
|
||||
|
||||
fmt.Println(msg)
|
||||
}
|
||||
}
|
||||
s.registerConnection(iD, connection)
|
||||
defer s.unregisterConnection(iD)
|
||||
|
||||
connection.StartReaderWriter()
|
||||
|
||||
if _, _, err := s.interceptor.BindSocketConnection(connection, s, s); err != nil {
|
||||
fmt.Println(fmt.Errorf("error while binding socket to interceptors; err: %s", err.Error()))
|
||||
fmt.Println("dropping client...")
|
||||
return
|
||||
}
|
||||
defer s.interceptor.UnBindSocketConnection(connection)
|
||||
|
||||
if err := s.interceptor.Init(connection); err != nil {
|
||||
fmt.Println("error while connection init; dropping client")
|
||||
fmt.Println("dropping client...")
|
||||
return
|
||||
}
|
||||
|
||||
connection.WaitUntilClose()
|
||||
}
|
||||
|
||||
func (s *Socket) ShutDown(ctx context.Context) error {
|
||||
ctx2, cancel := context.WithTimeout(ctx, s.settings.ShutdownTimout)
|
||||
defer cancel()
|
||||
|
||||
s.cancel()
|
||||
if err := s.server.Shutdown(ctx2); err != nil {
|
||||
return fmt.Errorf("server shutdown error: %w", err)
|
||||
}
|
||||
|
||||
s.closeAllConnections()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHealth provides a health check endpoint
|
||||
func (s *Socket) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
}
|
||||
|
||||
// handleMetrics exposes server metrics
|
||||
func (s *Socket) handleMetrics(w http.ResponseWriter, _ *http.Request) {
|
||||
s.metrics.mux.RLock()
|
||||
defer s.metrics.mux.RUnlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = fmt.Fprintf(w, `{
|
||||
"active_connections": %d,
|
||||
"total_connections": %d,
|
||||
"failed_connections": %d
|
||||
}`, s.metrics.ActiveConnections, s.metrics.TotalConnections,
|
||||
s.metrics.FailedConnections)
|
||||
}
|
||||
|
Reference in New Issue
Block a user