feat: implement websocket and UI

This commit is contained in:
Oarkflow
2024-10-22 21:12:55 +05:45
parent 05a673bcd6
commit f337e00b88
21 changed files with 1880 additions and 33 deletions

View File

@@ -3,6 +3,7 @@ package dag
import (
"encoding/json"
"fmt"
"github.com/oarkflow/mq/sio"
"io"
"net/http"
"os"
@@ -21,8 +22,20 @@ type Request struct {
Recurring bool `json:"recurring"`
}
func (tm *DAG) SetupWS() *sio.Server {
ws := sio.New(sio.Config{
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: true,
})
WsEvents(ws)
tm.Notifier = ws
return ws
}
func (tm *DAG) Handlers() {
metrics.HandleHTTP()
http.Handle("/", http.FileServer(http.Dir("webroot")))
http.Handle("/notify", tm.SetupWS())
http.HandleFunc("POST /request", tm.Request)
http.HandleFunc("POST /publish", tm.Publish)
http.HandleFunc("POST /schedule", tm.Schedule)

View File

@@ -3,6 +3,7 @@ package dag
import (
"context"
"fmt"
"github.com/oarkflow/mq/sio"
"log"
"net/http"
"sync"
@@ -54,28 +55,34 @@ type (
)
type DAG struct {
nodes map[string]*Node
server *mq.Broker
consumer *mq.Consumer
taskContext map[string]*TaskManager
conditions map[FromNode]map[When]Then
pool *mq.Pool
taskCleanupCh chan string
name string
key string
startNode string
consumerTopic string
opts []mq.Option
mu sync.RWMutex
paused bool
Error error
report string
nodes map[string]*Node
server *mq.Broker
consumer *mq.Consumer
taskContext map[string]*TaskManager
conditions map[FromNode]map[When]Then
pool *mq.Pool
taskCleanupCh chan string
name string
key string
startNode string
consumerTopic string
opts []mq.Option
mu sync.RWMutex
reportNodeResultCallback func(mq.Result)
Notifier *sio.Server
paused bool
Error error
report string
}
func (tm *DAG) SetKey(key string) {
tm.key = key
}
func (tm *DAG) ReportNodeResult(callback func(mq.Result)) {
tm.reportNodeResultCallback = callback
}
func (tm *DAG) GetType() string {
return tm.key
}
@@ -210,7 +217,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
}(con)
}
}
log.Printf("DAG - HTTP_SERVER ~> started on %s", addr)
log.Printf("DAG - HTTP_SERVER ~> started on http://localhost%s", addr)
tm.Handlers()
config := tm.server.TLSConfig()
if config.UseTLS {
@@ -518,3 +525,7 @@ func (tm *DAG) doConsumer(ctx context.Context, id string, action consts.CMD) {
log.Printf("[WARNING] - Consumer %s not found", id)
}
}
func (tm *DAG) SetNotifyResponse(callback mq.Callback) {
tm.server.SetNotifyHandler(callback)
}

View File

@@ -188,6 +188,9 @@ func (tm *TaskManager) appendResult(result mq.Result, final bool) {
}
tm.nodeResults[result.Topic] = result
tm.mutex.Unlock()
if tm.dag.reportNodeResultCallback != nil {
tm.dag.reportNodeResultCallback(result)
}
}
func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json.RawMessage) {
@@ -202,9 +205,8 @@ func (tm *TaskManager) processNode(ctx context.Context, node *Node, payload json
var result mq.Result
if tm.dag.server.SyncMode() {
defer func() {
tm.mutex.Lock()
tm.nodeResults[node.Key] = result
tm.mutex.Unlock()
result.Topic = node.Key
tm.appendResult(result, false)
tm.handleCallback(ctx, result)
}()
}

View File

@@ -168,7 +168,7 @@ func (tm *DAG) ExportDOT() string {
for _, nodeKey := range sortedNodes {
node := tm.nodes[nodeKey]
nodeColor := "lightblue"
sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s"];`, node.Key, node.Name, nodeColor))
sb.WriteString(fmt.Sprintf(` "%s" [label=" %s", fillcolor="%s", id="node_%s"];`, node.Key, node.Name, nodeColor, node.Key))
sb.WriteString("\n")
}
for _, nodeKey := range sortedNodes {

32
dag/websocket.go Normal file
View File

@@ -0,0 +1,32 @@
package dag
import (
"encoding/json"
"github.com/oarkflow/mq/sio"
)
func WsEvents(s *sio.Server) {
s.On("join", join)
s.On("message", message)
}
func join(s *sio.Socket, data []byte) {
//just one room at a time for the simple example
currentRooms := s.GetRooms()
for _, room := range currentRooms {
s.Leave(room)
}
s.Join(string(data))
s.Emit("joinedRoom", string(data))
}
type msg struct {
Room string
Message string
}
func message(s *sio.Socket, data []byte) {
var m msg
json.Unmarshal(data, &m)
s.ToRoom(m.Room, "message", m.Message)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/oarkflow/mq/examples/tasks"
"github.com/oarkflow/mq/services"
@@ -13,7 +12,6 @@ import (
)
func main() {
// Sync()
aSync()
}
@@ -43,14 +41,20 @@ func sendData(f *dag.DAG) {
fmt.Println(string(result.Payload))
}
func Sync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse))
setup(f)
sendData(f)
}
func aSync() {
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithNotifyResponse(tasks.NotifyResponse))
f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete())
f.SetNotifyResponse(func(ctx context.Context, result mq.Result) error {
if f.Notifier != nil {
f.Notifier.ToRoom("global", "final-message", result)
}
return nil
})
f.ReportNodeResult(func(result mq.Result) {
if f.Notifier != nil {
f.Notifier.ToRoom("global", "message", result)
}
})
setup(f)
err := f.Validate()
if err != nil {

View File

@@ -0,0 +1,25 @@
#container{
width: 400px;
margin: 30px auto 0 auto;
}
#messages{
height: 200px;
overflow-y: scroll;
}
.control-item{
margin-top: 15px;
}
#message-send-btn, #clear-messages-btn{
width: 100%;
}
#message-list{
padding-left: 25px;
}
.message-item{
margin-bottom: 8px;
}

View File

320
examples/webroot/index.html Normal file
View File

