general commit

This commit is contained in:
harshabose
2025-05-23 01:34:52 +05:30
parent 2bae581d37
commit 3bc71a1627
7 changed files with 417 additions and 204 deletions

View File

@@ -46,6 +46,7 @@ type Factory interface {
type Connection interface {
Write(ctx context.Context, p []byte) error
Read(ctx context.Context) ([]byte, error)
io.Closer
}
type Interceptor interface {

View File

@@ -59,4 +59,12 @@ func (f *InterceptorFactory) NewInterceptor(ctx context.Context, id string, regi
writeProcessMessages: message.NewDefaultRegistry(),
states: state.NewManager(),
}
for _, option := range f.options {
if err := option(i); err != nil {
return nil, err
}
}
return i, nil
}

View File

@@ -5,7 +5,6 @@ import (
"github.com/harshabose/socket-comm/pkg/interceptor"
"github.com/harshabose/socket-comm/pkg/message"
"github.com/harshabose/socket-comm/pkg/transport/types"
)
type API struct {
@@ -43,13 +42,8 @@ func NewAPI(options ...APIOption) (*API, error) {
// TODO: MAKE REGISTRIES TO NON POINTERS
func (a *API) NewSocket(ctx context.Context, id types.SocketID, options ...Option) (*Socket, error) {
s := &Socket{
ID: id,
settings: NewDefaultSettings(),
messageRegistry: a.messagesRegistry,
ctx: ctx,
}
func (a *API) NewSocket(ctx context.Context, options ...Option) (*Socket, error) {
s := NewSocket(NewDefaultSettings(), a.messagesRegistry)
interceptors, err := a.interceptorRegistry.Build(s.ctx, string(s.ID))
if err != nil {

View File

@@ -3,6 +3,9 @@ package socket
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/coder/websocket"
)
@@ -10,35 +13,190 @@ import (
var (
ErrNotSupportedMessageType = errors.New("not supported message type")
ErrConnectionClosed = errors.New("connection closed")
ErrServerClosed = errors.New("server closed")
ErrReaderWriterAlreadySet = errors.New("reader and writer already set")
)
type adaptor struct {
id string
conn *websocket.Conn
connectionSettings
id string
conn *websocket.Conn
readQ Buffer[[]byte]
writeQ Buffer[[]byte]
ctx context.Context
cancel context.CancelFunc
closeOnce sync.Once
wg sync.WaitGroup
closeErr error
closeErrMu sync.Mutex
}
func newAdaptor(id string, conn *websocket.Conn) *adaptor {
func newAdaptor(ctx context.Context, id string, conn *websocket.Conn, readTimeout time.Duration, writeTimeout time.Duration) *adaptor {
// Create a child context with cancellation
childCtx, cancel := context.WithCancel(ctx)
return &adaptor{
id: id,
conn: conn,
connectionSettings: connectionSettings{
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
},
ctx: childCtx,
cancel: cancel,
id: id,
conn: conn,
}
}
// Write pushes the message of the type '[]byte' to the WriteQ, which will be later sent through the socket
func (a *adaptor) Write(ctx context.Context, p []byte) error {
return a.conn.Write(ctx, websocket.MessageText, p)
if ctx == nil {
ctx = context.Background()
}
// Create a copy of the message to prevent data races
msgCopy := make([]byte, len(p))
copy(msgCopy, p)
return a.writeQ.Push(ctx, msgCopy)
}
// Read reads a message of the type '[]byte' from the ReadQ, which was read from the websocket.
func (a *adaptor) Read(ctx context.Context) ([]byte, error) {
msgType, p, err := a.conn.Read(ctx)
if ctx == nil {
ctx = context.Background()
}
data, err := a.readQ.Pop(ctx)
if err != nil {
return nil, err
}
if msgType != websocket.MessageText {
return nil, ErrNotSupportedMessageType
}
// Return a copy to prevent data races
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
return p, nil
return dataCopy, nil
}
func (a *adaptor) StartReaderWriter() {
a.wg.Add(2) // One for the reader, one for the writer
go a.Writer()
go a.Reader()
}
func (a *adaptor) Writer() {
defer a.wg.Done()
defer a.closeWithError(ErrConnectionClosed)
for {
select {
case <-a.ctx.Done():
return
default:
// Use a timeout context for the write operation
writeCtx, cancel := context.WithTimeout(a.ctx, a.WriteTimeout)
p, err := a.writeQ.Pop(writeCtx)
cancel()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
// Check if context was cancelled due to connection close
select {
case <-a.ctx.Done():
return
default:
// Just a timeout, continue
continue
}
}
fmt.Printf("Error while popping message from WriteQ; err: %s\n", err.Error())
continue
}
// Use a timeout context for the websocket write
writeCtx, cancel = context.WithTimeout(a.ctx, a.WriteTimeout)
err = a.conn.Write(writeCtx, websocket.MessageText, p)
cancel()
if err != nil {
fmt.Printf("Error while writing message to socket; err: %s\n", err.Error())
return
}
}
}
}
func (a *adaptor) Reader() {
defer a.wg.Done()
defer a.closeWithError(ErrConnectionClosed)
for {
select {
case <-a.ctx.Done():
return
default:
// Use a timeout context for the read operation
readCtx, cancel := context.WithTimeout(a.ctx, a.ReadTimeout)
msgType, p, err := a.conn.Read(readCtx)
cancel()
if err != nil {
fmt.Printf("Error while reading message from socket; err: %s\n", err.Error())
return
}
if msgType != websocket.MessageText {
fmt.Printf("Error while reading message from socket; err: %s\n", ErrNotSupportedMessageType.Error())
continue
}
// Use a background context since we want to buffer the message even if the operation takes time
err = a.readQ.Push(a.ctx, p)
if err != nil {
fmt.Printf("Error while pushing message to ReadQ; err: %s\n", err.Error())
continue
}
}
}
}
// Close initiates a graceful shutdown of the connection
func (a *adaptor) Close() error {
var err error
a.closeOnce.Do(func() {
// Cancel the context to signal all goroutines to stop
a.cancel()
// Try to send a close message to the peer
closeErr := a.conn.Close(websocket.StatusNormalClosure, "connection closed")
if closeErr != nil {
err = closeErr
}
// Close the buffers
a.readQ.Close()
a.writeQ.Close()
})
return err
}
// closeWithError stores the error that caused the connection to close
func (a *adaptor) closeWithError(err error) {
a.closeErrMu.Lock()
if a.closeErr == nil {
a.closeErr = err
}
a.closeErrMu.Unlock()
a.Close()
}
// GetCloseError returns the error that caused the connection to close
func (a *adaptor) GetCloseError() error {
a.closeErrMu.Lock()
defer a.closeErrMu.Unlock()
return a.closeErr
}
// WaitUntilClose blocks until all reader and writer goroutines exit
func (a *adaptor) WaitUntilClose() {
a.wg.Wait()
}

View File

@@ -2,158 +2,39 @@ package socket
import (
"context"
"errors"
"fmt"
"sync"
)
var (
ErrorElementUnallocated = errors.New("encountered nil in the buffer. this should not happen. check usage")
ErrorChannelBufferClose = errors.New("channel buffer has be closed. cannot perform this operation")
)
// TODO: LIMIT AND SELF KILL BUFFER.
// ELEMENT IS RESPONSIBLE TO KILL AND FREE ITS MEMORY. IT SHOULD ALSO DELETE ITSELF FROM THE PARENT
// HARD LIMIT ON NUMBER OF BUFFERS
type Pool[T any] interface {
Get() T
Put(T)
Release()
type Buffered[T any] struct {
element T
ctx context.Context
}
type Buffer[T any] interface {
Pop(context.Context) (T, error)
Push(context.Context, T) error
Pop(ctx context.Context) (T, error)
Size() int
Close()
}
type ChannelBuffer[T any] struct {
pool Pool[T]
bufferChannel chan T
inputBuffer chan T
closed bool
mux sync.RWMutex
ctx context.Context
type LimitKillBuffer[T any] struct {
buffer chan Buffered[T]
}
func CreateChannelBuffer[T any](ctx context.Context, size int, pool Pool[T]) *ChannelBuffer[T] {
buffer := &ChannelBuffer[T]{
pool: pool,
bufferChannel: make(chan T, size),
inputBuffer: make(chan T),
closed: false,
ctx: ctx,
}
go buffer.loop()
return buffer
func (b *LimitKillBuffer[T]) Pop(ctx context.Context) (T, error) {
}
func (buffer *ChannelBuffer[T]) Push(ctx context.Context, element T) error {
buffer.mux.RLock()
defer buffer.mux.RUnlock()
func (b *LimitKillBuffer[T]) Push(ctx context.Context, element T) error {
if buffer.closed {
return errors.New("buffer closed")
}
select {
case buffer.inputBuffer <- element:
// WARN: LACKS CHECKS FOR CLOSED CHANNEL
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (buffer *ChannelBuffer[T]) Pop(ctx context.Context) (T, error) {
buffer.mux.RLock()
defer buffer.mux.RUnlock()
func (b *LimitKillBuffer[T]) Close() {
if buffer.closed {
var t T
return t, errors.New("buffer closed")
}
select {
case <-ctx.Done():
var t T
return t, ctx.Err()
case data, ok := <-buffer.bufferChannel:
if !ok {
var t T
return t, ErrorChannelBufferClose
}
if data == nil {
var t T
return t, ErrorElementUnallocated
}
return data, nil
}
}
func (buffer *ChannelBuffer[T]) Generate() T {
return buffer.pool.Get()
}
func (b *LimitKillBuffer[T]) manager() {
func (buffer *ChannelBuffer[T]) PutBack(element T) {
if buffer.pool != nil {
buffer.pool.Put(element)
}
}
func (buffer *ChannelBuffer[T]) GetChannel() chan T {
return buffer.bufferChannel
}
func (buffer *ChannelBuffer[T]) Size() int {
return len(buffer.bufferChannel)
}
func (buffer *ChannelBuffer[T]) loop() {
defer buffer.close()
loop:
for {
select {
case <-buffer.ctx.Done():
return
case element, ok := <-buffer.inputBuffer:
if !ok || element == nil {
continue loop
}
select {
case buffer.bufferChannel <- element: // SUCCESSFULLY BUFFERED
continue loop
default:
select {
case oldElement := <-buffer.bufferChannel:
buffer.PutBack(oldElement)
select {
case buffer.bufferChannel <- element:
continue loop
default:
fmt.Println("unexpected buffer state. skipping the element..")
buffer.PutBack(element)
}
}
}
}
}
}
func (buffer *ChannelBuffer[T]) close() {
buffer.mux.Lock()
buffer.closed = true
buffer.mux.Unlock()
loop:
for {
select {
case element := <-buffer.bufferChannel:
if buffer.pool != nil {
buffer.pool.Put(element)
}
default:
close(buffer.bufferChannel)
close(buffer.inputBuffer)
break loop
}
}
if buffer.pool != nil {
buffer.pool.Release()
}
}

View File

@@ -1,18 +1,26 @@
package socket
import (
"crypto/tls"
"errors"
"fmt"
"time"
)
var ErrSettingsInvalid = errors.New("server settings invalid")
type connectionSettings struct {
ReadTimout time.Duration
WriteTimout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
}
type Settings struct {
connectionSettings
IdleTimout time.Duration
ShutdownTimout time.Duration
Address string
Port uint16
ReadHeaderTimeout time.Duration
IdleTimout time.Duration
ShutdownTimout time.Duration
// NOTE: TLS SETTINGS ARE OPTIONAL
TLSCertFile string
@@ -20,19 +28,58 @@ type Settings struct {
MaxConnections int
ConnectionTimeout time.Duration
PopMessageTimeout time.Duration
PushMessageTimout time.Duration
}
func NewDefaultSettings() *Settings {
return &Settings{
func (s Settings) Validate() error {
// TODO: IMPLEMENT THIS
return ErrSettingsInvalid
}
func NewDefaultSettings() Settings {
return Settings{
connectionSettings: connectionSettings{
ReadTimout: 30 * time.Second,
WriteTimout: 30 * time.Second,
ReadTimeout: time.Second,
WriteTimeout: time.Second,
},
ReadHeaderTimeout: 5 * time.Second,
ConnectionTimeout: 10 * time.Second,
IdleTimout: 120 * time.Second,
ShutdownTimout: 30 * time.Second,
TLSCertFile: "",
TLSKeyFile: "",
MaxConnections: 1000,
ConnectionTimeout: 10 * time.Second,
PopMessageTimeout: 30 * time.Second,
PushMessageTimout: 30 * time.Second,
}
}
func GetTLSV1(tlsCertPath, tlsKeyFile string) (*tls.Config, error) {
var tlsConfig *tls.Config
if tlsCertPath != "" && tlsKeyFile != "" {
cert, err := tls.LoadX509KeyPair(tlsCertPath, tlsKeyFile)
if err != nil {
return nil, fmt.Errorf("loading TLS certificates: %w", err)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP256},
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
}
return tlsConfig, nil
}
return nil, errors.New("invalid cert or/and file path")
}

View File

@@ -2,7 +2,9 @@ package socket
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -19,55 +21,100 @@ const DEBUG = true
type Option func(*Socket) error
// Metrics holds server statistics
type Metrics struct {
ActiveConnections int
TotalConnections int
FailedConnections int
mux sync.RWMutex
}
type Socket struct {
ID types.SocketID `json:"id"`
server *http.Server
router *http.ServeMux
handlerFunc http.HandlerFunc
settings *Settings
settings Settings
interceptor interceptor.Interceptor
connections map[string]interceptor.Connection
metrics *Metrics
messageRegistry message.Registry
mux sync.RWMutex
cancel context.CancelFunc
ctx context.Context
mux sync.Mutex
}
func NewSocket(ctx context.Context, settings Settings, registry message.Registry) *Socket {
ctx2, cancel := context.WithCancel(ctx)
return &Socket{
ID: types.SocketID(uuid.NewString()),
settings: settings,
messageRegistry: registry,
connections: make(map[string]interceptor.Connection),
metrics: &Metrics{},
cancel: cancel,
ctx: ctx2,
}
}
func (s *Socket) GetID() types.SocketID {
return s.ID
}
func (s *Socket) Ctx(l net.Listener) context.Context {
return s.ctx
}
func (s *Socket) Init() error {
s.mux.Lock()
defer s.mux.Unlock()
if err := s.settings.Validate(); err != nil {
return err
}
tlsConfig, err := GetTLSV1(s.settings.TLSCertFile, s.settings.TLSKeyFile)
if err != nil {
return err
}
s.router = http.NewServeMux()
s.server = &http.Server{
ReadTimeout: s.settings.ReadTimout,
WriteTimeout: s.settings.WriteTimout,
IdleTimeout: s.settings.IdleTimout,
Addr: fmt.Sprintf("%s:%d", s.settings.Address, s.settings.Port),
ReadTimeout: s.settings.PopMessageTimeout,
WriteTimeout: s.settings.PushMessageTimout,
ReadHeaderTimeout: s.settings.ReadHeaderTimeout,
IdleTimeout: s.settings.IdleTimout,
TLSConfig: tlsConfig,
Handler: s.router,
BaseContext: s.Ctx,
// TODO: MAYBE ADD MORE
}
s.handlerFunc = s.handler
// Set up routes
s.router.HandleFunc("/ws", s.handleWebSocket)
s.router.HandleFunc("/health", s.handleHealth)
s.router.HandleFunc("/metrics", s.handleMetrics)
return nil
}
func (s *Socket) Serve() error {
defer s.Close()
for {
select {
case <-s.ctx.Done():
return nil // TODO: add error
default:
if DEBUG {
if s.server.TLSConfig != nil {
if err := s.server.ListenAndServe(); err != nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
fmt.Println("error while serving; err: ", err.Error())
fmt.Println("retrying...")
time.Sleep(1 * time.Second)
}
continue
}
if err := s.server.ListenAndServeTLS(s.settings.TLSCertFile, s.settings.TLSKeyFile); err != nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
fmt.Println("error while serving; err: ", err.Error())
fmt.Println("retrying...")
time.Sleep(1 * time.Second)
@@ -76,12 +123,8 @@ func (s *Socket) Serve() error {
}
}
func (s *Socket) Close() {
}
func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (message.Message, error) {
ctx, cancel := context.WithTimeout(s.ctx, s.settings.ReadTimout)
ctx, cancel := context.WithTimeout(s.ctx, s.settings.PopMessageTimeout)
defer cancel()
data, err := connection.Read(ctx)
@@ -93,7 +136,7 @@ func (s *Socket) Read(ctx context.Context, connection interceptor.Connection) (m
}
func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, msg message.Message) error {
ctx, cancel := context.WithTimeout(s.ctx, s.settings.WriteTimout)
ctx, cancel := context.WithTimeout(s.ctx, s.settings.PushMessageTimout)
defer cancel()
data, err := msg.Marshal()
@@ -104,38 +147,119 @@ func (s *Socket) Write(ctx context.Context, connection interceptor.Connection, m
return connection.Write(ctx, data)
}
func (s *Socket) handler(writer http.ResponseWriter, request *http.Request) {
conn, err := websocket.Accept(writer, request, nil)
if err != nil {
fmt.Println("error while handling client; removing client...")
}
connection := newAdaptor(uuid.NewString(), conn)
// registerConnection adds a new connection
func (s *Socket) registerConnection(id string, conn interceptor.Connection) {
s.mux.Lock()
defer s.mux.Unlock()
w, r, err := s.interceptor.BindSocketConnection(connection, s, s)
if err != nil {
s.connections[id] = conn
s.metrics.mux.Lock()
s.metrics.ActiveConnections++
s.metrics.TotalConnections++
s.metrics.mux.Unlock()
}
// unregisterConnection removes a connection
func (s *Socket) unregisterConnection(id string) {
s.mux.Lock()
defer s.mux.Unlock()
if _, exists := s.connections[id]; exists {
delete(s.connections, id)
s.metrics.mux.Lock()
s.metrics.ActiveConnections--
s.metrics.mux.Unlock()
}
}
// closeAllConnections closes all active connections
func (s *Socket) closeAllConnections() {
s.mux.Lock()
defer s.mux.Unlock()
for _, conn := range s.connections {
if err := conn.Close(); err != nil {
fmt.Println("error while closing a connection; err:", err.Error())
}
}
}
func (s *Socket) handleWebSocket(writer http.ResponseWriter, request *http.Request) {
s.metrics.mux.RLock()
currentConnections := s.metrics.ActiveConnections
s.metrics.mux.RUnlock()
if currentConnections >= s.settings.MaxConnections {
fmt.Println("Connection limit reached", "active", currentConnections)
http.Error(writer, "Service Unavailable", http.StatusServiceUnavailable)
return
}
for {
select {
case <-request.Context().Done():
return
case <-s.ctx.Done():
return
default:
msg, err := r.Read(s.ctx, connection)
if err != nil {
return
}
conn, err := websocket.Accept(writer, request, nil)
if err != nil {
s.metrics.mux.Lock()
s.metrics.FailedConnections++
s.metrics.mux.Unlock()
fmt.Println(msg)
}
fmt.Println("error while handling client; removing client...")
return
}
iD := uuid.NewString()
connection := newAdaptor(request.Context(), iD, conn, s.settings.PopMessageTimeout, s.settings.PushMessageTimout)
s.registerConnection(iD, connection)
defer s.unregisterConnection(iD)
connection.StartReaderWriter()
if _, _, err := s.interceptor.BindSocketConnection(connection, s, s); err != nil {
fmt.Println(fmt.Errorf("error while binding socket to interceptors; err: %s", err.Error()))
fmt.Println("dropping client...")
return
}
defer s.interceptor.UnBindSocketConnection(connection)
if err := s.interceptor.Init(connection); err != nil {
fmt.Println("error while connection init; dropping client")
fmt.Println("dropping client...")
return
}
connection.WaitUntilClose()
}
func (s *Socket) ShutDown(ctx context.Context) error {
ctx2, cancel := context.WithTimeout(ctx, s.settings.ShutdownTimout)
defer cancel()
s.cancel()
if err := s.server.Shutdown(ctx2); err != nil {
return fmt.Errorf("server shutdown error: %w", err)
}
s.closeAllConnections()
return nil
}
// handleHealth provides a health check endpoint
func (s *Socket) handleHealth(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"ok"}`))
}
// handleMetrics exposes server metrics
func (s *Socket) handleMetrics(w http.ResponseWriter, _ *http.Request) {
s.metrics.mux.RLock()
defer s.metrics.mux.RUnlock()
w.Header().Set("Content-Type", "application/json")
_, _ = fmt.Fprintf(w, `{
"active_connections": %d,
"total_connections": %d,
"failed_connections": %d
}`, s.metrics.ActiveConnections, s.metrics.TotalConnections,
s.metrics.FailedConnections)
}