diff --git a/pkg/interceptor/interceptor.go b/pkg/interceptor/interceptor.go index 08f94ae..d508efa 100644 --- a/pkg/interceptor/interceptor.go +++ b/pkg/interceptor/interceptor.go @@ -46,6 +46,7 @@ type Factory interface { type Connection interface { Write(ctx context.Context, p []byte) error Read(ctx context.Context) ([]byte, error) + io.Closer } type Interceptor interface { diff --git a/pkg/middleware/chat/factory.go b/pkg/middleware/chat/factory.go index 0811b39..ebccebf 100644 --- a/pkg/middleware/chat/factory.go +++ b/pkg/middleware/chat/factory.go @@ -59,4 +59,12 @@ func (f *InterceptorFactory) NewInterceptor(ctx context.Context, id string, regi writeProcessMessages: message.NewDefaultRegistry(), states: state.NewManager(), } + + for _, option := range f.options { + if err := option(i); err != nil { + return nil, err + } + } + + return i, nil } diff --git a/pkg/transport/socket/api.go b/pkg/transport/socket/api.go index 71ed6a8..67783d0 100644 --- a/pkg/transport/socket/api.go +++ b/pkg/transport/socket/api.go @@ -5,7 +5,6 @@ import ( "github.com/harshabose/socket-comm/pkg/interceptor" "github.com/harshabose/socket-comm/pkg/message" - "github.com/harshabose/socket-comm/pkg/transport/types" ) type API struct { @@ -43,13 +42,8 @@ func NewAPI(options ...APIOption) (*API, error) { // TODO: MAKE REGISTRIES TO NON POINTERS -func (a *API) NewSocket(ctx context.Context, id types.SocketID, options ...Option) (*Socket, error) { - s := &Socket{ - ID: id, - settings: NewDefaultSettings(), - messageRegistry: a.messagesRegistry, - ctx: ctx, - } +func (a *API) NewSocket(ctx context.Context, options ...Option) (*Socket, error) { + s := NewSocket(NewDefaultSettings(), a.messagesRegistry) interceptors, err := a.interceptorRegistry.Build(s.ctx, string(s.ID)) if err != nil { diff --git a/pkg/transport/socket/connection.go b/pkg/transport/socket/connection.go index cb7d43b..2ca79e1 100644 --- a/pkg/transport/socket/connection.go +++ b/pkg/transport/socket/connection.go @@ -3,6 +3,9 @@ package socket import ( "context" "errors" + "fmt" + "sync" + "time" "github.com/coder/websocket" ) @@ -10,35 +13,190 @@ import ( var ( ErrNotSupportedMessageType = errors.New("not supported message type") ErrConnectionClosed = errors.New("connection closed") - ErrServerClosed = errors.New("server closed") - ErrReaderWriterAlreadySet = errors.New("reader and writer already set") ) type adaptor struct { - id string - conn *websocket.Conn + connectionSettings + id string + conn *websocket.Conn + readQ Buffer[[]byte] + writeQ Buffer[[]byte] + ctx context.Context + cancel context.CancelFunc + closeOnce sync.Once + wg sync.WaitGroup + closeErr error + closeErrMu sync.Mutex } -func newAdaptor(id string, conn *websocket.Conn) *adaptor { +func newAdaptor(ctx context.Context, id string, conn *websocket.Conn, readTimeout time.Duration, writeTimeout time.Duration) *adaptor { + // Create a child context with cancellation + childCtx, cancel := context.WithCancel(ctx) + return &adaptor{ - id: id, - conn: conn, + connectionSettings: connectionSettings{ + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + }, + ctx: childCtx, + cancel: cancel, + id: id, + conn: conn, } } +// Write pushes the message of the type '[]byte' to the WriteQ, which will be later sent through the socket func (a *adaptor) Write(ctx context.Context, p []byte) error { - return a.conn.Write(ctx, websocket.MessageText, p) + if ctx == nil { + ctx = context.Background() + } + + // Create a copy of the message to prevent data races + msgCopy := make([]byte, len(p)) + copy(msgCopy, p) + + return a.writeQ.Push(ctx, msgCopy) } +// Read reads a message of the type '[]byte' from the ReadQ, which was read from the websocket. func (a *adaptor) Read(ctx context.Context) ([]byte, error) { - msgType, p, err := a.conn.Read(ctx) + if ctx == nil { + ctx = context.Background() + } + + data, err := a.readQ.Pop(ctx) if err != nil { return nil, err } - if msgType != websocket.MessageText { - return nil, ErrNotSupportedMessageType - } + // Return a copy to prevent data races + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) - return p, nil + return dataCopy, nil +} + +func (a *adaptor) StartReaderWriter() { + a.wg.Add(2) // One for the reader, one for the writer + go a.Writer() + go a.Reader() +} + +func (a *adaptor) Writer() { + defer a.wg.Done() + defer a.closeWithError(ErrConnectionClosed) + + for { + select { + case <-a.ctx.Done(): + return + default: + // Use a timeout context for the write operation + writeCtx, cancel := context.WithTimeout(a.ctx, a.WriteTimeout) + p, err := a.writeQ.Pop(writeCtx) + cancel() + + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + // Check if context was cancelled due to connection close + select { + case <-a.ctx.Done(): + return + default: + // Just a timeout, continue + continue + } + } + + fmt.Printf("Error while popping message from WriteQ; err: %s\n", err.Error()) + continue + } + + // Use a timeout context for the websocket write + writeCtx, cancel = context.WithTimeout(a.ctx, a.WriteTimeout) + err = a.conn.Write(writeCtx, websocket.MessageText, p) + cancel() + + if err != nil { + fmt.Printf("Error while writing message to socket; err: %s\n", err.Error()) + return + } + } + } +} + +func (a *adaptor) Reader() { + defer a.wg.Done() + defer a.closeWithError(ErrConnectionClosed) + + for { + select { + case <-a.ctx.Done(): + return + default: + // Use a timeout context for the read operation + readCtx, cancel := context.WithTimeout(a.ctx, a.ReadTimeout) + msgType, p, err := a.conn.Read(readCtx) + cancel() + + if err != nil { + fmt.Printf("Error while reading message from socket; err: %s\n", err.Error()) + return + } + + if msgType != websocket.MessageText { + fmt.Printf("Error while reading message from socket; err: %s\n", ErrNotSupportedMessageType.Error()) + continue + } + + // Use a background context since we want to buffer the message even if the operation takes time + err = a.readQ.Push(a.ctx, p) + if err != nil { + fmt.Printf("Error while pushing message to ReadQ; err: %s\n", err.Error()) + continue + } + } + } +} + +// Close initiates a graceful shutdown of the connection +func (a *adaptor) Close() error { + var err error + a.closeOnce.Do(func() { + // Cancel the context to signal all goroutines to stop + a.cancel() + + // Try to send a close message to the peer + closeErr := a.conn.Close(websocket.StatusNormalClosure, "connection closed") + if closeErr != nil { + err = closeErr + } + + // Close the buffers + a.readQ.Close() + a.writeQ.Close() + }) + return err +} + +// closeWithError stores the error that caused the connection to close +func (a *adaptor) closeWithError(err error) { + a.closeErrMu.Lock() + if a.closeErr == nil { + a.closeErr = err + } + a.closeErrMu.Unlock() + a.Close() +} + +// GetCloseError returns the error that caused the connection to close +func (a *adaptor) GetCloseError() error { + a.closeErrMu.Lock() + defer a.closeErrMu.Unlock() + return a.closeErr +} + +// WaitUntilClose blocks until all reader and writer goroutines exit +func (a *adaptor) WaitUntilClose() { + a.wg.Wait() } diff --git a/pkg/transport/socket/read_buffer.go b/pkg/transport/socket/read_buffer.go index 11b7037..d359abb 100644 --- a/pkg/transport/socket/read_buffer.go +++ b/pkg/transport/socket/read_buffer.go @@ -2,158 +2,39 @@ package socket import ( "context" - "errors" - "fmt" - "sync" ) -var ( - ErrorElementUnallocated = errors.New("encountered nil in the buffer. this should not happen. check usage") - ErrorChannelBufferClose = errors.New("channel buffer has be closed. cannot perform this operation") -) +// TODO: LIMIT AND SELF KILL BUFFER. +// ELEMENT IS RESPONSIBLE TO KILL AND FREE ITS MEMORY. IT SHOULD ALSO DELETE ITSELF FROM THE PARENT +// HARD LIMIT ON NUMBER OF BUFFERS -type Pool[T any] interface { - Get() T - Put(T) - Release() +type Buffered[T any] struct { + element T + ctx context.Context } type Buffer[T any] interface { + Pop(context.Context) (T, error) Push(context.Context, T) error - Pop(ctx context.Context) (T, error) - Size() int + Close() } -type ChannelBuffer[T any] struct { - pool Pool[T] - bufferChannel chan T - inputBuffer chan T - closed bool - mux sync.RWMutex - ctx context.Context +type LimitKillBuffer[T any] struct { + buffer chan Buffered[T] } -func CreateChannelBuffer[T any](ctx context.Context, size int, pool Pool[T]) *ChannelBuffer[T] { - buffer := &ChannelBuffer[T]{ - pool: pool, - bufferChannel: make(chan T, size), - inputBuffer: make(chan T), - closed: false, - ctx: ctx, - } - go buffer.loop() - return buffer +func (b *LimitKillBuffer[T]) Pop(ctx context.Context) (T, error) { + } -func (buffer *ChannelBuffer[T]) Push(ctx context.Context, element T) error { - buffer.mux.RLock() - defer buffer.mux.RUnlock() +func (b *LimitKillBuffer[T]) Push(ctx context.Context, element T) error { - if buffer.closed { - return errors.New("buffer closed") - } - select { - case buffer.inputBuffer <- element: - // WARN: LACKS CHECKS FOR CLOSED CHANNEL - return nil - case <-ctx.Done(): - return ctx.Err() - } } -func (buffer *ChannelBuffer[T]) Pop(ctx context.Context) (T, error) { - buffer.mux.RLock() - defer buffer.mux.RUnlock() +func (b *LimitKillBuffer[T]) Close() { - if buffer.closed { - var t T - return t, errors.New("buffer closed") - } - select { - case <-ctx.Done(): - var t T - return t, ctx.Err() - case data, ok := <-buffer.bufferChannel: - if !ok { - var t T - return t, ErrorChannelBufferClose - } - if data == nil { - var t T - return t, ErrorElementUnallocated - } - return data, nil - } } -func (buffer *ChannelBuffer[T]) Generate() T { - return buffer.pool.Get() -} +func (b *LimitKillBuffer[T]) manager() { -func (buffer *ChannelBuffer[T]) PutBack(element T) { - if buffer.pool != nil { - buffer.pool.Put(element) - } -} - -func (buffer *ChannelBuffer[T]) GetChannel() chan T { - return buffer.bufferChannel -} - -func (buffer *ChannelBuffer[T]) Size() int { - return len(buffer.bufferChannel) -} - -func (buffer *ChannelBuffer[T]) loop() { - defer buffer.close() -loop: - for { - select { - case <-buffer.ctx.Done(): - return - case element, ok := <-buffer.inputBuffer: - if !ok || element == nil { - continue loop - } - select { - case buffer.bufferChannel <- element: // SUCCESSFULLY BUFFERED - continue loop - default: - select { - case oldElement := <-buffer.bufferChannel: - buffer.PutBack(oldElement) - select { - case buffer.bufferChannel <- element: - continue loop - default: - fmt.Println("unexpected buffer state. skipping the element..") - buffer.PutBack(element) - } - } - } - } - } -} - -func (buffer *ChannelBuffer[T]) close() { - buffer.mux.Lock() - buffer.closed = true - buffer.mux.Unlock() - -loop: - for { - select { - case element := <-buffer.bufferChannel: - if buffer.pool != nil { - buffer.pool.Put(element) - } - default: - close(buffer.bufferChannel) - close(buffer.inputBuffer) - break loop - } - } - if buffer.pool != nil { - buffer.pool.Release() - } } diff --git a/pkg/transport/socket/settings.go b/pkg/transport/socket/settings.go index 446b70f..01f4a74 100644 --- a/pkg/transport/socket/settings.go +++ b/pkg/transport/socket/settings.go @@ -1,18 +1,26 @@ package socket import ( + "crypto/tls" + "errors" + "fmt" "time" ) +var ErrSettingsInvalid = errors.New("server settings invalid") + type connectionSettings struct { - ReadTimout time.Duration - WriteTimout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration } type Settings struct { connectionSettings - IdleTimout time.Duration - ShutdownTimout time.Duration + Address string + Port uint16 + ReadHeaderTimeout time.Duration + IdleTimout time.Duration + ShutdownTimout time.Duration // NOTE: TLS SETTINGS ARE OPTIONAL TLSCertFile string @@ -20,19 +28,58 @@ type Settings struct { MaxConnections int ConnectionTimeout time.Duration + + PopMessageTimeout time.Duration + PushMessageTimout time.Duration } -func NewDefaultSettings() *Settings { - return &Settings{ +func (s Settings) Validate() error { + // TODO: IMPLEMENT THIS + return ErrSettingsInvalid +} + +func NewDefaultSettings() Settings { + return Settings{ connectionSettings: connectionSettings{ - ReadTimout: 30 * time.Second, - WriteTimout: 30 * time.Second, + ReadTimeout: time.Second, + WriteTimeout: time.Second, }, + ReadHeaderTimeout: 5 * time.Second, + ConnectionTimeout: 10 * time.Second, IdleTimout: 120 * time.Second, ShutdownTimout: 30 * time.Second, TLSCertFile: "", TLSKeyFile: "", MaxConnections: 1000, - ConnectionTimeout: 10 * time.Second, + PopMessageTimeout: 30 * time.Second, + PushMessageTimout: 30 * time.Second, } } + +func GetTLSV1(tlsCertPath, tlsKeyFile string) (*tls.Config, error) { + var tlsConfig *tls.Config + if tlsCertPath != "" && tlsKeyFile != "" { + cert, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyFile) + if err != nil { + return nil, fmt.Errorf("loading TLS certificates: %w", err) + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP256}, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + } + + return tlsConfig, nil + } + + return nil, errors.New("invalid cert or/and file path") +} diff --git a/pkg/transport/socket/socket.go b/pkg/transport/socket/socket.go index 608c266..8c03044 100644 --- a/pkg/transport/socket/socket.go +++ b/pkg/transport/socket/socket.go @@ -2,7 +2,9 @@ package socket import ( "context" + "errors" "fmt" + "net" "net/http" "sync" "time" @@ -19,55 +21,100 @@ const DEBUG = true type Option func(*Socket) error +// Metrics holds server statistics +type Metrics struct { + ActiveConnections int + TotalConnections int + FailedConnections int + mux sync.RWMutex +} + type Socket struct { ID types.SocketID `json:"id"` server *http.Server router *http.ServeMux - handlerFunc http.HandlerFunc - settings *Settings + settings Settings interceptor interceptor.Interceptor + connections map[string]interceptor.Connection + metrics *Metrics messageRegistry message.Registry - mux sync.RWMutex + cancel context.CancelFunc ctx context.Context + mux sync.Mutex +} + +func NewSocket(ctx context.Context, settings Settings, registry message.Registry) *Socket { + ctx2, cancel := context.WithCancel(ctx) + return &Socket{ + ID: types.SocketID(uuid.NewString()), + settings: settings, + messageRegistry: registry, + connections: make(map[string]interceptor.Connection), + metrics: &Metrics{}, + cancel: cancel, + ctx: ctx2, + } } func (s *Socket) GetID() types.SocketID { return s.ID } +func (s *Socket) Ctx(l net.Listener) context.Context { + return s.ctx +} + func (s *Socket) Init() error { - s.mux.Lock() - defer s.mux.Unlock() + if err := s.settings.Validate(); err != nil { + return err + } + + tlsConfig, err := GetTLSV1(s.settings.TLSCertFile, s.settings.TLSKeyFile) + if err != nil { + return err + } s.router = http.NewServeMux() s.server = &http.Server{ - ReadTimeout: s.settings.ReadTimout, - WriteTimeout: s.settings.WriteTimout, - IdleTimeout: s.settings.IdleTimout, + Addr: fmt.Sprintf("%s:%d", s.settings.Address, s.settings.Port), + ReadTimeout: s.settings.PopMessageTimeout, + WriteTimeout: s.settings.PushMessageTimout, + ReadHeaderTimeout: s.settings.ReadHeaderTimeout, + IdleTimeout: s.settings.IdleTimout, + TLSConfig: tlsConfig, + Handler: s.router, + BaseContext: s.Ctx, // TODO: MAYBE ADD MORE } - s.handlerFunc = s.handler + // Set up routes + s.router.HandleFunc("/ws", s.handleWebSocket) + s.router.HandleFunc("/health", s.handleHealth) + s.router.HandleFunc("/metrics", s.handleMetrics) return nil } func (s *Socket) Serve() error { - defer s.Close() - for { select { case <-s.ctx.Done(): return nil // TODO: add error - default: - if DEBUG { + if s.server.TLSConfig != nil { if err := s.server.ListenAndServe(); err != nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } fmt.Println("error while serving; err: ", err.Error()) fmt.Println("retrying...") time.Sleep(1 * time.Second) } + continue } if err := s.server.ListenAndServeTLS(s.settings.TLSCertFile, s.settings.TLSKeyFile); err != nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } fmt.Println("error while serving; err: ", err.Error()) fmt.Println("retrying...") time.Sleep(1 * time.Second) @@ -76,12 +123,8 @@ func (s *Socket) Serve() error { } } -func (s *Socket) Close() { - -} - func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (message.Message, error) { - ctx, cancel := context.WithTimeout(s.ctx, s.settings.ReadTimout) + ctx, cancel := context.WithTimeout(s.ctx, s.settings.PopMessageTimeout) defer cancel() data, err := connection.Read(ctx) @@ -93,7 +136,7 @@ func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (m } func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, msg message.Message) error { - ctx, cancel := context.WithTimeout(s.ctx, s.settings.WriteTimout) + ctx, cancel := context.WithTimeout(s.ctx, s.settings.PushMessageTimout) defer cancel() data, err := msg.Marshal() @@ -104,38 +147,119 @@ func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, m return connection.Write(ctx, data) } -func (s *Socket) handler(writer http.ResponseWriter, request *http.Request) { - conn, err := websocket.Accept(writer, request, nil) - if err != nil { - fmt.Println("error while handling client; removing client...") - } - connection := newAdaptor(uuid.NewString(), conn) +// registerConnection adds a new connection +func (s *Socket) registerConnection(id string, conn interceptor.Connection) { + s.mux.Lock() + defer s.mux.Unlock() - w, r, err := s.interceptor.BindSocketConnection(connection, s, s) - if err != nil { + s.connections[id] = conn + + s.metrics.mux.Lock() + s.metrics.ActiveConnections++ + s.metrics.TotalConnections++ + s.metrics.mux.Unlock() +} + +// unregisterConnection removes a connection +func (s *Socket) unregisterConnection(id string) { + s.mux.Lock() + defer s.mux.Unlock() + + if _, exists := s.connections[id]; exists { + delete(s.connections, id) + + s.metrics.mux.Lock() + s.metrics.ActiveConnections-- + s.metrics.mux.Unlock() + } +} + +// closeAllConnections closes all active connections +func (s *Socket) closeAllConnections() { + s.mux.Lock() + defer s.mux.Unlock() + + for _, conn := range s.connections { + if err := conn.Close(); err != nil { + fmt.Println("error while closing a connection; err:", err.Error()) + } + } +} + +func (s *Socket) handleWebSocket(writer http.ResponseWriter, request *http.Request) { + s.metrics.mux.RLock() + currentConnections := s.metrics.ActiveConnections + s.metrics.mux.RUnlock() + + if currentConnections >= s.settings.MaxConnections { + fmt.Println("Connection limit reached", "active", currentConnections) + http.Error(writer, "Service Unavailable", http.StatusServiceUnavailable) return } - for { - select { - case <-request.Context().Done(): - return - case <-s.ctx.Done(): - return - default: - msg, err := r.Read(s.ctx, connection) - if err != nil { - return - } + conn, err := websocket.Accept(writer, request, nil) + if err != nil { + s.metrics.mux.Lock() + s.metrics.FailedConnections++ + s.metrics.mux.Unlock() - fmt.Println(msg) - } + fmt.Println("error while handling client; removing client...") + return } + iD := uuid.NewString() + connection := newAdaptor(request.Context(), iD, conn, s.settings.PopMessageTimeout, s.settings.PushMessageTimout) + + s.registerConnection(iD, connection) + defer s.unregisterConnection(iD) + + connection.StartReaderWriter() + + if _, _, err := s.interceptor.BindSocketConnection(connection, s, s); err != nil { + fmt.Println(fmt.Errorf("error while binding socket to interceptors; err: %s", err.Error())) + fmt.Println("dropping client...") + return + } defer s.interceptor.UnBindSocketConnection(connection) if err := s.interceptor.Init(connection); err != nil { fmt.Println("error while connection init; dropping client") + fmt.Println("dropping client...") return } + + connection.WaitUntilClose() +} + +func (s *Socket) ShutDown(ctx context.Context) error { + ctx2, cancel := context.WithTimeout(ctx, s.settings.ShutdownTimout) + defer cancel() + + s.cancel() + if err := s.server.Shutdown(ctx2); err != nil { + return fmt.Errorf("server shutdown error: %w", err) + } + + s.closeAllConnections() + return nil +} + +// handleHealth provides a health check endpoint +func (s *Socket) handleHealth(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"ok"}`)) +} + +// handleMetrics exposes server metrics +func (s *Socket) handleMetrics(w http.ResponseWriter, _ *http.Request) { + s.metrics.mux.RLock() + defer s.metrics.mux.RUnlock() + + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{ + "active_connections": %d, + "total_connections": %d, + "failed_connections": %d + }`, s.metrics.ActiveConnections, s.metrics.TotalConnections, + s.metrics.FailedConnections) }