@@ -0,0 +1,320 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Task Status Dashboard</title>
<script src="https://cdn.tailwindcss.com"></script>
<style>
/* Custom scrollbar styles */
.scrollbar {
scrollbar-width: thin;
scrollbar-color: #D1D5DB #E5E7EB;
}
.scrollbar::-webkit-scrollbar {
width: 8px;
}
.scrollbar::-webkit-scrollbar-track {
background: #E5E7EB;
}
.scrollbar::-webkit-scrollbar-thumb {
background-color: #D1D5DB;
border-radius: 9999px;
}
.svg-container {
max-width: 100%;
max-height: 80vh;
overflow: auto;
position: relative;
border: 1px solid #ccc;
}
#dag-diagram {
width: 100%;
height: auto;
}
.popover {
position: absolute;
z-index: 10;
display: none;
background-color: white;
border: 1px solid #E5E7EB;
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
padding: 10px;
width: 200px;
border-radius: 8px;
}
.popover.visible {
display: block;
}
</style>
</head>
<body class="bg-gray-100 flex h-screen p-4">
<div class="w-screen mx-auto bg-white shadow-lg rounded-lg overflow-hidden">
<!-- Task Table Section -->
<div class="p-4">
<h1 class="text-xl font-semibold text-gray-700 mb-4">Task Status</h1>
<!-- Scrollable Table -->
<div class="max-h-64 overflow-y-auto scrollbar relative">
<table class="min-w-full table-auto border-collapse border border-gray-300">
<thead class="bg-gray-100">
<tr>
<th class="px-4 py-2 border border-gray-300">Task ID</th>
<th class="px-4 py-2 border border-gray-300">Created At</th>
<th class="px-4 py-2 border border-gray-300">Processed At</th>
<th class="px-4 py-2 border border-gray-300">Latency</th>
<th class="px-4 py-2 border border-gray-300">Status</th>
<th class="px-4 py-2 border border-gray-300">Result</th>
</tr>
</thead>
<tbody id="taskTableBody">
<!-- Dynamic rows will be appended here -->
</tbody>
</table>
</div>
<!-- JSON Popover -->
<div id="jsonPopover" class="popover w-1/2"></div>
<div id="svgPopover" class="popover"></div>
</div>
<!-- SVG Viewer Section -->
<div class="p-4">
<h2 class="text-lg font-semibold text-gray-700 mb-2">SVG Viewer</h2>
<div class="border border-gray-300 p-4 bg-gray-50">
<div id="svg-container" class="svg-container">
<svg id="dag-diagram" class="svg-image" xmlns="http://www.w3.org/2000/svg"></svg>
</div>
</div>
</div>
</div>
<script src="/js/socket.js"></script>
<script>
(function(SS) {
'use strict';
// New function to load the SVG dynamically
function loadSVG(url) {
fetch(url)
.then(response => response.text())
.then(svgContent => {
const container = document.getElementById('svg-container');
container.innerHTML = svgContent; // Inject the SVG into the container
attachSVGNodeEvents();
})
.catch(err => console.error('Failed to load SVG:', err));
}
// Call loadSVG when the page loads
window.onload = function() {
loadSVG('http://localhost:8083/ui'); // Adjust this to the correct path of your SVG
};
const tasks = {}; // Local storage for tasks
// Attach click event listeners to all SVG nodes after loading the SVG
function attachSVGNodeEvents() {
const svgNodes = document.querySelectorAll('g.node'); // Adjust selector as per your SVG structure
svgNodes.forEach(node => {
node.addEventListener('click', handleSVGNodeClick);
});
}
// Function to handle the click on an SVG node and show the popover
function handleSVGNodeClick(event) {
const nodeId = event.currentTarget.id; // Get the node ID (e.g., 'node_store:data')
const nodeData = findNodeDataById(nodeId); // Fetch data related to the node (status, result, etc.)
if (nodeData) {
showSVGPopover(event, nodeData);
}
}
// Function to show the popover next to the clicked SVG node
function showSVGPopover(event, nodeData) {
const popover = document.getElementById('svgPopover');
popover.classList.add('visible');
popover.innerHTML = `
<div class="text-sm text-gray-700">
<strong>Result:</strong> ${JSON.stringify(nodeData.payload)}<br>
<strong>Status:</strong> ${nodeData.status}<br>
<strong>Error:</strong> ${nodeData.error || 'N/A'}
</div>
`;
const rect = event.currentTarget.getBoundingClientRect();
popover.style.top = `${rect.top + window.scrollY + rect.height + 10}px`;
popover.style.left = `${rect.left + window.scrollX + rect.width / 2}px`;
}
// Function to find node data (status, result, error) by the node's ID
function findNodeDataById(nodeId) {
for (const taskId in tasks) {
const task = tasks[taskId];
const node = task.nodes.find(n => `node_${n.topic}` === nodeId); // Ensure the ID format matches your SVG
if (node) {
return node;
}
}
return null; // Return null if no matching node is found
}
const addOrUpdateTask = (message, isFinal = false) => {
const taskTableBody = document.getElementById('taskTableBody');
const taskId = message.task_id;
const rowId = `row-${taskId}`;
// Check if task already exists
let existingRow = document.getElementById(rowId);
if (!existingRow) {
// Create new row if it doesn't exist
const row = document.createElement('tr');
row.id = rowId;
taskTableBody.insertBefore(row, taskTableBody.firstChild);
existingRow = row;
}
// Update task data
tasks[taskId] = tasks[taskId] || { nodes: [], final: null };
if (isFinal) {
tasks[taskId].final = message;
} else {
tasks[taskId].nodes.push(message);
}
// Determine the latest task status and status color
const latestStatus = isFinal ? message.status : message.status;
const statusColor = latestStatus === 'success' ? 'bg-green-100 text-green-700' :
latestStatus === 'fail' ? 'bg-red-100 text-red-700' :
'bg-yellow-100 text-yellow-700';
// Update row content
existingRow.innerHTML = `
<td class="px-4 py-2 border border-gray-300">${taskId}</td>
<td class="px-4 py-2 border border-gray-300">${new Date(message.created_at).toLocaleString()}</td>
<td class="px-4 py-2 border border-gray-300">${new Date(message.processed_at).toLocaleString()}</td>
<td class="px-4 py-2 border border-gray-300">${message.latency}</td>
<td class="px-4 py-2 border border-gray-300 ${statusColor}">${latestStatus}</td>
<td class="px-4 py-2 border border-gray-300">
<button class="view-btn text-blue-600 hover:underline" data-task-id='${taskId}'>
View
</button>
</td>
`;
attachViewButtonEvent();
}
let ss = new SS('ws://' + window.location.host + '/notify');
ss.onConnect(function() {
ss.emit('join', "global");
});
ss.onDisconnect(function() {
alert('chat disconnected');
});
ss.on('joinedRoom', function(room) {
console.log('Joined room: ' + room);
});
// Handle "message" for per-node status
ss.on('message', function(msg) {
addOrUpdateTask(msg, false);
});
// Handle "final-message" for final task status
ss.on('final-message', function(msg) {
addOrUpdateTask(msg, true);
});
// Attach event listeners to "View" buttons
function attachViewButtonEvent() {
const buttons = document.querySelectorAll('.view-btn');
buttons.forEach(button => {
button.removeEventListener('click', handleViewButtonClick); // Remove previous listeners
button.addEventListener('click', handleViewButtonClick); // Attach new listener
});
}
// Show formatted data in popover and update SVG
function handleViewButtonClick(event) {
const taskId = event.target.getAttribute('data-task-id');
const task = tasks[taskId];
showPopover(event.target, task);
updateSVGNodes(task);
}
// Function to show popover with node statuses and final status if available
function showPopover(target, task) {
const popover = document.getElementById('jsonPopover');
popover.classList.add('visible');
const rect = target.getBoundingClientRect();
popover.style.top = `${rect.top + window.scrollY + rect.height + 10}px`;
popover.style.left = `${rect.left + window.scrollX - 600}px`;
let popoverContent = '';
// If final status exists, display it
if (task.final) {
popoverContent += `<div class="text-sm text-gray-700">
<strong>Final Status:</strong> ${task.final.status}<br>
<strong>Result:</strong> ${JSON.stringify(task.final.payload)}
</div>`;
}
// Build table for task node statuses
popoverContent += `<table class="min-w-full table-auto border-collapse border border-gray-300 mb-2">
<thead class="bg-gray-100">
<tr>
<th class="px-2 py-1 border border-gray-300">Node</th>
<th class="px-2 py-1 border border-gray-300">Result</th>
<th class="px-2 py-1 border border-gray-300">Status</th>
<th class="px-2 py-1 border border-gray-300">Error</th>
</tr>
</thead>
<tbody>`;
const statusColor = (status) => status === 'success' ? 'bg-green-100 text-green-700' :
status === 'fail' ? 'bg-red-100 text-red-700' :
'bg-yellow-100 text-yellow-700';
task.nodes.forEach((node) => {
popoverContent += `
<tr>
<td class="px-2 py-1 border border-gray-300">${node.topic}</td>
<td class="px-2 py-1 border border-gray-300">${JSON.stringify(node.payload)}</td>
<td class="px-2 py-1 border border-gray-300 ${statusColor(node.status)}">${node.status}</td>
<td class="px-2 py-1 border border-gray-300">${node.error || 'N/A'}</td>
</tr>`;
});
popoverContent += '</tbody></table>';
popover.innerHTML = popoverContent;
document.addEventListener('click', function handleClickOutside(event) {
if (!popover.contains(event.target) && event.target !== target) {
popover.classList.remove('visible');
document.removeEventListener('click', handleClickOutside);
}
});
}
// Function to update the SVG nodes based on task statuses
function updateSVGNodes(task) {
task.nodes.forEach((node) => {
const svgNode = document.querySelector(`#node_${node.topic.replace(':', '\\:')}`);
if (svgNode) {
const status = node.status;
const fillColor = status === 'success' ? 'green' :
status === 'fail' ? 'red' : 'yellow';
const path = svgNode.querySelector('path');
if (path) {
path.setAttribute('fill', fillColor);
}
}
});
}
})(window.SS);
</script>
</body>
</html>

