mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-30 03:01:58 +08:00
Compare commits
158 Commits
v2.0.3
...
remove-ven
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9c3de2c8c | ||
|
|
787390e7ba | ||
|
|
916d022093 | ||
|
|
11e0256959 | ||
|
|
858a28ea4b | ||
|
|
82f32cf53d | ||
|
|
b4e2c61a72 | ||
|
|
2b0732b815 | ||
|
|
450b8a2817 | ||
|
|
0d880fa9bd | ||
|
|
a5310bdfdf | ||
|
|
96b96e8248 | ||
|
|
a786a0a526 | ||
|
|
c6297ee9e3 | ||
|
|
2f94751220 | ||
|
|
89a786832d | ||
|
|
8f93e35f1e | ||
|
|
ba6b12a303 | ||
|
|
a3fe5d4729 | ||
|
|
bb5fc40fe4 | ||
|
|
fbb1fb25aa | ||
|
|
fac733fd71 | ||
|
|
add87fea2e | ||
|
|
58f9fed336 | ||
|
|
1574443981 | ||
|
|
44bac0adc5 | ||
|
|
e784c755ae | ||
|
|
eafc2d91fc | ||
|
|
0df69a4a4e | ||
|
|
321a0514fe | ||
|
|
00387593c9 | ||
|
|
af78d10870 | ||
|
|
30ca94e878 | ||
|
|
ae3f72f677 | ||
|
|
9838262e66 | ||
|
|
ac812154e6 | ||
|
|
0234589152 | ||
|
|
ea97038052 | ||
|
|
48233334a5 | ||
|
|
33429451c8 | ||
|
|
050e24662f | ||
|
|
aec29e350e | ||
|
|
cb99d6f4bc | ||
|
|
c77d1c0331 | ||
|
|
6f42c3fd65 | ||
|
|
9c52292732 | ||
|
|
990f308faa | ||
|
|
fe0c1d15a6 | ||
|
|
0648e39507 | ||
|
|
233a82e448 | ||
|
|
51a8d8cb54 | ||
|
|
23c3208310 | ||
|
|
23e1092cda | ||
|
|
d498576927 | ||
|
|
7e14ce99b5 | ||
|
|
4db49a4b9d | ||
|
|
e60b8ff0c9 | ||
|
|
b9d5dcb5f0 | ||
|
|
6d394d1fe9 | ||
|
|
1ee2158711 | ||
|
|
af79b55b9f | ||
|
|
e1a9497c25 | ||
|
|
62659e17ba | ||
|
|
7ad6dd8e1a | ||
|
|
565e07747e | ||
|
|
6acd775a6b | ||
|
|
493f6c8bb0 | ||
|
|
d3785c2717 | ||
|
|
52a347169a | ||
|
|
797d75cb34 | ||
|
|
5225a357e5 | ||
|
|
a734a0dc73 | ||
|
|
6704cf7227 | ||
|
|
9233e6fd39 | ||
|
|
1ca65d9631 | ||
|
|
33229da885 | ||
|
|
c274d5fd08 | ||
|
|
10e82f41d6 | ||
|
|
e6c07b2b78 | ||
|
|
eed3ef9606 | ||
|
|
1ec880844d | ||
|
|
4b49652a8c | ||
|
|
d46e7b5bcf | ||
|
|
17fb7dadbc | ||
|
|
ed7fd836e1 | ||
|
|
605bb93c75 | ||
|
|
c73ace2ea0 | ||
|
|
aac6d699da | ||
|
|
7bd7bd5087 | ||
|
|
655bf9fdb1 | ||
|
|
b188055c7d | ||
|
|
aaf1d9d4c6 | ||
|
|
44ce819318 | ||
|
|
e4c76cc60c | ||
|
|
da79faa972 | ||
|
|
46babc89c8 | ||
|
|
9b7a943888 | ||
|
|
a909d30923 | ||
|
|
0851b09e4d | ||
|
|
a302c9dd88 | ||
|
|
1e8f922102 | ||
|
|
4c16e5593f | ||
|
|
49cada4fbc | ||
|
|
ef34510c0b | ||
|
|
e5716caad1 | ||
|
|
4b039cb35c | ||
|
|
aac245441a | ||
|
|
bb54cc68e6 | ||
|
|
7ba1352a60 | ||
|
|
ca849131eb | ||
|
|
ba7e534122 | ||
|
|
db760c34a5 | ||
|
|
ae3ee81bb4 | ||
|
|
c2ca02d149 | ||
|
|
77a64d9c87 | ||
|
|
8dec9cc962 | ||
|
|
f90e52328d | ||
|
|
50aae47618 | ||
|
|
0d79f2d63b | ||
|
|
300152413c | ||
|
|
0de1d731db | ||
|
|
80746abc52 | ||
|
|
a73cf4ca0e | ||
|
|
bc549ee7ed | ||
|
|
c464b46713 | ||
|
|
05ce56008c | ||
|
|
8254cb0cbc | ||
|
|
4ae58b79e3 | ||
|
|
b895d688e0 | ||
|
|
a600cd4ead | ||
|
|
cdb44990cf | ||
|
|
2d9c128111 | ||
|
|
a0d5bdb39f | ||
|
|
4ebcef3cb6 | ||
|
|
fb8d4720d7 | ||
|
|
4080c89127 | ||
|
|
1b67e6f3f6 | ||
|
|
1adb02e087 | ||
|
|
4d4140aa99 | ||
|
|
e31840a37d | ||
|
|
7d2e16f2d3 | ||
|
|
92cd935a16 | ||
|
|
25ce27ce2d | ||
|
|
527d084a4b | ||
|
|
bb9f937bb0 | ||
|
|
511fe88684 | ||
|
|
75504ff201 | ||
|
|
a556feb325 | ||
|
|
d06f47f4b9 | ||
|
|
8d4cc091b4 | ||
|
|
d8f28cb843 | ||
|
|
88861c219d | ||
|
|
7ba6cf28d9 | ||
|
|
c174cfdc6b | ||
|
|
4f198a99dd | ||
|
|
2a9c9fcc40 | ||
|
|
835a85c8bf | ||
|
|
fe5d9ffa61 |
86
.github/workflows/build.yml
vendored
86
.github/workflows/build.yml
vendored
@@ -3,41 +3,71 @@ name: build
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -race ./... && echo true
|
||||
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.21
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
- name: Test
|
||||
run: go test -race ./... && echo true
|
||||
|
||||
coverage:
|
||||
name: Test with Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
if: success()
|
||||
uses: actions/setup-go@v2
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Calc coverage
|
||||
go-version: '1.21'
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
go test -v -covermode=count -coverprofile=coverage.out ./...
|
||||
- name: Convert coverage.out to coverage.lcov
|
||||
uses: jandelgado/gcov2lcov-action@v1.0.6
|
||||
- name: Coveralls
|
||||
uses: coverallsapp/github-action@v1.1.2
|
||||
go mod download
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
go test -race -covermode atomic -coverprofile=covprofile ./...
|
||||
- name: Install goveralls
|
||||
run: go install github.com/mattn/goveralls@latest
|
||||
- name: Send coverage
|
||||
env:
|
||||
COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: goveralls -coverprofile=covprofile -service=github
|
||||
|
||||
docker:
|
||||
if: github.repository == 'mochi-mqtt/server' && startsWith(github.ref, 'refs/tags/v')
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
github-token: ${{ secrets.github_token }}
|
||||
path-to-lcov: coverage.lcov
|
||||
images: mochimqtt/server
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=raw,value=latest,enable=${{ endsWith(github.ref, 'main') }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
cmd/mqtt
|
||||
.DS_Store
|
||||
*.db
|
||||
.idea
|
||||
vendor
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.19.0-alpine3.15 AS builder
|
||||
FROM golang:1.21.0-alpine3.18 AS builder
|
||||
|
||||
RUN apk update
|
||||
RUN apk add git
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
|
||||
Copyright (c) 2023 Mochi-MQTT Organisation
|
||||
Copyright (c) 2019, 2022, 2023 Jonathan Blake (mochi-co)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
505
README-CN.md
Normal file
505
README-CN.md
Normal file
@@ -0,0 +1,505 @@
|
||||
# Mochi-MQTT Server
|
||||
|
||||
<p align="center">
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-mqtt/server?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-mqtt/server/v2)
|
||||
[](https://pkg.go.dev/github.com/mochi-mqtt/server/v2)
|
||||
[](https://github.com/mochi-mqtt/server/issues)
|
||||
|
||||
</p>
|
||||
|
||||
[English](README.md) | [简体中文](README-CN.md) | [招募翻译者!](https://github.com/orgs/mochi-mqtt/discussions/310)
|
||||
|
||||
|
||||
🎆 **mochi-co/mqtt 现在已经是新的 mochi-mqtt 组织的一部分。** 详细信息请[阅读公告.](https://github.com/orgs/mochi-mqtt/discussions/271)
|
||||
|
||||
|
||||
### Mochi-MQTT 是一个完全兼容的、可嵌入的高性能 Go MQTT v5(以及 v3.1.1)中间件/服务器。
|
||||
|
||||
Mochi MQTT 是一个[完全兼容 MQTT v5](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) 的可嵌入的中间件/服务器,完全使用 Go 语言编写,旨在用于遥测和物联网项目的开发。它可以作为独立的二进制文件使用,也可以嵌入到你自己的应用程序中作为库来使用,经过精心设计以实现尽可能的轻量化和快速部署,同时也极为重视代码的质量和可维护性。
|
||||
|
||||
#### 什么是 MQTT?
|
||||
MQTT 代表 MQ Telemetry Transport。它是一种发布/订阅、非常简单和轻量的消息传递协议,专为受限设备和低带宽、高延迟或不可靠网络设计而成([了解更多](https://mqtt.org/faq))。Mochi MQTT 实现了完整的 MQTT 协议的 5.0.0 版本。
|
||||
|
||||
#### Mochi-MQTT 特性
|
||||
- 完全兼容 MQTT v5 功能,与 MQTT v3.1.1 和 v3.0.0 兼容:
|
||||
- MQTT v5 用户和数据包属性
|
||||
- 主题别名(Topic Aliases)
|
||||
- 共享订阅(Shared Subscriptions)
|
||||
- 订阅选项和订阅标识符(Identifiers)
|
||||
- 消息过期(Message Expiry)
|
||||
- 客户端会话过期(Client Session Expiry)
|
||||
- 发送和接收 QoS 流量控制配额(Flow Control Quotas)
|
||||
- 服务器端的断开连接和数据包的权限验证(Auth Packets)
|
||||
- 遗愿消息延迟间隔(Will Delay Intervals)
|
||||
- 还有 Mochi MQTT v1 的所有原始 MQTT 功能,例如完全的 QoS(0,1,2)、$SYS 主题、保留消息等。
|
||||
- 面向开发者:
|
||||
- 核心代码都已开放并可访问,以便开发者完全控制。
|
||||
- 功能丰富且灵活的基于钩子(Hook)的接口系统,支持便捷的“插件(plugin)”开发。
|
||||
- 使用特殊的内联客户端(inline client)进行服务端的消息发布,也支持服务端伪装成现有的客户端。
|
||||
- 高性能且稳定:
|
||||
- 基于经典前缀树 Trie 的主题-订阅模型。
|
||||
- 客户端特定的写入缓冲区,避免因读取速度慢或客户端不规范行为而产生的问题。
|
||||
- 通过所有 [Paho互操作性测试](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability)(MQTT v5 和 MQTT v3)。
|
||||
- 超过一千多个经过仔细考虑的单元测试场景。
|
||||
- 支持 TCP、Websocket(包括 SSL/TLS)和$SYS 服务状态监控。
|
||||
- 内置 基于Redis、Badger 和 Bolt 的持久化(使用Hook钩子,你也可以自己创建)。
|
||||
- 内置基于规则的认证和 ACL 权限管理(使用Hook钩子,你也可以自己创建)。
|
||||
|
||||
### 兼容性说明(Compatibility Notes)
|
||||
由于 v5 规范与 MQTT 的早期版本存在重叠,因此服务器可以接受 v5 和 v3 客户端,但在连接了 v5 和 v3 客户端的情况下,为 v5 客户端提供的属性和功能将会对 v3 客户端进行降级处理(例如用户属性)。
|
||||
|
||||
对于 MQTT v3.0.0 和 v3.1.1 的支持被视为混合兼容性。在 v3 规范中没有明确限制的情况下,将使用更新的和以安全为首要考虑的 v5 规范 - 例如保留的消息(retained messages)的过期处理,待发送消息(inflight messages)的过期处理、客户端过期处理以及QOS消息数量的限制等。
|
||||
|
||||
#### 版本更新时间
|
||||
除非涉及关键问题,新版本通常在周末发布。
|
||||
|
||||
## 规划路线图(Roadmap)
|
||||
- 请[提出问题](https://github.com/mochi-mqtt/server/issues)来请求新功能或新的hook钩子接口!
|
||||
- 集群支持。
|
||||
- 统计度量支持。
|
||||
- 配置文件支持(支持 Docker)。
|
||||
|
||||
## 快速开始(Quick Start)
|
||||
### 使用 Go 运行服务端
|
||||
Mochi MQTT 可以作为独立的中间件使用。只需拉取此仓库代码,然后在 [cmd](cmd) 文件夹中运行 [cmd/main.go](cmd/main.go) ,默认将开启下面几个服务端口, tcp (:1883)、websocket (:1882) 和服务状态监控 (:8080) 。
|
||||
|
||||
```
|
||||
cd cmd
|
||||
go build -o mqtt && ./mqtt
|
||||
```
|
||||
|
||||
### 使用 Docker
|
||||
|
||||
你现在可以从 Docker Hub 仓库中拉取并运行Mochi MQTT[官方镜像](https://hub.docker.com/r/mochimqtt/server):
|
||||
```sh
|
||||
docker pull mochimqtt/server
|
||||
或者
|
||||
docker run mochimqtt/server
|
||||
```
|
||||
|
||||
我们还在积极完善这部分的工作,现在正在实现使用[配置文件的启动](https://github.com/orgs/mochi-mqtt/projects/2)方式。更多关于 Docker 的支持正在[这里](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545)和[这里](https://github.com/orgs/mochi-mqtt/discussions/209)进行讨论。如果你有在这个场景下使用 Mochi-MQTT,也可以参与到讨论中来。
|
||||
|
||||
我们提供了一个简单的 Dockerfile,用于运行 cmd/main.go 中的 Websocket(:1882)、TCP(:1883) 和服务端状态信息(:8080)这三个服务监听:
|
||||
|
||||
```sh
|
||||
docker build -t mochi:latest .
|
||||
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
|
||||
```
|
||||
|
||||
|
||||
## 使用 Mochi MQTT 进行开发
|
||||
### 将Mochi MQTT作为包导入使用
|
||||
将 Mochi MQTT 作为一个包导入只需要几行代码即可开始使用。
|
||||
``` go
|
||||
import (
|
||||
"log"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建信号用于等待服务端关闭信号
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 创建新的 MQTT 服务器。
|
||||
server := mqtt.New(nil)
|
||||
|
||||
// 允许所有连接(权限)。
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// 在标1883端口上创建一个 TCP 服务端。
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 服务端等待关闭信号
|
||||
<-done
|
||||
|
||||
// 关闭服务端时需要做的一些清理工作
|
||||
}
|
||||
```
|
||||
|
||||
在 [examples](examples) 文件夹中可以找到更多使用不同配置运行服务端的示例。
|
||||
|
||||
#### 网络监听器 (Network Listeners)
|
||||
|
||||
服务端内置了一些已经实现的网络监听(Network Listeners),这些Listeners允许服务端接受不同协议的连接。当前的监听Listeners有这些:
|
||||
|
||||
| Listener | Usage |
|
||||
|------------------------------|----------------------------------------------------------------------------------------------|
|
||||
| listeners.NewTCP | 一个 TCP 监听器,接收TCP连接 |
|
||||
| listeners.NewUnixSock | 一个 Unix 套接字监听器 |
|
||||
| listeners.NewNet | 一个 net.Listener 监听 |
|
||||
| listeners.NewWebsocket | 一个 Websocket 监听器 |
|
||||
| listeners.NewHTTPStats | 一个 HTTP $SYS 服务状态监听器 |
|
||||
| listeners.NewHTTPHealthCheck | 一个 HTTP 健康检测监听器,用于为例如云基础设施提供健康检查响应 |
|
||||
|
||||
> 可以使用listeners.Listener接口开发新的监听器。如果有兴趣,你可以实现自己的Listener,如果你在此期间你有更好的建议或疑问,你可以[提交问题](https://github.com/mochi-mqtt/server/issues)给我们。
|
||||
|
||||
可以在*listeners.Config 中配置TLS,传递给Listener使其支持TLS。
|
||||
我们提供了一些示例,可以在 [示例](examples) 文件夹或 [cmd/main.go](cmd/main.go) 中找到。
|
||||
|
||||
### 服务端选项和功能(Server Options and Capabilities)
|
||||
|
||||
有许多可配置的选项(Options)可用于更改服务器的行为或限制对某些功能的访问。
|
||||
```go
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
Capabilities: mqtt.Capabilities{
|
||||
MaximumSessionExpiryInterval: 3600,
|
||||
Compatibilities: mqtt.Compatibilities{
|
||||
ObscureNotAuthorized: true,
|
||||
},
|
||||
},
|
||||
ClientNetWriteBufferSize: 4096,
|
||||
ClientNetReadBufferSize: 4096,
|
||||
SysTopicResendInterval: 10,
|
||||
InlineClient: false,
|
||||
})
|
||||
```
|
||||
请参考 mqtt.Options、mqtt.Capabilities 和 mqtt.Compatibilities 结构体,以查看完整的所有服务端选项。ClientNetWriteBufferSize 和 ClientNetReadBufferSize 可以根据你的需求配置调整每个客户端的内存使用状况。
|
||||
|
||||
### 默认配置说明(Default Configuration Notes)
|
||||
|
||||
关于决定默认配置的值,在这里进行一些说明:
|
||||
|
||||
- 默认情况下,server.Options.Capabilities.MaximumMessageExpiryInterval 的值被设置为 86400(24小时),以防止在使用默认配置时网络上暴露服务器而受到恶意DOS攻击(如果不配置到期时间将允许无限数量的保留retained/待发送inflight消息累积)。如果您在一个受信任的环境中运行,或者您有更大的保留期容量,您可以选择覆盖此设置(设置为 0 或 math.MaxInt 以取消到期限制)。
|
||||
|
||||
## 事件钩子(Event Hooks)
|
||||
|
||||
服务端有一个通用的事件钩子(Event Hooks)系统,它允许开发人员在服务器和客户端生命周期的各个阶段定制添加和修改服务端的功能。这些通用Hook钩子用于提供从认证(authentication)、持久性存储(persistent storage)到调试工具(debugging tools)等各种功能。
|
||||
|
||||
钩子(Hook)是可叠加的 - 你可以向服务器添加多个钩子(Hook),它们将按添加的顺序运行。一些钩子(Hook)修改值,这些修改后的值将在所有钩子返回之前传递给后续的钩子(Hook)。
|
||||
|
||||
| 类型 | 导入包 | 描述 |
|
||||
|----------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|
|
||||
| 访问控制 | [mochi-mqtt/server/hooks/auth . AllowHook](hooks/auth/allow_all.go) | AllowHook 允许所有客户端连接访问并读写所有主题。 |
|
||||
| 访问控制 | [mochi-mqtt/server/hooks/auth . Auth](hooks/auth/auth.go) | 基于规则的访问权限控制。 |
|
||||
| 数据持久性 | [mochi-mqtt/server/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | 使用 [BoltDB](https://dbdb.io/db/boltdb) 进行持久性存储(已弃用)。 |
|
||||
| 数据持久性 | [mochi-mqtt/server/hooks/storage/badger](hooks/storage/badger/badger.go) | 使用 [BadgerDB](https://github.com/dgraph-io/badger) 进行持久性存储。 |
|
||||
| 数据持久性 | [mochi-mqtt/server/hooks/storage/redis](hooks/storage/redis/redis.go) | 使用 [Redis](https://redis.io) 进行持久性存储。 |
|
||||
| 调试跟踪 | [mochi-mqtt/server/hooks/debug](hooks/debug/debug.go) | 调试输出以查看数据包在服务端的链路追踪。 |
|
||||
|
||||
许多内部函数都已开放给开发者,你可以参考上述示例创建自己的Hook钩子。如果你有更好的关于Hook钩子方面的建议或者疑问,你可以[提交问题](https://github.com/mochi-mqtt/server/issues)给我们。 |
|
||||
|
||||
### 访问控制(Access Control)
|
||||
|
||||
#### 允许所有(Allow Hook)
|
||||
|
||||
默认情况下,Mochi MQTT 使用拒绝所有(DENY-ALL)的访问控制规则。要允许连接,必须实现一个访问控制的钩子(Hook)来替代默认的(DENY-ALL)钩子。其中最简单的钩子(Hook)是 auth.AllowAll 钩子(Hook),它为所有连接、订阅和发布提供允许所有(ALLOW-ALL)的规则。这也是使用最简单的钩子:
|
||||
|
||||
```go
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
```
|
||||
|
||||
>如果你将服务器暴露在互联网或不受信任的网络上,请不要这样做 - 它真的应该仅用于开发、测试和调试。
|
||||
|
||||
#### 权限认证(Auth Ledger)
|
||||
|
||||
权限认证钩子(Auth Ledger hook)使用结构化的定义来制定访问规则。认证规则分为两种形式:身份规则(连接时使用)和 ACL权限规则(发布订阅时使用)。
|
||||
|
||||
身份规则(Auth rules)有四个可选参数和一个是否允许参数:
|
||||
|
||||
| 参数 | 说明 |
|
||||
| -- | -- |
|
||||
| Client | 客户端的客户端 ID |
|
||||
| Username | 客户端的用户名 |
|
||||
| Password | 客户端的密码 |
|
||||
| Remote | 客户端的远程地址或 IP |
|
||||
| Allow | true(允许此用户)或 false(拒绝此用户) |
|
||||
|
||||
ACL权限规则(ACL rules)有三个可选参数和一个主题匹配参数:
|
||||
| 参数 | 说明 |
|
||||
| -- | -- |
|
||||
| Client | 客户端的客户端 ID |
|
||||
| Username | 客户端的用户名 |
|
||||
| Remote | 客户端的远程地址或 IP |
|
||||
| Filters | 用于匹配的主题数组 |
|
||||
|
||||
规则按索引顺序(0,1,2,3)处理,并在匹配到第一个规则时返回。请查看 [hooks/auth/ledger.go](hooks/auth/ledger.go) 的具体实现。
|
||||
|
||||
```go
|
||||
server := mqtt.New(nil)
|
||||
err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Ledger: &auth.Ledger{
|
||||
Auth: auth.AuthRules{ // Auth 默认情况下禁止所有连接
|
||||
{Username: "peach", Password: "password1", Allow: true},
|
||||
{Username: "melon", Password: "password2", Allow: true},
|
||||
{Remote: "127.0.0.1:*", Allow: true},
|
||||
{Remote: "localhost:*", Allow: true},
|
||||
},
|
||||
ACL: auth.ACLRules{ // ACL 默认情况下允许所有连接
|
||||
{Remote: "127.0.0.1:*"}, // 本地用户允许所有连接
|
||||
{
|
||||
// 用户 melon 可以读取和写入自己的主题
|
||||
Username: "melon", Filters: auth.Filters{
|
||||
"melon/#": auth.ReadWrite,
|
||||
"updates/#": auth.WriteOnly, // 可以写入 updates,但不能从其他人那里读取 updates
|
||||
},
|
||||
},
|
||||
{
|
||||
// 其他的客户端没有发布的权限
|
||||
Filters: auth.Filters{
|
||||
"#": auth.ReadOnly,
|
||||
"updates/#": auth.Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
规则还可以存储为 JSON 或 YAML,并使用 Data 字段加载文件的二进制数据:
|
||||
|
||||
```go
|
||||
err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Data: data, // 从字节数组(文件二进制)读取规则:yaml 或 json
|
||||
})
|
||||
```
|
||||
详细信息请参阅 [examples/auth/encoded/main.go](examples/auth/encoded/main.go)。
|
||||
|
||||
### 持久化存储(Persistent Storage)
|
||||
|
||||
#### Redis
|
||||
|
||||
我们提供了一个基本的 Redis 存储钩子(Hook),用于为服务端提供数据持久性。你可以将这个Redis的钩子(Hook)添加到服务器中,Redis的一些参数也是可以配置的。这个钩子(Hook)里使用 github.com/go-redis/redis/v8 这个库,可以通过 Options 来配置一些参数。
|
||||
|
||||
```go
|
||||
err := server.AddHook(new(redis.Hook), &redis.Options{
|
||||
Options: &rv8.Options{
|
||||
Addr: "localhost:6379", // Redis服务端地址
|
||||
Password: "", // Redis服务端的密码
|
||||
DB: 0, // Redis数据库的index
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
有关 Redis 钩子的工作原理或如何使用它的更多信息,请参阅 [examples/persistence/redis/main.go](examples/persistence/redis/main.go) 或 [hooks/storage/redis](hooks/storage/redis) 。
|
||||
|
||||
#### Badger DB
|
||||
|
||||
如果您更喜欢基于文件的存储,还有一个 BadgerDB 存储钩子(Hook)可用。它可以以与其他钩子大致相同的方式添加和配置(具有较少的选项)。
|
||||
|
||||
```go
|
||||
err := server.AddHook(new(badger.Hook), &badger.Options{
|
||||
Path: badgerPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
有关 Badger 钩子(Hook)的工作原理或如何使用它的更多信息,请参阅 [examples/persistence/badger/main.go](examples/persistence/badger/main.go) 或 [hooks/storage/badger](hooks/storage/badger)。
|
||||
|
||||
还有一个 BoltDB 钩子(Hook),已被弃用,推荐使用 Badger,但如果你想使用它,请参考 [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go)。
|
||||
|
||||
## 使用事件钩子 Event Hooks 进行开发
|
||||
|
||||
在服务端和客户端生命周期中,开发者可以使用各种钩子(Hook)增加对服务端或客户端的一些自定义的处理。
|
||||
所有的钩子都定义在mqtt.Hook这个接口中了,可以在 [hooks.go](hooks.go) 中找到这些钩子(Hook)函数。
|
||||
|
||||
> 最灵活的事件钩子是 OnPacketRead、OnPacketEncode 和 OnPacketSent - 这些钩子可以用来控制和修改所有传入和传出的数据包。
|
||||
|
||||
| 钩子函数 | 说明 |
|
||||
|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| OnStarted | 在服务器成功启动时调 |用。 |
|
||||
| OnStopped | 在服务器成功停止时调用。 |
|
||||
| OnConnectAuthenticate | 当用户尝试与服务器进行身份验证时调用。必须实现此方法来允许或拒绝对服务器的访问(请参阅 hooks/auth/allow_all 或 basic)。它可以在自定义Hook钩子中使用,以检查连接的用户是否与现有用户数据库中的用户匹配。如果允许访问,则返回 true。 |
|
||||
| OnACLCheck | 当用户尝试发布或订阅主题时调用,用来检测ACL规则。 |
|
||||
| OnSysInfoTick | 当 $SYS 主题相关的消息被发布时调用。 |
|
||||
| OnConnect | 当新客户端连接时调用,可能返回一个错误或错误码以中断客户端的连接。 |
|
||||
| OnSessionEstablish | 在新客户端连接并进行身份验证后,会立即调用此方法,并在会话建立和发送CONNACK之前立即调用。 |
|
||||
| OnSessionEstablished | 在新客户端成功建立会话(在OnConnect之后)时调用。 |
|
||||
| OnDisconnect | 当客户端因任何原因断开连接时调用。 |
|
||||
| OnAuthPacket | 当接收到认证数据包时调用。它旨在允许开发人员创建自己的 MQTT v5 认证数据包处理机制。在这里允许数据包的修改。 |
|
||||
| OnPacketRead | 当从客户端接收到数据包时调用。允许对数据包进行修改。 |
|
||||
| OnPacketEncode | 在数据包被编码并发送给客户端之前立即调用。允许修改数据包。 |
|
||||
| OnPacketSent | 在数据包已发送给客户端后调用。 |
|
||||
| OnPacketProcessed | 在数据包已接收并成功由服务端处理后调用。 |
|
||||
| OnSubscribe | 当客户端订阅一个或多个主题时调用。允许修改数据包。 |
|
||||
| OnSubscribed | 当客户端成功订阅一个或多个主题时调用。 |
|
||||
| OnSelectSubscribers | 当订阅者已被关联到一个主题中,在选择共享订阅的订阅者之前调用。允许接收者修改。 |
|
||||
| OnUnsubscribe | 当客户端取消订阅一个或多个主题时调用。允许包修改。 |
|
||||
| OnUnsubscribed | 当客户端成功取消订阅一个或多个主题时调用。 |
|
||||
| OnPublish | 当客户端发布消息时调用。允许修改数据包。 |
|
||||
| OnPublished | 当客户端向订阅者发布消息后调用。|
|
||||
| OnPublishDropped | 消息传递给客户端之前消息已被丢弃,将调用此方法。 例如当客户端响应时间过长需要丢弃消息时。 |
|
||||
| OnRetainMessage | 当消息被保留时调用。 |
|
||||
| OnRetainPublished | 当保留的消息被发布给客户端时调用。 |
|
||||
| OnQosPublish | 当发出QoS >= 1 的消息给订阅者后调用。 |
|
||||
| OnQosComplete | 在消息的QoS流程走完之后调用。 |
|
||||
| OnQosDropped | 在消息的QoS流程未完成,同时消息到期时调用。 |
|
||||
| OnPacketIDExhausted | 当packet ids已经用完后,没有可用的id可再分配时调用。 |
|
||||
| OnWill | 当客户端断开连接并打算发布遗嘱消息时调用。允许修改数据包。 |
|
||||
| OnWillSent | 遗嘱消息发送完成后被调用。 |
|
||||
| OnClientExpired | 在客户端会话已过期并应删除时调用。 |
|
||||
| OnRetainedExpired | 在保留的消息已过期并应删除时调用。| |
|
||||
| StoredClients | 这个接口需要返回客户端列表,例如从持久化数据库中获取客户端列表。 |
|
||||
| StoredSubscriptions | 返回客户端的所有订阅,例如从持久化数据库中获取客户端的订阅列表。 |
|
||||
| StoredInflightMessages | 返回待发送消息(inflight messages),例如从持久化数据库中获取到还有哪些消息未完成传输。 |
|
||||
| StoredRetainedMessages | 返回保留的消息,例如从持久化数据库获取保留的消息。 |
|
||||
| StoredSysInfo | 返回存储的系统状态信息,例如从持久化数据库获取的系统状态信息。 |
|
||||
|
||||
如果你想自己实现一个持久化存储的Hook钩子,请参考现有的持久存储Hook钩子以获取灵感和借鉴。如果您正在构建一个身份验证Hook钩子,您将需要实现OnACLCheck 和 OnConnectAuthenticate这两个函数接口。
|
||||
|
||||
### 内联客户端 (Inline Client v2.4.0+支持)
|
||||
|
||||
现在可以通过使用内联客户端功能直接在服务端上订阅主题和发布消息。内联客户端是内置在服务端中的特殊的客户端,可以在服务端的配置中启用:
|
||||
|
||||
```go
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
InlineClient: true,
|
||||
})
|
||||
```
|
||||
启用上述配置后,你将能够使用 server.Publish、server.Subscribe 和 server.Unsubscribe 方法来在服务端中直接发布和接收消息。
|
||||
|
||||
具体如何使用请参考 [direct examples](examples/direct/main.go) 。
|
||||
|
||||
#### 内联发布(Inline Publish)
|
||||
要想在服务端中直接发布Publish一个消息,可以使用 `server.Publish`方法。
|
||||
|
||||
```go
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
```
|
||||
> 在这种情况下,QoS级别只对订阅者有效,按照 MQTT v5 规范。
|
||||
|
||||
#### 内联订阅(Inline Subscribe)
|
||||
要想在服务端中直接订阅一个主题,可以使用 `server.Subscribe`方法并提供一个处理订阅消息的回调函数。内联订阅的 QoS默认都是0。如果您希望对相同的主题有多个回调,可以使用 MQTTv5 的 subscriptionId 属性进行区分。
|
||||
|
||||
```go
|
||||
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
|
||||
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
|
||||
}
|
||||
server.Subscribe("direct/#", 1, callbackFn)
|
||||
```
|
||||
|
||||
#### 取消内联订阅(Inline Unsubscribe)
|
||||
如果您使用内联客户端订阅了某个主题,如果需要取消订阅。您可以使用 `server.Unsubscribe` 方法取消内联订阅:
|
||||
|
||||
```go
|
||||
server.Unsubscribe("direct/#", 1)
|
||||
```
|
||||
|
||||
### 注入数据包(Packet Injection)
|
||||
|
||||
如果你想要更多的服务端控制,或者想要设置特定的MQTT v5属性或其他属性,你可以选择指定的客户端创建自己的发布包(publish packets)。这种方法允许你将MQTT数据包(packets)直接注入到运行中的服务端,相当于服务端直接自己模拟接收到了某个客户端的数据包。
|
||||
|
||||
数据包注入(Packet Injection)可用于任何MQTT数据包,包括ping请求、订阅等。你可以获取客户端的详细信息,因此你甚至可以直接在服务端模拟某个在线的客户端,发布一个数据包。
|
||||
|
||||
大多数情况下,您可能希望使用上面描述的内联客户端(Inline Client),因为它具有独特的特权:它可以绕过所有ACL和主题验证检查,这意味着它甚至可以发布到$SYS主题。你也可以自己从头开始制定一个自己的内联客户端,它将与内置的内联客户端行为相同。
|
||||
|
||||
```go
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("scheduled message"),
|
||||
})
|
||||
```
|
||||
|
||||
> MQTT数据包仍然需要满足规范的结构,所以请参考[测试用例中数据包的定义](packets/tpackets.go) 和 [MQTTv5规范](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) 以获取一些帮助。
|
||||
|
||||
具体如何使用请参考 [hooks example](examples/hooks/main.go) 。
|
||||
|
||||
|
||||
### 测试(Testing)
|
||||
#### 单元测试(Unit Tests)
|
||||
|
||||
Mochi MQTT 使用精心编写的单元测试,测试了一千多种场景,以确保每个函数都表现出我们期望的行为。您可以使用以下命令运行测试:
|
||||
```
|
||||
go run --cover ./...
|
||||
```
|
||||
|
||||
#### Paho 互操作性测试(Paho Interoperability Test)
|
||||
|
||||
您可以使用 `examples/paho/main.go` 启动服务器,然后在 _interoperability_ 文件夹中运行 `python3 client_test5.py` 来检查代理是否符合 [Paho互操作性测试](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) 的要求,包括 MQTT v5 和 v3 的测试。
|
||||
|
||||
> 请注意,关于 paho 测试套件存在一些尚未解决的问题,因此在 `paho/main.go` 示例中启用了某些兼容性模式。
|
||||
|
||||
|
||||
## 基准测试(Performance Benchmarks)
|
||||
|
||||
Mochi MQTT 的性能与其他的一些主流的mqtt中间件(如 Mosquitto、EMQX 等)不相上下。
|
||||
|
||||
基准测试是使用 [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) 在 Apple Macbook Air M2 上进行的,使用 `cmd/main.go` 默认设置。考虑到高低吞吐量的突发情况,中位数分数是最有用的。数值越高越好。
|
||||
|
||||
|
||||
> 基准测试中呈现的数值不代表真实每秒消息吞吐量。它们依赖于 mqtt-stresser 的一种不寻常的计算方法,但它们在所有代理之间是一致的。性能基准测试的结果仅供参考。这些比较都是使用默认配置进行的。
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.2.10 | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 |
|
||||
| [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
|
||||
| [EMQX v5.0.11](https://github.com/emqx/emqx) | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
|
||||
| [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 |
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.2.10 | 41,825 | 31,663| 23,008 | 144,058 | 65,903 | 37,618 |
|
||||
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
|
||||
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
|
||||
| Rumqtt v0.21.0 | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 |
|
||||
|
||||
百万消息挑战(立即向服务器发送100万条消息):
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.2.10 | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 |
|
||||
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
|
||||
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
|
||||
| Rumqtt v0.21.0 | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 |
|
||||
|
||||
> 这里还不确定EMQX是不是哪里出了问题,可能是因为 Docker 的默认配置优化不对,所以要持保留意见,因为我们确实知道它是一款可靠的软件。
|
||||
|
||||
## 贡献指南(Contribution Guidelines)
|
||||
|
||||
我们欢迎代码贡献和反馈!如果你发现了漏洞(bug)或者有任何疑问,又或者是有新的需求,请[提交给我们](https://github.com/mochi-mqtt/server/issues)。如果您提交了一个PR(pull request)请求,请尽量遵循以下准则:
|
||||
|
||||
- 在合理的情况下,尽量保持测试覆盖率。
|
||||
- 清晰地说明PR(pull request)请求的作用和原因。
|
||||
- 请不要忘记在你贡献的文件中添加 SPDX FileContributor 标签。
|
||||
|
||||
[SPDX 注释] (https://spdx.dev) 用于智能的识别每个文件的许可证、版权和贡献。如果您正在向本仓库添加一个新文件,请确保它具有以下 SPDX 头部:
|
||||
```go
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt
|
||||
// SPDX-FileContributor: Your name or alias <optional@email.address>
|
||||
|
||||
package name
|
||||
```
|
||||
|
||||
请确保为文件的每位贡献者添加一个新的SPDX-FileContributor 行。可以参考其他文件的示例。请务必记得这样做,你对这个项目的贡献是有价值且受到赞赏的 - 获得认可非常重要!
|
||||
|
||||
## 给我们星星的人数(Stargazers over time) 🥰
|
||||
[](https://starchart.cc/mochi-mqtt/server)
|
||||
|
||||
您是否在项目中使用 Mochi MQTT?[请告诉我们!](https://github.com/mochi-mqtt/server/issues)
|
||||
|
||||
314
README.md
314
README.md
@@ -1,28 +1,28 @@
|
||||
# Mochi-MQTT Server
|
||||
|
||||
<p align="center">
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-co/mqtt?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
|
||||
[](https://github.com/mochi-co/mqtt/issues)
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-mqtt/server?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-mqtt/server/v2)
|
||||
[](https://pkg.go.dev/github.com/mochi-mqtt/server/v2)
|
||||
[](https://github.com/mochi-mqtt/server/issues)
|
||||
|
||||
</p>
|
||||
|
||||
# Mochi MQTT Broker
|
||||
## The fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker server
|
||||
[English](README.md) | [简体中文](README-CN.md) | [Translators Wanted!](https://github.com/orgs/mochi-mqtt/discussions/310)
|
||||
|
||||
🎆 **mochi-co/mqtt is now part of the new mochi-mqtt organisation.** [Read about this announcement here.](https://github.com/orgs/mochi-mqtt/discussions/271)
|
||||
|
||||
|
||||
### Mochi-MQTT is a fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker/server
|
||||
|
||||
Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.
|
||||
|
||||
### What is MQTT?
|
||||
#### What is MQTT?
|
||||
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
|
||||
|
||||
## What's new in Version 2.0.0?
|
||||
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
|
||||
|
||||
Don't forget to use the new v2 import paths:
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/v2"
|
||||
```
|
||||
#### Mochi-MQTT Features
|
||||
|
||||
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
|
||||
- User and MQTTv5 Packet Properties
|
||||
@@ -37,26 +37,27 @@ import "github.com/mochi-co/mqtt/v2"
|
||||
- Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.
|
||||
- Developer-centric:
|
||||
- Most core broker code is now exported and accessible, for total developer control.
|
||||
- Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
|
||||
- Full-featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
|
||||
- Direct Packet Injection using special inline client, or masquerade as existing clients.
|
||||
- Performant and Stable:
|
||||
- Our classic trie-based Topic-Subscription model.
|
||||
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
|
||||
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
|
||||
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
|
||||
- Over a thousand carefully considered unit test scenarios.
|
||||
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
|
||||
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
|
||||
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
|
||||
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
|
||||
|
||||
> There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system).
|
||||
|
||||
### Compatibility Notes
|
||||
Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).
|
||||
|
||||
Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits.
|
||||
|
||||
#### When is this repo updated?
|
||||
Unless it's a critical issue, new releases typically go out over the weekend.
|
||||
|
||||
## Roadmap
|
||||
- Please [open an issue](https://github.com/mochi-co/mqtt/issues) to request new features or event hooks!
|
||||
- Please [open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or event hooks!
|
||||
- Cluster support.
|
||||
- Enhanced Metrics support.
|
||||
- File-based server configuration (supporting docker).
|
||||
@@ -71,6 +72,16 @@ go build -o mqtt && ./mqtt
|
||||
```
|
||||
|
||||
### Using Docker
|
||||
You can now pull and run the [official Mochi MQTT image](https://hub.docker.com/r/mochimqtt/server) from our Docker repo:
|
||||
|
||||
```sh
|
||||
docker pull mochimqtt/server
|
||||
or
|
||||
docker run mochimqtt/server
|
||||
```
|
||||
|
||||
This is a work in progress, and a [file-based configuration](https://github.com/orgs/mochi-mqtt/projects/2) is being developed to better support this implementation. _More substantial docker support is being discussed [here](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545) and [here](https://github.com/orgs/mochi-mqtt/discussions/209). Please join the discussion if you use Mochi-MQTT in this environment._
|
||||
|
||||
A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server:
|
||||
|
||||
```sh
|
||||
@@ -83,27 +94,48 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
|
||||
Importing Mochi MQTT as a package requires just a few lines of code to get started.
|
||||
``` go
|
||||
import (
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"log"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create signals channel to run server until interrupted
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Create the new MQTT Server.
|
||||
server := mqtt.New(nil)
|
||||
|
||||
// Allow all connections.
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Create a TCP listener on a standard port.
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.Serve()
|
||||
// Allow all connections.
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Create a TCP listener on a standard port.
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Run server until interrupted
|
||||
<-done
|
||||
|
||||
// Cleanup
|
||||
}
|
||||
```
|
||||
|
||||
@@ -112,16 +144,23 @@ Examples of running the broker with various configurations can be found in the [
|
||||
#### Network Listeners
|
||||
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
||||
|
||||
- `listeners.NewTCP(...)` - A TCP listener.
|
||||
- `listeners.NewWebsocket(...)` A Websocket listener.
|
||||
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
|
||||
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||
| Listener | Usage |
|
||||
|------------------------------|----------------------------------------------------------------------------------------------|
|
||||
| listeners.NewTCP | A TCP listener |
|
||||
| listeners.NewUnixSock | A Unix Socket listener |
|
||||
| listeners.NewNet | A net.Listener listener |
|
||||
| listeners.NewWebsocket | A Websocket listener |
|
||||
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
|
||||
| listeners.NewHTTPHealthCheck | An HTTP healthcheck listener to provide health check responses for e.g. cloud infrastructure |
|
||||
|
||||
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||
|
||||
A `*listeners.Config` may be passed to configure TLS.
|
||||
|
||||
Examples of usage can be found in the [examples](examples) folder or [cmd/main.go](cmd/main.go).
|
||||
|
||||
### Server Options and Capabilities
|
||||
|
||||
## Server Options and Capabilities
|
||||
A number of configurable options are available which can be used to alter the behaviour or restrict access to certain features in the server.
|
||||
|
||||
```go
|
||||
@@ -132,28 +171,36 @@ server := mqtt.New(&mqtt.Options{
|
||||
ObscureNotAuthorized: true,
|
||||
},
|
||||
},
|
||||
ClientNetWriteBufferSize: 4096,
|
||||
ClientNetReadBufferSize: 4096,
|
||||
SysTopicResendInterval: 10,
|
||||
InlineClient: false,
|
||||
})
|
||||
```
|
||||
|
||||
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
|
||||
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
|
||||
|
||||
### Default Configuration Notes
|
||||
|
||||
Some choices were made when deciding the default configuration that need to be mentioned here:
|
||||
|
||||
- By default, the value of `server.Options.Capabilities.MaximumMessageExpiryInterval` is set to 86400 (24 hours), in order to prevent exposing the broker to DOS attacks on hostile networks when using the out-of-the-box configuration (as an infinite expiry would allow an infinite number of retained/inflight messages to accumulate). If you are operating in a trusted environment, or you have capacity for a larger retention period, uou may wish to override this (set to `0` or `math.MaxInt` for no expiry).
|
||||
|
||||
## Event Hooks
|
||||
A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools.
|
||||
|
||||
Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code.
|
||||
|
||||
| Type | Import | Info |
|
||||
| -- | -- | -- |
|
||||
| Access Control | [mochi-co/mqtt/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
|
||||
| Access Control | [mochi-co/mqtt/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
|
||||
| Persistence | [mochi-co/mqtt/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
|
||||
| Persistence | [mochi-co/mqtt/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
|
||||
| Persistence | [mochi-co/mqtt/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
|
||||
| Debugging | [mochi-co/mqtt/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
|
||||
| Type | Import | Info |
|
||||
|----------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|
|
||||
| Access Control | [mochi-mqtt/server/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
|
||||
| Access Control | [mochi-mqtt/server/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
|
||||
| Persistence | [mochi-mqtt/server/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
|
||||
| Persistence | [mochi-mqtt/server/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
|
||||
| Persistence | [mochi-mqtt/server/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
|
||||
| Debugging | [mochi-mqtt/server/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
|
||||
|
||||
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-co/mqtt/issues) and let everyone know!
|
||||
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-mqtt/server/issues) and let everyone know!
|
||||
|
||||
### Access Control
|
||||
#### Allow Hook
|
||||
@@ -221,7 +268,7 @@ err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
|
||||
The ledger can also be stored as JSON or YAML and loaded using the Data field:
|
||||
```go
|
||||
err = server.AddHook(new(auth.Hook), &auth.Options{
|
||||
err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Data: data, // build ledger from byte slice: yaml or json
|
||||
})
|
||||
```
|
||||
@@ -258,60 +305,99 @@ For more information on how the badger hook works, or how to use it, see the [ex
|
||||
|
||||
There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).
|
||||
|
||||
|
||||
|
||||
## Developing with Event Hooks
|
||||
Many hooks are available for interacting with the broker and client lifecycle.
|
||||
The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go).
|
||||
|
||||
> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets.
|
||||
|
||||
| Function | Usage |
|
||||
| -------------------------- | -- |
|
||||
| OnStarted | Called when the server has successfully started.|
|
||||
| OnStopped | Called when the server has successfully stopped. |
|
||||
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
|
||||
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
|
||||
| OnSysInfoTick | Called when the $SYS topic values are published out. |
|
||||
| OnConnect | Called when a new client connects |
|
||||
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
|
||||
| OnDisconnect | Called when a client is disconnected for any reason. |
|
||||
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
|
||||
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
|
||||
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
|
||||
| OnPacketSent | Called when a packet has been sent to a client. |
|
||||
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
|
||||
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
|
||||
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
|
||||
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|
|
||||
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
|
||||
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
|
||||
| OnPublish | Called when a client publishes a message. Allows packet modification. |
|
||||
| OnPublished | Called when a client has published a message to subscribers. |
|
||||
| OnRetainMessage | Called then a published message is retained. |
|
||||
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
|
||||
| OnQosComplete | Called when the Qos flow for a message has been completed. |
|
||||
| OnQosDropped | Called when an inflight message expires before completion. |
|
||||
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
|
||||
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
||||
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
||||
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
||||
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
|
||||
| StoredClients | Returns clients, eg. from a persistent store. |
|
||||
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
||||
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
||||
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
|
||||
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
|
||||
| Function | Usage |
|
||||
|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| OnStarted | Called when the server has successfully started. |
|
||||
| OnStopped | Called when the server has successfully stopped. |
|
||||
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
|
||||
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
|
||||
| OnSysInfoTick | Called when the $SYS topic values are published out. |
|
||||
| OnConnect | Called when a new client connects, may return an error or packet code to halt the client connection process. |
|
||||
| OnSessionEstablish | Called immediately after a new client connects and authenticates and immediately before the session is established and CONNACK is sent. |
|
||||
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
|
||||
| OnDisconnect | Called when a client is disconnected for any reason. |
|
||||
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
|
||||
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
|
||||
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
|
||||
| OnPacketSent | Called when a packet has been sent to a client. |
|
||||
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
|
||||
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
|
||||
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
|
||||
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification. |
|
||||
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
|
||||
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
|
||||
| OnPublish | Called when a client publishes a message. Allows packet modification. |
|
||||
| OnPublished | Called when a client has published a message to subscribers. |
|
||||
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
|
||||
| OnRetainMessage | Called then a published message is retained. |
|
||||
| OnRetainPublished | Called then a retained message is published to a client. |
|
||||
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
|
||||
| OnQosComplete | Called when the Qos flow for a message has been completed. |
|
||||
| OnQosDropped | Called when an inflight message expires before completion. |
|
||||
| OnPacketIDExhausted | Called when a client runs out of unused packet ids to assign. |
|
||||
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
|
||||
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
||||
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
||||
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
||||
| StoredClients | Returns clients, eg. from a persistent store. |
|
||||
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
||||
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
||||
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
|
||||
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
|
||||
|
||||
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
|
||||
|
||||
### Packet Injection
|
||||
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
|
||||
### Inline Client (v2.4.0+)
|
||||
It's now possible to subscribe and publish to topics directly from the embedding code, by using the `inline client` feature. The Inline Client is an embedded client which operates as part of the server, and can be enabled in the server options:
|
||||
```go
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
InlineClient: true,
|
||||
})
|
||||
```
|
||||
Once enabled, you will be able to use the `server.Publish`, `server.Subscribe`, and `server.Unsubscribe` methods to issue and received messages from broker-adjacent code.
|
||||
|
||||
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
|
||||
> See [direct examples](examples/direct/main.go) for real-life usage examples.
|
||||
|
||||
#### Inline Publish
|
||||
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
|
||||
|
||||
```go
|
||||
cl := mqtt.NewInlineClient("inline", "local")
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
```
|
||||
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
|
||||
|
||||
#### Inline Subscribe
|
||||
To subscribe to a topic filter from within the embedding application, you can use the `server.Subscribe(filter string, subscriptionId int, handler InlineSubFn) error` method with a callback function. Note that only QoS 0 is supported for inline subscriptions. If you wish to have multiple callbacks for the same filter, you can use the MQTTv5 `subscriptionId` property to differentiate.
|
||||
|
||||
```go
|
||||
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
|
||||
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
|
||||
}
|
||||
server.Subscribe("direct/#", 1, callbackFn)
|
||||
```
|
||||
|
||||
#### Inline Unsubscribe
|
||||
You may wish to unsubscribe if you have subscribed to a filter using the inline client. You can do this easily with the `server.Unsubscribe(filter string, subscriptionId int) error` method:
|
||||
|
||||
```go
|
||||
server.Unsubscribe("direct/#", 1)
|
||||
```
|
||||
|
||||
### Packet Injection
|
||||
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client.
|
||||
|
||||
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
|
||||
|
||||
Most of the time you'll want to use the Inline Client described above, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics. In this case, you can create an inline client from scratch which will behave the same as the built-in inline client.
|
||||
|
||||
```go
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
@@ -325,6 +411,7 @@ server.InjectPacket(cl, packets.Packet{
|
||||
|
||||
See the [hooks example](examples/hooks/main.go) to see this feature in action.
|
||||
|
||||
|
||||
### Testing
|
||||
#### Unit Tests
|
||||
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
|
||||
@@ -344,39 +431,54 @@ Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQ
|
||||
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
|
||||
|
||||
> The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers.
|
||||
> Benchmarks are provided as a general performance expectation guideline only.
|
||||
> Benchmarks are provided as a general performance expectation guideline only. Comparisons are performed using out-of-the-box default configurations.
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
|
||||
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
|
||||
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
|
||||
| Mochi v2.2.10 | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 |
|
||||
| [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
|
||||
| [EMQX v5.0.11](https://github.com/emqx/emqx) | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
|
||||
| [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 |
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
|
||||
| Mochi v2.2.10 | 41,825 | 31,663| 23,008 | 144,058 | 65,903 | 37,618 |
|
||||
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
|
||||
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
|
||||
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
|
||||
| Rumqtt v0.21.0 | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 |
|
||||
|
||||
Million Message Challenge (hit the server with 1 million messages immediately):
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
|
||||
| Mochi v2.2.10 | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 |
|
||||
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
|
||||
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
|
||||
| Rumqtt v0.21.0 | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 |
|
||||
|
||||
> Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software.
|
||||
|
||||
## Contribution Guidelines
|
||||
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-mqtt/server/issues) to report a bug, ask a question, or make a feature request. If you open a pull request, please try to follow the following guidelines:
|
||||
- Try to maintain test coverage where reasonably possible.
|
||||
- Clearly state what the PR does and why.
|
||||
- Please remember to add your SPDX FileContributor tag to files where you have made a meaningful contribution.
|
||||
|
||||
[SPDX Annotations](https://spdx.dev) are used to clearly indicate the license, copyright, and contributions of each file in a machine-readable format. If you are adding a new file to the repository, please ensure it has the following SPDX header:
|
||||
```go
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt
|
||||
// SPDX-FileContributor: Your name or alias <optional@email.address>
|
||||
|
||||
package name
|
||||
```
|
||||
|
||||
Please ensure to add a new `SPDX-FileContributor` line for each contributor to the file. Refer to other files for examples. Please remember to do this, your contributions to this project are valuable and appreciated - it's important to receive credit!
|
||||
|
||||
## Stargazers over time 🥰
|
||||
[](https://starchart.cc/mochi-co/mqtt)
|
||||
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues)
|
||||
|
||||
## Contributions
|
||||
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.
|
||||
|
||||
|
||||
[](https://starchart.cc/mochi-mqtt/server)
|
||||
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-mqtt/server/issues)
|
||||
|
||||
|
||||
231
clients.go
231
clients.go
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
@@ -7,6 +7,8 @@ package mqtt
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -16,12 +18,17 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
|
||||
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds.
|
||||
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
|
||||
minimumKeepalive uint16 = 5 // the minimum recommended keepalive - values under with display a warning.
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability")
|
||||
)
|
||||
|
||||
// ReadFn is the function signature for the function used for reading and processing new packets.
|
||||
@@ -87,7 +94,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
defer cl.RUnlock()
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
for _, client := range cl.internal {
|
||||
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
|
||||
if client.Net.Listener == id && !client.Closed() {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
@@ -98,7 +105,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
type Client struct {
|
||||
Properties ClientProperties // client properties
|
||||
State ClientState // the operational state of the client.
|
||||
Net ClientConnection // network connection state of the clinet
|
||||
Net ClientConnection // network connection state of the client
|
||||
ID string // the client id.
|
||||
ops *ops // ops provides a reference to server ops.
|
||||
sync.RWMutex // mutex
|
||||
@@ -106,11 +113,11 @@ type Client struct {
|
||||
|
||||
// ClientConnection contains the connection transport and metadata for the client.
|
||||
type ClientConnection struct {
|
||||
conn net.Conn // the net.Conn used to establish the connection
|
||||
Conn net.Conn // the net.Conn used to establish the connection
|
||||
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
|
||||
Remote string // the remote address of the client
|
||||
Listener string // listener id of the client
|
||||
Inline bool // client is an inline programmetic client
|
||||
Inline bool // if true, the client is the built-in 'inline' embedded client
|
||||
}
|
||||
|
||||
// ClientProperties contains the properties which define the client behaviour.
|
||||
@@ -133,32 +140,37 @@ type Will struct {
|
||||
Retain bool // -
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
// ClientState tracks the state of the client.
|
||||
type ClientState struct {
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
endOnce sync.Once // only end once
|
||||
packetID uint32 // the current highest packetID
|
||||
done uint32 // atomic counter which indicates that the client has closed
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
outbound chan *packets.Packet // queue for pending outbound packets
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
open context.Context // indicate that the client is open for packet exchange
|
||||
cancelOpen context.CancelFunc // cancel function for open context
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
Keepalive uint16 // the number of seconds the connection can wait
|
||||
ServerKeepalive bool // keepalive was set by the server
|
||||
}
|
||||
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, o *ops) *Client {
|
||||
// newClient returns a new instance of Client. This is almost exclusively used by Server
|
||||
// for creating new clients, but it lives here because it's not dependent.
|
||||
func newClient(c net.Conn, o *ops) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cl := &Client{
|
||||
Net: ClientConnection{
|
||||
conn: c,
|
||||
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
},
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
|
||||
keepalive: defaultKeepalive,
|
||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||
open: ctx,
|
||||
cancelOpen: cancel,
|
||||
Keepalive: defaultKeepalive,
|
||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
@@ -166,43 +178,33 @@ func NewClient(c net.Conn, o *ops) *Client {
|
||||
ops: o,
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
Conn: c,
|
||||
bconn: bufio.NewReadWriter(
|
||||
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
|
||||
bufio.NewWriterSize(c, o.options.ClientNetWriteBufferSize),
|
||||
),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// NewInlineClient returns a client used when publishing from the embedding system.
|
||||
func NewInlineClient(id, remote string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Net: ClientConnection{
|
||||
Remote: remote,
|
||||
Inline: true,
|
||||
},
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(0),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newClientStub returns an instance of Client with minimal initializations, such as
|
||||
// restoring client data from a db. In particular, the client is marked as offline (done).
|
||||
func newClientStub() *Client {
|
||||
return &Client{
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(0),
|
||||
done: 1,
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
|
||||
func (cl *Client) WriteLoop() {
|
||||
for {
|
||||
select {
|
||||
case pk := <-cl.State.outbound:
|
||||
if err := cl.WritePacket(*pk); err != nil {
|
||||
// TODO : Figure out what to do with error
|
||||
cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
|
||||
}
|
||||
atomic.AddInt32(&cl.State.outboundQty, -1)
|
||||
case <-cl.State.open.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,9 +217,18 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Properties.Clean = pk.Connect.Clean
|
||||
cl.Properties.Props = pk.Properties.Copy(false)
|
||||
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
if pk.Connect.Keepalive <= minimumKeepalive {
|
||||
cl.ops.log.Warn(
|
||||
ErrMinimumKeepalive.Error(),
|
||||
"client", cl.ID,
|
||||
"keepalive", pk.Connect.Keepalive,
|
||||
"recommended", minimumKeepalive,
|
||||
)
|
||||
}
|
||||
|
||||
cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
|
||||
|
||||
cl.ID = pk.Connect.ClientIdentifier
|
||||
@@ -226,11 +237,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Properties.Props.AssignedClientID = cl.ID
|
||||
}
|
||||
|
||||
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
|
||||
if pk.Connect.Keepalive > 0 {
|
||||
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
|
||||
}
|
||||
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will = Will{
|
||||
Qos: pk.Connect.WillQos,
|
||||
@@ -248,19 +254,17 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Properties.Will.Flag = 1 // atomic for checking
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
}
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
if cl.Net.conn != nil {
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
|
||||
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,28 +272,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
// If no unused packet ids are available, an error is returned and the client
|
||||
// should be disconnected.
|
||||
func (cl *Client) NextPacketID() (i uint32, err error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
i = atomic.LoadUint32(&cl.State.packetID)
|
||||
started := i + 1
|
||||
started := i
|
||||
overflowed := false
|
||||
for {
|
||||
if i >= 65535 {
|
||||
overflowed = true
|
||||
i = 1
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
|
||||
if overflowed && i == started {
|
||||
return 0, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
if i >= cl.ops.options.Capabilities.maximumPacketID {
|
||||
overflowed = true
|
||||
i = 0
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
|
||||
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
|
||||
break
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
|
||||
@@ -303,7 +309,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
|
||||
}
|
||||
|
||||
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
|
||||
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
|
||||
err := cl.WritePacket(tk)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -320,18 +326,19 @@ func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
|
||||
var deleted int64
|
||||
// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||
deleted := []uint16{}
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted++
|
||||
deleted = append(deleted, tk.PacketID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
@@ -341,11 +348,11 @@ func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Closed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
cl.refreshDeadline(cl.State.Keepalive)
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
@@ -366,20 +373,20 @@ func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop(err error) {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
if cl.Net.conn != nil {
|
||||
_ = cl.Net.conn.Close() // omit close error
|
||||
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.Close() // omit close error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cl.State.stopCause.Store(err)
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.done, 1)
|
||||
if cl.State.cancelOpen != nil {
|
||||
cl.State.cancelOpen()
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||
})
|
||||
}
|
||||
@@ -392,6 +399,11 @@ func (cl *Client) StopCause() error {
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// Closed returns true if client connection is closed.
|
||||
func (cl *Client) Closed() bool {
|
||||
return cl.State.open == nil || cl.State.open.Err() != nil
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
if cl.Net.bconn == nil {
|
||||
@@ -414,6 +426,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||
return nil
|
||||
}
|
||||
@@ -481,15 +497,14 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Closed() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.conn == nil {
|
||||
if cl.Net.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer cl.refreshDeadline(cl.State.keepalive)
|
||||
if pk.Expiry > 0 {
|
||||
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
|
||||
}
|
||||
@@ -503,8 +518,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||
}
|
||||
|
||||
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
|
||||
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||
}
|
||||
|
||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||
@@ -554,7 +569,11 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
}
|
||||
|
||||
nb := net.Buffers{buf.Bytes()}
|
||||
n, err := nb.WriteTo(cl.Net.conn)
|
||||
n, err := func() (int64, error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
return nb.WriteTo(cl.Net.Conn)
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
276
clients_test.go
276
clients_test.go
@@ -1,19 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -22,16 +27,20 @@ const pkInfo = "packet type %v, %s"
|
||||
|
||||
var errClientStop = errors.New("test stop")
|
||||
|
||||
func newClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
|
||||
cl = NewClient(w, &ops{
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: &logger,
|
||||
capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
log: logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -42,6 +51,9 @@ func newClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
go cl.WriteLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,8 +119,8 @@ func TestClientsDelete(t *testing.T) {
|
||||
|
||||
func TestClientsGetByListener(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
|
||||
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
@@ -119,34 +131,23 @@ func TestClientsGetByListener(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Nil(t, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestNewClientStub(t *testing.T) {
|
||||
cl := newClientStub()
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
|
||||
}
|
||||
|
||||
func TestNewInlineClient(t *testing.T) {
|
||||
cl := NewInlineClient("inline", "local")
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
|
||||
require.Equal(t, "inline", cl.ID)
|
||||
require.Equal(t, "local", cl.Net.Remote)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.NotNil(t, cl.ops.options.Capabilities)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestClientParseConnect(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
@@ -168,7 +169,7 @@ func TestClientParseConnect(t *testing.T) {
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
|
||||
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
|
||||
@@ -176,14 +177,14 @@ func TestClientParseConnect(t *testing.T) {
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
@@ -208,13 +209,34 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientParseConnectNoID(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ParseConnect("tcp1", packets.Packet{})
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientParseConnectBelowMinimumKeepalive(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
var b bytes.Buffer
|
||||
x := bufio.NewWriter(&b)
|
||||
cl.ops.log = slog.New(slog.NewTextHandler(x, nil))
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Keepalive: minimumKeepalive - 1,
|
||||
ClientIdentifier: "mochi",
|
||||
},
|
||||
}
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
err := x.Flush()
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.Contains(b.String(), ErrMinimumKeepalive.Error()))
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
@@ -226,7 +248,7 @@ func TestClientNextPacketID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
// skip over 2
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
@@ -249,33 +271,37 @@ func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
for i := 0; i <= 65535; i++ {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
|
||||
}
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
require.Equal(t, uint32(0), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
|
||||
}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
|
||||
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
|
||||
_, err = cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
|
||||
cl.State.packetID = uint32(65534)
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(65535), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
}
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
@@ -285,13 +311,15 @@ func TestClientClearInflights(t *testing.T) {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
cl.ClearInflights(n, 4)
|
||||
deleted := cl.ClearInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
@@ -300,7 +328,7 @@ func TestClientResendInflightMessages(t *testing.T) {
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
@@ -311,8 +339,8 @@ func TestClientResendInflightMessages(t *testing.T) {
|
||||
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newClient()
|
||||
r.Close()
|
||||
cl, r, _ := newTestClient()
|
||||
_ = r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
@@ -323,24 +351,24 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
|
||||
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
r.Close()
|
||||
_, _ = r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -350,12 +378,12 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
r.Close()
|
||||
_, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -363,12 +391,28 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumPacketSize = 2
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Close()
|
||||
_, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -378,12 +422,12 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
r.Close()
|
||||
_, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -392,10 +436,10 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 18, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
@@ -405,7 +449,7 @@ func TestClientReadOK(t *testing.T) {
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
var pks []packets.Packet
|
||||
@@ -446,9 +490,9 @@ func TestClientReadOK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.done = 1
|
||||
cl.State.cancelOpen()
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
@@ -461,21 +505,29 @@ func TestClientReadDone(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientStop(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||
require.Equal(t, uint32(1), cl.State.done)
|
||||
require.True(t, cl.Closed())
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientClosed(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
require.False(t, cl.Closed())
|
||||
cl.Stop(nil)
|
||||
require.True(t, cl.Closed())
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
cl.Net.bconn = nil
|
||||
@@ -486,16 +538,16 @@ func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
err := cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
@@ -506,16 +558,16 @@ func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -538,7 +590,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
for _, tx := range pkTable {
|
||||
@@ -546,7 +598,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
t.Run(tt.Desc, func(t *testing.T) {
|
||||
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
|
||||
go func() {
|
||||
r.Write(tt.RawBytes)
|
||||
_, _ = r.Write(tt.RawBytes)
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -571,9 +623,17 @@ func TestClientReadPacket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
_ = cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
}
|
||||
|
||||
func TestClientWritePacket(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
|
||||
@@ -589,7 +649,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cl.Net.conn.Close()
|
||||
_ = cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
@@ -613,7 +673,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
|
||||
err := cl.WritePacket(pk)
|
||||
@@ -622,16 +682,16 @@ func TestWriteClientOversizePacket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
@@ -642,16 +702,16 @@ func TestClientReadPacketReadingError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
@@ -661,7 +721,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(errClientStop)
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
@@ -670,15 +730,15 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl.Net.conn.Close()
|
||||
cl, _, _ := newTestClient()
|
||||
_ = cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketInvalidPacket(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.WritePacket(packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
15
cmd/main.go
15
cmd/main.go
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -59,7 +59,8 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -77,7 +77,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -59,7 +59,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
52
examples/benchmark/main.go
Normal file
52
examples/benchmark/main.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
|
||||
flag.Parse()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
@@ -1,20 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/debug"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/debug"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -27,17 +27,21 @@ func main() {
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
err := server.AddHook(new(auth.AllowHook), nil)
|
||||
level := new(slog.LevelVar)
|
||||
server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
level.Set(slog.LevelDebug)
|
||||
|
||||
err := server.AddHook(new(debug.Hook), &debug.Options{
|
||||
// ShowPacketData: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(debug.Hook), &debug.Options{
|
||||
ShowPacketData: true,
|
||||
})
|
||||
err = server.AddHook(new(auth.AllowHook), nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -56,7 +60,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
83
examples/direct/main.go
Normal file
83
examples/direct/main.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
InlineClient: true, // you must enable inline client to use direct publishing and subscribing.
|
||||
})
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Start the server
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Demonstration of using an inline client to directly subscribe to a topic and receive a message when
|
||||
// that subscription is activated. The inline subscription method uses the same internal subscription logic
|
||||
// as used for external (normal) clients.
|
||||
go func() {
|
||||
// Inline subscriptions can also receive retained messages on subscription.
|
||||
_ = server.Publish("direct/retained", []byte("retained message"), true, 0)
|
||||
_ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0)
|
||||
|
||||
// Subscribe to a filter and handle any received messages via a callback function.
|
||||
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
|
||||
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
|
||||
}
|
||||
server.Log.Info("inline client subscribing")
|
||||
_ = server.Subscribe("direct/#", 1, callbackFn)
|
||||
_ = server.Subscribe("direct/#", 2, callbackFn)
|
||||
}()
|
||||
|
||||
// There is a shorthand convenience function, Publish, for easily sending publish packets if you are not
|
||||
// concerned with creating your own packets. If you want to have more control over your packets, you can
|
||||
//directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 3) {
|
||||
err := server.Publish("direct/publish", []byte("scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error("server.Publish", "error", err)
|
||||
}
|
||||
server.Log.Info("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Second * 10)
|
||||
// Unsubscribe from the same filter to stop receiving messages.
|
||||
server.Log.Info("inline client unsubscribing")
|
||||
_ = server.Unsubscribe("direct/#", 1)
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
@@ -1,21 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -52,23 +53,38 @@ func main() {
|
||||
// `server.Publish` method. Subscribe to `direct/publish` using your
|
||||
// MQTT client to see the messages.
|
||||
go func() {
|
||||
cl := mqtt.NewInlineClient("inline", "local")
|
||||
for range time.Tick(time.Second * 10) {
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
for range time.Tick(time.Second * 1) {
|
||||
err := server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("scheduled message"),
|
||||
Payload: []byte("injected scheduled message"),
|
||||
})
|
||||
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
|
||||
if err != nil {
|
||||
server.Log.Error("server.InjectPacket", "error", err)
|
||||
}
|
||||
server.Log.Info("main.go injected packet to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
// There is also a shorthand convenience function, Publish, for easily sending
|
||||
// publish packets if you are not concerned with creating your own packets.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 5) {
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error("server.Publish", "error", err)
|
||||
}
|
||||
server.Log.Info("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
type ExampleHook struct {
|
||||
@@ -91,38 +107,44 @@ func (h *ExampleHook) Provides(b byte) bool {
|
||||
}
|
||||
|
||||
func (h *ExampleHook) Init(config any) error {
|
||||
h.Log.Info().Msg("initialised")
|
||||
h.Log.Info("initialised")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
|
||||
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
|
||||
h.Log.Info("client connected", "client", cl.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
|
||||
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
|
||||
if err != nil {
|
||||
h.Log.Info("client disconnected", "client", cl.ID, "expire", expire, "error", err)
|
||||
} else {
|
||||
h.Log.Info("client disconnected", "client", cl.ID, "expire", expire)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
|
||||
h.Log.Info(fmt.Sprintf("subscribed qos=%v", reasonCodes), "client", cl.ID, "filters", pk.Filters)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
|
||||
h.Log.Info("unsubscribed", "client", cl.ID, "filters", pk.Filters)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
|
||||
h.Log.Info("received from client", "client", cl.ID, "payload", string(pk.Payload))
|
||||
|
||||
pkx := pk
|
||||
if string(pk.Payload) == "hello" {
|
||||
pkx.Payload = []byte("hello world")
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
|
||||
h.Log.Info("received modified packet from client", "client", cl.ID, "payload", string(pkx.Payload))
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
|
||||
h.Log.Info("published to client", "client", cl.ID, "payload", string(pk.Payload))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -26,10 +26,9 @@ func main() {
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.ServerKeepAlive = 60
|
||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
|
||||
server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
|
||||
|
||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
@@ -46,9 +45,9 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
type pahoAuthHook struct {
|
||||
@@ -62,6 +61,7 @@ func (h *pahoAuthHook) ID() string {
|
||||
func (h *pahoAuthHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnConnect,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
@@ -73,3 +73,12 @@ func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet)
|
||||
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return topic != "test/nosubscribe"
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
|
||||
// Handle paho test_server_keep_alive
|
||||
if pk.Connect.Keepalive == 120 && pk.Connect.Clean {
|
||||
cl.State.Keepalive = 60
|
||||
cl.State.ServerKeepalive = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/badger"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -52,8 +52,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/bolt"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
@@ -30,12 +30,15 @@ func main() {
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
err := server.AddHook(new(bolt.Hook), bolt.Options{
|
||||
err := server.AddHook(new(bolt.Hook), &bolt.Options{
|
||||
Path: "bolt.db",
|
||||
Options: &bbolt.Options{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
@@ -51,7 +54,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
|
||||
rv8 "github.com/go-redis/redis/v8"
|
||||
)
|
||||
@@ -30,8 +30,12 @@ func main() {
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
level := new(slog.LevelVar)
|
||||
server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
level.Set(slog.LevelDebug)
|
||||
|
||||
err := server.AddHook(new(redis.Hook), &redis.Options{
|
||||
Options: &rv8.Options{
|
||||
@@ -58,8 +62,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -52,7 +52,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -97,7 +97,7 @@ func main() {
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
}, nil)
|
||||
}, server.Info)
|
||||
err = server.AddListener(stats)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -111,7 +111,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -41,7 +41,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
101
fanpool.go
101
fanpool.go
@@ -1,101 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
xh "github.com/cespare/xxhash/v2"
|
||||
)
|
||||
|
||||
// taskChan is a channel for incoming task functions.
|
||||
type taskChan chan func()
|
||||
|
||||
// FanPool is a fixed-sized fan-style worker pool with multiple
|
||||
// working 'columns'. Instead of a single queue channel processed by
|
||||
// many goroutines, this fan pool uses many queue channels each
|
||||
// processed by a single goroutine.
|
||||
// Very special thanks are given to the authors of HMQ in particular
|
||||
// @chowyu08 and @muXxer for their work on the fixpool worker pool
|
||||
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
|
||||
// from which this fan-pool is heavily inspired.
|
||||
type FanPool struct {
|
||||
queue []taskChan
|
||||
wg sync.WaitGroup
|
||||
capacity uint64
|
||||
perChan uint64
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
|
||||
// of the fan, whereas queueSize controls the size of each column's queue.
|
||||
func NewFanPool(fanSize, queueSize uint64) *FanPool {
|
||||
pool := &FanPool{
|
||||
capacity: fanSize,
|
||||
perChan: queueSize,
|
||||
queue: make([]taskChan, fanSize),
|
||||
}
|
||||
|
||||
pool.fillWorkers(fanSize)
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
|
||||
func (p *FanPool) fillWorkers(n uint64) {
|
||||
for i := uint64(0); i < n; i++ {
|
||||
p.queue[i] = make(taskChan, p.perChan)
|
||||
go p.worker(p.queue[i])
|
||||
p.wg.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// worker is a worker goroutine which processes tasks from a single queue.
|
||||
func (p *FanPool) worker(ch taskChan) {
|
||||
defer p.wg.Done()
|
||||
var task func()
|
||||
var ok bool
|
||||
for {
|
||||
task, ok = <-ch
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
task()
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue adds a new task to the queue to be processed.
|
||||
func (p *FanPool) Enqueue(id string, task func()) {
|
||||
if p.Size() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// We can use xh.Sum64 to get a specific queue index
|
||||
// which remains the same for a client id, giving each
|
||||
// client their own queue.
|
||||
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
|
||||
}
|
||||
|
||||
// Wait blocks until all the workers in the pool have completed.
|
||||
func (p *FanPool) Wait() {
|
||||
p.wg.Wait()
|
||||
}
|
||||
|
||||
// Close issues a shutdown signal to the workers.
|
||||
func (p *FanPool) Close() {
|
||||
for i := 0; i < int(p.Size()); i++ {
|
||||
if p.queue[i] != nil {
|
||||
close(p.queue[i])
|
||||
}
|
||||
}
|
||||
p.queue = nil
|
||||
atomic.StoreUint64(&p.capacity, 0)
|
||||
}
|
||||
|
||||
// Size returns the current number of workers in the pool.
|
||||
func (p *FanPool) Size() uint64 {
|
||||
return atomic.LoadUint64(&p.capacity)
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFanPool(t *testing.T) {
|
||||
f := NewFanPool(1, 2)
|
||||
require.NotNil(t, f)
|
||||
require.Equal(t, uint64(1), f.capacity)
|
||||
require.Equal(t, 2, cap(f.queue[0]))
|
||||
|
||||
o := make(chan bool)
|
||||
go func() {
|
||||
f.Enqueue("test", func() {
|
||||
o <- true
|
||||
})
|
||||
}()
|
||||
|
||||
require.True(t, <-o)
|
||||
f.Close()
|
||||
f.Wait()
|
||||
}
|
||||
|
||||
func TestFillWorkers(t *testing.T) {
|
||||
f := &FanPool{
|
||||
perChan: 3,
|
||||
queue: make([]taskChan, 2),
|
||||
}
|
||||
f.fillWorkers(2)
|
||||
require.Len(t, f.queue, 2)
|
||||
require.Equal(t, 3, cap(f.queue[0]))
|
||||
}
|
||||
|
||||
func TestEnqueue(t *testing.T) {
|
||||
f := &FanPool{
|
||||
capacity: 2,
|
||||
queue: []taskChan{
|
||||
make(taskChan, 2),
|
||||
make(taskChan, 2),
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
f.Enqueue("a", func() {})
|
||||
}()
|
||||
require.NotNil(t, <-f.queue[1])
|
||||
}
|
||||
|
||||
func TestEnqueueOnEmpty(t *testing.T) {
|
||||
f := &FanPool{
|
||||
queue: []taskChan{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
f.Enqueue("a", func() {})
|
||||
}()
|
||||
|
||||
require.Len(t, f.queue, 0)
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
f := &FanPool{
|
||||
capacity: 10,
|
||||
}
|
||||
|
||||
require.Equal(t, uint64(10), f.Size())
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
f := &FanPool{
|
||||
capacity: 3,
|
||||
queue: []taskChan{
|
||||
make(taskChan, 2),
|
||||
make(taskChan, 2),
|
||||
make(taskChan, 2),
|
||||
},
|
||||
}
|
||||
|
||||
f.Close()
|
||||
require.Equal(t, uint64(0), f.Size())
|
||||
require.Nil(t, f.queue)
|
||||
}
|
||||
15
go.mod
15
go.mod
@@ -1,17 +1,15 @@
|
||||
module github.com/mochi-co/mqtt/v2
|
||||
module github.com/mochi-mqtt/server/v2
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.23.0
|
||||
github.com/asdine/storm v2.1.2+incompatible
|
||||
github.com/asdine/storm/v3 v3.2.1
|
||||
github.com/cespare/xxhash/v2 v2.1.2
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.3.5
|
||||
github.com/rs/xid v1.4.0
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/timshannon/badgerhold v1.0.0
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
@@ -21,6 +19,7 @@ require (
|
||||
require (
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/badger v1.6.0 // indirect
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
|
||||
@@ -28,13 +27,11 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.0 // indirect
|
||||
github.com/golang/protobuf v1.5.0 // indirect
|
||||
github.com/golang/snappy v0.0.3 // indirect
|
||||
github.com/mattn/go-colorable v0.1.12 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/google/go-cmp v0.5.8 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
|
||||
golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
)
|
||||
|
||||
32
go.sum
32
go.sum
@@ -23,7 +23,6 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
|
||||
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -38,9 +37,9 @@ github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
|
||||
@@ -48,8 +47,9 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
@@ -62,15 +62,14 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@@ -79,8 +78,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
|
||||
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
|
||||
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
@@ -109,24 +106,21 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ=
|
||||
golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
@@ -136,8 +130,10 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
260
hooks.go
260
hooks.go
@@ -1,20 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop, dgduncan
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,6 +24,7 @@ const (
|
||||
OnConnectAuthenticate
|
||||
OnACLCheck
|
||||
OnConnect
|
||||
OnSessionEstablish
|
||||
OnSessionEstablished
|
||||
OnDisconnect
|
||||
OnAuthPacket
|
||||
@@ -39,15 +39,17 @@ const (
|
||||
OnUnsubscribed
|
||||
OnPublish
|
||||
OnPublished
|
||||
OnPublishDropped
|
||||
OnRetainMessage
|
||||
OnRetainPublished
|
||||
OnQosPublish
|
||||
OnQosComplete
|
||||
OnQosDropped
|
||||
OnPacketIDExhausted
|
||||
OnWill
|
||||
OnWillSent
|
||||
OnClientExpired
|
||||
OnRetainedExpired
|
||||
OnExpireInflights
|
||||
StoredClients
|
||||
StoredSubscriptions
|
||||
StoredInflightMessages
|
||||
@@ -67,13 +69,14 @@ type Hook interface {
|
||||
Provides(b byte) bool
|
||||
Init(config any) error
|
||||
Stop() error
|
||||
SetOpts(l *zerolog.Logger, o *HookOptions)
|
||||
SetOpts(l *slog.Logger, o *HookOptions)
|
||||
OnStarted()
|
||||
OnStopped()
|
||||
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
|
||||
OnACLCheck(cl *Client, topic string, write bool) bool
|
||||
OnSysInfoTick(*system.Info)
|
||||
OnConnect(cl *Client, pk packets.Packet)
|
||||
OnConnect(cl *Client, pk packets.Packet) error
|
||||
OnSessionEstablish(cl *Client, pk packets.Packet)
|
||||
OnSessionEstablished(cl *Client, pk packets.Packet)
|
||||
OnDisconnect(cl *Client, err error, expire bool)
|
||||
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
@@ -88,15 +91,17 @@ type Hook interface {
|
||||
OnUnsubscribed(cl *Client, pk packets.Packet)
|
||||
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPublished(cl *Client, pk packets.Packet)
|
||||
OnPublishDropped(cl *Client, pk packets.Packet)
|
||||
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
|
||||
OnRetainPublished(cl *Client, pk packets.Packet)
|
||||
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
|
||||
OnQosComplete(cl *Client, pk packets.Packet)
|
||||
OnQosDropped(cl *Client, pk packets.Packet)
|
||||
OnPacketIDExhausted(cl *Client, pk packets.Packet)
|
||||
OnWill(cl *Client, will Will) (Will, error)
|
||||
OnWillSent(cl *Client, pk packets.Packet)
|
||||
OnClientExpired(cl *Client)
|
||||
OnRetainedExpired(filter string)
|
||||
OnExpireInflights(cl *Client, expiry int64)
|
||||
StoredClients() ([]storage.Client, error)
|
||||
StoredSubscriptions() ([]storage.Subscription, error)
|
||||
StoredInflightMessages() ([]storage.Message, error)
|
||||
@@ -111,11 +116,11 @@ type HookOptions struct {
|
||||
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal []Hook // a slice of hooks
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex
|
||||
Log *slog.Logger // a logger for the hook (from the server)
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
@@ -125,7 +130,7 @@ func (h *Hooks) Len() int64 {
|
||||
|
||||
// Provides returns true if any one hook provides any of the requested hook methods.
|
||||
func (h *Hooks) Provides(b ...byte) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
for _, hb := range b {
|
||||
if hook.Provides(hb) {
|
||||
return true
|
||||
@@ -140,29 +145,42 @@ func (h *Hooks) Provides(b ...byte) bool {
|
||||
func (h *Hooks) Add(hook Hook, config any) error {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
if h.internal == nil {
|
||||
h.internal = []Hook{}
|
||||
}
|
||||
|
||||
err := hook.Init(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||
}
|
||||
|
||||
h.internal = append(h.internal, hook)
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
i = []Hook{}
|
||||
}
|
||||
|
||||
i = append(i, hook)
|
||||
h.internal.Store(i)
|
||||
atomic.AddInt64(&h.qty, 1)
|
||||
h.wg.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
return []Hook{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.internal {
|
||||
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info("stopping hook", "hook", hook.ID())
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
|
||||
h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID())
|
||||
}
|
||||
|
||||
h.wg.Done()
|
||||
@@ -174,7 +192,7 @@ func (h *Hooks) Stop() {
|
||||
|
||||
// OnSysInfoTick is called when the $SYS topic values are published out.
|
||||
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSysInfoTick) {
|
||||
hook.OnSysInfoTick(sys)
|
||||
}
|
||||
@@ -183,7 +201,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
// OnStarted is called when the server has successfully started.
|
||||
func (h *Hooks) OnStarted() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStarted) {
|
||||
hook.OnStarted()
|
||||
}
|
||||
@@ -192,25 +210,39 @@ func (h *Hooks) OnStarted() {
|
||||
|
||||
// OnStopped is called when the server has successfully stopped.
|
||||
func (h *Hooks) OnStopped() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStopped) {
|
||||
hook.OnStopped()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnect) {
|
||||
hook.OnConnect(cl, pk)
|
||||
err := hook.OnConnect(cl, pk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnSessionEstablish is called right after a new client connects and authenticates and right before
|
||||
// the session is established and CONNACK is sent.
|
||||
func (h *Hooks) OnSessionEstablish(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablish) {
|
||||
hook.OnSessionEstablish(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablished) {
|
||||
hook.OnSessionEstablished(cl, pk)
|
||||
}
|
||||
@@ -219,7 +251,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnDisconnect) {
|
||||
hook.OnDisconnect(cl, err, expire)
|
||||
}
|
||||
@@ -229,11 +261,11 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
// OnPacketRead is called when a packet is received from a client.
|
||||
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
|
||||
h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx)
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
@@ -250,7 +282,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
// to create their own auth packet handling mechanisms.
|
||||
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnAuthPacket) {
|
||||
npk, err := hook.OnAuthPacket(cl, pkx)
|
||||
if err != nil {
|
||||
@@ -266,7 +298,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
|
||||
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
|
||||
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketEncode) {
|
||||
pk = hook.OnPacketEncode(cl, pk)
|
||||
}
|
||||
@@ -277,7 +309,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
|
||||
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketProcessed) {
|
||||
hook.OnPacketProcessed(cl, pk, err)
|
||||
}
|
||||
@@ -287,7 +319,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
|
||||
// containing the bytes sent.
|
||||
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketSent) {
|
||||
hook.OnPacketSent(cl, pk, b)
|
||||
}
|
||||
@@ -299,7 +331,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribe) {
|
||||
pk = hook.OnSubscribe(cl, pk)
|
||||
}
|
||||
@@ -309,7 +341,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribed) {
|
||||
hook.OnSubscribed(cl, pk, reasonCodes)
|
||||
}
|
||||
@@ -321,7 +353,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
|
||||
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
|
||||
// group in a custom manner (such as based on client id, ip, etc).
|
||||
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSelectSubscribers) {
|
||||
subs = hook.OnSelectSubscribers(subs, pk)
|
||||
}
|
||||
@@ -334,7 +366,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribe) {
|
||||
pk = hook.OnUnsubscribe(cl, pk)
|
||||
}
|
||||
@@ -344,28 +376,35 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribed) {
|
||||
hook.OnUnsubscribed(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnMessage
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublish) {
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
|
||||
if err != nil {
|
||||
if errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug("publish packet rejected",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"packet", pkx)
|
||||
return pk, err
|
||||
}
|
||||
h.Log.Error("publish packet error",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"packet", pkx)
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
@@ -375,27 +414,46 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublished) {
|
||||
hook.OnPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublishDropped is called when a message to a client was dropped instead of delivered
|
||||
// such as when a client is too slow to respond.
|
||||
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublishDropped) {
|
||||
hook.OnPublishDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainMessage) {
|
||||
hook.OnRetainMessage(cl, pk, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainPublished is called when a retained message is published.
|
||||
func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainPublished) {
|
||||
hook.OnRetainPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
|
||||
// In other words, this method is called when a new inflight message is created or resent.
|
||||
// It is typically used to store a new inflight message.
|
||||
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosPublish) {
|
||||
hook.OnQosPublish(cl, pk, sent, resends)
|
||||
}
|
||||
@@ -406,7 +464,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
|
||||
// In other words, when an inflight message is resolved.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosComplete) {
|
||||
hook.OnQosComplete(cl, pk)
|
||||
}
|
||||
@@ -414,26 +472,39 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||
// an inflight message expires or is abandoned.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
hook.OnQosDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
|
||||
// assign to a packet.
|
||||
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketIDExhausted) {
|
||||
hook.OnPacketIDExhausted(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message. This method
|
||||
// differs from OnWillSent in that it allows you to modify the LWT message before it is
|
||||
// published. The return values of the hook methods are passed-through in the order
|
||||
// the hooks were attached.
|
||||
func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
|
||||
h.Log.Error("parse will error",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"will", will)
|
||||
continue
|
||||
}
|
||||
will = mlwt
|
||||
@@ -445,7 +516,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWillSent) {
|
||||
hook.OnWillSent(cl, pk)
|
||||
}
|
||||
@@ -454,7 +525,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnClientExpired is called when a client session has expired and should be deleted.
|
||||
func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnClientExpired) {
|
||||
hook.OnClientExpired(cl)
|
||||
}
|
||||
@@ -463,7 +534,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
|
||||
// OnRetainedExpired is called when a retained message has expired and should be deleted.
|
||||
func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainedExpired) {
|
||||
hook.OnRetainedExpired(filter)
|
||||
}
|
||||
@@ -473,11 +544,11 @@ func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
// StoredClients returns all clients, e.g. from a persistent store, is used to
|
||||
// populate the server clients list before start.
|
||||
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
|
||||
h.Log.Error("failed to load clients", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -493,11 +564,11 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
|
||||
// used to populate the server subscriptions list before start.
|
||||
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
|
||||
h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -513,11 +584,11 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
|
||||
// and is used to populate the restored clients with inflight messages before start.
|
||||
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
|
||||
h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -533,11 +604,11 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
|
||||
// and is used to populate the server topics with retained messages before start.
|
||||
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
|
||||
h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -552,11 +623,11 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
|
||||
h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -574,7 +645,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check connecting users against an existing user database.
|
||||
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnectAuthenticate) {
|
||||
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
|
||||
return true
|
||||
@@ -590,7 +661,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check publishing and subscribing users against an existing permissions or roles database.
|
||||
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnACLCheck) {
|
||||
if ok := hook.OnACLCheck(cl, topic, write); ok {
|
||||
return true
|
||||
@@ -601,24 +672,11 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnExpireInflights is called when the server issues a clear request for expired
|
||||
// inflight messages. Expiry should be the time after which the message is no longer
|
||||
// valid (usually some time in the past). A message has expired if it's created time
|
||||
// is older than time.Now() minus Inflight TTL. This method can be used to expire
|
||||
// old inflight messages in a persistent store which doesnt support per-item TTL.
|
||||
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
|
||||
for _, hook := range h.internal {
|
||||
if hook.Provides(OnExpireInflights) {
|
||||
hook.OnExpireInflights(cl, expiry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
Hook
|
||||
Log *zerolog.Logger
|
||||
Log *slog.Logger
|
||||
Opts *HookOptions
|
||||
}
|
||||
|
||||
@@ -641,12 +699,12 @@ func (h *HookBase) Init(config any) error {
|
||||
|
||||
// SetOpts is called by the server to propagate internal values and generally should
|
||||
// not be called manually.
|
||||
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
|
||||
func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) {
|
||||
h.Log = l
|
||||
h.Opts = opts
|
||||
}
|
||||
|
||||
// Stop is called to gracefully shutdown the hook.
|
||||
// Stop is called to gracefully shut down the hook.
|
||||
func (h *HookBase) Stop() error {
|
||||
return nil
|
||||
}
|
||||
@@ -671,7 +729,13 @@ func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
|
||||
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnSessionEstablish is called right after a new client connects and authenticates and right before
|
||||
// the session is established and CONNACK is sent.
|
||||
func (h *HookBase) OnSessionEstablish(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
|
||||
@@ -729,9 +793,15 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
|
||||
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
|
||||
|
||||
// OnRetainPublished is called when a retained message is published.
|
||||
func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
|
||||
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
|
||||
|
||||
@@ -741,6 +811,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
|
||||
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message.
|
||||
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
return will, nil
|
||||
@@ -755,9 +828,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
|
||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||
|
||||
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
|
||||
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
|
||||
|
||||
// StoredClients returns all clients from a store.
|
||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||
return
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
@@ -7,8 +7,8 @@ package auth
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
// AllowHook is an authentication hook which allows connection access
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
@@ -7,8 +7,8 @@ package auth
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
@@ -7,8 +7,8 @@ package auth
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
// Options contains the configuration/rules data for the auth ledger.
|
||||
@@ -67,10 +67,9 @@ func (h *Hook) Init(config any) error {
|
||||
}
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Int("authentication", len(h.ledger.Auth)).
|
||||
Int("acl", len(h.ledger.ACL)).
|
||||
Msg("loaded auth rules")
|
||||
h.Log.Info("loaded auth rules",
|
||||
"authentication", len(h.ledger.Auth),
|
||||
"acl", len(h.ledger.ACL))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -82,11 +81,9 @@ func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("username", string(pk.Connect.Username)).
|
||||
Str("remote", cl.Net.Remote).
|
||||
Msg("client failed authentication check")
|
||||
|
||||
h.Log.Info("client failed authentication check",
|
||||
"username", string(pk.Connect.Username),
|
||||
"remote", cl.Net.Remote)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -97,11 +94,10 @@ func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Debug().
|
||||
Str("client", cl.ID).
|
||||
Str("username", string(cl.Properties.Username)).
|
||||
Str("topic", topic).
|
||||
Msg("client failed allowed ACL check")
|
||||
h.Log.Debug("client failed allowed ACL check",
|
||||
"client", cl.ID,
|
||||
"username", string(cl.Properties.Username),
|
||||
"topic", topic)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/rs/zerolog"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
// func teardown(t *testing.T, path string, h *Hook) {
|
||||
// h.Stop()
|
||||
@@ -34,7 +34,7 @@ func TestBasicProvides(t *testing.T) {
|
||||
|
||||
func TestBasicInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -42,7 +42,7 @@ func TestBasicInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
@@ -50,7 +50,7 @@ func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := &Ledger{
|
||||
Auth: []AuthRule{
|
||||
@@ -79,7 +79,7 @@ func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
@@ -93,7 +93,7 @@ func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
@@ -107,7 +107,7 @@ func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
@@ -119,7 +119,7 @@ func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
|
||||
func TestOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
@@ -158,7 +158,7 @@ func TestOnConnectAuthenticate(t *testing.T) {
|
||||
|
||||
func TestOnACL(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
@@ -9,9 +9,10 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -79,8 +80,8 @@ func (r RString) Matches(a string) bool {
|
||||
}
|
||||
|
||||
// FilterMatches returns true if a filter matches a topic rule.
|
||||
func (f RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(f), a)
|
||||
func (r RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(r), a)
|
||||
return ok
|
||||
}
|
||||
|
||||
@@ -160,7 +161,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
|
||||
}
|
||||
|
||||
// ACLOk returns true if the rules indicate the user is allowed to read or write to
|
||||
// a specific filter or topic respectively, based on the write bool.
|
||||
// a specific filter or topic respectively, based on the `write` bool.
|
||||
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
@@ -188,17 +189,31 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo
|
||||
return n, true
|
||||
}
|
||||
|
||||
for filter, access := range rule.Filters {
|
||||
if filter.FilterMatches(topic) {
|
||||
if !write && (access == ReadOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else if write && (access == WriteOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else {
|
||||
return n, false
|
||||
if write {
|
||||
for filter, access := range rule.Filters {
|
||||
if access == WriteOnly || access == ReadWrite {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !write {
|
||||
for filter, access := range rule.Filters {
|
||||
if access == ReadOnly || access == ReadWrite {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for filter := range rule.Filters {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
@@ -7,8 +7,8 @@ package auth
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -561,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
new := &Ledger{
|
||||
n := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
{Remote: "192.168.*", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
old.Update(new)
|
||||
old.Update(n)
|
||||
require.Len(t, old.Auth, 2)
|
||||
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
|
||||
require.NotSame(t, new, old)
|
||||
require.NotSame(t, n, old)
|
||||
}
|
||||
|
||||
func TestLedgerToJSON(t *testing.T) {
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
// Options contains configuration settings for the debug output.
|
||||
@@ -25,7 +25,7 @@ type Options struct {
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
Log *zerolog.Logger
|
||||
Log *slog.Logger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
@@ -54,25 +54,25 @@ func (h *Hook) Init(config any) error {
|
||||
}
|
||||
|
||||
// SetOpts is called when the hook receives inheritable server parameters.
|
||||
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
|
||||
func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) {
|
||||
h.Log = l
|
||||
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
|
||||
h.Log.Debug("", "method", "SetOpts")
|
||||
}
|
||||
|
||||
// Stop is called when the hook is stopped.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Debug().Str("method", "Stop").Send()
|
||||
h.Log.Debug("", "method", "Stop")
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *Hook) OnStarted() {
|
||||
h.Log.Debug().Str("method", "OnStarted").Send()
|
||||
h.Log.Debug("", "method", "OnStarted")
|
||||
}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *Hook) OnStopped() {
|
||||
h.Log.Debug().Str("method", "OnStopped").Send()
|
||||
h.Log.Debug("", "method", "OnStopped")
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a new packet is received from a client.
|
||||
@@ -81,8 +81,7 @@ func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet,
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
|
||||
h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
@@ -92,85 +91,72 @@ func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
|
||||
h.Log.Debug("retained message on topic", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
|
||||
h.Log.Debug("inflight out", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
|
||||
h.Log.Debug("inflight complete", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
|
||||
h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnLWTSent is called when a will message has been issued from a disconnecting client.
|
||||
// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
|
||||
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
|
||||
h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when the server clears expired retained messages.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
|
||||
h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter)
|
||||
}
|
||||
|
||||
// OnClientExpired is called when the server clears an expired client.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
|
||||
h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID)
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores clients from a store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
h.Log.Debug("", "method", "StoredClients")
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores subscriptions from a store.
|
||||
// StoredSubscriptions is called when the server restores subscriptions from a store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredSubscriptions").
|
||||
Send()
|
||||
|
||||
h.Log.Debug("", "method", "StoredSubscriptions")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores retained messages from a store.
|
||||
// StoredRetainedMessages is called when the server restores retained messages from a store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredRetainedMessages").
|
||||
Send()
|
||||
|
||||
h.Log.Debug("", "method", "StoredRetainedMessages")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores inflight messages from a store.
|
||||
// StoredInflightMessages is called when the server restores inflight messages from a store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredInflightMessages").
|
||||
Send()
|
||||
|
||||
h.Log.Debug("", "method", "StoredInflightMessages")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores system info from a store.
|
||||
// StoredSysInfo is called when the server restores system info from a store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
h.Log.Debug("", "method", "StoredSysInfo")
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, gsagula
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
@@ -80,7 +81,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -128,8 +128,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
@@ -137,7 +136,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -166,14 +165,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
|
||||
h.Log.Error("failed to upsert client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if their session has expired.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -183,32 +182,40 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
|
||||
h.Log.Error("failed to delete client data", "error", err, "data", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Qos: reasonCodes[i],
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
|
||||
h.Log.Error("failed to upsert subscription data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,14 +223,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
|
||||
h.Log.Error("failed to delete subscription data", "error", err, "data", subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -231,14 +238,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
|
||||
h.Log.Error("failed to delete retained message data", "error", err, "data", retainedKey(pk.TopicName))
|
||||
}
|
||||
|
||||
return
|
||||
@@ -267,14 +274,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
|
||||
h.Log.Error("failed to upsert retained message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -303,27 +310,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
|
||||
h.Log.Error("failed to upsert qos inflight data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
|
||||
h.Log.Error("failed to delete inflight message data", "error", err, "data", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
@@ -332,66 +339,52 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
Info: *sys.Clone(),
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var v []storage.Message
|
||||
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := h.db.Delete(m.ID, new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
h.Log.Error("failed to upsert $SYS data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||
h.Log.Error("failed to delete expired retained message data", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||
h.Log.Error("failed to delete expired client data", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -406,7 +399,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -421,7 +414,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -436,7 +429,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -451,7 +444,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -465,20 +458,21 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
|
||||
// Errorf satisfies the badger interface for an error logger.
|
||||
func (h *Hook) Errorf(m string, v ...interface{}) {
|
||||
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
|
||||
}
|
||||
|
||||
// Warningf satisfies the badger interface for a warning logger.
|
||||
func (h *Hook) Warningf(m string, v ...interface{}) {
|
||||
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Infof satisfies the badger interface for an info logger.
|
||||
func (h *Hook) Infof(m string, v ...interface{}) {
|
||||
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Debugf satisfies the badger interface for a debug logger.
|
||||
func (h *Hook) Debugf(m string, v ...interface{}) {
|
||||
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/rs/zerolog"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
@@ -40,8 +38,8 @@ var (
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
h.db.Badger().Close()
|
||||
_ = h.Stop()
|
||||
_ = h.db.Badger().Close()
|
||||
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -97,7 +95,7 @@ func TestProvides(t *testing.T) {
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -105,7 +103,7 @@ func TestInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -115,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -148,7 +146,7 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -170,15 +168,30 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -187,7 +200,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -206,22 +219,45 @@ func TestOnWillSent(t *testing.T) {
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -243,13 +279,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -258,13 +294,13 @@ func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -273,7 +309,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -308,7 +344,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -333,15 +369,30 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -350,7 +401,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -385,13 +436,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -400,13 +451,13 @@ func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -415,55 +466,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var v []storage.Message
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "i1", v[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -485,13 +494,13 @@ func TestOnSysInfoTick(t *testing.T) {
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -500,7 +509,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -525,7 +534,7 @@ func TestStoredClients(t *testing.T) {
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -533,7 +542,7 @@ func TestStoredClientsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -558,7 +567,7 @@ func TestStoredSubscriptions(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -566,7 +575,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -594,7 +603,7 @@ func TestStoredRetainedMessages(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -602,7 +611,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -630,7 +639,7 @@ func TestStoredInflightMessages(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -638,7 +647,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -660,7 +669,7 @@ func TestStoredSysInfo(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -669,27 +678,27 @@ func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
func TestErrorf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Errorf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestWarningf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Warningf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Infof("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Debugf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
|
||||
// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
package bolt
|
||||
|
||||
import (
|
||||
@@ -9,10 +10,10 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
sgob "github.com/asdine/storm/codec/gob"
|
||||
"github.com/asdine/storm/v3"
|
||||
@@ -85,7 +86,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -133,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
@@ -142,7 +141,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -170,14 +169,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
|
||||
h.Log.Error("failed to save client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -185,34 +184,40 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Qos: reasonCodes[i],
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save subscription data")
|
||||
h.Log.Error("failed to save subscription data", "error", err, "client", cl.ID, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,7 +225,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -229,9 +234,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
|
||||
Msg("failed to delete client")
|
||||
h.Log.Error("failed to delete client", "error", err, "id", subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -239,7 +242,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -248,9 +251,7 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
ID: retainedKey(pk.TopicName),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", retainedKey(pk.TopicName)).
|
||||
Msg("failed to delete retained publish")
|
||||
h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(pk.TopicName))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -277,17 +278,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save retained publish data")
|
||||
h.Log.Error("failed to save retained publish data", "error", err, "client", cl.ID, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -315,17 +313,14 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save qos inflight data")
|
||||
h.Log.Error("failed to save qos inflight data", "error", err, "client", cl.ID, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -333,16 +328,14 @@ func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
ID: inflightKey(cl, pk),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", inflightKey(cl, pk)).
|
||||
Msg("failed to delete inflight data")
|
||||
h.Log.Error("failed to delete inflight data", "error", err, "id", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
@@ -351,7 +344,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -363,57 +356,39 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Interface("data", in).
|
||||
Msg("failed to save $SYS data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the
|
||||
// provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var v []storage.Message
|
||||
err := h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
|
||||
return
|
||||
}
|
||||
}
|
||||
h.Log.Error("failed to save $SYS data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||
h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -428,7 +403,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -443,7 +418,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -458,7 +433,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -473,7 +448,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
@@ -39,7 +38,7 @@ var (
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
_ = h.Stop()
|
||||
err := os.Remove(path)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -95,7 +94,7 @@ func TestProvides(t *testing.T) {
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -103,7 +102,7 @@ func TestInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -114,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
|
||||
func TestInitBadPath(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Path: "..",
|
||||
})
|
||||
@@ -123,7 +122,7 @@ func TestInitBadPath(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -156,13 +155,13 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -171,7 +170,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -190,7 +189,7 @@ func TestOnWillSent(t *testing.T) {
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -212,24 +211,62 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -251,13 +288,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -266,13 +303,13 @@ func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -281,7 +318,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -316,7 +353,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -341,15 +378,30 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -358,7 +410,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -393,13 +445,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -408,13 +460,13 @@ func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -423,55 +475,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var v []storage.Message
|
||||
err = h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "i1", v[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -493,13 +503,13 @@ func TestOnSysInfoTick(t *testing.T) {
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -508,7 +518,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -533,7 +543,7 @@ func TestStoredClients(t *testing.T) {
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -541,7 +551,7 @@ func TestStoredClientsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -552,7 +562,7 @@ func TestStoredClientsClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -577,7 +587,7 @@ func TestStoredSubscriptions(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -585,7 +595,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -596,7 +606,7 @@ func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -624,7 +634,7 @@ func TestStoredRetainedMessages(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -632,7 +642,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -643,7 +653,7 @@ func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -671,7 +681,7 @@ func TestStoredInflightMessages(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -679,7 +689,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -690,7 +700,7 @@ func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -712,7 +722,7 @@ func TestStoredSysInfo(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -720,7 +730,7 @@ func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
@@ -10,12 +10,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// defaultAddr is the default address to the redis service.
|
||||
@@ -83,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -118,12 +117,11 @@ func (h *Hook) Init(config any) error {
|
||||
h.config.HPrefix = defaultHPrefix
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("address", h.config.Options.Addr).
|
||||
Str("username", h.config.Options.Username).
|
||||
Int("password-len", len(h.config.Options.Password)).
|
||||
Int("db", h.config.Options.DB).
|
||||
Msg("connecting to redis service")
|
||||
h.Log.Info("connecting to redis service",
|
||||
"address", h.config.Options.Addr,
|
||||
"username", h.config.Options.Username,
|
||||
"password-len", len(h.config.Options.Password),
|
||||
"db", h.config.Options.DB)
|
||||
|
||||
h.db = redis.NewClient(h.config.Options)
|
||||
_, err := h.db.Ping(context.Background()).Result()
|
||||
@@ -131,14 +129,15 @@ func (h *Hook) Init(config any) error {
|
||||
return fmt.Errorf("failed to ping service: %w", err)
|
||||
}
|
||||
|
||||
h.Log.Info().Msg("connected to redis service")
|
||||
h.Log.Info("connected to redis service")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the redis connection.
|
||||
// Stop closes the redis connection.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Info().Msg("disconnecting from redis service")
|
||||
h.Log.Info("disconnecting from redis service")
|
||||
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
@@ -147,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
@@ -156,7 +154,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -185,14 +183,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
|
||||
h.Log.Error("failed to hset client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -200,32 +198,40 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Qos: reasonCodes[i],
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
|
||||
h.Log.Error("failed to hset subscription data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,14 +239,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
|
||||
h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -248,14 +254,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
|
||||
h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName))
|
||||
}
|
||||
|
||||
return
|
||||
@@ -284,14 +290,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
|
||||
h.Log.Error("failed to hset retained message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -319,27 +325,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
|
||||
h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
@@ -348,7 +354,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -360,72 +366,53 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the
|
||||
// provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
}
|
||||
|
||||
if d.Created < expiry || d.Created == 0 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
h.Log.Error("failed to hset server info data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||
h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
|
||||
h.Log.Error("failed to HGetAll client data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Client
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
|
||||
h.Log.Error("failed to unmarshal client data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -437,20 +424,20 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
|
||||
h.Log.Error("failed to HGetAll subscription data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Subscription
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
|
||||
h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -462,20 +449,20 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
|
||||
h.Log.Error("failed to HGetAll retained message data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
|
||||
h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -487,20 +474,20 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
|
||||
h.Log.Error("failed to HGetAll inflight message data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -512,7 +499,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -522,7 +509,7 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
}
|
||||
|
||||
if err = v.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
|
||||
h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
@@ -41,7 +41,7 @@ var (
|
||||
|
||||
func newHook(t *testing.T, addr string) *Hook {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
@@ -87,13 +87,13 @@ func TestSysInfoKey(t *testing.T) {
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
require.Equal(t, "redis-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
@@ -116,7 +116,7 @@ func TestHKey(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
h := newHook(t, defaultAddr)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h)
|
||||
@@ -137,7 +137,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -145,7 +145,7 @@ func TestInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestInitBadAddr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: "abc:123",
|
||||
@@ -253,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, redis.Nil, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -269,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -392,6 +430,22 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -484,60 +538,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
n := time.Now().Unix()
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
|
||||
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
|
||||
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
|
||||
&storage.Message{ID: "i3", T: storage.InflightKey},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var r []storage.Message
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
err = d.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
r = append(r, d)
|
||||
}
|
||||
require.Len(t, r, 1)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,7 +25,7 @@ var (
|
||||
ErrDBFileNotOpen = errors.New("db file not open")
|
||||
)
|
||||
|
||||
// Client is a storable representation of an mqtt client.
|
||||
// Client is a storable representation of an MQTT client.
|
||||
type Client struct {
|
||||
Will ClientWill `json:"will"` // will topic and payload data if applicable
|
||||
Properties ClientProperties `json:"properties"` // the connect properties for the client
|
||||
@@ -117,7 +117,37 @@ func (d *Message) UnmarshalBinary(data []byte) error {
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an mqtt subscription.
|
||||
// ToPacket converts a storage.Message to a standard packet.
|
||||
func (d *Message) ToPacket() packets.Packet {
|
||||
pk := packets.Packet{
|
||||
FixedHeader: d.FixedHeader,
|
||||
PacketID: d.PacketID,
|
||||
TopicName: d.TopicName,
|
||||
Payload: d.Payload,
|
||||
Origin: d.Origin,
|
||||
Created: d.Created,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: d.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
|
||||
ContentType: d.Properties.ContentType,
|
||||
ResponseTopic: d.Properties.ResponseTopic,
|
||||
CorrelationData: d.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: d.Properties.TopicAlias,
|
||||
User: d.Properties.User,
|
||||
},
|
||||
}
|
||||
|
||||
// Return a deep copy of the packet data otherwise the slices will
|
||||
// continue pointing at the values from the storage packet.
|
||||
pk = pk.Copy(true)
|
||||
pk.FixedHeader.Dup = d.FixedHeader.Dup
|
||||
|
||||
return pk
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an MQTT subscription.
|
||||
type Subscription struct {
|
||||
T string `json:"t"`
|
||||
ID string `json:"id" storm:"id"`
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -104,6 +104,7 @@ var (
|
||||
ClientsMaximum: 7,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
PacketsReceived: 12,
|
||||
PacketsSent: 13,
|
||||
Retained: 15,
|
||||
@@ -111,13 +112,13 @@ var (
|
||||
InflightDropped: 17,
|
||||
},
|
||||
}
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
)
|
||||
|
||||
func TestClientMarshalBinary(t *testing.T) {
|
||||
data, err := clientStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientJSON, data)
|
||||
require.JSONEq(t, string(clientJSON), string(data))
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinary(t *testing.T) {
|
||||
@@ -137,7 +138,7 @@ func TestClientUnmarshalBinaryEmpty(t *testing.T) {
|
||||
func TestMessageMarshalBinary(t *testing.T) {
|
||||
data, err := messageStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageJSON, data)
|
||||
require.JSONEq(t, string(messageJSON), string(data))
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinary(t *testing.T) {
|
||||
@@ -157,7 +158,7 @@ func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
|
||||
func TestSubscriptionMarshalBinary(t *testing.T) {
|
||||
data, err := subscriptionStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionJSON, data)
|
||||
require.JSONEq(t, string(subscriptionJSON), string(data))
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinary(t *testing.T) {
|
||||
@@ -177,7 +178,7 @@ func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
|
||||
func TestSysInfoMarshalBinary(t *testing.T) {
|
||||
data, err := sysInfoStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoJSON, data)
|
||||
require.JSONEq(t, string(sysInfoJSON), string(data))
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinary(t *testing.T) {
|
||||
@@ -193,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SystemInfo{}, d)
|
||||
}
|
||||
|
||||
func TestMessageToPacket(t *testing.T) {
|
||||
d := messageStruct
|
||||
pk := d.ToPacket()
|
||||
|
||||
require.Equal(t, packets.Packet{
|
||||
Payload: []byte("payload"),
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Remaining: d.FixedHeader.Remaining,
|
||||
Type: d.FixedHeader.Type,
|
||||
Qos: d.FixedHeader.Qos,
|
||||
Dup: d.FixedHeader.Dup,
|
||||
Retain: d.FixedHeader.Retain,
|
||||
},
|
||||
Origin: d.Origin,
|
||||
TopicName: d.TopicName,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: d.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
|
||||
ContentType: d.Properties.ContentType,
|
||||
ResponseTopic: d.Properties.ResponseTopic,
|
||||
CorrelationData: d.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: d.Properties.TopicAlias,
|
||||
User: d.Properties.User,
|
||||
},
|
||||
PacketID: 100,
|
||||
Created: d.Created,
|
||||
}, pk)
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -27,6 +27,10 @@ type modifiedHookBase struct {
|
||||
|
||||
var errTestHook = errors.New("error")
|
||||
|
||||
func (h *modifiedHookBase) ID() string {
|
||||
return "modified"
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Init(config any) error {
|
||||
if config != nil {
|
||||
return errTestHook
|
||||
@@ -46,6 +50,14 @@ func (h *modifiedHookBase) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error {
|
||||
if h.fail {
|
||||
return errTestHook
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
@@ -178,12 +190,20 @@ func TestHooksProvides(t *testing.T) {
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHooksAddAndLen(t *testing.T) {
|
||||
func TestHooksAddLenGetAll(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(1), h.Len())
|
||||
|
||||
err = h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(2), h.Len())
|
||||
|
||||
all := h.GetAll()
|
||||
require.Equal(t, "base", all[0].ID())
|
||||
require.Equal(t, "modified", all[1].ID())
|
||||
}
|
||||
|
||||
func TestHooksAddInitFailure(t *testing.T) {
|
||||
@@ -195,7 +215,7 @@ func TestHooksAddInitFailure(t *testing.T) {
|
||||
|
||||
func TestHooksStop(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
@@ -216,7 +236,7 @@ func TestHooksNonReturns(t *testing.T) {
|
||||
h.OnStarted()
|
||||
h.OnStopped()
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
h.OnConnect(cl, packets.Packet{})
|
||||
h.OnSessionEstablish(cl, packets.Packet{})
|
||||
h.OnSessionEstablished(cl, packets.Packet{})
|
||||
h.OnDisconnect(cl, nil, false)
|
||||
h.OnPacketSent(cl, packets.Packet{}, []byte{})
|
||||
@@ -224,14 +244,16 @@ func TestHooksNonReturns(t *testing.T) {
|
||||
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
|
||||
h.OnUnsubscribed(cl, packets.Packet{})
|
||||
h.OnPublished(cl, packets.Packet{})
|
||||
h.OnPublishDropped(cl, packets.Packet{})
|
||||
h.OnRetainMessage(cl, packets.Packet{}, 0)
|
||||
h.OnRetainPublished(cl, packets.Packet{})
|
||||
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
|
||||
h.OnQosComplete(cl, packets.Packet{})
|
||||
h.OnQosDropped(cl, packets.Packet{})
|
||||
h.OnPacketIDExhausted(cl, packets.Packet{})
|
||||
h.OnWillSent(cl, packets.Packet{})
|
||||
h.OnClientExpired(cl)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
h.OnExpireInflights(cl, time.Now().Unix()-1)
|
||||
|
||||
// on second iteration, check added hook methods
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
@@ -312,7 +334,7 @@ func TestHooksOnUnsubscribe(t *testing.T) {
|
||||
|
||||
func TestHooksOnPublish(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -325,7 +347,7 @@ func TestHooksOnPublish(t *testing.T) {
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
@@ -338,7 +360,7 @@ func TestHooksOnPublish(t *testing.T) {
|
||||
|
||||
func TestHooksOnPacketRead(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -364,7 +386,7 @@ func TestHooksOnPacketRead(t *testing.T) {
|
||||
|
||||
func TestHooksOnAuthPacket(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -380,9 +402,25 @@ func TestHooksOnAuthPacket(t *testing.T) {
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnConnect(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
|
||||
hook.fail = true
|
||||
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketEncode(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -394,7 +432,7 @@ func TestHooksOnPacketEncode(t *testing.T) {
|
||||
|
||||
func TestHooksOnLWT(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -411,7 +449,7 @@ func TestHooksOnLWT(t *testing.T) {
|
||||
|
||||
func TestHooksStoredClients(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
@@ -433,7 +471,7 @@ func TestHooksStoredClients(t *testing.T) {
|
||||
|
||||
func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
@@ -455,7 +493,7 @@ func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
|
||||
func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
@@ -477,7 +515,7 @@ func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
|
||||
func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
@@ -499,7 +537,7 @@ func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
|
||||
func TestHooksStoredSysInfo(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
@@ -537,7 +575,7 @@ func TestHookBaseInit(t *testing.T) {
|
||||
|
||||
func TestHookBaseSetOpts(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
h.SetOpts(&logger, new(HookOptions))
|
||||
h.SetOpts(logger, new(HookOptions))
|
||||
require.NotNil(t, h.Log)
|
||||
require.NotNil(t, h.Opts)
|
||||
}
|
||||
@@ -552,12 +590,19 @@ func TestHookBaseOnConnectAuthenticate(t *testing.T) {
|
||||
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseOnACLCheck(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnACLCheck(new(Client), "topic", true)
|
||||
require.False(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseOnConnect(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
err := h.OnConnect(new(Client), packets.Packet{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPublish(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
|
||||
28
inflight.go
28
inflight.go
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
// Inflight is a map of InflightMessage keyed on packet id.
|
||||
@@ -58,6 +58,18 @@ func (i *Inflight) Len() int {
|
||||
return len(i.internal)
|
||||
}
|
||||
|
||||
// Clone returns a new instance of Inflight with the same message data.
|
||||
// This is used when transferring inflights from a taken-over session.
|
||||
func (i *Inflight) Clone() *Inflight {
|
||||
c := NewInflights()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
for k, v := range i.internal {
|
||||
c.internal[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetAll returns all the inflight messages.
|
||||
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
|
||||
i.RLock()
|
||||
@@ -104,14 +116,14 @@ func (i *Inflight) Delete(id uint16) bool {
|
||||
}
|
||||
|
||||
// TakeRecieveQuota reduces the receive quota by 1.
|
||||
func (i *Inflight) TakeReceiveQuota() {
|
||||
func (i *Inflight) DecreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) > 0 {
|
||||
atomic.AddInt32(&i.receiveQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// TakeRecieveQuota increases the receive quota by 1.
|
||||
func (i *Inflight) ReturnReceiveQuota() {
|
||||
func (i *Inflight) IncreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
|
||||
atomic.AddInt32(&i.receiveQuota, 1)
|
||||
}
|
||||
@@ -123,15 +135,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) {
|
||||
atomic.StoreInt32(&i.maximumReceiveQuota, n)
|
||||
}
|
||||
|
||||
// TakeSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) TakeSendQuota() {
|
||||
// DecreaseSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) DecreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) > 0 {
|
||||
atomic.AddInt32(&i.sendQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// ReturnSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) ReturnSendQuota() {
|
||||
// IncreaseSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) IncreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
|
||||
atomic.AddInt32(&i.sendQuota, 1)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInflightSet(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.True(t, r)
|
||||
@@ -25,7 +25,7 @@ func TestInflightSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightGet(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
msg, ok := cl.State.Inflight.Get(2)
|
||||
@@ -34,7 +34,7 @@ func TestInflightGet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
@@ -56,13 +56,23 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightLen(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightClone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
cloned := cl.State.Inflight.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.NotSame(t, cloned, cl.State.Inflight)
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
|
||||
require.NotNil(t, cl.State.Inflight.internal[3])
|
||||
@@ -95,12 +105,12 @@ func TestReceiveQuota(t *testing.T) {
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Return 1
|
||||
i.ReturnReceiveQuota()
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.ReturnReceiveQuota()
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
@@ -110,12 +120,12 @@ func TestReceiveQuota(t *testing.T) {
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Take 1
|
||||
i.TakeReceiveQuota()
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.TakeReceiveQuota()
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
@@ -137,12 +147,12 @@ func TestSendQuota(t *testing.T) {
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Return 1
|
||||
i.ReturnSendQuota()
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.ReturnSendQuota()
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
@@ -152,18 +162,18 @@ func TestSendQuota(t *testing.T) {
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Take 1
|
||||
i.TakeSendQuota()
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.TakeSendQuota()
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestNextImmediate(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
|
||||
100
listeners/http_healthcheck.go
Normal file
100
listeners/http_healthcheck.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: Derek Duncan
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
|
||||
type HTTPHealthCheck struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address.
|
||||
func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
return &HTTPHealthCheck{
|
||||
id: id,
|
||||
address: address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *HTTPHealthCheck) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *HTTPHealthCheck) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *HTTPHealthCheck) Protocol() string {
|
||||
if l.listen != nil && l.listen.TLSConfig != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPHealthCheck) Init(_ *slog.Logger) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
l.listen = &http.Server{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
|
||||
if l.listen.TLSConfig != nil {
|
||||
_ = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
_ = l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
143
listeners/http_healthcheck_test.go
Normal file
143
listeners/http_healthcheck_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: Derek Duncan
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPHealthCheck(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
require.Equal(t, "healthcheck", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckID(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
require.Equal(t, "healthcheck", l.ID())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckAddress(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckProtocol(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
require.Equal(t, "http", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
_ = l.Init(logger)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckInit(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.listen)
|
||||
require.Equal(t, testAddr, l.listen.Addr)
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckServeAndClose(t *testing.T) {
|
||||
// setup http stats listener
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// call healthcheck
|
||||
resp, err := http.Get("http://localhost" + testAddr + "/healthcheck")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck")
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
|
||||
// setup http stats listener
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// make disallowed method type http request
|
||||
resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody)
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -8,26 +8,25 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
)
|
||||
|
||||
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
log *zerolog.Logger // server logger
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
end uint32 // ensure the close methods are only called once
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
@@ -63,9 +62,8 @@ func (l *HTTPStats) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPStats) Init(log *zerolog.Logger) error {
|
||||
func (l *HTTPStats) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.jsonHandler)
|
||||
l.listen = &http.Server{
|
||||
@@ -84,10 +82,17 @@ func (l *HTTPStats) Init(log *zerolog.Logger) error {
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPStats) Serve(establish EstablishFn) {
|
||||
|
||||
var err error
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
err = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
err = l.listen.ListenAndServe()
|
||||
}
|
||||
|
||||
// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
|
||||
if err != nil && atomic.LoadUint32(&l.end) == 0 {
|
||||
l.log.Error("failed to serve.", "error", err, "listener", l.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +104,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
@@ -107,33 +112,12 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
|
||||
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
|
||||
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
info := &system.Info{
|
||||
Version: l.sysInfo.Version,
|
||||
Started: atomic.LoadInt64(&l.sysInfo.Started),
|
||||
Time: atomic.LoadInt64(&l.sysInfo.Time),
|
||||
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
|
||||
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
|
||||
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
|
||||
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
|
||||
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
|
||||
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
|
||||
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
|
||||
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
|
||||
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
|
||||
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
|
||||
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
|
||||
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
|
||||
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
|
||||
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
|
||||
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
|
||||
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
|
||||
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
|
||||
}
|
||||
info := *l.sysInfo.Clone()
|
||||
|
||||
out, err := json.MarshalIndent(info, "", "\t")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
_, _ = io.WriteString(w, err.Error())
|
||||
}
|
||||
|
||||
w.Write(out)
|
||||
_, _ = w.Write(out)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -42,14 +42,14 @@ func TestHTTPStatsTLSProtocol(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, nil)
|
||||
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsInit(t *testing.T) {
|
||||
sysInfo := new(system.Info)
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.sysInfo)
|
||||
@@ -65,7 +65,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) {
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -113,7 +113,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, sysInfo)
|
||||
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -125,3 +125,28 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
|
||||
func TestHTTPStatsFailedToServe(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats("t1", "wrong_addr", nil, sysInfo)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
<-o
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
@@ -22,24 +22,24 @@ type Config struct {
|
||||
// EstablishFn is a callback function for establishing new clients.
|
||||
type EstablishFn func(id string, c net.Conn) error
|
||||
|
||||
// CloseFunc is a callback function for closing all listener clients.
|
||||
// CloseFn is a callback function for closing all listener clients.
|
||||
type CloseFn func(id string)
|
||||
|
||||
// Listener is an interface for network listeners. A network listener listens
|
||||
// for incoming client connections and adds them to the server.
|
||||
type Listener interface {
|
||||
Init(*zerolog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
Init(*slog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
}
|
||||
|
||||
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
ClientsWg sync.WaitGroup // a waitgroup that waits for all clients in all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -86,8 +86,6 @@ func (l *Listeners) Serve(id string, establisher EstablishFn) {
|
||||
listener := l.internal[id]
|
||||
|
||||
go func(e EstablishFn) {
|
||||
defer l.wg.Done()
|
||||
l.wg.Add(1)
|
||||
listener.Serve(e)
|
||||
}(establisher)
|
||||
}
|
||||
@@ -131,5 +129,5 @@ func (l *Listeners) CloseAll(closer CloseFn) {
|
||||
for _, id := range ids {
|
||||
l.Close(id, closer)
|
||||
}
|
||||
l.wg.Wait()
|
||||
l.ClientsWg.Wait()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -11,14 +11,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAddr = ":22222"
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// MockEstablisher is a function signature which can be used in testing.
|
||||
@@ -53,7 +53,7 @@ func (l *MockListener) Serve(establisher EstablishFn) {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *MockListener) Init(log *zerolog.Logger) error {
|
||||
func (l *MockListener) Init(log *slog.Logger) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -16,7 +16,7 @@ func TestMockEstablisher(t *testing.T) {
|
||||
_, w := net.Pipe()
|
||||
err := MockEstablisher("t1", w)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
}
|
||||
|
||||
func TestNewMockListener(t *testing.T) {
|
||||
@@ -86,7 +86,7 @@ func TestMockListenerServe(t *testing.T) {
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
|
||||
mocked.Init(nil)
|
||||
_ = mocked.Init(nil)
|
||||
}
|
||||
|
||||
func TestMockListenerClose(t *testing.T) {
|
||||
|
||||
92
listeners/net.go
Normal file
92
listeners/net.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: Jeroen Rinzema
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Net is a listener for establishing client connections on basic TCP protocol.
|
||||
type Net struct { // [MQTT-4.2.0-1]
|
||||
mu sync.Mutex
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
|
||||
func NewNet(id string, listener net.Listener) *Net {
|
||||
return &Net{
|
||||
id: id,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Net) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Net) Address() string {
|
||||
return l.listener.Addr().String()
|
||||
}
|
||||
|
||||
// Protocol returns the network of the listener.
|
||||
func (l *Net) Protocol() string {
|
||||
return l.listener.Addr().Network()
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Net) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *Net) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Net) Close(closeClients CloseFn) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listener != nil {
|
||||
err := l.listener.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
105
listeners/net_test.go
Normal file
105
listeners/net_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewNet(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.id)
|
||||
}
|
||||
|
||||
func TestNetID(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestNetAddress(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, n.Addr().String(), l.Address())
|
||||
}
|
||||
|
||||
func TestNetProtocol(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestNetInit(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetServeAndClose(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestNetEstablishThenEnd(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
_, _ = net.Dial("tcp", n.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -10,18 +10,18 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// TCP is a listener for establishing client connections on basic TCP protocol.
|
||||
type TCP struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config *Config // configuration values for the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config *Config // configuration values for the listener
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
@@ -53,7 +53,7 @@ func (l *TCP) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *TCP) Init(log *zerolog.Logger) error {
|
||||
func (l *TCP) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
@@ -83,7 +83,7 @@ func (l *TCP) Serve(establish EstablishFn) {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -39,21 +39,21 @@ func TestTCPProtocolTLS(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
l.Init(&logger)
|
||||
_ = l.Init(logger)
|
||||
defer l.listen.Close()
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPInit(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewTCP("t2", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err = l2.Init(&logger)
|
||||
err = l2.Init(logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l2.config.TLSConfig)
|
||||
@@ -61,7 +61,7 @@ func TestTCPInit(t *testing.T) {
|
||||
|
||||
func TestTCPServeAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -88,7 +88,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -110,7 +110,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
|
||||
func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -124,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", l.listen.Addr().String())
|
||||
_, _ = net.Dial("tcp", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
|
||||
98
listeners/unixsock.go
Normal file
98
listeners/unixsock.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||
func NewUnixSock(id, address string) *UnixSock {
|
||||
return &UnixSock{
|
||||
id: id,
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *UnixSock) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *UnixSock) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *UnixSock) Protocol() string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
_ = os.Remove(l.address)
|
||||
l.listen, err = net.Listen("unix", l.address)
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
96
listeners/unixsock_test.go
Normal file
96
listeners/unixsock_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnixAddr = "mochi.sock"
|
||||
|
||||
func TestNewUnixSock(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testUnixAddr, l.address)
|
||||
}
|
||||
|
||||
func TestUnixSockID(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestUnixSockAddress(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, testUnixAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestUnixSockProtocol(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "unix", l.Protocol())
|
||||
}
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewUnixSock("t2", testUnixAddr)
|
||||
err = l2.Init(logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
_, _ = net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -7,14 +7,16 @@ package listeners
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -28,8 +30,8 @@ type Websocket struct { // [MQTT-4.2.0-1]
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // an http server for serving websocket connections
|
||||
log *zerolog.Logger // server logger
|
||||
listen *http.Server // a http server for serving websocket connections
|
||||
log *slog.Logger // server logger
|
||||
establish EstablishFn // the server's establish connection handler
|
||||
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
|
||||
end uint32 // ensure the close methods are only called once
|
||||
@@ -74,7 +76,7 @@ func (l *Websocket) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Websocket) Init(log *zerolog.Logger) error {
|
||||
func (l *Websocket) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@@ -98,21 +100,27 @@ func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
|
||||
err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c})
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts waiting for new Websocket connections, and calls the connection
|
||||
// establishment callback for any received.
|
||||
func (l *Websocket) Serve(establish EstablishFn) {
|
||||
var err error
|
||||
l.establish = establish
|
||||
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
err = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
err = l.listen.ListenAndServe()
|
||||
}
|
||||
|
||||
// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
|
||||
if err != nil && atomic.LoadUint32(&l.end) == 0 {
|
||||
l.log.Error("failed to serve.", "error", err, "listener", l.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +132,7 @@ func (l *Websocket) Close(closeClients CloseFn) {
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
@@ -134,28 +142,54 @@ func (l *Websocket) Close(closeClients CloseFn) {
|
||||
type wsConn struct {
|
||||
net.Conn
|
||||
c *websocket.Conn
|
||||
|
||||
// reader for the current message (can be nil)
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (n int, err error) {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return
|
||||
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
if ws.r == nil {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ws.r = r
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return
|
||||
}
|
||||
var n int
|
||||
for {
|
||||
// buffer is full, return what we've read so far
|
||||
if n == len(p) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
return r.Read(p)
|
||||
br, err := ws.r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
// when ANY error occurs, we consider this the end of the current message (either because it really is, via
|
||||
// io.EOF, or because something bad happened, in which case we want to drop the remainder)
|
||||
ws.r = nil
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (n int, err error) {
|
||||
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
@@ -37,24 +37,24 @@ func TestWebsocketProtocol(t *testing.T) {
|
||||
require.Equal(t, "ws", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocoTLS(t *testing.T) {
|
||||
func TestWebsocketProtocolTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
require.Equal(t, "wss", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsockeInit(t *testing.T) {
|
||||
func TestWebsocketInit(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Nil(t, l.listen)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketServeAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
@@ -77,7 +77,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -92,11 +92,33 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketFailedToServe(t *testing.T) {
|
||||
l := NewWebsocket("t1", "wrong_addr", &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
<-o
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestWebsocketUpgrade(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
|
||||
e := make(chan bool)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
@@ -110,5 +132,46 @@ func TestWebsocketUpgrade(t *testing.T) {
|
||||
require.Equal(t, true, <-e)
|
||||
|
||||
s.Close()
|
||||
ws.Close()
|
||||
_ = ws.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketConnectionReads(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
_ = l.Init(nil)
|
||||
|
||||
recv := make(chan []byte)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
var out []byte
|
||||
for {
|
||||
buf := make([]byte, 2048)
|
||||
n, err := c.Read(buf)
|
||||
require.NoError(t, err)
|
||||
out = append(out, buf[:n]...)
|
||||
if n < 2048 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
recv <- out
|
||||
return nil
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(l.handler))
|
||||
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkt := make([]byte, 3000) // make sure this is >2048
|
||||
for i := 0; i < len(pkt); i++ {
|
||||
pkt[i] = byte(i % 100)
|
||||
}
|
||||
|
||||
err = ws.WriteMessage(websocket.BinaryMessage, pkt)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := <-recv
|
||||
require.Equal(t, 3000, len(got))
|
||||
require.Equal(t, pkt, got)
|
||||
|
||||
s.Close()
|
||||
_ = ws.Close()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -376,7 +376,7 @@ func TestEncodeUint16(t *testing.T) {
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result)
|
||||
|
||||
result = encodeUint16(65535)
|
||||
result = encodeUint16(math.MaxUint16)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -21,13 +21,14 @@ func (c Code) Error() string {
|
||||
}
|
||||
|
||||
var (
|
||||
// QosCodes indicicates the reason codes for each Qos byte.
|
||||
// QosCodes indicates the reason codes for each Qos byte.
|
||||
QosCodes = map[byte]Code{
|
||||
0: CodeGrantedQos0,
|
||||
1: CodeGrantedQos1,
|
||||
2: CodeGrantedQos2,
|
||||
}
|
||||
|
||||
CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"}
|
||||
CodeSuccess = Code{Code: 0x00, Reason: "success"}
|
||||
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
|
||||
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
|
||||
@@ -113,15 +114,36 @@ var (
|
||||
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
|
||||
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
|
||||
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
|
||||
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
|
||||
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
|
||||
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
|
||||
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
|
||||
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
|
||||
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
|
||||
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
|
||||
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
Err3ClientIdentifierNotValid = Code{Code: 0x02}
|
||||
Err3ServerUnavailable = Code{Code: 0x03}
|
||||
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
|
||||
Err3NotAuthorized = Code{Code: 0x05}
|
||||
|
||||
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
|
||||
// This is required because MQTTv3 has different return byte specification.
|
||||
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
|
||||
V5CodesToV3 = map[Code]Code{
|
||||
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
|
||||
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
|
||||
ErrServerUnavailable: Err3ServerUnavailable,
|
||||
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
|
||||
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
|
||||
ErrBadUsernameOrPassword: Err3NotAuthorized,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -19,7 +19,7 @@ func TestCodesString(t *testing.T) {
|
||||
require.Equal(t, "test", c.String())
|
||||
}
|
||||
|
||||
func TestCodesErrorr(t *testing.T) {
|
||||
func TestCodesError(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "error",
|
||||
Code: 0x1,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -8,36 +8,38 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// All of the valid packet types and their packet identifier.
|
||||
// All valid packet types and their packet identifiers.
|
||||
const (
|
||||
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
|
||||
Connect // 1
|
||||
Connack // 2
|
||||
Publish // 3
|
||||
Puback // 4
|
||||
Pubrec // 5
|
||||
Pubrel // 6
|
||||
Pubcomp // 7
|
||||
Subscribe // 8
|
||||
Suback // 9
|
||||
Unsubscribe // 10
|
||||
Unsuback // 11
|
||||
Pingreq // 12
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
Auth // 15
|
||||
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
|
||||
Connect // 1
|
||||
Connack // 2
|
||||
Publish // 3
|
||||
Puback // 4
|
||||
Pubrec // 5
|
||||
Pubrel // 6
|
||||
Pubcomp // 7
|
||||
Subscribe // 8
|
||||
Suback // 9
|
||||
Unsubscribe // 10
|
||||
Unsuback // 11
|
||||
Pingreq // 12
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
Auth // 15
|
||||
WillProperties byte = 99 // Special byte for validating Will Properties.
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification.
|
||||
ErrNoValidPacketAvailable error = errors.New("no valid packet available")
|
||||
ErrNoValidPacketAvailable = errors.New("no valid packet available")
|
||||
|
||||
// PacketNames is a map of packet bytes to human readable names, for easier debugging.
|
||||
// PacketNames is a map of packet bytes to human-readable names, for easier debugging.
|
||||
PacketNames = map[byte]string{
|
||||
0: "Reserved",
|
||||
1: "Connect",
|
||||
@@ -133,6 +135,7 @@ type Packet struct {
|
||||
SessionPresent bool // session existed for connack
|
||||
ReasonCode byte // reason code for a packet response (acks, etc)
|
||||
ReservedBit byte // reserved, do not use (except in testing)
|
||||
Ignore bool // if true, do not perform any message forwarding operations
|
||||
}
|
||||
|
||||
// Mods specifies certain values required for certain mqtt v5 compliance within packet encoding/decoding.
|
||||
@@ -173,6 +176,7 @@ type Subscription struct {
|
||||
Qos byte
|
||||
RetainAsPublished bool
|
||||
NoLocal bool
|
||||
FwdRetainedFlag bool // true if the subscription forms part of a publish response to a client subscription and packet is retained.
|
||||
}
|
||||
|
||||
// Copy creates a new instance of a packet, but with an empty header for inheriting new QoS flags, etc.
|
||||
@@ -208,7 +212,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet {
|
||||
Created: pk.Created,
|
||||
Expiry: pk.Expiry,
|
||||
Origin: pk.Origin,
|
||||
PacketID: pk.PacketID, // ... ? Packet ID must not be transferred (in this manner)
|
||||
}
|
||||
|
||||
if allowTransfer {
|
||||
p.PacketID = pk.PacketID
|
||||
}
|
||||
|
||||
if len(pk.Connect.ProtocolName) > 0 {
|
||||
@@ -265,28 +272,28 @@ func (s Subscription) Merge(n Subscription) Subscription {
|
||||
}
|
||||
|
||||
// encode encodes a subscription and properties into bytes.
|
||||
func (p Subscription) encode() byte {
|
||||
func (s Subscription) encode() byte {
|
||||
var flag byte
|
||||
flag |= p.Qos
|
||||
flag |= s.Qos
|
||||
|
||||
if p.NoLocal {
|
||||
if s.NoLocal {
|
||||
flag |= 1 << 2
|
||||
}
|
||||
|
||||
if p.RetainAsPublished {
|
||||
if s.RetainAsPublished {
|
||||
flag |= 1 << 3
|
||||
}
|
||||
|
||||
flag |= p.RetainHandling << 4
|
||||
flag |= s.RetainHandling << 4
|
||||
return flag
|
||||
}
|
||||
|
||||
// decode decodes subscription bytes into a subscription struct.
|
||||
func (p *Subscription) decode(b byte) {
|
||||
p.Qos = b & 3 // byte
|
||||
p.NoLocal = 1&(b>>2) > 0 // bool
|
||||
p.RetainAsPublished = 1&(b>>3) > 0 // bool
|
||||
p.RetainHandling = 3 & (b >> 4) // byte
|
||||
func (s *Subscription) decode(b byte) {
|
||||
s.Qos = b & 3 // byte
|
||||
s.NoLocal = 1&(b>>2) > 0 // bool
|
||||
s.RetainAsPublished = 1&(b>>3) > 0 // bool
|
||||
s.RetainHandling = 3 & (b >> 4) // byte
|
||||
}
|
||||
|
||||
// ConnectEncode encodes a connect packet.
|
||||
@@ -309,7 +316,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
(&pk.Properties).Encode(pk, pb, 0)
|
||||
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -318,7 +325,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
if pk.Connect.WillFlag {
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
|
||||
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -336,7 +343,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -379,21 +386,21 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
|
||||
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) // [MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
|
||||
if err != nil {
|
||||
return ErrClientIdentifierNotValid // [MQTT-3.1.3-8]
|
||||
}
|
||||
|
||||
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
|
||||
if pk.ProtocolVersion == 5 {
|
||||
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
|
||||
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
|
||||
if err != nil {
|
||||
return ErrMalformedWillProperties
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
|
||||
@@ -408,6 +415,10 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
}
|
||||
|
||||
if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12]
|
||||
if offset >= len(buf) { // we are at the end of the packet
|
||||
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
|
||||
}
|
||||
|
||||
pk.Connect.Username, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedUsername
|
||||
@@ -439,18 +450,14 @@ func (pk *Packet) ConnectValidate() Code {
|
||||
return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
|
||||
}
|
||||
|
||||
if len(pk.Connect.Password) > 65535 {
|
||||
if len(pk.Connect.Password) > math.MaxUint16 {
|
||||
return ErrProtocolViolationPasswordTooLong
|
||||
}
|
||||
|
||||
if len(pk.Connect.Username) > 65535 {
|
||||
if len(pk.Connect.Username) > math.MaxUint16 {
|
||||
return ErrProtocolViolationUsernameTooLong
|
||||
}
|
||||
|
||||
if pk.Connect.UsernameFlag && len(pk.Connect.Username) == 0 {
|
||||
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
|
||||
}
|
||||
|
||||
if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 {
|
||||
return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16]
|
||||
}
|
||||
@@ -463,7 +470,7 @@ func (pk *Packet) ConnectValidate() Code {
|
||||
return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
|
||||
}
|
||||
|
||||
if len(pk.Connect.ClientIdentifier) > 65535 {
|
||||
if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
|
||||
return ErrClientIdentifierNotValid
|
||||
}
|
||||
|
||||
@@ -492,13 +499,13 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -535,13 +542,13 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -604,7 +611,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -612,7 +619,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -640,7 +647,7 @@ func (pk *Packet) PublishDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.Payload = buf[offset:]
|
||||
@@ -688,7 +695,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
}
|
||||
@@ -700,7 +707,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -829,7 +836,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -837,7 +844,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -857,7 +864,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.ReasonCodes = buf[offset:]
|
||||
@@ -886,7 +893,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -894,7 +901,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -914,7 +921,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
var filter string
|
||||
@@ -981,7 +988,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -989,7 +996,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1010,7 +1017,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
|
||||
offset += n + 1
|
||||
offset += n
|
||||
|
||||
pk.ReasonCodes = buf[offset:]
|
||||
}
|
||||
@@ -1034,7 +1041,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -1042,7 +1049,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1062,7 +1069,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
var filter string
|
||||
@@ -1097,12 +1104,12 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -150,7 +150,7 @@ func TestPacketEncode(t *testing.T) {
|
||||
}
|
||||
|
||||
pk := new(Packet)
|
||||
copier.Copy(pk, wanted.Packet)
|
||||
_ = copier.Copy(pk, wanted.Packet)
|
||||
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
@@ -218,7 +218,7 @@ func TestPacketDecode(t *testing.T) {
|
||||
|
||||
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
_ = pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
if len(wanted.RawBytes) > 0 {
|
||||
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
|
||||
}
|
||||
@@ -464,6 +464,9 @@ func TestCopy(t *testing.T) {
|
||||
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
|
||||
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
|
||||
|
||||
pkcc := tt.Packet.Copy(false)
|
||||
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -42,11 +42,11 @@ const (
|
||||
|
||||
// validPacketProperties indicates which properties are valid for which packet types.
|
||||
var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropPayloadFormat: {Publish: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1},
|
||||
PropContentType: {Publish: 1},
|
||||
PropResponseTopic: {Publish: 1},
|
||||
PropCorrelationData: {Publish: 1},
|
||||
PropPayloadFormat: {Publish: 1, WillProperties: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
|
||||
PropContentType: {Publish: 1, WillProperties: 1},
|
||||
PropResponseTopic: {Publish: 1, WillProperties: 1},
|
||||
PropCorrelationData: {Publish: 1, WillProperties: 1},
|
||||
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
|
||||
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
|
||||
PropAssignedClientID: {Connack: 1},
|
||||
@@ -54,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropRequestProblemInfo: {Connect: 1},
|
||||
PropWillDelayInterval: {Connect: 1},
|
||||
PropWillDelayInterval: {WillProperties: 1},
|
||||
PropRequestResponseInfo: {Connect: 1},
|
||||
PropResponseInfo: {Connack: 1},
|
||||
PropServerReference: {Connack: 1, Disconnect: 1},
|
||||
@@ -64,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropTopicAlias: {Publish: 1},
|
||||
PropMaximumQos: {Connack: 1},
|
||||
PropRetainAvailable: {Connack: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
|
||||
PropMaximumPacketSize: {Connect: 1, Connack: 1},
|
||||
PropWildcardSubAvailable: {Connack: 1},
|
||||
PropSubIDAvailable: {Connack: 1},
|
||||
@@ -77,7 +77,7 @@ type UserProperty struct { // [MQTT-1.5.7-1]
|
||||
Val string `json:"v"`
|
||||
}
|
||||
|
||||
// Properties contains all of the mqtt v5 properties available for a packet.
|
||||
// Properties contains all mqtt v5 properties available for a packet.
|
||||
// Some properties have valid values of 0 or not-present. In this case, we opt for
|
||||
// property flags to indicate the usage of property.
|
||||
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
|
||||
@@ -194,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
|
||||
}
|
||||
|
||||
// Encode encodes properties into a bytes buffer.
|
||||
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
pkt := pk.FixedHeader.Type
|
||||
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
@@ -217,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseTopic)
|
||||
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropCorrelationData)
|
||||
buf.Write(encodeBytes(p.CorrelationData))
|
||||
}
|
||||
@@ -277,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.WriteByte(p.RequestResponseInfo)
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseInfo)
|
||||
buf.Write(encodeString(p.ResponseInfo))
|
||||
}
|
||||
@@ -289,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
|
||||
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
|
||||
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
b := encodeString(p.ReasonString)
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
|
||||
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
|
||||
buf.WriteByte(PropReasonString)
|
||||
buf.Write(b)
|
||||
}
|
||||
@@ -322,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.WriteByte(p.RetainAvailable)
|
||||
}
|
||||
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
@@ -331,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
}
|
||||
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
|
||||
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
|
||||
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
|
||||
buf.Write(pb.Bytes())
|
||||
}
|
||||
}
|
||||
@@ -357,22 +355,23 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
}
|
||||
|
||||
encodeLength(b, int64(buf.Len()))
|
||||
buf.WriteTo(b) // [MQTT-3.1.3-10]
|
||||
_, _ = buf.WriteTo(b) // [MQTT-3.1.3-10]
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
|
||||
if p == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, _, err = DecodeLength(b)
|
||||
var bu int
|
||||
n, bu, err = DecodeLength(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return n, nil
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
bt := b.Bytes()
|
||||
@@ -380,11 +379,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
for offset := 0; offset < n; {
|
||||
k, offset, err = decodeByte(bt, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if _, ok := validPacketProperties[k][pk]; !ok {
|
||||
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
|
||||
if _, ok := validPacketProperties[k][pkt]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
switch k {
|
||||
@@ -406,7 +405,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
|
||||
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
|
||||
offset += bu
|
||||
@@ -452,7 +451,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
var k, v string
|
||||
k, offset, err = decodeString(bt, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
v, offset, err = decodeString(bt, offset)
|
||||
p.User = append(p.User, UserProperty{Key: k, Val: v})
|
||||
@@ -470,9 +469,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -202,14 +202,14 @@ func init() {
|
||||
func TestEncodeProperties(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, propertiesBytes, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
|
||||
@@ -219,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
|
||||
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
|
||||
@@ -232,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
|
||||
|
||||
pr := tmp{}
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
|
||||
pr.p.Encode(Reserved, Mods{}, b, 0)
|
||||
require.Equal(t, []byte{}, b.Bytes())
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
|
||||
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
|
||||
props := new(Properties)
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, []byte{0x00}, b.Bytes())
|
||||
}
|
||||
|
||||
@@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
|
||||
props := new(Properties)
|
||||
n, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 172, n)
|
||||
require.Equal(t, 172+2, n)
|
||||
require.EqualValues(t, propertiesStruct, *props)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
@@ -40,7 +40,6 @@ const (
|
||||
TConnectMqtt5
|
||||
TConnectMqtt5LWT
|
||||
TConnectClean
|
||||
TConnectCleanLWT
|
||||
TConnectUserPass
|
||||
TConnectUserPassLWT
|
||||
TConnectMalProtocolName
|
||||
@@ -61,7 +60,6 @@ const (
|
||||
TConnectInvalidProtocolVersion2
|
||||
TConnectInvalidReservedBit
|
||||
TConnectInvalidClientIDTooLong
|
||||
TConnectInvalidPasswordNoUsername
|
||||
TConnectInvalidFlagNoUsername
|
||||
TConnectInvalidFlagNoPassword
|
||||
TConnectInvalidUsernameNoFlag
|
||||
@@ -71,7 +69,7 @@ const (
|
||||
TConnectInvalidWillFlagNoPayload
|
||||
TConnectInvalidWillFlagQosOutOfRange
|
||||
TConnectInvalidWillSurplusRetain
|
||||
TConnectNotCleanNoClientID
|
||||
TConnectZeroByteUsername
|
||||
TConnectSpecInvalidUTF8D800
|
||||
TConnectSpecInvalidUTF8DFFF
|
||||
TConnectSpecInvalidUTF80000
|
||||
@@ -82,6 +80,7 @@ const (
|
||||
TConnackAcceptedAdjustedExpiryInterval
|
||||
TConnackMinMqtt5
|
||||
TConnackMinCleanMqtt5
|
||||
TConnackServerKeepalive
|
||||
TConnackInvalidMinMqtt5
|
||||
TConnackBadProtocolVersion
|
||||
TConnackProtocolViolationNoSession
|
||||
@@ -89,6 +88,7 @@ const (
|
||||
TConnackServerUnavailable
|
||||
TConnackBadUsernamePassword
|
||||
TConnackBadUsernamePasswordNoSession
|
||||
TConnackMqtt5BadUsernamePasswordNoSession
|
||||
TConnackNotAuthorised
|
||||
TConnackMalSessionPresent
|
||||
TConnackMalReturnCode
|
||||
@@ -101,6 +101,7 @@ const (
|
||||
TPublishBasicMqtt5
|
||||
TPublishMqtt5
|
||||
TPublishQos1
|
||||
TPublishQos1Mqtt5
|
||||
TPublishQos1NoPayload
|
||||
TPublishQos1Dup
|
||||
TPublishQos2
|
||||
@@ -128,11 +129,14 @@ const (
|
||||
TPublishSpecDenySysTopic
|
||||
TPuback
|
||||
TPubackMqtt5
|
||||
TPubackMqtt5NotAuthorized
|
||||
TPubackMalPacketID
|
||||
TPubackMalProperties
|
||||
TPubackUnexpectedError
|
||||
TPubrec
|
||||
TPubrecMqtt5
|
||||
TPubrecMqtt5IDInUse
|
||||
TPubrecMqtt5NotAuthorized
|
||||
TPubrecMalPacketID
|
||||
TPubrecMalProperties
|
||||
TPubrecMalReasonCode
|
||||
@@ -180,7 +184,6 @@ const (
|
||||
TUnsubscribe
|
||||
TUnsubscribeMany
|
||||
TUnsubscribeMqtt5
|
||||
TUnsubscribeDropProperties
|
||||
TUnsubscribeMalPacketID
|
||||
TUnsubscribeMalTopicName
|
||||
TUnsubscribeMalProperties
|
||||
@@ -198,7 +201,6 @@ const (
|
||||
TDisconnect
|
||||
TDisconnectTakeover
|
||||
TDisconnectMqtt5
|
||||
TDisconnectNormalMqtt5
|
||||
TDisconnectSecondConnect
|
||||
TDisconnectReceiveMaximum
|
||||
TDisconnectDropProperties
|
||||
@@ -249,26 +251,26 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "mqtt v3.1.1",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 16, // Fixed header
|
||||
Connect << 4, 15, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 60, // Keepalive
|
||||
0, 4, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', '3', // Client ID "zen"
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connect,
|
||||
Remaining: 16,
|
||||
Remaining: 15,
|
||||
},
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: false,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "zen3",
|
||||
ClientIdentifier: "zen",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -425,9 +427,9 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Connect << 4, 28, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
194, // Packet Flags
|
||||
0, 20, // Keepalive
|
||||
4, // Protocol Version
|
||||
0 | 1<<6 | 1<<7, // Packet Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 5, // Username MSB+LSB
|
||||
@@ -443,7 +445,7 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: true,
|
||||
Clean: false,
|
||||
Keepalive: 20,
|
||||
ClientIdentifier: "zen",
|
||||
UsernameFlag: true,
|
||||
@@ -497,6 +499,43 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectZeroByteUsername,
|
||||
Desc: "username flag but 0 byte username",
|
||||
Group: "decode",
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 23, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
5, // Protocol Version
|
||||
130, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
5, // length
|
||||
17, 0, 0, 0, 120, // Session Expiry Interval (17)
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 0, // Username MSB+LSB
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connect,
|
||||
Remaining: 23,
|
||||
},
|
||||
ProtocolVersion: 5,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: true,
|
||||
Keepalive: 30,
|
||||
ClientIdentifier: "zen",
|
||||
Username: []byte{},
|
||||
UsernameFlag: true,
|
||||
},
|
||||
Properties: Properties{
|
||||
SessionExpiryInterval: uint32(120),
|
||||
SessionExpiryIntervalFlag: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Fail States
|
||||
{
|
||||
@@ -623,6 +662,24 @@ var TPacketData = map[byte]TPacketCases{
|
||||
'm', 'o', 'c',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TConnectInvalidFlagNoUsername,
|
||||
Desc: "username flag with no username bytes",
|
||||
Group: "decode",
|
||||
FailFirst: ErrProtocolViolationFlagNoUsername,
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
5, // Protocol Version
|
||||
130, // Flags
|
||||
0, 20, // Keepalive
|
||||
0,
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectMalPassword,
|
||||
Desc: "malformed password",
|
||||
@@ -783,20 +840,6 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectInvalidFlagNoUsername,
|
||||
Desc: "has username flag but no username",
|
||||
Group: "validate",
|
||||
Expect: ErrProtocolViolationFlagNoUsername,
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{Type: Connect},
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
UsernameFlag: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectInvalidUsernameNoFlag,
|
||||
Desc: "has username but no flag",
|
||||
@@ -1043,25 +1086,22 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "accepted, no session, adjusted expiry interval mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 11, // fixed header
|
||||
Connack << 4, 8, // fixed header
|
||||
0, // Session present
|
||||
CodeSuccess.Code,
|
||||
8, // length
|
||||
5, // length
|
||||
17, 0, 0, 0, 120, // Session Expiry Interval (17)
|
||||
19, 0, 10, // Server Keep Alive (19)
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 11,
|
||||
Remaining: 8,
|
||||
},
|
||||
ReasonCode: CodeSuccess.Code,
|
||||
Properties: Properties{
|
||||
SessionExpiryInterval: uint32(120),
|
||||
SessionExpiryIntervalFlag: true,
|
||||
ServerKeepAlive: uint16(10),
|
||||
ServerKeepAliveFlag: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1148,28 +1188,25 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "accepted min properties mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 16, // fixed header
|
||||
Connack << 4, 13, // fixed header
|
||||
1, // existing session
|
||||
CodeSuccess.Code,
|
||||
13, // Properties length
|
||||
10, // Properties length
|
||||
18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18)
|
||||
19, 0, 20, // Server Keep Alive (19)
|
||||
36, 1, // Maximum Qos (36)
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 16,
|
||||
Remaining: 13,
|
||||
},
|
||||
SessionPresent: true,
|
||||
ReasonCode: CodeSuccess.Code,
|
||||
Properties: Properties{
|
||||
ServerKeepAlive: uint16(20),
|
||||
ServerKeepAliveFlag: true,
|
||||
AssignedClientID: "mochi",
|
||||
MaximumQos: byte(1),
|
||||
MaximumQosFlag: true,
|
||||
AssignedClientID: "mochi",
|
||||
MaximumQos: byte(1),
|
||||
MaximumQosFlag: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1178,11 +1215,10 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "accepted min properties mqtt5b",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 6, // fixed header
|
||||
Connack << 4, 3, // fixed header
|
||||
0, // existing session
|
||||
CodeSuccess.Code,
|
||||
3, // Properties length
|
||||
19, 0, 10, // server keepalive
|
||||
0, // Properties length
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
@@ -1192,6 +1228,27 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
SessionPresent: false,
|
||||
ReasonCode: CodeSuccess.Code,
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnackServerKeepalive,
|
||||
Desc: "server set keepalive",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 6, // fixed header
|
||||
1, // existing session
|
||||
CodeSuccess.Code,
|
||||
3, // Properties length
|
||||
19, 0, 10, // server keepalive
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 6,
|
||||
},
|
||||
SessionPresent: true,
|
||||
ReasonCode: CodeSuccess.Code,
|
||||
Properties: Properties{
|
||||
ServerKeepAlive: uint16(10),
|
||||
ServerKeepAliveFlag: true,
|
||||
@@ -1203,26 +1260,23 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "failure min properties mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: append([]byte{
|
||||
Connack << 4, 26, // fixed header
|
||||
Connack << 4, 23, // fixed header
|
||||
0, // No existing session
|
||||
ErrUnspecifiedError.Code,
|
||||
// Properties
|
||||
23, // length
|
||||
19, 0, 20, // Server Keep Alive (19)
|
||||
20, // length
|
||||
31, 0, 17, // Reason String (31)
|
||||
}, []byte(ErrUnspecifiedError.Reason)...),
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 25,
|
||||
Remaining: 23,
|
||||
},
|
||||
SessionPresent: false,
|
||||
ReasonCode: ErrUnspecifiedError.Code,
|
||||
Properties: Properties{
|
||||
ServerKeepAlive: uint16(20),
|
||||
ServerKeepAliveFlag: true,
|
||||
ReasonString: ErrUnspecifiedError.Reason,
|
||||
ReasonString: ErrUnspecifiedError.Reason,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1316,10 +1370,28 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "bad username or password no session",
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 2, // fixed header
|
||||
0, // No session present
|
||||
ErrBadUsernameOrPassword.Code,
|
||||
0, // No session present
|
||||
Err3NotAuthorized.Code, // use v3 remapping
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 2,
|
||||
},
|
||||
ReasonCode: Err3NotAuthorized.Code,
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnackMqtt5BadUsernamePasswordNoSession,
|
||||
Desc: "mqtt5 bad username or password no session",
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 3, // fixed header
|
||||
0, // No session present
|
||||
ErrBadUsernameOrPassword.Code,
|
||||
0,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 2,
|
||||
@@ -1327,6 +1399,7 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ReasonCode: ErrBadUsernameOrPassword.Code,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TConnackNotAuthorised,
|
||||
Desc: "not authorised",
|
||||
@@ -1631,6 +1704,43 @@ var TPacketData = map[byte]TPacketCases{
|
||||
PacketID: 7,
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPublishQos1Mqtt5,
|
||||
Desc: "mqtt v5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Publish<<4 | 1<<1, 37, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
// Properties
|
||||
16, // length
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Remaining: 37,
|
||||
Qos: 1,
|
||||
},
|
||||
PacketID: 7,
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TPublishQos1Dup,
|
||||
Desc: "qos:1, dup:true, packet id",
|
||||
@@ -1804,13 +1914,10 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Case: TPublishRetainMqtt5,
|
||||
Desc: "retain mqtt5",
|
||||
RawBytes: []byte{
|
||||
Publish<<4 | 1<<0, 35, // Fixed header
|
||||
Publish<<4 | 1<<0, 19, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
16, // properties length
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
0, // properties length
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
||||
},
|
||||
Packet: &Packet{
|
||||
@@ -1818,18 +1925,11 @@ var TPacketData = map[byte]TPacketCases{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Retain: true,
|
||||
Remaining: 35,
|
||||
Remaining: 19,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
Payload: []byte("hello mochi"),
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{},
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -2172,6 +2272,66 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubackMqtt5NotAuthorized,
|
||||
Desc: "QOS 1 publish not authorized mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Puback << 4, 37, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrNotAuthorized.Code, // Reason Code
|
||||
33, // Properties Length
|
||||
31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
|
||||
't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Puback,
|
||||
Remaining: 31,
|
||||
},
|
||||
PacketID: 7,
|
||||
ReasonCode: ErrNotAuthorized.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrNotAuthorized.Reason,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubackUnexpectedError,
|
||||
Desc: "unexpected error",
|
||||
Group: "decode",
|
||||
RawBytes: []byte{
|
||||
Puback << 4, 29, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrPayloadFormatInvalid.Code, // Reason Code
|
||||
25, // Properties Length
|
||||
31, 0, 22, 'p', 'a', 'y', 'l', 'o', 'a', 'd',
|
||||
' ', 'f', 'o', 'r', 'm', 'a', 't',
|
||||
' ', 'i', 'n', 'v', 'a', 'l', 'i', 'd', // Reason String (31)
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Puback,
|
||||
Remaining: 28,
|
||||
},
|
||||
PacketID: 7,
|
||||
ReasonCode: ErrPayloadFormatInvalid.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrPayloadFormatInvalid.Reason,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Fail states
|
||||
{
|
||||
@@ -2253,14 +2413,17 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "packet id in use mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Pubrec << 4, 31, // Fixed header
|
||||
Pubrec << 4, 47, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrPacketIdentifierInUse.Code, // Reason Code
|
||||
27, // Properties Length
|
||||
43, // Properties Length
|
||||
31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't',
|
||||
' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r',
|
||||
' ', 'i', 'n',
|
||||
' ', 'u', 's', 'e', // Reason String (31)
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
@@ -2272,6 +2435,46 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ReasonCode: ErrPacketIdentifierInUse.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrPacketIdentifierInUse.Reason,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubrecMqtt5NotAuthorized,
|
||||
Desc: "QOS 2 publish not authorized mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Pubrec << 4, 37, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrNotAuthorized.Code, // Reason Code
|
||||
33, // Properties Length
|
||||
31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
|
||||
't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Pubrec,
|
||||
Remaining: 31,
|
||||
},
|
||||
PacketID: 7,
|
||||
ReasonCode: ErrNotAuthorized.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrNotAuthorized.Reason,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
1650
server_test.go
1650
server_test.go
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package system
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Info contains atomic counters and values for various server statistics
|
||||
// commonly found in $SYS topics (and others).
|
||||
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
|
||||
@@ -20,6 +22,7 @@ type Info struct {
|
||||
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
|
||||
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
|
||||
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
|
||||
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
|
||||
Retained int64 `json:"retained"` // total number of retained messages active on the broker
|
||||
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
|
||||
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
|
||||
@@ -29,3 +32,30 @@ type Info struct {
|
||||
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
|
||||
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
|
||||
}
|
||||
|
||||
// Clone makes a copy of Info using atomic operation
|
||||
func (i *Info) Clone() *Info {
|
||||
return &Info{
|
||||
Version: i.Version,
|
||||
Started: atomic.LoadInt64(&i.Started),
|
||||
Time: atomic.LoadInt64(&i.Time),
|
||||
Uptime: atomic.LoadInt64(&i.Uptime),
|
||||
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
|
||||
BytesSent: atomic.LoadInt64(&i.BytesSent),
|
||||
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
|
||||
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
|
||||
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
|
||||
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
|
||||
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
|
||||
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
|
||||
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
|
||||
Retained: atomic.LoadInt64(&i.Retained),
|
||||
Inflight: atomic.LoadInt64(&i.Inflight),
|
||||
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
|
||||
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
|
||||
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
|
||||
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
|
||||
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
|
||||
Threads: atomic.LoadInt64(&i.Threads),
|
||||
}
|
||||
}
|
||||
|
||||
37
system/system_test.go
Normal file
37
system/system_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
o := &Info{
|
||||
Version: "version",
|
||||
Started: 1,
|
||||
Time: 2,
|
||||
Uptime: 3,
|
||||
BytesReceived: 4,
|
||||
BytesSent: 5,
|
||||
ClientsConnected: 6,
|
||||
ClientsMaximum: 7,
|
||||
ClientsTotal: 8,
|
||||
ClientsDisconnected: 9,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
Retained: 12,
|
||||
Inflight: 13,
|
||||
InflightDropped: 14,
|
||||
Subscriptions: 15,
|
||||
PacketsReceived: 16,
|
||||
PacketsSent: 17,
|
||||
MemoryAlloc: 18,
|
||||
Threads: 19,
|
||||
}
|
||||
|
||||
n := o.Clone()
|
||||
|
||||
require.Equal(t, o, n)
|
||||
}
|
||||
196
topics.go
196
topics.go
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio
|
||||
return m
|
||||
}
|
||||
|
||||
// InlineSubFn is the signature for a callback function which will be called
|
||||
// when an inline client receives a message on a topic it is subscribed to.
|
||||
// The sub argument contains information about the subscription that was matched for any filters.
|
||||
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
|
||||
|
||||
// InlineSubscriptions represents a map of internal subscriptions keyed on client.
|
||||
type InlineSubscriptions struct {
|
||||
internal map[int]InlineSubscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
|
||||
func NewInlineSubscriptions() *InlineSubscriptions {
|
||||
return &InlineSubscriptions{
|
||||
internal: map[int]InlineSubscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new internal subscription for a client id.
|
||||
func (s *InlineSubscriptions) Add(val InlineSubscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.internal[val.Identifier] = val
|
||||
}
|
||||
|
||||
// GetAll returns all internal subscriptions.
|
||||
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[int]InlineSubscription{}
|
||||
for k, v := range s.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns an internal subscription for a client id.
|
||||
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val, ok = s.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the number of internal subscriptions.
|
||||
func (s *InlineSubscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes an internal subscription by the client id.
|
||||
func (s *InlineSubscriptions) Delete(id int) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal, id)
|
||||
}
|
||||
|
||||
// Subscriptions is a map of subscriptions keyed on client.
|
||||
type Subscriptions struct {
|
||||
internal map[string]packets.Subscription
|
||||
@@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) {
|
||||
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
||||
type ClientSubscriptions map[string]packets.Subscription
|
||||
|
||||
type InlineSubscription struct {
|
||||
packets.Subscription
|
||||
Handler InlineSubFn
|
||||
}
|
||||
|
||||
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
||||
type Subscribers struct {
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
InlineSubscriptions map[int]InlineSubscription
|
||||
}
|
||||
|
||||
// SelectShared returns one subscriber for each shared subscription group.
|
||||
@@ -298,9 +363,45 @@ func NewTopicsIndex() *TopicsIndex {
|
||||
}
|
||||
}
|
||||
|
||||
// InlineSubscribe adds a new internal subscription for a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
n := x.set(subscription.Filter, 0)
|
||||
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
|
||||
n.inlineSubscriptions.Add(subscription)
|
||||
|
||||
return !existed
|
||||
}
|
||||
|
||||
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
|
||||
// returning true if the subscription existed.
|
||||
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
particle := x.seek(filter, 0)
|
||||
if particle == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
particle.inlineSubscriptions.Delete(id)
|
||||
|
||||
if particle.inlineSubscriptions.Len() == 0 {
|
||||
x.trim(particle)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
prefix, _ := isolateParticle(subscription.Filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
@@ -320,8 +421,13 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
|
||||
// Unsubscribe removes a subscription filter for a client, returning true if the
|
||||
// subscription existed.
|
||||
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var d int
|
||||
if strings.HasPrefix(filter, SharePrefix) {
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
shareSub := strings.EqualFold(prefix, SharePrefix)
|
||||
if shareSub {
|
||||
d = 2
|
||||
}
|
||||
|
||||
@@ -330,8 +436,7 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
if shareSub {
|
||||
group, _ := isolateParticle(filter, 1)
|
||||
particle.shared.Delete(group, client)
|
||||
} else {
|
||||
@@ -346,7 +451,12 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
n := x.set(pk.TopicName, 0)
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
if len(pk.Payload) > 0 {
|
||||
n.retainPath = pk.TopicName
|
||||
x.Retained.Add(pk.TopicName, pk)
|
||||
@@ -361,6 +471,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
n.retainPath = ""
|
||||
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
|
||||
x.trim(n)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -471,9 +582,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack
|
||||
// their subscription ids and highest qos.
|
||||
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
||||
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
InlineSubscriptions: map[int]InlineSubscription{},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -488,20 +600,30 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(topic, d)
|
||||
for _, partKey := range []string{key, "+", "#"} {
|
||||
for _, partKey := range []string{key, "+"} {
|
||||
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
|
||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||
}
|
||||
|
||||
if hasNext {
|
||||
x.scanSubscribers(topic, d+1, particle, subs)
|
||||
} else {
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
|
||||
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
|
||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||
x.gatherSharedSubscriptions(wild, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if particle := n.particles.get("#"); particle != nil {
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
}
|
||||
|
||||
return subs
|
||||
}
|
||||
|
||||
@@ -542,6 +664,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr
|
||||
}
|
||||
}
|
||||
|
||||
// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
|
||||
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
|
||||
if subs.InlineSubscriptions == nil {
|
||||
subs.InlineSubscriptions = map[int]InlineSubscription{}
|
||||
}
|
||||
|
||||
for id, inline := range particle.inlineSubscriptions.GetAll() {
|
||||
subs.InlineSubscriptions[id] = inline
|
||||
}
|
||||
}
|
||||
|
||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||
var next, end int
|
||||
@@ -572,7 +705,7 @@ func IsSharedFilter(filter string) bool {
|
||||
|
||||
// IsValidFilter returns true if the filter is valid.
|
||||
func IsValidFilter(filter string, forPublish bool) bool {
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
|
||||
return false // [MQTT-4.7.3-1]
|
||||
}
|
||||
|
||||
@@ -613,22 +746,25 @@ func IsValidFilter(filter string, forPublish bool) bool {
|
||||
|
||||
// particle is a child node on the tree.
|
||||
type particle struct {
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
retainPath string // path of a retained message
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
}
|
||||
|
||||
// newParticle returns a pointer to a new instance of particle.
|
||||
func newParticle(key string, parent *particle) *particle {
|
||||
return &particle{
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
inlineSubscriptions: NewInlineSubscriptions(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
242
topics_test.go
242
topics_test.go
@@ -1,13 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -319,7 +320,7 @@ func TestUnsubscribeShared(t *testing.T) {
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(2), client.Qos)
|
||||
|
||||
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
|
||||
require.True(t, index.Unsubscribe("$share/tmp/a/b/c", "cl1"))
|
||||
_, exists = final.shared.Get("tmp", "cl1")
|
||||
require.False(t, exists)
|
||||
}
|
||||
@@ -501,28 +502,40 @@ func TestScanSubscribers(t *testing.T) {
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
|
||||
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 4, len(subs.Subscriptions))
|
||||
require.Equal(t, 3, len(subs.Subscriptions))
|
||||
require.Contains(t, subs.Subscriptions, "cl1")
|
||||
require.Contains(t, subs.Subscriptions, "cl2")
|
||||
require.Contains(t, subs.Subscriptions, "cl3")
|
||||
require.Contains(t, subs.Subscriptions, "cl4")
|
||||
|
||||
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
|
||||
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
|
||||
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
|
||||
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
|
||||
|
||||
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
|
||||
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
|
||||
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
|
||||
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
|
||||
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
|
||||
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
|
||||
|
||||
subs = index.scanSubscribers("d/e/f/g", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 1, len(subs.Subscriptions))
|
||||
require.Contains(t, subs.Subscriptions, "cl4")
|
||||
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
|
||||
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
|
||||
|
||||
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 0, len(subs.Subscriptions))
|
||||
}
|
||||
|
||||
func TestScanSubscribersTopicInheritanceBug(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 0, Filter: "a/b/c"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/b"})
|
||||
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 1, len(subs.Subscriptions))
|
||||
}
|
||||
|
||||
func TestScanSubscribersShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
|
||||
@@ -531,8 +544,9 @@ func TestScanSubscribersShared(t *testing.T) {
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
|
||||
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
|
||||
index.Subscribe("cl5", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c/#"})
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 3, len(subs.Shared))
|
||||
require.Equal(t, 4, len(subs.Shared))
|
||||
}
|
||||
|
||||
func TestSelectSharedSubscriber(t *testing.T) {
|
||||
@@ -840,3 +854,215 @@ func TestNewTopicAliases(t *testing.T) {
|
||||
require.NotNil(t, a.Outbound)
|
||||
require.Equal(t, uint16(5), a.Outbound.maximum)
|
||||
}
|
||||
|
||||
func TestNewInlineSubscriptions(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
require.NotNil(t, subscriptions)
|
||||
require.NotNil(t, subscriptions.internal)
|
||||
require.Equal(t, 0, subscriptions.Len())
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionAdd(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
subscription := InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
Handler: handler,
|
||||
}
|
||||
subscriptions.Add(subscription)
|
||||
|
||||
sub, ok := subscriptions.Get(1)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a/b/c", sub.Filter)
|
||||
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionGet(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
subscription := InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
Handler: handler,
|
||||
}
|
||||
subscriptions.Add(subscription)
|
||||
|
||||
sub, ok := subscriptions.Get(1)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a/b/c", sub.Filter)
|
||||
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
|
||||
|
||||
_, ok = subscriptions.Get(999)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionsGetAll(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
})
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
})
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2},
|
||||
})
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3},
|
||||
})
|
||||
|
||||
allSubs := subscriptions.GetAll()
|
||||
require.Len(t, allSubs, 3)
|
||||
require.Contains(t, allSubs, 1)
|
||||
require.Contains(t, allSubs, 2)
|
||||
require.Contains(t, allSubs, 3)
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionDelete(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
subscription := InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
Handler: handler,
|
||||
}
|
||||
subscriptions.Add(subscription)
|
||||
|
||||
subscriptions.Delete(1)
|
||||
_, ok := subscriptions.Get(1)
|
||||
require.False(t, ok)
|
||||
require.Empty(t, subscriptions.GetAll())
|
||||
require.Zero(t, subscriptions.Len())
|
||||
}
|
||||
|
||||
func TestInlineSubscribe(t *testing.T) {
|
||||
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
desc string
|
||||
filter string
|
||||
subscription InlineSubscription
|
||||
wasNew bool
|
||||
}{
|
||||
{
|
||||
desc: "subscribe",
|
||||
filter: "a/b/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "subscribe existed",
|
||||
filter: "a/b/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
|
||||
wasNew: false,
|
||||
},
|
||||
{
|
||||
desc: "subscribe different identifier",
|
||||
filter: "a/b/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "subscribe case sensitive didnt exist",
|
||||
filter: "A/B/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard+ sub",
|
||||
filter: "d/+",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard# sub",
|
||||
filter: "d/e/#",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
for _, tx := range tt {
|
||||
t.Run(tx.desc, func(t *testing.T) {
|
||||
require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription))
|
||||
})
|
||||
}
|
||||
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
}
|
||||
|
||||
func TestInlineUnsubscribe(t *testing.T) {
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
|
||||
sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index = NewTopicsIndex()
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
ok := index.InlineUnsubscribe(1, "a/b/c/d")
|
||||
require.True(t, ok)
|
||||
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
|
||||
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
ok = index.InlineUnsubscribe(1, "d/e/f")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f"))
|
||||
|
||||
ok = index.InlineUnsubscribe(1, "not/exist")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
1
vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
1
vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
@@ -1 +0,0 @@
|
||||
language: go
|
||||
35
vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
35
vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
@@ -1,35 +0,0 @@
|
||||
bbloom.go
|
||||
|
||||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
siphash.go
|
||||
|
||||
// https://github.com/dchest/siphash
|
||||
//
|
||||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
131
vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
131
vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
@@ -1,131 +0,0 @@
|
||||
## bbloom: a bitset Bloom filter for go/golang
|
||||
===
|
||||
|
||||
[](http://travis-ci.org/AndreasBriese/bbloom)
|
||||
|
||||
package implements a fast bloom filter with real 'bitset' and JSONMarshal/JSONUnmarshal to store/reload the Bloom filter.
|
||||
|
||||
NOTE: the package uses unsafe.Pointer to set and read the bits from the bitset. If you're uncomfortable with using the unsafe package, please consider using my bloom filter package at github.com/AndreasBriese/bloom
|
||||
|
||||
===
|
||||
|
||||
changelog 11/2015: new thread safe methods AddTS(), HasTS(), AddIfNotHasTS() following a suggestion from Srdjan Marinovic (github @a-little-srdjan), who used this to code a bloomfilter cache.
|
||||
|
||||
This bloom filter was developed to strengthen a website-log database and was tested and optimized for this log-entry mask: "2014/%02i/%02i %02i:%02i:%02i /info.html".
|
||||
Nonetheless bbloom should work with any other form of entries.
|
||||
|
||||
~~Hash function is a modified Berkeley DB sdbm hash (to optimize for smaller strings). sdbm http://www.cse.yorku.ca/~oz/hash.html~~
|
||||
|
||||
Found sipHash (SipHash-2-4, a fast short-input PRF created by Jean-Philippe Aumasson and Daniel J. Bernstein.) to be about as fast. sipHash had been ported by Dimtry Chestnyk to Go (github.com/dchest/siphash )
|
||||
|
||||
Minimum hashset size is: 512 ([4]uint64; will be set automatically).
|
||||
|
||||
###install
|
||||
|
||||
```sh
|
||||
go get github.com/AndreasBriese/bbloom
|
||||
```
|
||||
|
||||
###test
|
||||
+ change to folder ../bbloom
|
||||
+ create wordlist in file "words.txt" (you might use `python permut.py`)
|
||||
+ run 'go test -bench=.' within the folder
|
||||
|
||||
```go
|
||||
go test -bench=.
|
||||
```
|
||||
|
||||
~~If you've installed the GOCONVEY TDD-framework http://goconvey.co/ you can run the tests automatically.~~
|
||||
|
||||
using go's testing framework now (have in mind that the op timing is related to 65536 operations of Add, Has, AddIfNotHas respectively)
|
||||
|
||||
### usage
|
||||
|
||||
after installation add
|
||||
|
||||
```go
|
||||
import (
|
||||
...
|
||||
"github.com/AndreasBriese/bbloom"
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
at your header. In the program use
|
||||
|
||||
```go
|
||||
// create a bloom filter for 65536 items and 1 % wrong-positive ratio
|
||||
bf := bbloom.New(float64(1<<16), float64(0.01))
|
||||
|
||||
// or
|
||||
// create a bloom filter with 650000 for 65536 items and 7 locs per hash explicitly
|
||||
// bf = bbloom.New(float64(650000), float64(7))
|
||||
// or
|
||||
bf = bbloom.New(650000.0, 7.0)
|
||||
|
||||
// add one item
|
||||
bf.Add([]byte("butter"))
|
||||
|
||||
// Number of elements added is exposed now
|
||||
// Note: ElemNum will not be included in JSON export (for compatability to older version)
|
||||
nOfElementsInFilter := bf.ElemNum
|
||||
|
||||
// check if item is in the filter
|
||||
isIn := bf.Has([]byte("butter")) // should be true
|
||||
isNotIn := bf.Has([]byte("Butter")) // should be false
|
||||
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added := bf.AddIfNotHas([]byte("butter")) // should be false because 'butter' is already in the set
|
||||
added = bf.AddIfNotHas([]byte("buTTer")) // should be true because 'buTTer' is new
|
||||
|
||||
// thread safe versions for concurrent use: AddTS, HasTS, AddIfNotHasTS
|
||||
// add one item
|
||||
bf.AddTS([]byte("peanutbutter"))
|
||||
// check if item is in the filter
|
||||
isIn = bf.HasTS([]byte("peanutbutter")) // should be true
|
||||
isNotIn = bf.HasTS([]byte("peanutButter")) // should be false
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added = bf.AddIfNotHasTS([]byte("butter")) // should be false because 'peanutbutter' is already in the set
|
||||
added = bf.AddIfNotHasTS([]byte("peanutbuTTer")) // should be true because 'penutbuTTer' is new
|
||||
|
||||
// convert to JSON ([]byte)
|
||||
Json := bf.JSONMarshal()
|
||||
|
||||
// bloomfilters Mutex is exposed for external un-/locking
|
||||
// i.e. mutex lock while doing JSON conversion
|
||||
bf.Mtx.Lock()
|
||||
Json = bf.JSONMarshal()
|
||||
bf.Mtx.Unlock()
|
||||
|
||||
// restore a bloom filter from storage
|
||||
bfNew := bbloom.JSONUnmarshal(Json)
|
||||
|
||||
isInNew := bfNew.Has([]byte("butter")) // should be true
|
||||
isNotInNew := bfNew.Has([]byte("Butter")) // should be false
|
||||
|
||||
```
|
||||
|
||||
to work with the bloom filter.
|
||||
|
||||
### why 'fast'?
|
||||
|
||||
It's about 3 times faster than William Fitzgeralds bitset bloom filter https://github.com/willf/bloom . And it is about so fast as my []bool set variant for Boom filters (see https://github.com/AndreasBriese/bloom ) but having a 8times smaller memory footprint:
|
||||
|
||||
|
||||
Bloom filter (filter size 524288, 7 hashlocs)
|
||||
github.com/AndreasBriese/bbloom 'Add' 65536 items (10 repetitions): 6595800 ns (100 ns/op)
|
||||
github.com/AndreasBriese/bbloom 'Has' 65536 items (10 repetitions): 5986600 ns (91 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Add' 65536 items (10 repetitions): 6304684 ns (96 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Has' 65536 items (10 repetitions): 6568663 ns (100 ns/op)
|
||||
|
||||
github.com/willf/bloom 'Add' 65536 items (10 repetitions): 24367224 ns (371 ns/op)
|
||||
github.com/willf/bloom 'Test' 65536 items (10 repetitions): 21881142 ns (333 ns/op)
|
||||
github.com/dataence/bloom/standard 'Add' 65536 items (10 repetitions): 23041644 ns (351 ns/op)
|
||||
github.com/dataence/bloom/standard 'Check' 65536 items (10 repetitions): 19153133 ns (292 ns/op)
|
||||
github.com/cabello/bloom 'Add' 65536 items (10 repetitions): 131921507 ns (2012 ns/op)
|
||||
github.com/cabello/bloom 'Contains' 65536 items (10 repetitions): 131108962 ns (2000 ns/op)
|
||||
|
||||
(on MBPro15 OSX10.8.5 i7 4Core 2.4Ghz)
|
||||
|
||||
|
||||
With 32bit bloom filters (bloom32) using modified sdbm, bloom32 does hashing with only 2 bit shifts, one xor and one substraction per byte. smdb is about as fast as fnv64a but gives less collisions with the dataset (see mask above). bloom.New(float64(10 * 1<<16),float64(7)) populated with 1<<16 random items from the dataset (see above) and tested against the rest results in less than 0.05% collisions.
|
||||
284
vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
284
vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
@@ -1,284 +0,0 @@
|
||||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
// 2019/08/25 code revision to reduce unsafe use
|
||||
// Parts are adopted from the fork at ipfs/bbloom after performance rev by
|
||||
// Steve Allen (https://github.com/Stebalien)
|
||||
// (see https://github.com/ipfs/bbloom/blob/master/bbloom.go)
|
||||
// -> func Has
|
||||
// -> func set
|
||||
// -> func add
|
||||
|
||||
package bbloom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// helper
|
||||
// not needed anymore by Set
|
||||
// var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128}
|
||||
|
||||
func getSize(ui64 uint64) (size uint64, exponent uint64) {
|
||||
if ui64 < uint64(512) {
|
||||
ui64 = uint64(512)
|
||||
}
|
||||
size = uint64(1)
|
||||
for size < ui64 {
|
||||
size <<= 1
|
||||
exponent++
|
||||
}
|
||||
return size, exponent
|
||||
}
|
||||
|
||||
func calcSizeByWrongPositives(numEntries, wrongs float64) (uint64, uint64) {
|
||||
size := -1 * numEntries * math.Log(wrongs) / math.Pow(float64(0.69314718056), 2)
|
||||
locs := math.Ceil(float64(0.69314718056) * size / numEntries)
|
||||
return uint64(size), uint64(locs)
|
||||
}
|
||||
|
||||
// New
|
||||
// returns a new bloomfilter
|
||||
func New(params ...float64) (bloomfilter Bloom) {
|
||||
var entries, locs uint64
|
||||
if len(params) == 2 {
|
||||
if params[1] < 1 {
|
||||
entries, locs = calcSizeByWrongPositives(params[0], params[1])
|
||||
} else {
|
||||
entries, locs = uint64(params[0]), uint64(params[1])
|
||||
}
|
||||
} else {
|
||||
log.Fatal("usage: New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(3)) or New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(0.03))")
|
||||
}
|
||||
size, exponent := getSize(uint64(entries))
|
||||
bloomfilter = Bloom{
|
||||
Mtx: &sync.Mutex{},
|
||||
sizeExp: exponent,
|
||||
size: size - 1,
|
||||
setLocs: locs,
|
||||
shift: 64 - exponent,
|
||||
}
|
||||
bloomfilter.Size(size)
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// NewWithBoolset
|
||||
// takes a []byte slice and number of locs per entry
|
||||
// returns the bloomfilter with a bitset populated according to the input []byte
|
||||
func NewWithBoolset(bs *[]byte, locs uint64) (bloomfilter Bloom) {
|
||||
bloomfilter = New(float64(len(*bs)<<3), float64(locs))
|
||||
for i, b := range *bs {
|
||||
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bloomfilter.bitset[0])) + uintptr(i))) = b
|
||||
}
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// bloomJSONImExport
|
||||
// Im/Export structure used by JSONMarshal / JSONUnmarshal
|
||||
type bloomJSONImExport struct {
|
||||
FilterSet []byte
|
||||
SetLocs uint64
|
||||
}
|
||||
|
||||
// JSONUnmarshal
|
||||
// takes JSON-Object (type bloomJSONImExport) as []bytes
|
||||
// returns Bloom object
|
||||
func JSONUnmarshal(dbData []byte) Bloom {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
json.Unmarshal(dbData, &bloomImEx)
|
||||
buf := bytes.NewBuffer(bloomImEx.FilterSet)
|
||||
bs := buf.Bytes()
|
||||
bf := NewWithBoolset(&bs, bloomImEx.SetLocs)
|
||||
return bf
|
||||
}
|
||||
|
||||
//
|
||||
// Bloom filter
|
||||
type Bloom struct {
|
||||
Mtx *sync.Mutex
|
||||
ElemNum uint64
|
||||
bitset []uint64
|
||||
sizeExp uint64
|
||||
size uint64
|
||||
setLocs uint64
|
||||
shift uint64
|
||||
}
|
||||
|
||||
// <--- http://www.cse.yorku.ca/~oz/hash.html
|
||||
// modified Berkeley DB Hash (32bit)
|
||||
// hash is casted to l, h = 16bit fragments
|
||||
// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.shift >> bl.shift
|
||||
// return l, h
|
||||
// }
|
||||
|
||||
// Update: found sipHash of Jean-Philippe Aumasson & Daniel J. Bernstein to be even faster than absdbm()
|
||||
// https://131002.net/siphash/
|
||||
// siphash was implemented for Go by Dmitry Chestnykh https://github.com/dchest/siphash
|
||||
|
||||
// Add
|
||||
// set the bit(s) for entry; Adds an entry to the Bloom filter
|
||||
func (bl *Bloom) Add(entry []byte) {
|
||||
l, h := bl.sipHash(entry)
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
bl.set((h + i*l) & bl.size)
|
||||
bl.ElemNum++
|
||||
}
|
||||
}
|
||||
|
||||
// AddTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) AddTS(entry []byte) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
bl.Add(entry)
|
||||
}
|
||||
|
||||
// Has
|
||||
// check if bit(s) for entry is/are set
|
||||
// returns true if the entry was added to the Bloom Filter
|
||||
func (bl Bloom) Has(entry []byte) bool {
|
||||
l, h := bl.sipHash(entry)
|
||||
res := true
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
res = res && bl.isSet((h+i*l)&bl.size)
|
||||
// https://github.com/ipfs/bbloom/commit/84e8303a9bfb37b2658b85982921d15bbb0fecff
|
||||
// // Branching here (early escape) is not worth it
|
||||
// // This is my conclusion from benchmarks
|
||||
// // (prevents loop unrolling)
|
||||
// switch bl.IsSet((h + i*l) & bl.size) {
|
||||
// case false:
|
||||
// return false
|
||||
// }
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// HasTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) HasTS(entry []byte) bool {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.Has(entry)
|
||||
}
|
||||
|
||||
// AddIfNotHas
|
||||
// Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl Bloom) AddIfNotHas(entry []byte) (added bool) {
|
||||
if bl.Has(entry) {
|
||||
return added
|
||||
}
|
||||
bl.Add(entry)
|
||||
return true
|
||||
}
|
||||
|
||||
// AddIfNotHasTS
|
||||
// Tread safe: Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl *Bloom) AddIfNotHasTS(entry []byte) (added bool) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.AddIfNotHas(entry)
|
||||
}
|
||||
|
||||
// Size
|
||||
// make Bloom filter with as bitset of size sz
|
||||
func (bl *Bloom) Size(sz uint64) {
|
||||
bl.bitset = make([]uint64, sz>>6)
|
||||
}
|
||||
|
||||
// Clear
|
||||
// resets the Bloom filter
|
||||
func (bl *Bloom) Clear() {
|
||||
bs := bl.bitset
|
||||
for i := range bs {
|
||||
bs[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Set
|
||||
// set the bit[idx] of bitsit
|
||||
func (bl *Bloom) set(idx uint64) {
|
||||
// ommit unsafe
|
||||
// *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3))) |= mask[idx%8]
|
||||
bl.bitset[idx>>6] |= 1 << (idx % 64)
|
||||
}
|
||||
|
||||
// IsSet
|
||||
// check if bit[idx] of bitset is set
|
||||
// returns true/false
|
||||
func (bl *Bloom) isSet(idx uint64) bool {
|
||||
// ommit unsafe
|
||||
// return (((*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)))) >> (idx % 8)) & 1) == 1
|
||||
return bl.bitset[idx>>6]&(1<<(idx%64)) != 0
|
||||
}
|
||||
|
||||
// JSONMarshal
|
||||
// returns JSON-object (type bloomJSONImExport) as []byte
|
||||
func (bl Bloom) JSONMarshal() []byte {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
bloomImEx.SetLocs = uint64(bl.setLocs)
|
||||
bloomImEx.FilterSet = make([]byte, len(bl.bitset)<<3)
|
||||
for i := range bloomImEx.FilterSet {
|
||||
bloomImEx.FilterSet[i] = *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[0])) + uintptr(i)))
|
||||
}
|
||||
data, err := json.Marshal(bloomImEx)
|
||||
if err != nil {
|
||||
log.Fatal("json.Marshal failed: ", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// // alternative hashFn
|
||||
// func (bl Bloom) fnv64a(b *[]byte) (l, h uint64) {
|
||||
// h64 := fnv.New64a()
|
||||
// h64.Write(*b)
|
||||
// hash := h64.Sum64()
|
||||
// h = hash >> 32
|
||||
// l = hash << 32 >> 32
|
||||
// return l, h
|
||||
// }
|
||||
//
|
||||
// // <-- http://partow.net/programming/hashfunctions/index.html
|
||||
// // citation: An algorithm proposed by Donald E. Knuth in The Art Of Computer Programming Volume 3,
|
||||
// // under the topic of sorting and search chapter 6.4.
|
||||
// // modified to fit with boolset-length
|
||||
// func (bl Bloom) DEKHash(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = ((hash << 5) ^ (hash >> bl.shift)) ^ uint64(c)
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.sizeExp >> bl.sizeExp
|
||||
// return l, h
|
||||
// }
|
||||
225
vendor/github.com/AndreasBriese/bbloom/sipHash.go
generated
vendored
225
vendor/github.com/AndreasBriese/bbloom/sipHash.go
generated
vendored
@@ -1,225 +0,0 @@
|
||||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
|
||||
package bbloom
|
||||
|
||||
// Hash returns the 64-bit SipHash-2-4 of the given byte slice with two 64-bit
|
||||
// parts of 128-bit key: k0 and k1.
|
||||
func (bl Bloom) sipHash(p []byte) (l, h uint64) {
|
||||
// Initialization.
|
||||
v0 := uint64(8317987320269560794) // k0 ^ 0x736f6d6570736575
|
||||
v1 := uint64(7237128889637516672) // k1 ^ 0x646f72616e646f6d
|
||||
v2 := uint64(7816392314733513934) // k0 ^ 0x6c7967656e657261
|
||||
v3 := uint64(8387220255325274014) // k1 ^ 0x7465646279746573
|
||||
t := uint64(len(p)) << 56
|
||||
|
||||
// Compression.
|
||||
for len(p) >= 8 {
|
||||
|
||||
m := uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
|
||||
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
|
||||
|
||||
v3 ^= m
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= m
|
||||
p = p[8:]
|
||||
}
|
||||
|
||||
// Compress last block.
|
||||
switch len(p) {
|
||||
case 7:
|
||||
t |= uint64(p[6]) << 48
|
||||
fallthrough
|
||||
case 6:
|
||||
t |= uint64(p[5]) << 40
|
||||
fallthrough
|
||||
case 5:
|
||||
t |= uint64(p[4]) << 32
|
||||
fallthrough
|
||||
case 4:
|
||||
t |= uint64(p[3]) << 24
|
||||
fallthrough
|
||||
case 3:
|
||||
t |= uint64(p[2]) << 16
|
||||
fallthrough
|
||||
case 2:
|
||||
t |= uint64(p[1]) << 8
|
||||
fallthrough
|
||||
case 1:
|
||||
t |= uint64(p[0])
|
||||
}
|
||||
|
||||
v3 ^= t
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= t
|
||||
|
||||
// Finalization.
|
||||
v2 ^= 0xff
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 3.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 4.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// return v0 ^ v1 ^ v2 ^ v3
|
||||
|
||||
hash := v0 ^ v1 ^ v2 ^ v3
|
||||
h = hash >> bl.shift
|
||||
l = hash << bl.shift >> bl.shift
|
||||
return l, h
|
||||
|
||||
}
|
||||
140
vendor/github.com/AndreasBriese/bbloom/words.txt
generated
vendored
140
vendor/github.com/AndreasBriese/bbloom/words.txt
generated
vendored
@@ -1,140 +0,0 @@
|
||||
2014/01/01 00:00:00 /info.html
|
||||
2014/01/01 00:00:00 /info.html
|
||||
2014/01/01 00:00:01 /info.html
|
||||
2014/01/01 00:00:02 /info.html
|
||||
2014/01/01 00:00:03 /info.html
|
||||
2014/01/01 00:00:04 /info.html
|
||||
2014/01/01 00:00:05 /info.html
|
||||
2014/01/01 00:00:06 /info.html
|
||||
2014/01/01 00:00:07 /info.html
|
||||
2014/01/01 00:00:08 /info.html
|
||||
2014/01/01 00:00:09 /info.html
|
||||
2014/01/01 00:00:10 /info.html
|
||||
2014/01/01 00:00:11 /info.html
|
||||
2014/01/01 00:00:12 /info.html
|
||||
2014/01/01 00:00:13 /info.html
|
||||
2014/01/01 00:00:14 /info.html
|
||||
2014/01/01 00:00:15 /info.html
|
||||
2014/01/01 00:00:16 /info.html
|
||||
2014/01/01 00:00:17 /info.html
|
||||
2014/01/01 00:00:18 /info.html
|
||||
2014/01/01 00:00:19 /info.html
|
||||
2014/01/01 00:00:20 /info.html
|
||||
2014/01/01 00:00:21 /info.html
|
||||
2014/01/01 00:00:22 /info.html
|
||||
2014/01/01 00:00:23 /info.html
|
||||
2014/01/01 00:00:24 /info.html
|
||||
2014/01/01 00:00:25 /info.html
|
||||
2014/01/01 00:00:26 /info.html
|
||||
2014/01/01 00:00:27 /info.html
|
||||
2014/01/01 00:00:28 /info.html
|
||||
2014/01/01 00:00:29 /info.html
|
||||
2014/01/01 00:00:30 /info.html
|
||||
2014/01/01 00:00:31 /info.html
|
||||
2014/01/01 00:00:32 /info.html
|
||||
2014/01/01 00:00:33 /info.html
|
||||
2014/01/01 00:00:34 /info.html
|
||||
2014/01/01 00:00:35 /info.html
|
||||
2014/01/01 00:00:36 /info.html
|
||||
2014/01/01 00:00:37 /info.html
|
||||
2014/01/01 00:00:38 /info.html
|
||||
2014/01/01 00:00:39 /info.html
|
||||
2014/01/01 00:00:40 /info.html
|
||||
2014/01/01 00:00:41 /info.html
|
||||
2014/01/01 00:00:42 /info.html
|
||||
2014/01/01 00:00:43 /info.html
|
||||
2014/01/01 00:00:44 /info.html
|
||||
2014/01/01 00:00:45 /info.html
|
||||
2014/01/01 00:00:46 /info.html
|
||||
2014/01/01 00:00:47 /info.html
|
||||
2014/01/01 00:00:48 /info.html
|
||||
2014/01/01 00:00:49 /info.html
|
||||
2014/01/01 00:00:50 /info.html
|
||||
2014/01/01 00:00:51 /info.html
|
||||
2014/01/01 00:00:52 /info.html
|
||||
2014/01/01 00:00:53 /info.html
|
||||
2014/01/01 00:00:54 /info.html
|
||||
2014/01/01 00:00:55 /info.html
|
||||
2014/01/01 00:00:56 /info.html
|
||||
2014/01/01 00:00:57 /info.html
|
||||
2014/01/01 00:00:58 /info.html
|
||||
2014/01/01 00:00:59 /info.html
|
||||
2014/01/01 00:01:00 /info.html
|
||||
2014/01/01 00:01:01 /info.html
|
||||
2014/01/01 00:01:02 /info.html
|
||||
2014/01/01 00:01:03 /info.html
|
||||
2014/01/01 00:01:04 /info.html
|
||||
2014/01/01 00:01:05 /info.html
|
||||
2014/01/01 00:01:06 /info.html
|
||||
2014/01/01 00:01:07 /info.html
|
||||
2014/01/01 00:01:08 /info.html
|
||||
2014/01/01 00:01:09 /info.html
|
||||
2014/01/01 00:01:10 /info.html
|
||||
2014/01/01 00:01:11 /info.html
|
||||
2014/01/01 00:01:12 /info.html
|
||||
2014/01/01 00:01:13 /info.html
|
||||
2014/01/01 00:01:14 /info.html
|
||||
2014/01/01 00:01:15 /info.html
|
||||
2014/01/01 00:01:16 /info.html
|
||||
2014/01/01 00:01:17 /info.html
|
||||
2014/01/01 00:01:18 /info.html
|
||||
2014/01/01 00:01:19 /info.html
|
||||
2014/01/01 00:01:20 /info.html
|
||||
2014/01/01 00:01:21 /info.html
|
||||
2014/01/01 00:01:22 /info.html
|
||||
2014/01/01 00:01:23 /info.html
|
||||
2014/01/01 00:01:24 /info.html
|
||||
2014/01/01 00:01:25 /info.html
|
||||
2014/01/01 00:01:26 /info.html
|
||||
2014/01/01 00:01:27 /info.html
|
||||
2014/01/01 00:01:28 /info.html
|
||||
2014/01/01 00:01:29 /info.html
|
||||
2014/01/01 00:01:30 /info.html
|
||||
2014/01/01 00:01:31 /info.html
|
||||
2014/01/01 00:01:32 /info.html
|
||||
2014/01/01 00:01:33 /info.html
|
||||
2014/01/01 00:01:34 /info.html
|
||||
2014/01/01 00:01:35 /info.html
|
||||
2014/01/01 00:01:36 /info.html
|
||||
2014/01/01 00:01:37 /info.html
|
||||
2014/01/01 00:01:38 /info.html
|
||||
2014/01/01 00:01:39 /info.html
|
||||
2014/01/01 00:01:40 /info.html
|
||||
2014/01/01 00:01:41 /info.html
|
||||
2014/01/01 00:01:42 /info.html
|
||||
2014/01/01 00:01:43 /info.html
|
||||
2014/01/01 00:01:44 /info.html
|
||||
2014/01/01 00:01:45 /info.html
|
||||
2014/01/01 00:01:46 /info.html
|
||||
2014/01/01 00:01:47 /info.html
|
||||
2014/01/01 00:01:48 /info.html
|
||||
2014/01/01 00:01:49 /info.html
|
||||
2014/01/01 00:01:50 /info.html
|
||||
2014/01/01 00:01:51 /info.html
|
||||
2014/01/01 00:01:52 /info.html
|
||||
2014/01/01 00:01:53 /info.html
|
||||
2014/01/01 00:01:54 /info.html
|
||||
2014/01/01 00:01:55 /info.html
|
||||
2014/01/01 00:01:56 /info.html
|
||||
2014/01/01 00:01:57 /info.html
|
||||
2014/01/01 00:01:58 /info.html
|
||||
2014/01/01 00:01:59 /info.html
|
||||
2014/01/01 00:02:00 /info.html
|
||||
2014/01/01 00:02:01 /info.html
|
||||
2014/01/01 00:02:02 /info.html
|
||||
2014/01/01 00:02:03 /info.html
|
||||
2014/01/01 00:02:04 /info.html
|
||||
2014/01/01 00:02:05 /info.html
|
||||
2014/01/01 00:02:06 /info.html
|
||||
2014/01/01 00:02:07 /info.html
|
||||
2014/01/01 00:02:08 /info.html
|
||||
2014/01/01 00:02:09 /info.html
|
||||
2014/01/01 00:02:10 /info.html
|
||||
2014/01/01 00:02:11 /info.html
|
||||
2014/01/01 00:02:12 /info.html
|
||||
2014/01/01 00:02:13 /info.html
|
||||
2014/01/01 00:02:14 /info.html
|
||||
2014/01/01 00:02:15 /info.html
|
||||
2014/01/01 00:02:16 /info.html
|
||||
2014/01/01 00:02:17 /info.html
|
||||
2014/01/01 00:02:18 /info.html
|
||||
24
vendor/github.com/alicebob/gopher-json/LICENSE
generated
vendored
24
vendor/github.com/alicebob/gopher-json/LICENSE
generated
vendored
@@ -1,24 +0,0 @@
|
||||
This is free and unencumbered software released into the public domain.
|
||||
|
||||
Anyone is free to copy, modify, publish, use, compile, sell, or
|
||||
distribute this software, either in source code form or as a compiled
|
||||
binary, for any purpose, commercial or non-commercial, and by any
|
||||
means.
|
||||
|
||||
In jurisdictions that recognize copyright laws, the author or authors
|
||||
of this software dedicate any and all copyright interest in the
|
||||
software to the public domain. We make this dedication for the benefit
|
||||
of the public at large and to the detriment of our heirs and
|
||||
successors. We intend this dedication to be an overt act of
|
||||
relinquishment in perpetuity of all present and future rights to this
|
||||
software under copyright law.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
For more information, please refer to <http://unlicense.org/>
|
||||
7
vendor/github.com/alicebob/gopher-json/README.md
generated
vendored
7
vendor/github.com/alicebob/gopher-json/README.md
generated
vendored
@@ -1,7 +0,0 @@
|
||||
# gopher-json [](https://godoc.org/layeh.com/gopher-json)
|
||||
|
||||
Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).
|
||||
|
||||
## License
|
||||
|
||||
Public domain
|
||||
33
vendor/github.com/alicebob/gopher-json/doc.go
generated
vendored
33
vendor/github.com/alicebob/gopher-json/doc.go
generated
vendored
@@ -1,33 +0,0 @@
|
||||
// Package json is a simple JSON encoder/decoder for gopher-lua.
|
||||
//
|
||||
// Documentation
|
||||
//
|
||||
// The following functions are exposed by the library:
|
||||
// decode(string): Decodes a JSON string. Returns nil and an error string if
|
||||
// the string could not be decoded.
|
||||
// encode(value): Encodes a value into a JSON string. Returns nil and an error
|
||||
// string if the value could not be encoded.
|
||||
//
|
||||
// The following types are supported:
|
||||
//
|
||||
// Lua | JSON
|
||||
// ---------+-----
|
||||
// nil | null
|
||||
// number | number
|
||||
// string | string
|
||||
// table | object: when table is non-empty and has only string keys
|
||||
// | array: when table is empty, or has only sequential numeric keys
|
||||
// | starting from 1
|
||||
//
|
||||
// Attempting to encode any other Lua type will result in an error.
|
||||
//
|
||||
// Example
|
||||
//
|
||||
// Below is an example usage of the library:
|
||||
// import (
|
||||
// luajson "layeh.com/gopher-json"
|
||||
// )
|
||||
//
|
||||
// L := lua.NewState()
|
||||
// luajson.Preload(s)
|
||||
package json
|
||||
189
vendor/github.com/alicebob/gopher-json/json.go
generated
vendored
189
vendor/github.com/alicebob/gopher-json/json.go
generated
vendored
@@ -1,189 +0,0 @@
|
||||
package json
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// Preload adds json to the given Lua state's package.preload table. After it
|
||||
// has been preloaded, it can be loaded using require:
|
||||
//
|
||||
// local json = require("json")
|
||||
func Preload(L *lua.LState) {
|
||||
L.PreloadModule("json", Loader)
|
||||
}
|
||||
|
||||
// Loader is the module loader function.
|
||||
func Loader(L *lua.LState) int {
|
||||
t := L.NewTable()
|
||||
L.SetFuncs(t, api)
|
||||
L.Push(t)
|
||||
return 1
|
||||
}
|
||||
|
||||
var api = map[string]lua.LGFunction{
|
||||
"decode": apiDecode,
|
||||
"encode": apiEncode,
|
||||
}
|
||||
|
||||
func apiDecode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to decode"), 1)
|
||||
return 0
|
||||
}
|
||||
str := L.CheckString(1)
|
||||
|
||||
value, err := Decode(L, []byte(str))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(value)
|
||||
return 1
|
||||
}
|
||||
|
||||
func apiEncode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to encode"), 1)
|
||||
return 0
|
||||
}
|
||||
value := L.CheckAny(1)
|
||||
|
||||
data, err := Encode(value)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(lua.LString(string(data)))
|
||||
return 1
|
||||
}
|
||||
|
||||
var (
|
||||
errNested = errors.New("cannot encode recursively nested tables to JSON")
|
||||
errSparseArray = errors.New("cannot encode sparse array")
|
||||
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
|
||||
)
|
||||
|
||||
type invalidTypeError lua.LValueType
|
||||
|
||||
func (i invalidTypeError) Error() string {
|
||||
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
|
||||
}
|
||||
|
||||
// Encode returns the JSON encoding of value.
|
||||
func Encode(value lua.LValue) ([]byte, error) {
|
||||
return json.Marshal(jsonValue{
|
||||
LValue: value,
|
||||
visited: make(map[*lua.LTable]bool),
|
||||
})
|
||||
}
|
||||
|
||||
type jsonValue struct {
|
||||
lua.LValue
|
||||
visited map[*lua.LTable]bool
|
||||
}
|
||||
|
||||
func (j jsonValue) MarshalJSON() (data []byte, err error) {
|
||||
switch converted := j.LValue.(type) {
|
||||
case lua.LBool:
|
||||
data, err = json.Marshal(bool(converted))
|
||||
case lua.LNumber:
|
||||
data, err = json.Marshal(float64(converted))
|
||||
case *lua.LNilType:
|
||||
data = []byte(`null`)
|
||||
case lua.LString:
|
||||
data, err = json.Marshal(string(converted))
|
||||
case *lua.LTable:
|
||||
if j.visited[converted] {
|
||||
return nil, errNested
|
||||
}
|
||||
j.visited[converted] = true
|
||||
|
||||
key, value := converted.Next(lua.LNil)
|
||||
|
||||
switch key.Type() {
|
||||
case lua.LTNil: // empty table
|
||||
data = []byte(`[]`)
|
||||
case lua.LTNumber:
|
||||
arr := make([]jsonValue, 0, converted.Len())
|
||||
expectedKey := lua.LNumber(1)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTNumber {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
if expectedKey != key {
|
||||
err = errSparseArray
|
||||
return
|
||||
}
|
||||
arr = append(arr, jsonValue{value, j.visited})
|
||||
expectedKey++
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(arr)
|
||||
case lua.LTString:
|
||||
obj := make(map[string]jsonValue)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTString {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
obj[key.String()] = jsonValue{value, j.visited}
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(obj)
|
||||
default:
|
||||
err = errInvalidKeys
|
||||
}
|
||||
default:
|
||||
err = invalidTypeError(j.LValue.Type())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode converts the JSON encoded data to Lua values.
|
||||
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
|
||||
var value interface{}
|
||||
err := json.Unmarshal(data, &value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeValue(L, value), nil
|
||||
}
|
||||
|
||||
// DecodeValue converts the value to a Lua value.
|
||||
//
|
||||
// This function only converts values that the encoding/json package decodes to.
|
||||
// All other values will return lua.LNil.
|
||||
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
|
||||
switch converted := value.(type) {
|
||||
case bool:
|
||||
return lua.LBool(converted)
|
||||
case float64:
|
||||
return lua.LNumber(converted)
|
||||
case string:
|
||||
return lua.LString(converted)
|
||||
case json.Number:
|
||||
return lua.LString(converted)
|
||||
case []interface{}:
|
||||
arr := L.CreateTable(len(converted), 0)
|
||||
for _, item := range converted {
|
||||
arr.Append(DecodeValue(L, item))
|
||||
}
|
||||
return arr
|
||||
case map[string]interface{}:
|
||||
tbl := L.CreateTable(0, len(converted))
|
||||
for key, item := range converted {
|
||||
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
|
||||
}
|
||||
return tbl
|
||||
case nil:
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
return lua.LNil
|
||||
}
|
||||
6
vendor/github.com/alicebob/miniredis/v2/.gitignore
generated
vendored
6
vendor/github.com/alicebob/miniredis/v2/.gitignore
generated
vendored
@@ -1,6 +0,0 @@
|
||||
/integration/redis_src/
|
||||
/integration/dump.rdb
|
||||
*.swp
|
||||
/integration/nodes.conf
|
||||
.idea/
|
||||
miniredis.iml
|
||||
225
vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md
generated
vendored
225
vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md
generated
vendored
@@ -1,225 +0,0 @@
|
||||
## Changelog
|
||||
|
||||
|
||||
### v2.23.0
|
||||
|
||||
- basic INFO support (thanks @kirill-a-belov)
|
||||
- support COUNT in SSCAN (thanks @Abdi-dd)
|
||||
- test and support Go 1.19
|
||||
- support LPOS (thanks @ianstarz)
|
||||
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
|
||||
|
||||
|
||||
### v2.22.0
|
||||
|
||||
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
|
||||
- fix invalid resposne of COMMAND (thanks @zsh1995)
|
||||
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
|
||||
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
|
||||
|
||||
|
||||
### v2.21.0
|
||||
|
||||
- support for GETEX (thanks @dntj)
|
||||
- support for GT and LT in ZADD (thanks @lsgndln)
|
||||
- support for XAUTOCLAIM (thanks @randall-fulton)
|
||||
|
||||
|
||||
### v2.20.0
|
||||
|
||||
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
|
||||
|
||||
|
||||
### v2.19.0
|
||||
|
||||
- support for TYPE in SCAN (thanks @0xDiddi)
|
||||
- update BITPOS (thanks @dirkm)
|
||||
- fix a lua redis.call() return value (thanks @mpetronic)
|
||||
- update ZRANGE (thanks @valdemarpereira)
|
||||
|
||||
|
||||
### v2.18.0
|
||||
|
||||
- support for ZUNION (thanks @propan)
|
||||
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
|
||||
- support for LMOVE (thanks @btwear)
|
||||
|
||||
|
||||
### v2.17.0
|
||||
|
||||
- added miniredis.RunT(t)
|
||||
|
||||
|
||||
### v2.16.1
|
||||
|
||||
- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
|
||||
- fix exclusive ranges in XRANGE (thanks @joseotoro)
|
||||
|
||||
|
||||
### v2.16.0
|
||||
|
||||
- simplify some code (thanks @zonque)
|
||||
- support for EXAT/PXAT in SET
|
||||
- support for XTRIM (thanks @joseotoro)
|
||||
- support for ZRANDMEMBER
|
||||
- support for redis.log() in lua (thanks @dirkm)
|
||||
|
||||
|
||||
### v2.15.2
|
||||
|
||||
- Fix race condition in blocking code (thanks @zonque and @robx)
|
||||
- XREAD accepts '$' as ID (thanks @bradengroom)
|
||||
|
||||
|
||||
### v2.15.1
|
||||
|
||||
- EVAL should cache the script (thanks @guoshimin)
|
||||
|
||||
|
||||
### v2.15.0
|
||||
|
||||
- target redis 6.2 and added new args to various commands
|
||||
- support for all hyperlog commands (thanks @ilbaktin)
|
||||
- support for GETDEL (thanks @wszaranski)
|
||||
|
||||
|
||||
### v2.14.5
|
||||
|
||||
- added XPENDING
|
||||
- support for BLOCK option in XREAD and XREADGROUP
|
||||
|
||||
|
||||
### v2.14.4
|
||||
|
||||
- fix BITPOS error (thanks @xiaoyuzdy)
|
||||
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
|
||||
- fix empty EXEC return type (thanks @ashanbrown)
|
||||
- fix XDEL (thanks @svakili and @yvesf)
|
||||
- fix FLUSHALL for streams (thanks @svakili)
|
||||
|
||||
|
||||
### v2.14.3
|
||||
|
||||
- fix problem where Lua code didn't set the selected DB
|
||||
- update to redis 6.0.10 (thanks @lazappa)
|
||||
|
||||
|
||||
### v2.14.2
|
||||
|
||||
- update LUA dependency
|
||||
- deal with (p)unsubscribe when there are no channels
|
||||
|
||||
|
||||
### v2.14.1
|
||||
|
||||
- mod tidy
|
||||
|
||||
|
||||
### v2.14.0
|
||||
|
||||
- support for HELLO and the RESP3 protocol
|
||||
- KEEPTTL in SET (thanks @johnpena)
|
||||
|
||||
|
||||
### v2.13.3
|
||||
|
||||
- support Go 1.14 and 1.15
|
||||
- update the `Check...()` methods
|
||||
- support for XREAD (thanks @pieterlexis)
|
||||
|
||||
|
||||
### v2.13.2
|
||||
|
||||
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
|
||||
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
|
||||
- changed unit and integration tests to compare raw payloads, not parsed payloads
|
||||
- remove "redigo" dependency
|
||||
|
||||
|
||||
### v2.13.1
|
||||
|
||||
- added HSTRLEN
|
||||
- minimal support for ACL users in AUTH
|
||||
|
||||
|
||||
### v2.13.0
|
||||
|
||||
- added RunTLS(...)
|
||||
- added SetError(...)
|
||||
|
||||
|
||||
### v2.12.0
|
||||
|
||||
- redis 6
|
||||
- Lua json update (thanks @gsmith85)
|
||||
- CLUSTER commands (thanks @kratisto)
|
||||
- fix TOUCH
|
||||
- fix a shutdown race condition
|
||||
|
||||
|
||||
### v2.11.4
|
||||
|
||||
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
|
||||
|
||||
|
||||
### v2.11.3
|
||||
|
||||
- support for TOUCH (thanks @cleroux)
|
||||
- support for cluster and stream commands (thanks @kak-tus)
|
||||
|
||||
|
||||
### v2.11.2
|
||||
|
||||
- make sure Lua code is executed concurrently
|
||||
- add command GEORADIUSBYMEMBER (thanks @kyeett)
|
||||
|
||||
|
||||
### v2.11.1
|
||||
|
||||
- globals protection for Lua code (thanks @vk-outreach)
|
||||
- HSET update (thanks @carlgreen)
|
||||
- fix BLPOP block on shutdown (thanks @Asalle)
|
||||
|
||||
|
||||
### v2.11.0
|
||||
|
||||
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
|
||||
- added GEODIST
|
||||
- improved precision for geohashes, closer to what real redis does
|
||||
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
|
||||
|
||||
|
||||
### v2.10.1
|
||||
|
||||
- added m.Server()
|
||||
|
||||
|
||||
### v2.10.0
|
||||
|
||||
- added UNLINK
|
||||
- fix DEL zero-argument case
|
||||
- cleanup some direct access commands
|
||||
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
|
||||
|
||||
|
||||
### v2.9.1
|
||||
|
||||
- fix issue with ZRANGEBYLEX
|
||||
- fix issue with BRPOPLPUSH and direct access
|
||||
|
||||
|
||||
### v2.9.0
|
||||
|
||||
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
|
||||
- fix messages generated by PSUBSCRIBE
|
||||
- optional internal seed (thanks @zikaeroh)
|
||||
|
||||
|
||||
### v2.8.0
|
||||
|
||||
Proper `v2` in go.mod.
|
||||
|
||||
|
||||
### older
|
||||
|
||||
See https://github.com/alicebob/miniredis/releases for the full changelog
|
||||
21
vendor/github.com/alicebob/miniredis/v2/LICENSE
generated
vendored
21
vendor/github.com/alicebob/miniredis/v2/LICENSE
generated
vendored
@@ -1,21 +0,0 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Harmen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
12
vendor/github.com/alicebob/miniredis/v2/Makefile
generated
vendored
12
vendor/github.com/alicebob/miniredis/v2/Makefile
generated
vendored
@@ -1,12 +0,0 @@
|
||||
.PHONY: all test testrace int
|
||||
|
||||
all: test
|
||||
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
testrace:
|
||||
go test -race ./...
|
||||
|
||||
int:
|
||||
${MAKE} -C integration all
|
||||
333
vendor/github.com/alicebob/miniredis/v2/README.md
generated
vendored
333
vendor/github.com/alicebob/miniredis/v2/README.md
generated
vendored
@@ -1,333 +0,0 @@
|
||||
# Miniredis
|
||||
|
||||
Pure Go Redis test server, used in Go unittests.
|
||||
|
||||
|
||||
##
|
||||
|
||||
Sometimes you want to test code which uses Redis, without making it a full-blown
|
||||
integration test.
|
||||
Miniredis implements (parts of) the Redis server, to be used in unittests. It
|
||||
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
|
||||
|
||||
It saves you from using mock code, and since the redis server lives in the
|
||||
test process you can query for values directly, without going through the server
|
||||
stack.
|
||||
|
||||
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
|
||||
|
||||
Be sure to import v2:
|
||||
```
|
||||
import "github.com/alicebob/miniredis/v2"
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
Implemented commands:
|
||||
|
||||
- Connection (complete)
|
||||
- AUTH -- see RequireAuth()
|
||||
- ECHO
|
||||
- HELLO -- see RequireUserAuth()
|
||||
- PING
|
||||
- SELECT
|
||||
- SWAPDB
|
||||
- QUIT
|
||||
- Key
|
||||
- COPY
|
||||
- DEL
|
||||
- EXISTS
|
||||
- EXPIRE
|
||||
- EXPIREAT
|
||||
- KEYS
|
||||
- MOVE
|
||||
- PERSIST
|
||||
- PEXPIRE
|
||||
- PEXPIREAT
|
||||
- PTTL
|
||||
- RENAME
|
||||
- RENAMENX
|
||||
- RANDOMKEY -- see m.Seed(...)
|
||||
- SCAN
|
||||
- TOUCH
|
||||
- TTL
|
||||
- TYPE
|
||||
- UNLINK
|
||||
- Transactions (complete)
|
||||
- DISCARD
|
||||
- EXEC
|
||||
- MULTI
|
||||
- UNWATCH
|
||||
- WATCH
|
||||
- Server
|
||||
- DBSIZE
|
||||
- FLUSHALL
|
||||
- FLUSHDB
|
||||
- TIME -- returns time.Now() or value set by SetTime()
|
||||
- COMMAND -- partly
|
||||
- INFO -- partly, returns only "clients" section with one field "connected_clients"
|
||||
- String keys (complete)
|
||||
- APPEND
|
||||
- BITCOUNT
|
||||
- BITOP
|
||||
- BITPOS
|
||||
- DECR
|
||||
- DECRBY
|
||||
- GET
|
||||
- GETBIT
|
||||
- GETRANGE
|
||||
- GETSET
|
||||
- GETDEL
|
||||
- GETEX
|
||||
- INCR
|
||||
- INCRBY
|
||||
- INCRBYFLOAT
|
||||
- MGET
|
||||
- MSET
|
||||
- MSETNX
|
||||
- PSETEX
|
||||
- SET
|
||||
- SETBIT
|
||||
- SETEX
|
||||
- SETNX
|
||||
- SETRANGE
|
||||
- STRLEN
|
||||
- Hash keys (complete)
|
||||
- HDEL
|
||||
- HEXISTS
|
||||
- HGET
|
||||
- HGETALL
|
||||
- HINCRBY
|
||||
- HINCRBYFLOAT
|
||||
- HKEYS
|
||||
- HLEN
|
||||
- HMGET
|
||||
- HMSET
|
||||
- HSET
|
||||
- HSETNX
|
||||
- HSTRLEN
|
||||
- HVALS
|
||||
- HSCAN
|
||||
- List keys (complete)
|
||||
- BLPOP
|
||||
- BRPOP
|
||||
- BRPOPLPUSH
|
||||
- LINDEX
|
||||
- LINSERT
|
||||
- LLEN
|
||||
- LPOP
|
||||
- LPUSH
|
||||
- LPUSHX
|
||||
- LRANGE
|
||||
- LREM
|
||||
- LSET
|
||||
- LTRIM
|
||||
- RPOP
|
||||
- RPOPLPUSH
|
||||
- RPUSH
|
||||
- RPUSHX
|
||||
- LMOVE
|
||||
- Pub/Sub (complete)
|
||||
- PSUBSCRIBE
|
||||
- PUBLISH
|
||||
- PUBSUB
|
||||
- PUNSUBSCRIBE
|
||||
- SUBSCRIBE
|
||||
- UNSUBSCRIBE
|
||||
- Set keys (complete)
|
||||
- SADD
|
||||
- SCARD
|
||||
- SDIFF
|
||||
- SDIFFSTORE
|
||||
- SINTER
|
||||
- SINTERSTORE
|
||||
- SISMEMBER
|
||||
- SMEMBERS
|
||||
- SMOVE
|
||||
- SPOP -- see m.Seed(...)
|
||||
- SRANDMEMBER -- see m.Seed(...)
|
||||
- SREM
|
||||
- SUNION
|
||||
- SUNIONSTORE
|
||||
- SSCAN
|
||||
- Sorted Set keys (complete)
|
||||
- ZADD
|
||||
- ZCARD
|
||||
- ZCOUNT
|
||||
- ZINCRBY
|
||||
- ZINTERSTORE
|
||||
- ZLEXCOUNT
|
||||
- ZPOPMIN
|
||||
- ZPOPMAX
|
||||
- ZRANDMEMBER
|
||||
- ZRANGE
|
||||
- ZRANGEBYLEX
|
||||
- ZRANGEBYSCORE
|
||||
- ZRANK
|
||||
- ZREM
|
||||
- ZREMRANGEBYLEX
|
||||
- ZREMRANGEBYRANK
|
||||
- ZREMRANGEBYSCORE
|
||||
- ZREVRANGE
|
||||
- ZREVRANGEBYLEX
|
||||
- ZREVRANGEBYSCORE
|
||||
- ZREVRANK
|
||||
- ZSCORE
|
||||
- ZUNION
|
||||
- ZUNIONSTORE
|
||||
- ZSCAN
|
||||
- Stream keys
|
||||
- XACK
|
||||
- XADD
|
||||
- XAUTOCLAIM
|
||||
- XCLAIM
|
||||
- XDEL
|
||||
- XGROUP CREATE
|
||||
- XGROUP CREATECONSUMER
|
||||
- XGROUP DESTROY
|
||||
- XGROUP DELCONSUMER
|
||||
- XINFO STREAM -- partly
|
||||
- XINFO GROUPS
|
||||
- XINFO CONSUMERS -- partly
|
||||
- XLEN
|
||||
- XRANGE
|
||||
- XREAD
|
||||
- XREADGROUP
|
||||
- XREVRANGE
|
||||
- XPENDING
|
||||
- XTRIM
|
||||
- Scripting
|
||||
- EVAL
|
||||
- EVALSHA
|
||||
- SCRIPT LOAD
|
||||
- SCRIPT EXISTS
|
||||
- SCRIPT FLUSH
|
||||
- GEO
|
||||
- GEOADD
|
||||
- GEODIST
|
||||
- ~~GEOHASH~~
|
||||
- GEOPOS
|
||||
- GEORADIUS
|
||||
- GEORADIUS_RO
|
||||
- GEORADIUSBYMEMBER
|
||||
- GEORADIUSBYMEMBER_RO
|
||||
- Cluster
|
||||
- CLUSTER SLOTS
|
||||
- CLUSTER KEYSLOT
|
||||
- CLUSTER NODES
|
||||
- HyperLogLog (complete)
|
||||
- PFADD
|
||||
- PFCOUNT
|
||||
- PFMERGE
|
||||
|
||||
|
||||
## TTLs, key expiration, and time
|
||||
|
||||
Since miniredis is intended to be used in unittests TTLs don't decrease
|
||||
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
|
||||
key. It will return 0 when no TTL is set.
|
||||
|
||||
`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
|
||||
0 will be removed.
|
||||
|
||||
EXPIREAT and PEXPIREAT values will be
|
||||
converted to a duration. For that you can either set m.SetTime(t) to use that
|
||||
time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in
|
||||
which case time.Now() will be used.
|
||||
|
||||
SetTime() also sets the value returned by TIME, which defaults to time.Now().
|
||||
It is not updated by FastForward, only by SetTime.
|
||||
|
||||
## Randomness and Seed()
|
||||
|
||||
Miniredis will use `math/rand`'s global RNG for randomness unless a seed is
|
||||
provided by calling `m.Seed(...)`. If a seed is provided, then miniredis will
|
||||
use its own RNG based on that seed.
|
||||
|
||||
Commands which use randomness are: RANDOMKEY, SPOP, and SRANDMEMBER.
|
||||
|
||||
## Example
|
||||
|
||||
``` Go
|
||||
|
||||
import (
|
||||
...
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
...
|
||||
)
|
||||
|
||||
func TestSomething(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
|
||||
// Optionally set some keys your code expects:
|
||||
s.Set("foo", "bar")
|
||||
s.HSet("some", "other", "key")
|
||||
|
||||
// Run your code and see if it behaves.
|
||||
// An example using the redigo library from "github.com/gomodule/redigo/redis":
|
||||
c, err := redis.Dial("tcp", s.Addr())
|
||||
_, err = c.Do("SET", "foo", "bar")
|
||||
|
||||
// Optionally check values in redis...
|
||||
if got, err := s.Get("foo"); err != nil || got != "bar" {
|
||||
t.Error("'foo' has the wrong value")
|
||||
}
|
||||
// ... or use a helper for that:
|
||||
s.CheckGet(t, "foo", "bar")
|
||||
|
||||
// TTL and expiration:
|
||||
s.Set("foo", "bar")
|
||||
s.SetTTL("foo", 10*time.Second)
|
||||
s.FastForward(11 * time.Second)
|
||||
if s.Exists("foo") {
|
||||
t.Fatal("'foo' should not have existed anymore")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Not supported
|
||||
|
||||
Commands which will probably not be implemented:
|
||||
|
||||
- CLUSTER (all)
|
||||
- ~~CLUSTER *~~
|
||||
- ~~READONLY~~
|
||||
- ~~READWRITE~~
|
||||
- Key
|
||||
- ~~DUMP~~
|
||||
- ~~MIGRATE~~
|
||||
- ~~OBJECT~~
|
||||
- ~~RESTORE~~
|
||||
- ~~WAIT~~
|
||||
- Scripting
|
||||
- ~~SCRIPT DEBUG~~
|
||||
- ~~SCRIPT KILL~~
|
||||
- Server
|
||||
- ~~BGSAVE~~
|
||||
- ~~BGWRITEAOF~~
|
||||
- ~~CLIENT *~~
|
||||
- ~~CONFIG *~~
|
||||
- ~~DEBUG *~~
|
||||
- ~~LASTSAVE~~
|
||||
- ~~MONITOR~~
|
||||
- ~~ROLE~~
|
||||
- ~~SAVE~~
|
||||
- ~~SHUTDOWN~~
|
||||
- ~~SLAVEOF~~
|
||||
- ~~SLOWLOG~~
|
||||
- ~~SYNC~~
|
||||
|
||||
|
||||
## &c.
|
||||
|
||||
Integration tests are run against Redis 6.2.6. The [./integration](./integration/) subdir
|
||||
compares miniredis against a real redis instance.
|
||||
|
||||
The Redis 6 RESP3 protocol is supported. If there are problems, please open
|
||||
an issue.
|
||||
|
||||
If you want to test Redis Sentinel have a look at [minisentinel](https://github.com/Bose/minisentinel).
|
||||
|
||||
A changelog is kept at [CHANGELOG.md](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md).
|
||||
|
||||
[](https://pkg.go.dev/github.com/alicebob/miniredis/v2)
|
||||
63
vendor/github.com/alicebob/miniredis/v2/check.go
generated
vendored
63
vendor/github.com/alicebob/miniredis/v2/check.go
generated
vendored
@@ -1,63 +0,0 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// T is implemented by Testing.T
|
||||
type T interface {
|
||||
Helper()
|
||||
Errorf(string, ...interface{})
|
||||
}
|
||||
|
||||
// CheckGet does not call Errorf() iff there is a string key with the
|
||||
// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
|
||||
func (m *Miniredis) CheckGet(t T, key, expected string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Get(key)
|
||||
if err != nil {
|
||||
t.Errorf("GET error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if found != expected {
|
||||
t.Errorf("GET error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckList does not call Errorf() iff there is a list key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`.
|
||||
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.List(key)
|
||||
if err != nil {
|
||||
t.Errorf("List error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("List error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckSet does not call Errorf() iff there is a set key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
|
||||
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Members(key)
|
||||
if err != nil {
|
||||
t.Errorf("Set error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
sort.Strings(expected)
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("Set error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
67
vendor/github.com/alicebob/miniredis/v2/cmd_cluster.go
generated
vendored
67
vendor/github.com/alicebob/miniredis/v2/cmd_cluster.go
generated
vendored
@@ -1,67 +0,0 @@
|
||||
// Commands from https://redis.io/commands#cluster
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsCluster handles some cluster operations.
|
||||
func commandsCluster(m *Miniredis) {
|
||||
m.srv.Register("CLUSTER", m.cmdCluster)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdCluster(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "SLOTS":
|
||||
m.cmdClusterSlots(c, cmd, args)
|
||||
case "KEYSLOT":
|
||||
m.cmdClusterKeySlot(c, cmd, args)
|
||||
case "NODES":
|
||||
m.cmdClusterNodes(c, cmd, args)
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("ERR 'CLUSTER %s' not supported", strings.Join(args, " ")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CLUSTER SLOTS
|
||||
func (m *Miniredis) cmdClusterSlots(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteLen(1)
|
||||
c.WriteLen(3)
|
||||
c.WriteInt(0)
|
||||
c.WriteInt(16383)
|
||||
c.WriteLen(3)
|
||||
c.WriteBulk(m.srv.Addr().IP.String())
|
||||
c.WriteInt(m.srv.Addr().Port)
|
||||
c.WriteBulk("09dbe9720cda62f7865eabc5fd8857c5d2678366")
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER KEYSLOT
|
||||
func (m *Miniredis) cmdClusterKeySlot(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteInt(163)
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER NODES
|
||||
func (m *Miniredis) cmdClusterNodes(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk("e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca 127.0.0.1:7000@7000 myself,master - 0 0 1 connected 0-16383")
|
||||
})
|
||||
}
|
||||
2045
vendor/github.com/alicebob/miniredis/v2/cmd_command.go
generated
vendored
2045
vendor/github.com/alicebob/miniredis/v2/cmd_command.go
generated
vendored
File diff suppressed because it is too large
Load Diff
284
vendor/github.com/alicebob/miniredis/v2/cmd_connection.go
generated
vendored
284
vendor/github.com/alicebob/miniredis/v2/cmd_connection.go
generated
vendored
@@ -1,284 +0,0 @@
|
||||
// Commands from https://redis.io/commands#connection
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
func commandsConnection(m *Miniredis) {
|
||||
m.srv.Register("AUTH", m.cmdAuth)
|
||||
m.srv.Register("ECHO", m.cmdEcho)
|
||||
m.srv.Register("HELLO", m.cmdHello)
|
||||
m.srv.Register("PING", m.cmdPing)
|
||||
m.srv.Register("QUIT", m.cmdQuit)
|
||||
m.srv.Register("SELECT", m.cmdSelect)
|
||||
m.srv.Register("SWAPDB", m.cmdSwapdb)
|
||||
}
|
||||
|
||||
// PING
|
||||
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
payload := ""
|
||||
if len(args) > 0 {
|
||||
payload = args[0]
|
||||
}
|
||||
|
||||
// PING is allowed in subscribed state
|
||||
if sub := getCtx(c).subscriber; sub != nil {
|
||||
c.Block(func(c *server.Writer) {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("pong")
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if payload == "" {
|
||||
c.WriteInline("PONG")
|
||||
return
|
||||
}
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
}
|
||||
|
||||
// AUTH
|
||||
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 2 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
if getCtx(c).nested {
|
||||
c.WriteError(msgNotFromScripts)
|
||||
return
|
||||
}
|
||||
|
||||
var opts = struct {
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
username: "default",
|
||||
password: args[0],
|
||||
}
|
||||
if len(args) == 2 {
|
||||
opts.username, opts.password = args[0], args[1]
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
c.WriteError("ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
|
||||
return
|
||||
}
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.authenticated = true
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// HELLO
|
||||
func (m *Miniredis) cmdHello(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
version int
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.version, "ERR Protocol version is not an integer or out of range"); !ok {
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
switch opts.version {
|
||||
case 2, 3:
|
||||
default:
|
||||
c.WriteError("NOPROTO unsupported protocol version")
|
||||
return
|
||||
}
|
||||
|
||||
var checkAuth bool
|
||||
for len(args) > 0 {
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "AUTH":
|
||||
if len(args) < 3 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
opts.username, opts.password, args = args[1], args[2], args[3:]
|
||||
checkAuth = true
|
||||
case "SETNAME":
|
||||
if len(args) < 2 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
_, args = args[1], args[2:]
|
||||
default:
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
// redis ignores legacy "AUTH" if it's not enabled.
|
||||
checkAuth = false
|
||||
}
|
||||
if checkAuth {
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
getCtx(c).authenticated = true
|
||||
}
|
||||
|
||||
c.Resp3 = opts.version == 3
|
||||
|
||||
c.WriteMapLen(7)
|
||||
c.WriteBulk("server")
|
||||
c.WriteBulk("miniredis")
|
||||
c.WriteBulk("version")
|
||||
c.WriteBulk("6.0.5")
|
||||
c.WriteBulk("proto")
|
||||
c.WriteInt(opts.version)
|
||||
c.WriteBulk("id")
|
||||
c.WriteInt(42)
|
||||
c.WriteBulk("mode")
|
||||
c.WriteBulk("standalone")
|
||||
c.WriteBulk("role")
|
||||
c.WriteBulk("master")
|
||||
c.WriteBulk("modules")
|
||||
c.WriteLen(0)
|
||||
}
|
||||
|
||||
// ECHO
|
||||
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
msg := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk(msg)
|
||||
})
|
||||
}
|
||||
|
||||
// SELECT
|
||||
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.isValidCMD(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id int
|
||||
}
|
||||
if ok := optInt(c, args[0], &opts.id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.selectedDB = opts.id
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// SWAPDB
|
||||
func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id1 int
|
||||
id2 int
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.id1, "ERR invalid first DB index"); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optIntErr(c, args[1], &opts.id2, "ERR invalid second DB index"); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id1 < 0 || opts.id2 < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
m.swapDB(opts.id1, opts.id2)
|
||||
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// QUIT
|
||||
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
|
||||
// QUIT isn't transactionfied and accepts any arguments.
|
||||
c.WriteOK()
|
||||
c.Close()
|
||||
}
|
||||
669
vendor/github.com/alicebob/miniredis/v2/cmd_generic.go
generated
vendored
669
vendor/github.com/alicebob/miniredis/v2/cmd_generic.go
generated
vendored
@@ -1,669 +0,0 @@
|
||||
// Commands from https://redis.io/commands#generic
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsGeneric handles EXPIRE, TTL, PERSIST, &c.
|
||||
func commandsGeneric(m *Miniredis) {
|
||||
m.srv.Register("COPY", m.cmdCopy)
|
||||
m.srv.Register("DEL", m.cmdDel)
|
||||
// DUMP
|
||||
m.srv.Register("EXISTS", m.cmdExists)
|
||||
m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
|
||||
m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
|
||||
m.srv.Register("KEYS", m.cmdKeys)
|
||||
// MIGRATE
|
||||
m.srv.Register("MOVE", m.cmdMove)
|
||||
// OBJECT
|
||||
m.srv.Register("PERSIST", m.cmdPersist)
|
||||
m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
|
||||
m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
|
||||
m.srv.Register("PTTL", m.cmdPTTL)
|
||||
m.srv.Register("RANDOMKEY", m.cmdRandomkey)
|
||||
m.srv.Register("RENAME", m.cmdRename)
|
||||
m.srv.Register("RENAMENX", m.cmdRenamenx)
|
||||
// RESTORE
|
||||
m.srv.Register("TOUCH", m.cmdTouch)
|
||||
m.srv.Register("TTL", m.cmdTTL)
|
||||
m.srv.Register("TYPE", m.cmdType)
|
||||
m.srv.Register("SCAN", m.cmdScan)
|
||||
// SORT
|
||||
m.srv.Register("UNLINK", m.cmdDel)
|
||||
}
|
||||
|
||||
// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
|
||||
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
|
||||
// converted to a duration.
|
||||
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
|
||||
return func(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
value int
|
||||
}
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.value); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
// Key must be present.
|
||||
if _, ok := db.keys[opts.key]; !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if unix {
|
||||
db.ttl[opts.key] = m.at(opts.value, d)
|
||||
} else {
|
||||
db.ttl[opts.key] = time.Duration(opts.value) * d
|
||||
}
|
||||
db.keyVersion[opts.key]++
|
||||
db.checkTTL(opts.key)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TOUCH
|
||||
func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TTL
|
||||
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// No such key
|
||||
c.WriteInt(-2)
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := db.ttl[key]
|
||||
if !ok {
|
||||
// no expire value
|
||||
c.WriteInt(-1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(int(v.Seconds()))
|
||||
})
|
||||
}
|
||||
|
||||
// PTTL
|
||||
func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// no such key
|
||||
c.WriteInt(-2)
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := db.ttl[key]
|
||||
if !ok {
|
||||
// no expire value
|
||||
c.WriteInt(-1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(int(v.Nanoseconds() / 1000000))
|
||||
})
|
||||
}
|
||||
|
||||
// PERSIST
|
||||
func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// no such key
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.ttl[key]; !ok {
|
||||
// no expire value
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
delete(db.ttl, key)
|
||||
db.keyVersion[key]++
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// DEL and UNLINK
|
||||
func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
db.del(key, true) // delete expire
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TYPE
|
||||
func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError("usage error")
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteInline("none")
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInline(t)
|
||||
})
|
||||
}
|
||||
|
||||
// EXISTS
|
||||
func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
found := 0
|
||||
for _, k := range args {
|
||||
if db.exists(k) {
|
||||
found++
|
||||
}
|
||||
}
|
||||
c.WriteInt(found)
|
||||
})
|
||||
}
|
||||
|
||||
// MOVE
|
||||
func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
targetDB int
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
opts.targetDB, _ = strconv.Atoi(args[1])
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if ctx.selectedDB == opts.targetDB {
|
||||
c.WriteError("ERR source and destination objects are the same")
|
||||
return
|
||||
}
|
||||
db := m.db(ctx.selectedDB)
|
||||
targetDB := m.db(opts.targetDB)
|
||||
|
||||
if !db.move(opts.key, targetDB) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// KEYS
|
||||
func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
keys, _ := matchKeys(db.allKeys(), key)
|
||||
c.WriteLen(len(keys))
|
||||
for _, s := range keys {
|
||||
c.WriteBulk(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RANDOMKEY
|
||||
func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if len(db.keys) == 0 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
nr := m.randIntn(len(db.keys))
|
||||
for k := range db.keys {
|
||||
if nr == 0 {
|
||||
c.WriteBulk(k)
|
||||
return
|
||||
}
|
||||
nr--
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RENAME
|
||||
func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
from: args[0],
|
||||
to: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.from) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
db.rename(opts.from, opts.to)
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// RENAMENX
|
||||
func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
from: args[0],
|
||||
to: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.from) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if db.exists(opts.to) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
db.rename(opts.from, opts.to)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// SCAN
|
||||
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
cursor int
|
||||
withMatch bool
|
||||
match string
|
||||
withType bool
|
||||
_type string
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.cursor, msgInvalidCursor); !ok {
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
// MATCH, COUNT and TYPE options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
// we do nothing with count
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if _, err := strconv.Atoi(args[1]); err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match, args = args[1], args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "type" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withType = true
|
||||
opts._type, args = strings.ToLower(args[1]), args[2:]
|
||||
continue
|
||||
}
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// We return _all_ (matched) keys every time.
|
||||
|
||||
if opts.cursor != 0 {
|
||||
// Invalid cursor.
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
|
||||
var keys []string
|
||||
|
||||
if opts.withType {
|
||||
keys = make([]string, 0)
|
||||
for k, t := range db.keys {
|
||||
// type must be given exactly; no pattern matching is performed
|
||||
if t == opts._type {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys) // To make things deterministic.
|
||||
} else {
|
||||
keys = db.allKeys()
|
||||
}
|
||||
|
||||
if opts.withMatch {
|
||||
keys, _ = matchKeys(keys, opts.match)
|
||||
}
|
||||
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(len(keys))
|
||||
for _, k := range keys {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// COPY
|
||||
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts = struct {
|
||||
from string
|
||||
to string
|
||||
destinationDB int
|
||||
replace bool
|
||||
}{
|
||||
destinationDB: -1,
|
||||
}
|
||||
|
||||
opts.from, opts.to, args = args[0], args[1], args[2:]
|
||||
for len(args) > 0 {
|
||||
switch strings.ToLower(args[0]) {
|
||||
case "db":
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
db, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if db < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
return
|
||||
}
|
||||
opts.destinationDB = db
|
||||
args = args[2:]
|
||||
case "replace":
|
||||
opts.replace = true
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
fromDB, toDB := ctx.selectedDB, opts.destinationDB
|
||||
if toDB == -1 {
|
||||
toDB = fromDB
|
||||
}
|
||||
|
||||
if fromDB == toDB && opts.from == opts.to {
|
||||
c.WriteError("ERR source and destination objects are the same")
|
||||
return
|
||||
}
|
||||
|
||||
if !m.db(fromDB).exists(opts.from) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if !opts.replace {
|
||||
if m.db(toDB).exists(opts.to) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.copy(m.db(fromDB), opts.from, m.db(toDB), opts.to)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
609
vendor/github.com/alicebob/miniredis/v2/cmd_geo.go
generated
vendored
609
vendor/github.com/alicebob/miniredis/v2/cmd_geo.go
generated
vendored
@@ -1,609 +0,0 @@
|
||||
// Commands from https://redis.io/commands#geo
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsGeo handles GEOADD, GEORADIUS etc.
|
||||
func commandsGeo(m *Miniredis) {
|
||||
m.srv.Register("GEOADD", m.cmdGeoadd)
|
||||
m.srv.Register("GEODIST", m.cmdGeodist)
|
||||
m.srv.Register("GEOPOS", m.cmdGeopos)
|
||||
m.srv.Register("GEORADIUS", m.cmdGeoradius)
|
||||
m.srv.Register("GEORADIUS_RO", m.cmdGeoradius)
|
||||
m.srv.Register("GEORADIUSBYMEMBER", m.cmdGeoradiusbymember)
|
||||
m.srv.Register("GEORADIUSBYMEMBER_RO", m.cmdGeoradiusbymember)
|
||||
}
|
||||
|
||||
// GEOADD
|
||||
func (m *Miniredis) cmdGeoadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 || len(args[1:])%3 != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
toSet := map[string]float64{}
|
||||
for len(args) > 2 {
|
||||
rawLong, rawLat, name := args[0], args[1], args[2]
|
||||
args = args[3:]
|
||||
longitude, err := strconv.ParseFloat(rawLong, 64)
|
||||
if err != nil {
|
||||
c.WriteError("ERR value is not a valid float")
|
||||
return
|
||||
}
|
||||
latitude, err := strconv.ParseFloat(rawLat, 64)
|
||||
if err != nil {
|
||||
c.WriteError("ERR value is not a valid float")
|
||||
return
|
||||
}
|
||||
|
||||
if latitude < -85.05112878 ||
|
||||
latitude > 85.05112878 ||
|
||||
longitude < -180 ||
|
||||
longitude > 180 {
|
||||
c.WriteError(fmt.Sprintf("ERR invalid longitude,latitude pair %.6f,%.6f", longitude, latitude))
|
||||
return
|
||||
}
|
||||
|
||||
toSet[name] = float64(toGeohash(longitude, latitude))
|
||||
}
|
||||
|
||||
set := 0
|
||||
for name, score := range toSet {
|
||||
if db.ssetAdd(key, score, name) {
|
||||
set++
|
||||
}
|
||||
}
|
||||
c.WriteInt(set)
|
||||
})
|
||||
}
|
||||
|
||||
// GEODIST
|
||||
func (m *Miniredis) cmdGeodist(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, from, to, args := args[0], args[1], args[2], args[3:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
if !db.exists(key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
unit := "m"
|
||||
if len(args) > 0 {
|
||||
unit, args = args[0], args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
toMeter := parseUnit(unit)
|
||||
if toMeter == 0 {
|
||||
c.WriteError(msgUnsupportedUnit)
|
||||
return
|
||||
}
|
||||
|
||||
members := db.sortedsetKeys[key]
|
||||
fromD, okFrom := members.get(from)
|
||||
toD, okTo := members.get(to)
|
||||
if !okFrom || !okTo {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
fromLo, fromLat := fromGeohash(uint64(fromD))
|
||||
toLo, toLat := fromGeohash(uint64(toD))
|
||||
|
||||
dist := distance(fromLat, fromLo, toLat, toLo) / toMeter
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", dist))
|
||||
})
|
||||
}
|
||||
|
||||
// GEOPOS
|
||||
func (m *Miniredis) cmdGeopos(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(args))
|
||||
for _, l := range args {
|
||||
if !db.ssetExists(key, l) {
|
||||
c.WriteLen(-1)
|
||||
continue
|
||||
}
|
||||
score := db.ssetScore(key, l)
|
||||
c.WriteLen(2)
|
||||
long, lat := fromGeohash(uint64(score))
|
||||
c.WriteBulk(fmt.Sprintf("%f", long))
|
||||
c.WriteBulk(fmt.Sprintf("%f", lat))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type geoDistance struct {
|
||||
Name string
|
||||
Score float64
|
||||
Distance float64
|
||||
Longitude float64
|
||||
Latitude float64
|
||||
}
|
||||
|
||||
// GEORADIUS and GEORADIUS_RO
|
||||
func (m *Miniredis) cmdGeoradius(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 5 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
longitude, err := strconv.ParseFloat(args[1], 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
latitude, err := strconv.ParseFloat(args[2], 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
radius, err := strconv.ParseFloat(args[3], 64)
|
||||
if err != nil || radius < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
toMeter := parseUnit(args[4])
|
||||
if toMeter == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
args = args[5:]
|
||||
|
||||
var opts struct {
|
||||
withDist bool
|
||||
withCoord bool
|
||||
direction direction // unsorted
|
||||
count int
|
||||
withStore bool
|
||||
storeKey string
|
||||
withStoredist bool
|
||||
storedistKey string
|
||||
}
|
||||
for len(args) > 0 {
|
||||
arg := args[0]
|
||||
args = args[1:]
|
||||
switch strings.ToUpper(arg) {
|
||||
case "WITHCOORD":
|
||||
opts.withCoord = true
|
||||
case "WITHDIST":
|
||||
opts.withDist = true
|
||||
case "ASC":
|
||||
opts.direction = asc
|
||||
case "DESC":
|
||||
opts.direction = desc
|
||||
case "COUNT":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
n, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR COUNT must be > 0")
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
opts.count = n
|
||||
case "STORE":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStore = true
|
||||
opts.storeKey = args[0]
|
||||
args = args[1:]
|
||||
case "STOREDIST":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStoredist = true
|
||||
opts.storedistKey = args[0]
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToUpper(cmd) == "GEORADIUS_RO" && (opts.withStore || opts.withStoredist) {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
|
||||
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
|
||||
return
|
||||
}
|
||||
|
||||
db := m.db(ctx.selectedDB)
|
||||
members := db.ssetElements(key)
|
||||
|
||||
matches := withinRadius(members, longitude, latitude, radius*toMeter)
|
||||
|
||||
// deal with ASC/DESC
|
||||
if opts.direction != unsorted {
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if opts.direction == desc {
|
||||
return matches[i].Distance > matches[j].Distance
|
||||
}
|
||||
return matches[i].Distance < matches[j].Distance
|
||||
})
|
||||
}
|
||||
|
||||
// deal with COUNT
|
||||
if opts.count > 0 && len(matches) > opts.count {
|
||||
matches = matches[:opts.count]
|
||||
}
|
||||
|
||||
// deal with "STORE x"
|
||||
if opts.withStore {
|
||||
db.del(opts.storeKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storeKey, member.Score, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
// deal with "STOREDIST x"
|
||||
if opts.withStoredist {
|
||||
db.del(opts.storedistKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storedistKey, member.Distance/toMeter, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(matches))
|
||||
for _, member := range matches {
|
||||
if !opts.withDist && !opts.withCoord {
|
||||
c.WriteBulk(member.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
len := 1
|
||||
if opts.withDist {
|
||||
len++
|
||||
}
|
||||
if opts.withCoord {
|
||||
len++
|
||||
}
|
||||
c.WriteLen(len)
|
||||
c.WriteBulk(member.Name)
|
||||
if opts.withDist {
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/toMeter))
|
||||
}
|
||||
if opts.withCoord {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GEORADIUSBYMEMBER and GEORADIUSBYMEMBER_RO
|
||||
func (m *Miniredis) cmdGeoradiusbymember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 4 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
member string
|
||||
radius float64
|
||||
toMeter float64
|
||||
|
||||
withDist bool
|
||||
withCoord bool
|
||||
direction direction // unsorted
|
||||
count int
|
||||
withStore bool
|
||||
storeKey string
|
||||
withStoredist bool
|
||||
storedistKey string
|
||||
}{
|
||||
key: args[0],
|
||||
member: args[1],
|
||||
}
|
||||
|
||||
r, err := strconv.ParseFloat(args[2], 64)
|
||||
if err != nil || r < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
opts.radius = r
|
||||
|
||||
opts.toMeter = parseUnit(args[3])
|
||||
if opts.toMeter == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
args = args[4:]
|
||||
|
||||
for len(args) > 0 {
|
||||
arg := args[0]
|
||||
args = args[1:]
|
||||
switch strings.ToUpper(arg) {
|
||||
case "WITHCOORD":
|
||||
opts.withCoord = true
|
||||
case "WITHDIST":
|
||||
opts.withDist = true
|
||||
case "ASC":
|
||||
opts.direction = asc
|
||||
case "DESC":
|
||||
opts.direction = desc
|
||||
case "COUNT":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
n, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR COUNT must be > 0")
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
opts.count = n
|
||||
case "STORE":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStore = true
|
||||
opts.storeKey = args[0]
|
||||
args = args[1:]
|
||||
case "STOREDIST":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStoredist = true
|
||||
opts.storedistKey = args[0]
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToUpper(cmd) == "GEORADIUSBYMEMBER_RO" && (opts.withStore || opts.withStoredist) {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
|
||||
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
|
||||
return
|
||||
}
|
||||
|
||||
db := m.db(ctx.selectedDB)
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(opts.key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// get position of member
|
||||
if !db.ssetExists(opts.key, opts.member) {
|
||||
c.WriteError("ERR could not decode requested zset member")
|
||||
return
|
||||
}
|
||||
score := db.ssetScore(opts.key, opts.member)
|
||||
longitude, latitude := fromGeohash(uint64(score))
|
||||
|
||||
members := db.ssetElements(opts.key)
|
||||
matches := withinRadius(members, longitude, latitude, opts.radius*opts.toMeter)
|
||||
|
||||
// deal with ASC/DESC
|
||||
if opts.direction != unsorted {
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if opts.direction == desc {
|
||||
return matches[i].Distance > matches[j].Distance
|
||||
}
|
||||
return matches[i].Distance < matches[j].Distance
|
||||
})
|
||||
}
|
||||
|
||||
// deal with COUNT
|
||||
if opts.count > 0 && len(matches) > opts.count {
|
||||
matches = matches[:opts.count]
|
||||
}
|
||||
|
||||
// deal with "STORE x"
|
||||
if opts.withStore {
|
||||
db.del(opts.storeKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storeKey, member.Score, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
// deal with "STOREDIST x"
|
||||
if opts.withStoredist {
|
||||
db.del(opts.storedistKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storedistKey, member.Distance/opts.toMeter, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(matches))
|
||||
for _, member := range matches {
|
||||
if !opts.withDist && !opts.withCoord {
|
||||
c.WriteBulk(member.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
len := 1
|
||||
if opts.withDist {
|
||||
len++
|
||||
}
|
||||
if opts.withCoord {
|
||||
len++
|
||||
}
|
||||
c.WriteLen(len)
|
||||
c.WriteBulk(member.Name)
|
||||
if opts.withDist {
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/opts.toMeter))
|
||||
}
|
||||
if opts.withCoord {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func withinRadius(members []ssElem, longitude, latitude, radius float64) []geoDistance {
|
||||
matches := []geoDistance{}
|
||||
for _, el := range members {
|
||||
elLo, elLat := fromGeohash(uint64(el.score))
|
||||
distanceInMeter := distance(latitude, longitude, elLat, elLo)
|
||||
|
||||
if distanceInMeter <= radius {
|
||||
matches = append(matches, geoDistance{
|
||||
Name: el.member,
|
||||
Score: el.score,
|
||||
Distance: distanceInMeter,
|
||||
Longitude: elLo,
|
||||
Latitude: elLat,
|
||||
})
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
func parseUnit(u string) float64 {
|
||||
switch u {
|
||||
case "m":
|
||||
return 1
|
||||
case "km":
|
||||
return 1000
|
||||
case "mi":
|
||||
return 1609.34
|
||||
case "ft":
|
||||
return 0.3048
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user