View File

@@ -0,0 +1,232 @@
if(typeof window === 'undefined'){
var window = {};
}
if(typeof module === 'undefined'){
var module = {};
}
(function(window, module){ 'use strict';
var SS = function(url, opts){
opts = opts || {};
var self = this,
events = {},
reconnectOpts = {enabled: true, replayOnConnect: true, intervalMS: 5000},
reconnecting = false,
connectedOnce = false,
headerStartCharCode = 1,
headerStartChar = String.fromCharCode(headerStartCharCode),
dataStartCharCode = 2,
dataStartChar = String.fromCharCode(dataStartCharCode),
subProtocol = 'sac-sock',
ws = new WebSocket(url, subProtocol);
//blomp blomp-a noop noop a-noop noop noop
self.noop = function(){ };
//we really only support reconnect options for now
if(typeof opts.reconnectOpts == 'object'){
for(var i in opts.reconnectOpts){
if(!opts.reconnectOpts.hasOwnProperty(i)) continue;
reconnectOpts[i] = opts.reconnectOpts[i];
}
}
//sorry, only supporting arraybuffer at this time
//maybe if there is demand for it, I'll add Blob support
ws.binaryType = 'arraybuffer';
//Parses all incoming messages and dispatches their payload to the appropriate eventName if one has been registered. Messages received for unregistered events will be ignored.
ws.onmessage = function(e){
var msg = e.data,
headers = {},
eventName = '',
data = '',
chr = null,
i, msgLen;
if(typeof msg === 'string'){
var dataStarted = false,
headerStarted = false;
for(i = 0, msgLen = msg.length; i < msgLen; i++){
chr = msg[i];
if(!dataStarted && !headerStarted && chr !== dataStartChar && chr !== headerStartChar){
eventName += chr;
}else if(!headerStarted && chr === headerStartChar){
headerStarted = true;
}else if(headerStarted && !dataStarted && chr !== dataStartChar){
headers[chr] = true;
}else if(!dataStarted && chr === dataStartChar){
dataStarted = true;
}else{
data += chr;
}
}
}else if(msg && msg instanceof ArrayBuffer && msg.byteLength !== undefined){
var dv = new DataView(msg),
headersStarted = false;
for(i = 0, msgLen = dv.byteLength; i < msgLen; i++){
chr = dv.getUint8(i);
if(chr !== dataStartCharCode && chr !== headerStartCharCode && !headersStarted){
eventName += String.fromCharCode(chr);
}else if(chr === headerStartCharCode && !headersStarted){
headersStarted = true;
}else if(headersStarted && chr !== dataStartCharCode){
headers[String.fromCharCode(chr)] = true;
}else if(chr === dataStartCharCode){
data = dv.buffer.slice(i+1);
break;
}
}
}
if(eventName.length === 0) return; //no event to dispatch
if(typeof events[eventName] === 'undefined') return;
events[eventName].call(self, (headers.J) ? JSON.parse(data) : data);
};
/**
* startReconnect is an internal function for reconnecting after an unexpected disconnect
*
* @function startReconnect
*
*/
function startReconnect(){
setTimeout(function(){
console.log('attempting reconnect');
var newWS = new WebSocket(url, subProtocol);
newWS.onmessage = ws.onmessage;
newWS.onclose = ws.onclose;
newWS.binaryType = ws.binaryType;
//we need to run the initially set onConnect function on first successful connect,
//even if replayOnConnect is disabled. The server might not be available on first
//connection attempt.
if(reconnectOpts.replayOnConnect || !connectedOnce){
newWS.onopen = ws.onopen;
}
ws = newWS;
if(!reconnectOpts.replayOnConnect && connectedOnce){
self.onConnect(self.noop);
}
}, reconnectOpts.intervalMS);
}
/**
* onConnect registers a callback to be run when the websocket connection is open.
*
* @method onConnect
* @param {Function} callback(event) - The callback that will be executed when the websocket connection opens.
*
*/
self.onConnect = function(callback){
ws.onopen = function(){
connectedOnce = true;
var args = arguments;
callback.apply(self, args);
if(reconnecting){
reconnecting = false;
}
};
};
self.onConnect(self.noop);
/**
* onDisconnect registers a callback to be run when the websocket connection is closed.
*
* @method onDisconnect
* @param {Function} callback(event) - The callback that will be executed when the websocket connection is closed.
*/
self.onDisconnect = function(callback){
ws.onclose = function(){
var args = arguments;
if(!reconnecting && connectedOnce){
callback.apply(self, args);
}
if(reconnectOpts.enabled){
reconnecting = true;
startReconnect();
}
};
};
self.onDisconnect(self.noop);
/**
* on registers an event to be called when the client receives an emit from the server for
* the given eventName.
*
* @method on
* @param {String} eventName - The name of the event being registerd
* @param {Function} callback(payload) - The callback that will be ran whenever the client receives an emit from the server for the given eventName. The payload passed into callback may be of type String, Object, or ArrayBuffer
*
*/
self.on = function(eventName, callback){
events[eventName] = callback;
};
/**
* off unregisters an emit event
*
* @method off
* @param {String} eventName - The name of event being unregistered
*/
self.off = function(eventName){
if(events[eventName]){
delete events[eventName];
}
};
/**
* emit dispatches an event to the server
*
* @method emit
* @param {String} eventName - The event to dispatch
* @param {String|Object|ArrayBuffer} data - The data to be sent to the server. If data is a string then it will be sent as a normal string to the server. If data is an object it will be converted to JSON before being sent to the server. If data is an ArrayBuffer then it will be sent to the server as a uint8 binary payload.
*/
self.emit = function(eventName, data){
var rs = ws.readyState;
if(rs === 0){
console.warn("websocket is not open yet");
return;
}else if(rs === 2 || rs === 3){
console.error("websocket is closed");
return;
}
var msg = '';
if(data instanceof ArrayBuffer){
var ab = new ArrayBuffer(data.byteLength+eventName.length+1),
newBuf = new DataView(ab),
oldBuf = new DataView(data),
i = 0;
for(var evtLen = eventName.length; i < evtLen; i++){
newBuf.setUint8(i, eventName.charCodeAt(i));
}
newBuf.setUint8(i, dataStartCharCode);
i++;
for(var x = 0, xLen = oldBuf.byteLength; x < xLen; x++, i++){
newBuf.setUint8(i, oldBuf.getUint8(x));
}
msg = ab;
}else if(typeof data === 'object'){
msg = eventName+dataStartChar+JSON.stringify(data);
}else{
msg = eventName+dataStartChar+data;
}
ws.send(msg);
};
/**
* close will close the websocket connection, calling the "onDisconnect" event if one has been registered.
*
* @method close
*/
self.close = function(){
reconnectOpts.enabled = false; //don't reconnect if close is called
return ws.close(1000);
};
};
window.SS = SS;
module.exports = SS;
})(window, module);

2
go.mod
View File

@@ -3,11 +3,13 @@ module github.com/oarkflow/mq
go 1.23
require (
github.com/gorilla/websocket v1.5.3
github.com/oarkflow/date v0.0.4
github.com/oarkflow/dipper v0.0.6
github.com/oarkflow/errors v0.0.6
github.com/oarkflow/expr v0.0.11
github.com/oarkflow/json v0.0.13
github.com/oarkflow/sio v0.0.7
github.com/oarkflow/xid v1.2.5
github.com/prometheus/client_golang v1.20.5
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c

4
go.sum
View File

@@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
@@ -20,6 +22,8 @@ github.com/oarkflow/expr v0.0.11 h1:H6h+dIUlU+xDlijMXKQCh7TdE6MGVoFPpZU7q/dziRI=
github.com/oarkflow/expr v0.0.11/go.mod h1:WgMZqP44h7SBwKyuGZwC15vj46lHtI0/QpKdEZpRVE4=
github.com/oarkflow/json v0.0.13 h1:/ZKW924/v4U1ht34WY7rj/GC/qW9+10IiV5+MR2vO0A=
github.com/oarkflow/json v0.0.13/go.mod h1:S5BZA4/rM87+MY8mFrga3jISzxCL9RtLE6xHSk63VxI=
github.com/oarkflow/sio v0.0.7 h1:Bymu2GO9UccbgwWGcwqinhtFd9Pn+i24fkV6MFHyVx4=
github.com/oarkflow/sio v0.0.7/go.mod h1:tsmF+CUUdI0oksCGlLP+YT8aMF/dKXKAw1YXqQZtvtQ=
github.com/oarkflow/xid v1.2.5 h1:6RcNJm9+oZ/B647gkME9trCzhpxGQaSdNoD56Vmkeho=
github.com/oarkflow/xid v1.2.5/go.mod h1:jG4YBh+swbjlWApGWDBYnsJEa7hi3CCpmuqhB3RAxVo=
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=

View File

@@ -0,0 +1,114 @@
package bpool
import (
"io"
"unsafe"
)
// ByteBuffer provides byte buffer, which can be used for minimizing
// memory allocations.
//
// ByteBuffer may be used with functions appending data to the given []byte
// slice. See example code for details.
//
// Use Get for obtaining an empty byte buffer.
type ByteBuffer struct {
// B is a byte buffer to use in append-like workloads.
// See example code for details.
B []byte
}
// Len returns the size of the byte buffer.
func (b *ByteBuffer) Len() int {
return len(b.B)
}
// ReadFrom implements io.ReaderFrom.
//
// The function appends all the data read from r to b.
func (b *ByteBuffer) ReadFrom(r io.Reader) (int64, error) {
p := b.B
nStart := int64(len(p))
nMax := int64(cap(p))
n := nStart
if nMax == 0 {
nMax = 64
p = make([]byte, nMax)
} else {
p = p[:nMax]
}
for {
if n == nMax {
nMax *= 2
bNew := make([]byte, nMax)
copy(bNew, p)
p = bNew
}
nn, err := r.Read(p[n:])
n += int64(nn)
if err != nil {
b.B = p[:n]
n -= nStart
if err == io.EOF {
return n, nil
}
return n, err
}
}
}
// WriteTo implements io.WriterTo.
func (b *ByteBuffer) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(b.B)
return int64(n), err
}
// Bytes returns b.B, i.e. all the bytes accumulated in the buffer.
//
// The purpose of this function is bytes.Buffer compatibility.
func (b *ByteBuffer) Bytes() []byte {
return b.B
}
// Write implements io.Writer - it appends p to ByteBuffer.B
func (b *ByteBuffer) Write(p []byte) (int, error) {
b.B = append(b.B, p...)
return len(p), nil
}
// WriteByte appends the byte c to the buffer.
//
// The purpose of this function is bytes.Buffer compatibility.
//
// The function always returns nil.
func (b *ByteBuffer) WriteByte(c byte) error {
b.B = append(b.B, c)
return nil
}
// WriteString appends s to ByteBuffer.B.
func (b *ByteBuffer) WriteString(s string) (int, error) {
b.B = append(b.B, s...)
return len(s), nil
}
// Set sets ByteBuffer.B to p.
func (b *ByteBuffer) Set(p []byte) {
b.B = append(b.B[:0], p...)
}
// SetString sets ByteBuffer.B to s.
func (b *ByteBuffer) SetString(s string) {
b.B = append(b.B[:0], s...)
}
// String returns string representation of ByteBuffer.B.
func (b *ByteBuffer) String() string {
return *(*string)(unsafe.Pointer(&b.B))
}
// Reset makes ByteBuffer.B empty.
func (b *ByteBuffer) Reset() {
b.B = b.B[:0]
}

224
internal/bpool/pool.go Normal file
View File

@@ -0,0 +1,224 @@
package bpool
import (
"sort"
"sync"
"sync/atomic"
)
const (
defaultMinBitSize = 6 // 2**6=64 is a CPU cache line size
steps = 20
defaultMinSize = 1 << defaultMinBitSize
defaultMaxSize = 1 << (defaultMinBitSize + steps - 1)
calibrateCallsThreshold = 42000
maxPercentile = 0.95
)
// Pool represents byte buffer pool.
//
// Distinct pools may be used for distinct types of byte buffers.
// Properly determined byte buffer types with their own pools may help reducing
// memory waste.
type Pool struct {
calls [steps]uint64
calibrating uint64
defaultSize uint64
maxSize uint64
minBitSize uint64
minSize uint64
pool sync.Pool
}
var defaultPool Pool
// Get returns an empty byte buffer from the pool.
//
// Got byte buffer may be returned to the pool via Put call.
// This reduces the number of memory allocations required for byte buffer
// management.
func Get() *ByteBuffer { return defaultPool.Get() }
// Get returns new byte buffer with zero length.
//
// The byte buffer may be returned to the pool via Put after the use
// in order to minimize GC overhead.
func (p *Pool) Get() *ByteBuffer {
v := p.pool.Get()
if v != nil {
b := v.(*ByteBuffer)
b.Reset()
return b
}
return &ByteBuffer{
B: make([]byte, 0, atomic.LoadUint64(&p.defaultSize)),
}
}
// GetLen returns a buufer with its
// []byte slice of the exact len as specified
//
// The byte buffer may be returned to the pool via Put after the use
// in order to minimize GC overhead.
func GetLen(s int) *ByteBuffer { return defaultPool.GetLen(s) }
// GetLen return a buufer with its
// []byte slice of the exact len as specified
//
// The byte buffer may be returned to the pool via Put after the use
// in order to minimize GC overhead.
func (p *Pool) GetLen(s int) *ByteBuffer {
v := p.pool.Get()
if v == nil {
size := int(p.minSize << uint(index(p.minBitSize, s)))
if size < s {
size = s
}
return &ByteBuffer{
B: make([]byte, s, size),
}
}
b := v.(*ByteBuffer)
if cap(b.B) >= s {
b.B = b.B[:s]
return b
}
// The size is smaller, return it to the pool and create another one
p.pool.Put(b)
size := int(p.minSize << uint(index(p.minBitSize, s)))
if size < s {
size = s
}
return &ByteBuffer{
B: make([]byte, s, size),
}
}
// Put returns byte buffer to the pool.
//
// ByteBuffer.B mustn't be touched after returning it to the pool.
// Otherwise data races will occur.
func Put(b *ByteBuffer) { defaultPool.Put(b) }
// Put releases byte buffer obtained via Get to the pool.
//
// The buffer mustn't be accessed after returning to the pool.
func (p *Pool) Put(b *ByteBuffer) {
if p.minBitSize == 0 {
p.initBins()
}
idx := index(p.minBitSize, len(b.B))
if atomic.AddUint64(&p.calls[idx], 1) > calibrateCallsThreshold {
p.calibrate()
}
maxSize := int(atomic.LoadUint64(&p.maxSize))
if maxSize == 0 || cap(b.B) <= maxSize {
p.pool.Put(b)
}
}
func (p *Pool) calibrate() {
if !atomic.CompareAndSwapUint64(&p.calibrating, 0, 1) {
return
}
if p.minBitSize == 0 {
p.initBins()
}
a := make(callSizes, 0, steps)
var callsSum uint64
for i := uint64(0); i < steps; i++ {
calls := atomic.SwapUint64(&p.calls[i], 0)
callsSum += calls
a = append(a, callSize{
calls: calls,
size: p.minSize << i,
})
}
if p.minBitSize+steps < 32 && a[steps-1].calls > a[0].calls {
// Increase the first bin's size
p.resizeBins(p.minBitSize + 1)
} else if p.minBitSize > defaultMinBitSize &&
a[0].calls > 0 &&
a[steps-2].calls == 0 &&
a[steps-1].calls == 0 {
// Decrease the size of first bin's size
p.resizeBins(p.minBitSize - 1)
}
sort.Sort(a)
defaultSize := a[0].size
maxSize := defaultSize
maxSum := uint64(float64(callsSum) * maxPercentile)
callsSum = 0
for i := 0; i < steps; i++ {
if callsSum > maxSum {
break
}
callsSum += a[i].calls
size := a[i].size
if size > maxSize {
maxSize = size
}
}
atomic.StoreUint64(&p.defaultSize, defaultSize)
atomic.StoreUint64(&p.maxSize, maxSize)
atomic.StoreUint64(&p.calibrating, 0)
}
func (p *Pool) resizeBins(minBitSize uint64) {
atomic.StoreUint64(&p.minBitSize, minBitSize)
atomic.StoreUint64(&p.minSize, 1<<minBitSize)
}
func (p *Pool) initBins() {
atomic.StoreUint64(&p.minBitSize, defaultMinBitSize)
atomic.StoreUint64(&p.minSize, 1<<defaultMinBitSize)
}
type callSize struct {
calls uint64
size uint64
}
type callSizes []callSize
func (ci callSizes) Len() int {
return len(ci)
}
func (ci callSizes) Less(i, j int) bool {
return ci[i].calls > ci[j].calls
}
func (ci callSizes) Swap(i, j int) {
ci[i], ci[j] = ci[j], ci[i]
}
func index(minBitSize uint64, n int) int {
n--
n >>= minBitSize
idx := 0
for n > 0 {
n >>= 1
idx++
}
if idx >= steps {
idx = steps - 1
}
return idx
}

View File

@@ -24,7 +24,6 @@ type Result struct {
Payload json.RawMessage `json:"payload"`
}
// MarshalJSON customizes the marshaling of Result
func (r Result) MarshalJSON() ([]byte, error) {
type Alias Result
aux := &struct {
@@ -39,7 +38,6 @@ func (r Result) MarshalJSON() ([]byte, error) {
return json.Marshal(aux)
}
// UnmarshalJSON customizes the unmarshalling of Result
func (r *Result) UnmarshalJSON(data []byte) error {
type Alias Result
aux := &struct {
@@ -92,7 +90,6 @@ func (r Result) WithData(status string, data []byte) Result {
return Result{
Status: status,
Payload: data,
Error: nil,
Ctx: r.Ctx,
}
}

10
sio/adapter.go Normal file
View File

@@ -0,0 +1,10 @@
package sio
type Adapter interface {
Init()
Shutdown() error
BroadcastToBackend(*BroadcastMsg)
RoomcastToBackend(*RoomMsg)
BroadcastFromBackend(b chan<- *BroadcastMsg)
RoomcastFromBackend(r chan<- *RoomMsg)
}

210
sio/hub.go Normal file
View File

@@ -0,0 +1,210 @@
package sio
import (
"slices"
"sync"
)
type hub struct {
sockets map[string]*Socket
rooms map[string]*room
shutdownCh chan bool
socketList chan []*Socket
addCh chan *Socket
delCh chan *Socket
joinRoomCh chan *joinRequest
leaveRoomCh chan *leaveRequest
roomMsgCh chan *RoomMsg
broomcastCh chan *RoomMsg // for passing data from the backend
broadcastCh chan *BroadcastMsg
bbroadcastCh chan *BroadcastMsg
multihomeEnabled bool
multihomeBackend Adapter
l sync.RWMutex
}
type room struct {
name string
sockets map[string]*Socket
l sync.RWMutex
}
type joinRequest struct {
roomName string
socket *Socket
}
type leaveRequest struct {
roomName string
socket *Socket
}
// RoomMsg represents an event to be dispatched to a room of sockets
type RoomMsg struct {
RoomName string
Except []string
EventName string
Data any
}
// BroadcastMsg represents an event to be dispatched to all Sockets on the Server
type BroadcastMsg struct {
EventName string
Except []string
Data any
}
func (h *hub) addSocket(s *Socket) {
h.addCh <- s
}
func (h *hub) removeSocket(s *Socket) {
h.delCh <- s
}
func (h *hub) joinRoom(j *joinRequest) {
h.joinRoomCh <- j
}
func (h *hub) leaveRoom(l *leaveRequest) {
h.leaveRoomCh <- l
}
func (h *hub) toRoom(msg *RoomMsg) {
h.roomMsgCh <- msg
}
func (h *hub) broadcast(b *BroadcastMsg) {
h.broadcastCh <- b
}
func (h *hub) setMultihomeBackend(b Adapter) {
if h.multihomeEnabled {
return // can't have two backends... yet
}
h.multihomeBackend = b
h.multihomeEnabled = true
h.multihomeBackend.Init()
go h.multihomeBackend.BroadcastFromBackend(h.bbroadcastCh)
go h.multihomeBackend.RoomcastFromBackend(h.broomcastCh)
}
func (h *hub) listen() {
for {
select {
case c := <-h.addCh:
h.l.Lock()
h.sockets[c.ID()] = c
h.l.Unlock()
case c := <-h.delCh:
delete(h.sockets, c.ID())
case c := <-h.joinRoomCh:
if _, exists := h.rooms[c.roomName]; !exists { // make the room if it doesn't exist
h.rooms[c.roomName] = &room{name: c.roomName, sockets: make(map[string]*Socket)}
}
h.rooms[c.roomName].l.Lock()
h.rooms[c.roomName].sockets[c.socket.ID()] = c.socket
h.rooms[c.roomName].l.Unlock()
case c := <-h.leaveRoomCh:
if room, exists := h.rooms[c.roomName]; exists {
room.l.Lock()
delete(room.sockets, c.socket.ID())
room.l.Unlock()
if len(room.sockets) == 0 { // room is empty, delete it
delete(h.rooms, c.roomName)
}
}
case c := <-h.roomMsgCh:
if room, exists := h.rooms[c.RoomName]; exists {
room.l.Lock()
for _, s := range room.sockets {
if len(c.Except) > 0 {
if !slices.Contains(c.Except, s.ID()) {
s.Emit(c.EventName, c.Data)
}
} else {
s.Emit(c.EventName, c.Data)
}
}
room.l.Unlock()
}
if h.multihomeEnabled { // the room may exist on the other end
go h.multihomeBackend.RoomcastToBackend(c)
}
case c := <-h.broomcastCh:
if room, exists := h.rooms[c.RoomName]; exists {
room.l.Lock()
for _, s := range room.sockets {
if len(c.Except) > 0 {
if !slices.Contains(c.Except, s.ID()) {
s.Emit(c.EventName, c.Data)
}
} else {
s.Emit(c.EventName, c.Data)
}
}
room.l.Unlock()
}
case c := <-h.broadcastCh:
h.l.Lock()
for _, s := range h.sockets {
if len(c.Except) > 0 {
if !slices.Contains(c.Except, s.ID()) {
s.Emit(c.EventName, c.Data)
}
} else {
s.Emit(c.EventName, c.Data)
}
}
h.l.Unlock()
if h.multihomeEnabled {
go h.multihomeBackend.BroadcastToBackend(c)
}
case c := <-h.bbroadcastCh:
h.l.Lock()
for _, s := range h.sockets {
if len(c.Except) > 0 {
if !slices.Contains(c.Except, s.ID()) {
s.Emit(c.EventName, c.Data)
}
} else {
s.Emit(c.EventName, c.Data)
}
}
h.l.Unlock()
case _ = <-h.shutdownCh:
var socketList []*Socket
h.l.Lock()
for _, s := range h.sockets {
socketList = append(socketList, s)
}
h.l.Unlock()
h.socketList <- socketList
}
}
}
func newHub() *hub {
h := &hub{
shutdownCh: make(chan bool),
socketList: make(chan []*Socket),
sockets: make(map[string]*Socket),
rooms: make(map[string]*room),
addCh: make(chan *Socket),
delCh: make(chan *Socket),
joinRoomCh: make(chan *joinRequest),
leaveRoomCh: make(chan *leaveRequest),
roomMsgCh: make(chan *RoomMsg),
broomcastCh: make(chan *RoomMsg),
broadcastCh: make(chan *BroadcastMsg),
bbroadcastCh: make(chan *BroadcastMsg),
multihomeEnabled: false,
}
go h.listen()
return h
}

28
sio/rng.go Normal file
View File

@@ -0,0 +1,28 @@
package sio
import (
"math/rand"
"sync"
"time"
)
// RNG is a random number generator that is safe for concurrent use by multiple go routines
type RNG struct {
r *rand.Rand
mu *sync.Mutex
}
// Read reads len(b) random bytes into b and always returns a nil error
func (r *RNG) Read(b []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.r.Read(b)
}
// NewRNG creates a new random number generator
func NewRNG() *RNG {
return &RNG{
r: rand.New(rand.NewSource(time.Now().UnixNano())),
mu: &sync.Mutex{},
}
}

341
sio/server.go Normal file
View File

@@ -0,0 +1,341 @@
package sio
import (
"context"
"github.com/gorilla/websocket"
"io"
"log/slog"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
)
const ( // ASCII chars
startOfHeaderByte uint8 = 1 // SOH
startOfDataByte = 2 // STX
// SubProtocol is the official sacrificial-socket sub protocol
SubProtocol string = "sac-sock"
)
type event struct {
eventName string
eventHandler func(*Socket, []byte)
}
// Config specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// It is safe to call Config's methods concurrently.
type Config struct {
HandshakeTimeout time.Duration
ReadBufferSize, WriteBufferSize int
WriteBufferPool websocket.BufferPool
Subprotocols []string
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
CheckOrigin func(r *http.Request) bool
EnableCompression bool
}
// Server manages the coordination between
// sockets, rooms, events and the socket hub
// add my own custom field
type Server struct {
hub *hub
events map[string]*event
onConnectFunc func(*Socket) error
onDisconnectFunc func(*Socket) error
onError func(*Socket, error)
l *sync.RWMutex
upgrader *websocket.Upgrader
}
// New creates a new instance of Server
func New(cfg ...Config) *Server {
var config Config
upgrader := DefaultUpgrader()
if len(cfg) > 0 {
config = cfg[0]
}
if config.CheckOrigin != nil {
upgrader.CheckOrigin = config.CheckOrigin
}
if config.HandshakeTimeout != 0 {
upgrader.HandshakeTimeout = config.HandshakeTimeout
}
if config.ReadBufferSize != 0 {
upgrader.ReadBufferSize = config.ReadBufferSize
}
if config.WriteBufferSize != 0 {
upgrader.WriteBufferSize = config.WriteBufferSize
}
if len(config.Subprotocols) > 0 {
upgrader.Subprotocols = config.Subprotocols
}
if config.Error != nil {
upgrader.Error = config.Error
}
upgrader.EnableCompression = config.EnableCompression
s := &Server{
hub: newHub(),
events: make(map[string]*event),
l: &sync.RWMutex{},
upgrader: upgrader,
}
return s
}
func (serv *Server) ShutdownWithSignal() {
c := make(chan bool)
serv.EnableSignalShutdown(c)
go func() {
<-c
os.Exit(0)
}()
}
// EnableSignalShutdown listens for linux syscalls SIGHUP, SIGINT, SIGTERM, SIGQUIT, SIGKILL and
// calls the Server.Shutdown() to perform a clean shutdown. true will be passed into complete
// after the Shutdown proccess is finished
func (serv *Server) EnableSignalShutdown(complete chan<- bool) {
c := make(chan os.Signal, 1)
signal.Notify(c,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
syscall.SIGKILL)
go func() {
<-c
complete <- serv.Shutdown()
}()
}
func (serv *Server) Lock() {
serv.l.Lock()
}
func (serv *Server) Unlock() {
serv.l.Unlock()
}
func (serv *Server) RoomSocketList(id string) map[string]*Socket {
sockets := make(map[string]*Socket)
if room, exists := serv.hub.rooms[id]; exists {
room.l.Lock()
for id, socket := range room.sockets {
sockets[id] = socket
}
room.l.Unlock()
}
return sockets
}
func (serv *Server) SocketList() map[string]*Socket {
sockets := make(map[string]*Socket)
serv.l.Lock()
for id, socket := range serv.hub.sockets {
sockets[id] = socket
}
serv.l.Unlock()
return sockets
}
// Shutdown closes all active sockets and triggers the Shutdown()
// method on any Adapter that is currently set.
func (serv *Server) Shutdown() bool {
slog.Info("shutting down")
// complete := serv.hub.shutdown()
serv.hub.shutdownCh <- true
socketList := <-serv.hub.socketList
for _, s := range socketList {
s.Close()
}
if serv.hub.multihomeEnabled {
slog.Info("shutting down multihome backend")
serv.hub.multihomeBackend.Shutdown()
slog.Info("backend shutdown")
}
slog.Info("shutdown")
return true
}
// EventHandler is an interface for registering events using SockerServer.OnEvent
type EventHandler interface {
HandleEvent(*Socket, []byte)
EventName() string
}
// On registers event functions to be called on individual Socket connections
// when the server's socket receives an Emit from the client's socket.
//
// Any event functions registered with On, must be safe for concurrent use by multiple
// go routines
func (serv *Server) On(eventName string, handleFunc func(*Socket, []byte)) {
serv.events[eventName] = &event{eventName, handleFunc} // you think you can handle the func?
}
// OnEvent has the same functionality as On, but accepts
// an EventHandler interface instead of a handler function.
func (serv *Server) OnEvent(h EventHandler) {
serv.On(h.EventName(), h.HandleEvent)
}
// OnConnect registers an event function to be called whenever a new Socket connection
// is created
func (serv *Server) OnConnect(handleFunc func(*Socket) error) {
serv.onConnectFunc = handleFunc
}
// OnError registers an event function to be called whenever a new Socket connection
// is created
func (serv *Server) OnError(handleFunc func(*Socket, error)) {
serv.onError = handleFunc
}
// OnDisconnect registers an event function to be called as soon as a Socket connection
// is closed
func (serv *Server) OnDisconnect(handleFunc func(*Socket) error) {
serv.onDisconnectFunc = handleFunc
}
// WebHandler returns a http.Handler to be passed into http.Handle
//
// Depricated: The Server struct now satisfies the http.Handler interface, use that instead
func (serv *Server) WebHandler() http.Handler {
return serv
}
// ServeHTTP will upgrade a http request to a websocket using the sac-sock subprotocol
func (serv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws, err := serv.upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error(err.Error())
return
}
request := r.Clone(context.Background())
serv.loop(ws, request)
}
// DefaultUpgrader returns a websocket upgrader suitable for creating sacrificial-socket websockets.
func DefaultUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
Subprotocols: []string{SubProtocol},
CheckOrigin: func(r *http.Request) bool {
return true
},
}
}
// SetUpgrader sets the websocket.Upgrader used by the Server.
func (serv *Server) SetUpgrader(u *websocket.Upgrader) {
serv.upgrader = u
}
// SetMultihomeBackend registers an Adapter interface and calls its Init() method
func (serv *Server) SetMultihomeBackend(b Adapter) {
serv.hub.setMultihomeBackend(b)
}
// ToRoom dispatches an event to all Sockets in the specified room.
func (serv *Server) ToRoom(roomName, eventName string, data any) {
serv.hub.toRoom(&RoomMsg{RoomName: roomName, EventName: eventName, Data: data})
}
// ToRoomExcept dispatches an event to all Sockets in the specified room.
func (serv *Server) ToRoomExcept(roomName string, except []string, eventName string, data any) {
serv.hub.toRoom(&RoomMsg{RoomName: roomName, EventName: eventName, Data: data, Except: except})
}
// Broadcast dispatches an event to all Sockets on the Server.
func (serv *Server) Broadcast(eventName string, data any) {
serv.hub.broadcast(&BroadcastMsg{EventName: eventName, Data: data})
}
// BroadcastExcept dispatches an event to all Sockets on the Server.
func (serv *Server) BroadcastExcept(except []string, eventName string, data any) {
serv.hub.broadcast(&BroadcastMsg{EventName: eventName, Except: except, Data: data})
}
// ToSocket dispatches an event to the specified socket ID.
func (serv *Server) ToSocket(socketID, eventName string, data any) {
serv.ToRoom("__socket_id:"+socketID, eventName, data)
}
// loop handles all the coordination between new sockets
// reading frames and dispatching events
func (serv *Server) loop(ws *websocket.Conn, r *http.Request) {
s := newSocket(serv, ws, r)
slog.Info("connected", "id", s.ID())
defer s.Close()
s.Join("__socket_id:" + s.ID())
serv.l.RLock()
e := serv.onConnectFunc
serv.l.RUnlock()
if e != nil {
err := e(s)
if err != nil && serv.onError != nil {
serv.onError(s, err)
}
}
for {
msg, err := s.receive()
if ignorableError(err) {
return
}
if err != nil {
slog.Error(err.Error())
return
}
eventName := ""
contentIdx := 0
for idx, chr := range msg {
if chr == startOfDataByte {
eventName = string(msg[:idx])
contentIdx = idx + 1
break
}
}
if eventName == "" {
slog.Warn("no event to dispatch")
continue
}
serv.l.RLock()
e, exists := serv.events[eventName]
serv.l.RUnlock()
if exists {
go e.eventHandler(s, msg[contentIdx:])
}
}
}
func ignorableError(err error) bool {
// not an error
if err == nil {
return false
}
return err == io.EOF ||
websocket.IsCloseError(err, 1000) ||
websocket.IsCloseError(err, 1001) ||
strings.HasSuffix(err.Error(), "use of closed network connection")
}

274
sio/socket.go Normal file
View File

@@ -0,0 +1,274 @@
package sio
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/oarkflow/mq/internal/bpool"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"log/slog"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
var (
socketRNG = NewRNG()
)
// Socket represents a websocket connection
type Socket struct {
l *sync.RWMutex
id string
ws *websocket.Conn
closed bool
serv *Server
roomsl *sync.RWMutex
request *http.Request
context storage.IMap[string, any]
rooms map[string]bool
pingTicker *time.Ticker
tickerDone chan bool
pingInterval time.Duration
}
const (
idLen int = 24
typeJSON string = "J"
typeBin = "B"
typeStr = "S"
)
func newSocket(serv *Server, ws *websocket.Conn, r *http.Request) *Socket {
s := &Socket{
l: &sync.RWMutex{},
id: newSocketID(),
ws: ws,
closed: false,
serv: serv,
roomsl: &sync.RWMutex{},
rooms: make(map[string]bool),
context: memory.New[string, any](),
request: r,
pingTicker: time.NewTicker(5 * time.Second),
tickerDone: make(chan bool),
}
serv.hub.addSocket(s)
go s.Ping()
return s
}
func newSocketID() string {
idBuf := make([]byte, idLen)
socketRNG.Read(idBuf)
return base64.StdEncoding.EncodeToString(idBuf)
}
func (s *Socket) receive() ([]byte, error) {
_, data, err := s.ws.ReadMessage()
return data, err
}
func (s *Socket) send(msgType int, data []byte) error {
s.l.Lock()
defer s.l.Unlock()
return s.ws.WriteMessage(msgType, data)
}
func (s *Socket) Ping() error {
for {
select {
case <-s.tickerDone:
return nil
case <-s.pingTicker.C:
buf := bpool.Get()
defer bpool.Put(buf)
buf.WriteString(fmt.Sprintf("%d", websocket.PongMessage))
s.ws.WriteMessage(websocket.TextMessage, buf.Bytes())
}
}
}
// InRoom returns true if s is currently a member of roomName
func (s *Socket) InRoom(roomName string) bool {
s.roomsl.RLock()
defer s.roomsl.RUnlock()
inRoom := s.rooms[roomName]
return inRoom
}
// Request get request
func (s *Socket) Request() *http.Request {
return s.request
}
// Set get request
func (s *Socket) Set(key string, val any) {
s.context.Set(key, val)
}
// Get gets value
func (s *Socket) Get(key string) (any, bool) {
return s.context.Get(key)
}
// Context gets value
func (s *Socket) Context() storage.IMap[string, any] {
return s.context
}
// GetRooms returns a list of rooms that s is a member of
func (s *Socket) GetRooms() []string {
s.roomsl.RLock()
defer s.roomsl.RUnlock()
var roomList []string
for room := range s.rooms {
roomList = append(roomList, room)
}
return roomList
}
// Join adds s to the specified room. If the room does
// not exist, it will be created
func (s *Socket) Join(roomName string) {
s.roomsl.Lock()
defer s.roomsl.Unlock()
s.serv.hub.joinRoom(&joinRequest{roomName, s})
s.rooms[roomName] = true
}
// Leave removes s from the specified room. If s
// is not a member of the room, nothing will happen. If the room is
// empty upon removal of s, the room will be closed
func (s *Socket) Leave(roomName string) {
s.roomsl.Lock()
defer s.roomsl.Unlock()
s.serv.hub.leaveRoom(&leaveRequest{roomName, s})
delete(s.rooms, roomName)
}
// LeaveAll removes s from the specified room. If s
// is not a member of the room, nothing will happen. If the room is
// empty upon removal of s, the room will be closed
func (s *Socket) LeaveAll() {
for roomName := range s.rooms {
s.Leave(roomName)
}
}
// ToRoom dispatches an event to all Sockets in the specified room.
func (s *Socket) ToRoom(roomName, eventName string, data any) {
s.serv.hub.toRoom(&RoomMsg{RoomName: roomName, EventName: eventName, Data: data})
}
// ToRoomExcept dispatches an event to all Sockets in the specified room.
func (s *Socket) ToRoomExcept(roomName string, except []string, eventName string, data any) {
s.serv.hub.toRoom(&RoomMsg{RoomName: roomName, EventName: eventName, Data: data, Except: except})
}
// Broadcast dispatches an event to all Sockets on the Server.
func (s *Socket) Broadcast(eventName string, data any) {
s.serv.hub.broadcast(&BroadcastMsg{EventName: eventName, Data: data})
}
// BroadcastExcept dispatches an event to all Sockets on the Server.
func (s *Socket) BroadcastExcept(except []string, eventName string, data any) {
s.serv.hub.broadcast(&BroadcastMsg{EventName: eventName, Data: data, Except: except})
}
// ToSocket dispatches an event to the specified socket ID.
func (s *Socket) ToSocket(socketID, eventName string, data any) {
s.serv.ToRoom("__socket_id:"+socketID, eventName, data)
}
// Emit dispatches an event to s.
func (s *Socket) Emit(eventName string, data any) error {
return s.send(emitData(eventName, data))
}
// ID returns the unique ID of s
func (s *Socket) ID() string {
return s.id
}
// emitData combines the eventName and data into a payload that is understood
// by the sac-sock protocol.
func emitData(eventName string, data any) (int, []byte) {
buf := bpool.Get()
defer bpool.Put(buf)
buf.WriteString(eventName)
buf.WriteByte(startOfHeaderByte)
switch d := data.(type) {
case string:
buf.WriteString(typeStr)
buf.WriteByte(startOfDataByte)
buf.WriteString(d)
return websocket.TextMessage, buf.Bytes()
case []byte:
buf.WriteString(typeBin)
buf.WriteByte(startOfDataByte)
buf.Write(d)
return websocket.BinaryMessage, buf.Bytes()
default:
buf.WriteString(typeJSON)
buf.WriteByte(startOfDataByte)
jsonData, err := json.Marshal(d)
if err != nil {
slog.Error(err.Error())
} else {
buf.Write(jsonData)
}
return websocket.TextMessage, buf.Bytes()
}
}
// Close closes the Socket connection and removes the Socket
// from any rooms that it was a member of
func (s *Socket) Close() error {
s.l.Lock()
isAlreadyClosed := s.closed
s.closed = true
s.l.Unlock()
if isAlreadyClosed { // can't reclose the socket
return nil
}
defer slog.Debug(s.ID(), "disconnected")
err := s.ws.Close()
if err != nil {
return err
}
rooms := s.GetRooms()
for _, room := range rooms {
s.Leave(room)
}
s.serv.l.RLock()
event := s.serv.onDisconnectFunc
s.serv.l.RUnlock()
if event != nil {
if err := event(s); err != nil {
return err
}
}
s.serv.hub.removeSocket(s)
if s.pingTicker != nil {
s.pingTicker.Stop()
s.tickerDone <- true
}
return nil
}

View File

@@ -19,6 +19,10 @@ func (b *Broker) NotifyHandler() func(context.Context, Result) error {
return b.opts.notifyResponse
}
func (b *Broker) SetNotifyHandler(callback Callback) {
b.opts.notifyResponse = callback
}
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
if b.opts.callback != nil {
var result Result