Mirror of https://github.com/mochi-mqtt/server.git (synced 2025-10-30 03:01:58 +08:00)
Compare commits
50 Commits
| SHA1 |
|---|
| 40e9cdb383 |
| e9f72154b6 |
| 69412dd23c |
| 686c35ac0c |
| 65c78534dc |
| 83db7fff56 |
| 10a02ab3c7 |
| 5058333f36 |
| b2ab984949 |
| 5523d15a9b |
| c6c7c296f6 |
| 4c682384c5 |
| 624dde0986 |
| dc4eecdfb7 |
| e8f151bf1f |
| 4983b6b977 |
| 99e50ae74e |
| 8e52e49b94 |
| 4c0c862dcd |
| 2f2d867170 |
| 916d022093 |
| 11e0256959 |
| 858a28ea4b |
| 82f32cf53d |
| b4e2c61a72 |
| 2b0732b815 |
| 450b8a2817 |
| 0d880fa9bd |
| a5310bdfdf |
| 96b96e8248 |
| a786a0a526 |
| c6297ee9e3 |
| 2f94751220 |
| 89a786832d |
| 8f93e35f1e |
| ba6b12a303 |
| a3fe5d4729 |
| bb5fc40fe4 |
| fbb1fb25aa |
| fac733fd71 |
| add87fea2e |
| 58f9fed336 |
| 1574443981 |
| 44bac0adc5 |
| e784c755ae |
| eafc2d91fc |
| 0df69a4a4e |
| 321a0514fe |
| 00387593c9 |
| af78d10870 |
.github/workflows/build.yml (vendored): 92 changed lines
```diff
@@ -3,43 +3,71 @@ name: build
 on: [push, pull_request]
 
 jobs:
 
   build:
     runs-on: ubuntu-latest
     steps:
-    - uses: actions/checkout@v3
-
-    - name: Set up Go
-      uses: actions/setup-go@v3
-      with:
-        go-version: 1.19
-
-    - name: Vet
-      run: go vet ./...
-
-    - name: Test
-      run: go test -race ./... && echo true
-
+    - uses: actions/checkout@v3
+    - name: Set up Go
+      uses: actions/setup-go@v3
+      with:
+        go-version: 1.21
+    - name: Vet
+      run: go vet ./...
+    - name: Test
+      run: go test -race ./... && echo true
 
   coverage:
     name: Test with Coverage
     runs-on: ubuntu-latest
     steps:
-    - name: Set up Go
-      uses: actions/setup-go@v3
-      with:
-        go-version: '1.19'
-    - name: Check out code
-      uses: actions/checkout@v3
-    - name: Install dependencies
-      run: |
-        go mod download
-    - name: Run Unit tests
-      run: |
-        go test -race -covermode atomic -coverprofile=covprofile ./...
-    - name: Install goveralls
-      run: go install github.com/mattn/goveralls@latest
-    - name: Send coverage
-      env:
-        COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-      run: goveralls -coverprofile=covprofile -service=github
+    - name: Set up Go
+      uses: actions/setup-go@v3
+      with:
+        go-version: '1.21'
+    - name: Check out code
+      uses: actions/checkout@v3
+    - name: Install dependencies
+      run: |
+        go mod download
+    - name: Run Unit tests
+      run: |
+        go test -race -covermode atomic -coverprofile=covprofile ./...
+    - name: Install goveralls
+      run: go install github.com/mattn/goveralls@latest
+    - name: Send coverage
+      env:
+        COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      run: goveralls -coverprofile=covprofile -service=github
+
+  docker:
+    if: github.repository == 'mochi-mqtt/server' && startsWith(github.ref, 'refs/tags/v')
+    runs-on: ubuntu-latest
+    needs: build
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: mochimqtt/server
+          tags: |
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=raw,value=latest,enable=${{ endsWith(github.ref, 'main') }}
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Build and push
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: ./Dockerfile
+          platforms: linux/amd64,linux/arm64
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
```
.gitignore (vendored): 3 changed lines
```diff
@@ -1,4 +1,5 @@
 cmd/mqtt
 .DS_Store
 *.db
 .idea
+vendor
```
Dockerfile:

```diff
@@ -1,4 +1,4 @@
-FROM golang:1.19.0-alpine3.15 AS builder
+FROM golang:1.21.0-alpine3.18 AS builder
 
 RUN apk update
 RUN apk add git
```
README-CN.md (new file): 505 lines
@@ -0,0 +1,505 @@
# Mochi-MQTT Server

<p align="center">

[Coverage Status](https://coveralls.io/github/mochi-mqtt/server?branch=master)
[Go Report Card](https://goreportcard.com/report/github.com/mochi-mqtt/server/v2)
[Go Reference](https://pkg.go.dev/github.com/mochi-mqtt/server/v2)
[Issues](https://github.com/mochi-mqtt/server/issues)

</p>

[English](README.md) | [简体中文](README-CN.md) | [日本語](README-JP.md) | [Translators wanted!](https://github.com/orgs/mochi-mqtt/discussions/310)

🎆 **mochi-co/mqtt is now part of the new mochi-mqtt organisation.** [Read the announcement here.](https://github.com/orgs/mochi-mqtt/discussions/271)

### Mochi-MQTT is a fully compliant, embeddable, high-performance Go MQTT v5 (and v3.1.1) broker/server.

Mochi MQTT is an embeddable, [fully MQTT v5 compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) broker/server written entirely in Go, designed for the development of telemetry and IoT projects. It can be used as a standalone binary or embedded as a library in your own applications, and has been carefully designed to be as lightweight and fast to deploy as possible, with great attention paid to code quality and maintainability.

#### What is MQTT?
MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol designed for constrained devices and low-bandwidth, high-latency, or unreliable networks ([learn more](https://mqtt.org/faq)). Mochi MQTT implements the full version 5.0.0 of the MQTT protocol.

#### Mochi-MQTT Features
- Full MQTT v5 feature compliance, compatible with MQTT v3.1.1 and v3.0.0:
  - MQTT v5 user and packet properties
  - Topic Aliases
  - Shared Subscriptions
  - Subscription Options and Subscription Identifiers
  - Message Expiry
  - Client Session Expiry
  - Send and Receive QoS Flow Control Quotas
  - Server-side Disconnect and Auth Packets
  - Will Delay Intervals
  - Plus all the original MQTT features of Mochi MQTT v1, such as full QoS (0, 1, 2), $SYS topics, retained messages, and more.
- Developer-centric:
  - Most core code is exported and accessible, giving developers full control.
  - A feature-rich and flexible hook-based interface system makes it easy to develop "plugins".
  - Server-side message publishing via a special inline client, which can also masquerade as an existing client.
- Performant and stable:
  - Classic prefix-trie based topic-subscription model.
  - Client-specific write buffers to avoid problems caused by slow readers or misbehaving clients.
  - Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) (MQTT v5 and MQTT v3).
  - Over a thousand carefully considered unit test scenarios.
- Supports TCP, Websocket (including SSL/TLS), and a $SYS server status dashboard.
- Built-in Redis, Badger, and Bolt persistence using hooks (you can also build your own).
- Built-in rule-based authentication and ACL ledger using hooks (you can also build your own).

### Compatibility Notes
Because of the overlap between the v5 specification and earlier versions of MQTT, the server can accept both v5 and v3 clients; however, where both v5 and v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).

Support for MQTT v3.0.0 and v3.1.1 is considered hybrid compatibility. Where not specifically restricted by the v3 specification, the more modern and safety-first v5 behaviours are used instead, for example expiry of retained messages, expiry of inflight messages, client expiry, and QoS message quantity limits.

#### When are new releases published?
Unless it's a critical issue, new releases are typically published over the weekend.

## Roadmap
- Please [open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or new hook interfaces!
- Cluster support.
- Metrics support.
- Configuration file support (Docker support).

## Quick Start
### Running the Broker with Go
Mochi MQTT can be used as a standalone broker. Simply pull this repository and run [cmd/main.go](cmd/main.go) in the [cmd](cmd) folder; by default it exposes tcp (:1883), websocket (:1882), and the server status dashboard (:8080).

```
cd cmd
go build -o mqtt && ./mqtt
```
### Using Docker

You can now pull and run the [official Mochi MQTT image](https://hub.docker.com/r/mochimqtt/server) from our Docker Hub repository:
```sh
docker pull mochimqtt/server
or
docker run mochimqtt/server
```

We are actively improving this area and are currently implementing [file-based configuration](https://github.com/orgs/mochi-mqtt/projects/2) to better support it. More substantial Docker support is being discussed [here](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545) and [here](https://github.com/orgs/mochi-mqtt/discussions/209); if you use Mochi-MQTT in this environment, please join the discussion.

A simple Dockerfile is provided for running the Websocket (:1882), TCP (:1883), and server status (:8080) listeners from cmd/main.go:

```sh
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```

## Developing with Mochi MQTT
### Importing Mochi MQTT as a package
Importing Mochi MQTT as a package requires just a few lines of code to get started.
```go
import (
  "log"
  "os"
  "os/signal"
  "syscall"

  mqtt "github.com/mochi-mqtt/server/v2"
  "github.com/mochi-mqtt/server/v2/hooks/auth"
  "github.com/mochi-mqtt/server/v2/listeners"
)

func main() {
  // Create signals channel to run the server until interrupted
  sigs := make(chan os.Signal, 1)
  done := make(chan bool, 1)
  signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
  go func() {
    <-sigs
    done <- true
  }()

  // Create the new MQTT Server.
  server := mqtt.New(nil)

  // Allow all connections.
  _ = server.AddHook(new(auth.AllowHook), nil)

  // Create a TCP listener on the standard port 1883.
  tcp := listeners.NewTCP("t1", ":1883", nil)
  err := server.AddListener(tcp)
  if err != nil {
    log.Fatal(err)
  }

  go func() {
    err := server.Serve()
    if err != nil {
      log.Fatal(err)
    }
  }()

  // Run server until interrupted
  <-done

  // Cleanup
}
```

More examples of running the server with different configurations can be found in the [examples](examples) folder.

#### Network Listeners

The server ships with several built-in network listeners which allow it to accept connections on different protocols. The current listeners are:

| Listener | Usage |
|------------------------------|----------------------------------------------------------------------------------------------|
| listeners.NewTCP | A TCP listener, accepting TCP connections |
| listeners.NewUnixSock | A Unix socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS server status listener |
| listeners.NewHTTPHealthCheck | An HTTP health-check listener, e.g. to provide health-check responses for cloud infrastructure |

> New listeners can be developed using the listeners.Listener interface. If you are interested you can implement your own; if you have suggestions or questions along the way, please [open an issue](https://github.com/mochi-mqtt/server/issues).

TLS can be configured in a *listeners.Config and passed to a listener to enable TLS support; a hedged sketch is shown below.
Examples can be found in the [examples](examples) folder or in [cmd/main.go](cmd/main.go).
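For illustration only, the sketch below shows one way a Websocket listener and a TLS-enabled TCP listener might be added. It assumes that `listeners.NewWebsocket` follows the same `(id, address, config)` pattern as the `listeners.NewTCP` call shown earlier, and that `listeners.Config` carries a standard `*tls.Config` in a `TLSConfig` field; check the [listeners](listeners) package and the [examples](examples) folder for the authoritative API before relying on it.

```go
package main

import (
	"crypto/tls"
	"log"

	mqtt "github.com/mochi-mqtt/server/v2"
	"github.com/mochi-mqtt/server/v2/hooks/auth"
	"github.com/mochi-mqtt/server/v2/listeners"
)

func main() {
	server := mqtt.New(nil)
	_ = server.AddHook(new(auth.AllowHook), nil)

	// Load a certificate pair for TLS; the file names are placeholders.
	cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
	if err != nil {
		log.Fatal(err)
	}
	tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}

	// A plain Websocket listener on :1882 (assumed to mirror NewTCP's signature).
	ws := listeners.NewWebsocket("ws1", ":1882", nil)

	// A TLS-enabled TCP listener on :8883 (assumes listeners.Config exposes TLSConfig).
	tcps := listeners.NewTCP("t1s", ":8883", &listeners.Config{TLSConfig: tlsConfig})

	for _, l := range []listeners.Listener{ws, tcps} {
		if err := server.AddListener(l); err != nil {
			log.Fatal(err)
		}
	}

	if err := server.Serve(); err != nil {
		log.Fatal(err)
	}
}
```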
### Server Options and Capabilities

A number of configurable options are available which can be used to alter the behaviour of the server or restrict access to certain features.

```go
server := mqtt.New(&mqtt.Options{
  Capabilities: mqtt.Capabilities{
    MaximumSessionExpiryInterval: 3600,
    Compatibilities: mqtt.Compatibilities{
      ObscureNotAuthorized: true,
    },
  },
  ClientNetWriteBufferSize: 4096,
  ClientNetReadBufferSize: 4096,
  SysTopicResendInterval: 10,
  InlineClient: false,
})
```
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of server options. ClientNetWriteBufferSize and ClientNetReadBufferSize can be configured to adjust memory usage per client, based on your needs.

### Default Configuration Notes

Some of the choices made for the default configuration values are worth mentioning here:

- By default, server.Options.Capabilities.MaximumMessageExpiryInterval is set to 86400 (24 hours) to prevent the broker from being exposed to malicious DoS attacks on hostile networks when using the out-of-the-box configuration (an infinite expiry would allow an unlimited number of retained/inflight messages to accumulate). If you are operating in a trusted environment, or you have capacity for a larger retention period, you may wish to override this (set it to 0 for no expiry), as sketched below.
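As a sketch of that override (not code from the upstream examples): assuming the capabilities on `server.Options` are populated with defaults by `mqtt.New` and can be adjusted before calling `Serve`, lifting the default message expiry in a trusted environment might look like this.

```go
// Sketch: remove the default 24h cap on message expiry.
// Assumes server.Options.Capabilities is populated by mqtt.New and read at
// runtime; only do this on trusted networks.
server := mqtt.New(nil)
server.Options.Capabilities.MaximumMessageExpiryInterval = 0 // 0 = no expiry
```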
## Event Hooks

A universal event hooks system allows developers to hook into various stages of the server and client lifecycle to add and modify the functionality of the broker. These universal hooks are used to provide everything from authentication and persistent storage to debugging tools.

Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and the modified values are passed to subsequent hooks before being returned to the runtime code.

| Type | Import | Info |
|----------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|
| Access Control | [mochi-mqtt/server/hooks/auth . AllowHook](hooks/auth/allow_all.go) | AllowHook allows all connecting clients and read/write access to all topics. |
| Access Control | [mochi-mqtt/server/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
| Persistence | [mochi-mqtt/server/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-mqtt/server/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-mqtt/server/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-mqtt/server/hooks/debug](hooks/debug/debug.go) | Debugging output to visualise packet flow through the server. |

Many of the internal functions are exposed to developers, so you can build your own hooks using the above as examples. If you have suggestions or questions about hooks, please [open an issue](https://github.com/mochi-mqtt/server/issues).

### Access Control

#### Allow Hook

By default, Mochi MQTT uses a DENY-ALL access control rule. To allow connections, an access control hook must be added to replace the default DENY-ALL hook. The simplest of these is the auth.AllowAll hook, which provides ALLOW-ALL rules for all connections, subscriptions, and publishing. It is also the simplest hook to use:

```go
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
```

> Don't do this if you are exposing the server to the internet or untrusted networks - it should really only be used for development, testing, and debugging.

#### Auth Ledger

The Auth Ledger hook uses structured definitions to compose access rules. Auth Ledger rules come in two forms: auth rules (used on connection) and ACL rules (used on publish and subscribe).

Auth rules have four optional criteria and an allow flag:

| Criteria | Usage |
| -- | -- |
| Client | The client ID of the connecting client |
| Username | The username of the connecting client |
| Password | The password of the connecting client |
| Remote | The remote address or IP of the client |
| Allow | true (allow this user) or false (deny this user) |

ACL rules have three optional criteria and a topic filter match:
| Criteria | Usage |
| -- | -- |
| Client | The client ID of the connecting client |
| Username | The username of the connecting client |
| Remote | The remote address or IP of the client |
| Filters | An array of topic filters to match |

Rules are processed in index order (0, 1, 2, 3) and the first matching rule is returned. See [hooks/auth/ledger.go](hooks/auth/ledger.go) for the implementation.

```go
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
  Ledger: &auth.Ledger{
    Auth: auth.AuthRules{ // Auth disallows all by default
      {Username: "peach", Password: "password1", Allow: true},
      {Username: "melon", Password: "password2", Allow: true},
      {Remote: "127.0.0.1:*", Allow: true},
      {Remote: "localhost:*", Allow: true},
    },
    ACL: auth.ACLRules{ // ACL allows all by default
      {Remote: "127.0.0.1:*"}, // local clients are allowed everything
      {
        // user melon can read and write to their own topic
        Username: "melon", Filters: auth.Filters{
          "melon/#": auth.ReadWrite,
          "updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
        },
      },
      {
        // Otherwise, no clients have publishing permissions
        Filters: auth.Filters{
          "#": auth.ReadOnly,
          "updates/#": auth.Deny,
        },
      },
    },
  },
})
```

The rules can also be stored as JSON or YAML and loaded using the Data field with the file's binary contents:

```go
err := server.AddHook(new(auth.Hook), &auth.Options{
  Data: data, // build ledger from byte slice (file contents): yaml or json
})
```
See [examples/auth/encoded/main.go](examples/auth/encoded/main.go) for more information; a brief sketch of loading such a file follows below.
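For illustration, loading those rules from a file on disk can be as simple as reading the file into a byte slice and passing it as Data; the `ledger.yml` filename here is just a placeholder.

```go
// Sketch: read an encoded ledger (YAML or JSON) and hand it to the auth hook.
data, err := os.ReadFile("ledger.yml")
if err != nil {
	log.Fatal(err)
}

err = server.AddHook(new(auth.Hook), &auth.Options{
	Data: data, // the hook decodes the ledger from the raw bytes
})
if err != nil {
	log.Fatal(err)
}
```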
### Persistent Storage

#### Redis

A basic Redis storage hook is available which provides data persistence for the server. You can add the Redis hook to the server, and several Redis parameters are configurable. The hook uses the github.com/go-redis/redis/v8 library internally, which can be configured via Options.

```go
err := server.AddHook(new(redis.Hook), &redis.Options{
  Options: &rv8.Options{
    Addr:     "localhost:6379", // Redis server address
    Password: "",               // Redis server password
    DB:       0,                // Redis database index
  },
})
if err != nil {
  log.Fatal(err)
}
```

For more information on how the Redis hook works and how to use it, see [examples/persistence/redis/main.go](examples/persistence/redis/main.go) or [hooks/storage/redis](hooks/storage/redis).

#### Badger DB

If you prefer file-based storage, there is also a BadgerDB storage hook available. It can be added and configured in much the same way as the other hooks (with somewhat fewer options).

```go
err := server.AddHook(new(badger.Hook), &badger.Options{
  Path: badgerPath,
})
if err != nil {
  log.Fatal(err)
}
```

For more information on how the Badger hook works and how to use it, see [examples/persistence/badger/main.go](examples/persistence/badger/main.go) or [hooks/storage/badger](hooks/storage/badger).

There is also a BoltDB hook, which has been deprecated in favour of Badger; if you still want to use it, see [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).

## Developing with Event Hooks

Throughout the server and client lifecycle, developers can use a variety of hooks to add custom handling to the server or to clients.
All hooks are defined in the mqtt.Hook interface; the hook functions can be found in [hooks.go](hooks.go).

> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks can be used to control and modify all incoming and outgoing packets.

| Function | Usage |
|------------------------|-----------------------------------------------------------------------------------------------------------|
| OnStarted | Called when the server has successfully started. |
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. This method must be implemented to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in a custom hook to check whether a connecting user matches a user in an existing user database. Return true to allow access. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic, to check the ACL rules. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects; may return an error or packet code to halt the client connection. |
| OnSessionEstablish | Called immediately after a new client connects and authenticates, and immediately before the session is established and the CONNACK is sent. |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect). |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own MQTT v5 auth packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded and sent to a client. Allows packet modification. |
| OnPacketSent | Called after a packet has been sent to a client. |
| OnPacketProcessed | Called after a packet has been received and successfully processed by the server. |
| OnSubscribe | Called when a client subscribes to one or more topics. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more topics. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared-subscription subscribers are selected. Allows the receivers to be modified. |
| OnUnsubscribe | Called when a client unsubscribes from one or more topics. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more topics. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called after a client has published a message to subscribers. |
| OnPublishDropped | Called when a message is dropped before delivery to a client, for example when the client takes too long to respond. |
| OnRetainMessage | Called when a message is retained. |
| OnRetainPublished | Called when a retained message is published to a client. |
| OnQosPublish | Called when a message with QoS >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the QoS flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before its QoS flow is completed. |
| OnPacketIDExhausted | Called when packet IDs have been exhausted and no free IDs remain to assign. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called after a will message has been sent. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| StoredClients | Must return a list of clients, e.g. from a persistent store. |
| StoredSubscriptions | Returns all of a client's subscriptions, e.g. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, e.g. messages that have not yet completed transmission, from a persistent store. |
| StoredRetainedMessages | Returns retained messages, e.g. from a persistent store. |
| StoredSysInfo | Returns stored system status information, e.g. from a persistent store. |

If you want to implement a persistent storage hook of your own, refer to the existing persistent storage hooks for inspiration and patterns. If you are building an authentication hook, you will need to implement the OnACLCheck and OnConnectAuthenticate functions, as in the sketch below.
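As a hedged illustration of that advice (this is a sketch, not code from this repository's examples): a minimal custom authentication hook might look like the following. It assumes, mirroring the built-in hooks, that a custom hook can embed `mqtt.HookBase` and override only `ID`, `Provides`, `OnConnectAuthenticate`, and `OnACLCheck`; compare it with [hooks/auth/allow_all.go](hooks/auth/allow_all.go) before relying on it.

```go
package custom

import (
	"bytes"
	"strings"

	mqtt "github.com/mochi-mqtt/server/v2"
	"github.com/mochi-mqtt/server/v2/packets"
)

// MyAuthHook is a sketch of a custom authentication hook. Embedding
// mqtt.HookBase is assumed to supply no-op defaults for every hook
// method not overridden here, as the built-in hooks do.
type MyAuthHook struct {
	mqtt.HookBase
}

// ID returns the identifier of the hook.
func (h *MyAuthHook) ID() string {
	return "my-auth-hook"
}

// Provides indicates which hook methods this hook implements.
func (h *MyAuthHook) Provides(b byte) bool {
	return bytes.Contains([]byte{
		mqtt.OnConnectAuthenticate,
		mqtt.OnACLCheck,
	}, []byte{b})
}

// OnConnectAuthenticate returns true if the connecting client may connect.
// A real hook would look the credentials up in your own user database;
// the hard-coded username here is purely for illustration.
func (h *MyAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
	return string(pk.Connect.Username) == "peach"
}

// OnACLCheck returns true if the client may read (write=false) or
// write (write=true) the given topic filter.
func (h *MyAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
	if !write {
		return true // everyone may subscribe
	}
	return strings.HasPrefix(topic, "devices/") // publishing restricted to devices/...
}
```

It would then be registered like any other hook, e.g. `_ = server.AddHook(new(custom.MyAuthHook), nil)`.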
### Inline Client (supported in v2.4.0+)

It is now possible to subscribe to topics and publish messages directly on the server by using the inline client feature. The inline client is a special client built into the server, and can be enabled in the server options:

```go
server := mqtt.New(&mqtt.Options{
  InlineClient: true,
})
```
Once enabled, you can use the server.Publish, server.Subscribe, and server.Unsubscribe methods to publish and receive messages directly on the server.

See the [direct examples](examples/direct/main.go) for how to use them.

#### Inline Publish
To publish a message directly from the server, use the `server.Publish` method.

```go
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> In this case, the QoS level only applies to subscribers, in line with the MQTT v5 specification.

#### Inline Subscribe
To subscribe to a topic filter directly on the server, use the `server.Subscribe` method with a callback that handles the received messages. Inline subscriptions always use QoS 0. If you want multiple callbacks for the same topic filter, you can use the MQTT v5 subscriptionId property to distinguish them.

```go
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
  server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Subscribe("direct/#", 1, callbackFn)
```

#### Inline Unsubscribe
If you subscribed to a topic filter with the inline client and need to unsubscribe, you can use the `server.Unsubscribe` method:

```go
server.Unsubscribe("direct/#", 1)
```

### Packet Injection

If you want more control over the server, or want to set specific MQTT v5 properties or other values, you can create your own publish packets on behalf of a chosen client. This approach lets you inject MQTT packets directly into the running server, as if the server had received them from that client.

Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. Since you have access to the clients' details, you can even inject packets directly on behalf of an existing, connected client.

In most cases you will want to use the inline client described above, since it has unique privileges: it bypasses all ACL and topic validation checks, which means it can even publish to $SYS topics. You can also build an inline client of your own from scratch, which will behave the same as the built-in inline client.

```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
  FixedHeader: packets.FixedHeader{
    Type: packets.Publish,
  },
  TopicName: "direct/publish",
  Payload: []byte("scheduled message"),
})
```

> MQTT packets must still be correctly constructed to satisfy the specification, so refer to the [test packets catalogue](packets/tpackets.go) and the [MQTT v5 specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) for help.

See the [hooks example](examples/hooks/main.go) for how this works in practice.

### Testing
#### Unit Tests

Mochi MQTT is tested with carefully written unit tests covering over a thousand scenarios, to ensure each function behaves as we expect. You can run the tests with:
```
go test --cover ./...
```

#### Paho Interoperability Test

You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) (both MQTT v5 and v3) by starting the server with `examples/paho/main.go` and then running `python3 client_test5.py` in the _interoperability_ folder.

> Note that there are currently a number of outstanding issues regarding the Paho test suite, so certain compatibility modes are enabled in the `paho/main.go` example.

## Performance Benchmarks

Mochi MQTT performance is comparable to other mainstream MQTT brokers such as Mosquitto, EMQX, and others.

Benchmarks were performed using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on an Apple Macbook Air M2, using the default settings of `cmd/main.go`. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.

> The values presented in the benchmarks are not representative of true messages-per-second throughput. They rely on an unusual calculation by mqtt-stresser, but are consistent across all brokers. Benchmark results are provided as a general performance indication only. Comparisons were performed using the default configurations.

`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.10 | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 |
| [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| [EMQX v5.0.11](https://github.com/emqx/emqx) | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
| [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 |

`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.10 | 41,825 | 31,663 | 23,008 | 144,058 | 65,903 | 37,618 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
| Rumqtt v0.21.0 | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 |

Million Message Challenge (hit the server with 1 million messages immediately):

`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.10 | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
| Rumqtt v0.21.0 | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 |

> We're not sure what went wrong with EMQX here; perhaps the out-of-the-box Docker configuration was not optimal, so take these results with a pinch of salt, as we do know it is a solid piece of software.

## Contribution Guidelines

Contributions and feedback are both welcome! If you find a bug, have a question, or want to request a new feature, please [open an issue](https://github.com/mochi-mqtt/server/issues). If you open a pull request, please try to follow these guidelines:

- Try to maintain test coverage where reasonably possible.
- Clearly state what the pull request does and why.
- Please remember to add an SPDX FileContributor tag to files where you have made a meaningful contribution.

[SPDX annotations](https://spdx.dev) are used to clearly indicate the license, copyright, and contributions of each file in a machine-readable format. If you are adding a new file to the repository, please ensure it has the following SPDX header:
```go
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt
// SPDX-FileContributor: Your name or alias <optional@email.address>

package name
```

Please ensure a new SPDX-FileContributor line is added for each contributor to a file. Refer to other files for examples. Please remember to do this: your contribution to this project is valuable and appreciated, and getting credit for it matters!

## Stargazers over time 🥰
[Stargazers over time](https://starchart.cc/mochi-mqtt/server)

Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-mqtt/server/issues)
README-JP.md (new file): 494 lines
@@ -0,0 +1,494 @@
# Mochi-MQTT Server

<p align="center">

[Coverage Status](https://coveralls.io/github/mochi-mqtt/server?branch=master)
[Go Report Card](https://goreportcard.com/report/github.com/mochi-mqtt/server/v2)
[Go Reference](https://pkg.go.dev/github.com/mochi-mqtt/server/v2)
[Issues](https://github.com/mochi-mqtt/server/issues)

</p>

[English](README.md) | [简体中文](README-CN.md) | [日本語](README-JP.md) | [Translators Wanted!](https://github.com/orgs/mochi-mqtt/discussions/310)

🎆 **mochi-co/mqtt is now part of the new mochi-mqtt organisation.** [Read the announcement here.](https://github.com/orgs/mochi-mqtt/discussions/271)

### Mochi-MQTT is a fully MQTT v5 (and v3.1.1) compliant, embeddable, high-performance broker/server.

Mochi MQTT is a [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker written in Go, designed for IoT and telemetry development projects. It can be used as a standalone binary or embedded in your application as a library, and is designed to be as lightweight and fast as possible while keeping the project maintainable and of high quality.

#### What is MQTT?
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, simple and lightweight messaging protocol, designed with the constraints of low-bandwidth, high-latency, and unreliable networks in mind ([learn more about MQTT](https://mqtt.org/faq)). Mochi MQTT is a fully compliant implementation of MQTT protocol v5.0.0.

#### Mochi-MQTT Features

- Full MQTT v5 compliance and compatibility with MQTT v3.1.1 and v3.0.0:
  - User properties and the extended MQTT v5 packet properties
  - Topic Aliases
  - Shared Subscriptions
  - Subscription Options and Subscription Identifiers
  - Message Expiry
  - Client Sessions (session expiry)
  - Send and receive QoS flow control quotas
  - Server-side disconnects and auth packets
  - Will Delay Intervals
  - In addition to the above, all the MQTT v1 features such as QoS (0, 1, 2), $SYS topics, and retained messages
- Developer-centric:
  - Most of the core broker code is exported and accessible, so developers stay in control.
  - A full-featured and flexible hook-based interface makes it easy to develop "plugins".
  - A special inline client lets you perform packet injection or masquerade as an existing client.
- Performance and stability:
  - A classic tree-based topic-subscription model
  - Client-specific write buffers avoid problems caused by slow readers or irregular client behaviour.
  - Passes all the [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
  - Tested against many carefully considered unit test scenarios.
- TCP, Websocket (including SSL/TLS), and $SYS dashboard listeners
- Redis, Badger, and Bolt persistence via hooks (you can also write your own hooks)
- Rule-based authentication and an access control list Ledger via hooks (you can also write your own hooks)

### Compatibility Notes
Because of the compatibility between MQTT v5 and earlier versions, the server can accept both v5 and v3 clients; however, where both v5 and v3 clients are connected, the v5-specific features and behaviours provided to v5 clients are downgraded for v3 clients (such as user properties).

Support for MQTT v3.0.0 and v3.1.1 is considered hybrid compatibility. Where not restricted by the v3 specification, the more modern and safer v5 behaviour is used instead, for example expiry of inflight and retained messages and QoS flow-control limits.

#### When are releases published?
Unless there is a critical issue, new releases go out over the weekend.

## Roadmap
- [Open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or event hooks!
- Cluster support
- Enhanced metrics support
- File-based configuration (Docker image support)

## Quick Start
### Running the Broker with Go
Mochi MQTT can be used as a standalone broker. Simply check out this repository and run [cmd/main.go](cmd/main.go), the entry point in the [cmd](cmd) folder, which exposes the tcp (:1883), websocket (:1882), and dashboard (:8080) ports.

```
cd cmd
go build -o mqtt && ./mqtt
```

### Using Docker
You can pull and run the [official Mochi MQTT image](https://hub.docker.com/r/mochimqtt/server) from the Docker repository:

```sh
docker pull mochimqtt/server
or
docker run mochimqtt/server
```

This is a work in progress; a [file-based configuration](https://github.com/orgs/mochi-mqtt/projects/2) is being developed to better support it.
More substantial Docker support is being discussed [here](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545) and [here](https://github.com/orgs/mochi-mqtt/discussions/209). _If you use Mochi MQTT in a Docker environment, please join those discussions._

A simple Dockerfile is provided for running the Websocket, TCP, and Stats servers from [cmd/main.go](cmd/main.go):

```sh
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```
## Developing with Mochi MQTT
### Importing as a package
Importing Mochi MQTT as a package requires just a few lines of code to get started.
```go
import (
  "log"
  "os"
  "os/signal"
  "syscall"

  mqtt "github.com/mochi-mqtt/server/v2"
  "github.com/mochi-mqtt/server/v2/hooks/auth"
  "github.com/mochi-mqtt/server/v2/listeners"
)

func main() {
  // Create signals channel to run server until interrupted
  sigs := make(chan os.Signal, 1)
  done := make(chan bool, 1)
  signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
  go func() {
    <-sigs
    done <- true
  }()

  // Create the new MQTT Server.
  server := mqtt.New(nil)

  // Allow all connections.
  _ = server.AddHook(new(auth.AllowHook), nil)

  // Create a TCP listener on a standard port.
  tcp := listeners.NewTCP("t1", ":1883", nil)
  err := server.AddListener(tcp)
  if err != nil {
    log.Fatal(err)
  }

  go func() {
    err := server.Serve()
    if err != nil {
      log.Fatal(err)
    }
  }()

  // Run server until interrupted
  <-done

  // Cleanup
}
```

Working examples of the broker can be found in the [examples](examples) folder.

#### Network Listeners
The server supports connection listeners for a variety of protocols. The currently supported listeners are:

| Listener | Usage |
|------------------------------|----------------------------------------------------------------------------------------------|
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS dashboard listener |
| listeners.NewHTTPHealthCheck | An HTTP health-check listener, to provide health-check responses (e.g. for cloud infrastructure) |

> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!

A `*listeners.Config` can be passed to configure TLS.

Examples of usage can be found in the [examples](examples) folder and in [cmd/main.go](cmd/main.go).

## Configurable Options and Capabilities
Many options are available which can be used to alter the behaviour of the server or restrict access to certain features.

```go
server := mqtt.New(&mqtt.Options{
  Capabilities: mqtt.Capabilities{
    MaximumSessionExpiryInterval: 3600,
    Compatibilities: mqtt.Compatibilities{
      ObscureNotAuthorized: true,
    },
  },
  ClientNetWriteBufferSize: 4096,
  ClientNetReadBufferSize: 4096,
  SysTopicResendInterval: 10,
  InlineClient: false,
})
```

The mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs are helpful for understanding the full set of options.
`ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured as needed to match the memory used per client.
### Default Configuration Notes

Some decisions made when choosing the default configuration should be noted here:
- By default, `server.Options.Capabilities.MaximumMessageExpiryInterval` is set to 86400 (24 hours) to avoid exposure to DoS attacks on hostile networks, because an infinite expiry would let retained and inflight messages accumulate without limit. If you are in a trusted environment, or can afford a larger retention period, you can override this setting (set it to `0` for no expiry).

## Event Hooks
The universal event hooks system lets developers hook into various lifecycle stages of the server and clients to add or change broker functionality. These universal hooks are used for everything from authentication and persistent storage to debugging tools.
Hooks are stackable; you can add multiple hooks to the server, and they run in the order they were added. Some hooks change values, and those values are passed on to subsequent hooks before being returned to the runtime code.

| Type | Import | Info |
|----------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|
| Access Control | [mochi-mqtt/server/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allows all clients to connect and to read/write all topics. |
| Access Control | [mochi-mqtt/server/hooks/auth . Auth](hooks/auth/auth.go) | A rule-based access control ledger. |
| Persistence | [mochi-mqtt/server/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-mqtt/server/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-mqtt/server/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-mqtt/server/hooks/debug](hooks/debug/debug.go) | A debugging hook to visualise packet flow. |

Many internal functions are exposed to developers, so you can build your own hooks using the above as examples. If you do, please [open an issue](https://github.com/mochi-mqtt/server/issues) and let everyone know!

### Access Control
#### Allow Hook
By default, Mochi MQTT uses a DENY-ALL access control rule. To allow connections you must override the access control hook. The simplest is the `auth.AllowAll` hook, which applies an ALLOW-ALL rule to all connections, subscriptions, and publishes. Using it is as simple as:

```go
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
```

> Don't do this if the server is exposed to the internet or untrusted networks. It should only ever be used for development, testing, and debugging.

#### Auth Ledger
The Auth Ledger provides a sophisticated mechanism for access rules defined as structs. Auth Ledger rules come in two forms: auth rules (connections) and ACL rules (publish and subscribe).

Auth rules have four criteria and an assertion flag:
| Criteria | Usage |
| -- | -- |
| Client | The client ID of the connecting client |
| Username | The username of the connecting client |
| Password | The password of the connecting client |
| Remote | The remote address or IP of the client |
| Allow | true (allow this user) or false (deny this user) |

ACL rules have three criteria and a filter match:
| Criteria | Usage |
| -- | -- |
| Client | The client ID of the connecting client |
| Username | The username of the connecting client |
| Remote | The remote address or IP of the client |
| Filters | An array of filters to match |

Rules are processed in index order (0, 1, 2, 3) and the first matching rule is applied. See the structs in [hooks/auth/ledger.go](hooks/auth/ledger.go).

```go
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
  Ledger: &auth.Ledger{
    Auth: auth.AuthRules{ // Auth disallows all by default
      {Username: "peach", Password: "password1", Allow: true},
      {Username: "melon", Password: "password2", Allow: true},
      {Remote: "127.0.0.1:*", Allow: true},
      {Remote: "localhost:*", Allow: true},
    },
    ACL: auth.ACLRules{ // ACL allows all by default
      {Remote: "127.0.0.1:*"}, // local superuser allow all
      {
        // user melon can read and write to their own topic
        Username: "melon", Filters: auth.Filters{
          "melon/#": auth.ReadWrite,
          "updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
        },
      },
      {
        // Otherwise, no clients have publishing permissions
        Filters: auth.Filters{
          "#": auth.ReadOnly,
          "updates/#": auth.Deny,
        },
      },
    },
  },
})
```

The ledger can also be stored as JSON or YAML and loaded using the Data field.
```go
err := server.AddHook(new(auth.Hook), &auth.Options{
  Data: data, // build ledger from byte slice: yaml or json
})
```
See [examples/auth/encoded/main.go](examples/auth/encoded/main.go) for more information.
### Persistent Storage
#### Redis
A basic Redis storage hook is available which provides persistence for the broker. It can be added to the server in the same way as the other hooks, with a number of options. It uses github.com/go-redis/redis/v8 internally, and detailed settings can be provided via the Options value.
```go
err := server.AddHook(new(redis.Hook), &redis.Options{
  Options: &rv8.Options{
    Addr:     "localhost:6379", // default redis address
    Password: "",               // your password
    DB:       0,                // your redis db
  },
})
if err != nil {
  log.Fatal(err)
}
```
For more information on how the Redis hook works and how to use it, see the source at [examples/persistence/redis/main.go](examples/persistence/redis/main.go) or [hooks/storage/redis](hooks/storage/redis).

#### Badger DB
If file-based storage is more suitable, a BadgerDB storage hook is also available. It can be added and configured like the other hooks (with somewhat fewer options).

```go
err := server.AddHook(new(badger.Hook), &badger.Options{
  Path: badgerPath,
})
if err != nil {
  log.Fatal(err)
}
```

For more information on how the Badger hook works and how to use it, see the source at [examples/persistence/badger/main.go](examples/persistence/badger/main.go) or [hooks/storage/badger](hooks/storage/badger).

The BoltDB hook has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).

## Developing with Event Hooks

A large number of hooks are available, covering the broker and client lifecycle.
The function signatures for all the hooks and the `mqtt.Hook` interface can be found in [hooks.go](hooks.go).

> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent; they can be used to control and modify all incoming and outgoing packets.

| Function | Usage |
|------------------------|-----------------------------------------------------------------------------------------------------------|
| OnStarted | Called when the server has successfully started. |
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. This method must be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in a custom hook to check whether a user exists in your database. Return true to allow access. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published. |
| OnConnect | Called when a new client connects; may return an error or packet code to disconnect the client. |
| OnSessionEstablish | Called immediately after a new client connects, before the session is established and the CONNACK is sent. |
| OnSessionEstablished | Called when a new client establishes a session (after OnConnect). |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own MQTT v5 auth packet handling. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before an encoded packet is sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, before shared-subscription subscribers are selected. Allows the receivers to be modified. |
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more topic filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message is dropped before reaching a client, e.g. if the client takes too long to respond. |
| OnRetainMessage | Called when a published message is retained. |
| OnRetainPublished | Called when a retained message reaches a client. |
| OnQosPublish | Called when a packet with QoS >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the QoS flow for that message is complete. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnPacketIDExhausted | Called when a client runs out of packet IDs to assign. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| StoredClients | Returns clients, e.g. from persistent storage. |
| StoredSubscriptions | Returns client subscriptions, e.g. from persistent storage. |
| StoredInflightMessages | Returns inflight messages, e.g. from persistent storage. |
| StoredRetainedMessages | Returns retained messages, e.g. from persistent storage. |
| StoredSysInfo | Returns system info values, e.g. from persistent storage. |

If you are building a persistent storage hook, take a look at the existing persistence hooks for inspiration and patterns. If you are building an authentication hook, `OnACLCheck` and `OnConnectAuthenticate` are what you need.

### Inline Client (v2.4.0+)
You can subscribe and publish to topics directly from embedding code by using the `inline client` feature. The inline client is a client built into the server and can be enabled in the server options:
```go
server := mqtt.New(&mqtt.Options{
  InlineClient: true,
})
```
Once enabled, you can use the `server.Publish`, `server.Subscribe`, and `server.Unsubscribe` methods to send and receive messages directly from the broker.
> See the [direct examples](examples/direct/main.go) for real usage.

#### Inline Publish
To publish a message from embedding code, use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.

```go
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> In this case, the QoS byte is only used to set an upper ceiling for subscribers, in line with the MQTT v5 specification.

#### Inline Subscribe
To subscribe to a topic filter from within the embedding application, use the `server.Subscribe(filter string, subscriptionId int, handler InlineSubFn) error` method with a callback.
Inline subscriptions only operate at QoS 0. If you want multiple callbacks for the same filter, the MQTT v5 `subscriptionId` property can be used to distinguish them; a further sketch follows the example below.

```go
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
  server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Subscribe("direct/#", 1, callbackFn)
```
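As an illustrative sketch of the `subscriptionId` point above (not taken from the upstream examples): two independent callbacks can be attached to the same filter by giving each a different subscription ID, and each can later be removed individually with `server.Unsubscribe`.

```go
// Sketch: two inline subscriptions on the same filter, distinguished by
// subscription IDs 1 and 2.
auditLog := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
	server.Log.Info("audit", "topic", pk.TopicName, "payload", string(pk.Payload))
}
forward := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
	// Hand the payload to another part of your application (placeholder).
	_ = pk.Payload
}

_ = server.Subscribe("direct/#", 1, auditLog)
_ = server.Subscribe("direct/#", 2, forward)

// Later, remove only the audit subscription:
_ = server.Unsubscribe("direct/#", 1)
```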
#### Inline Unsubscribe
If you have subscribed with the inline client and want to unsubscribe, you can do so with the `server.Unsubscribe(filter string, subscriptionId int) error` method:

```go
server.Unsubscribe("direct/#", 1)
```

### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values, you can create publish packets on behalf of a client yourself. Rather than a simple publish, this method injects MQTT packets directly into the runtime as if they had been received from a specific client.

Packet injection can be used for any MQTT packet, such as ping requests or subscriptions. And since the client structs and methods are exported, you can even inject packets on behalf of a connected client (if you have very custom requirements).

In most cases you will want to use the inline client described above, since it has the privilege of bypassing ACL and topic validation checks, which means it can publish even to $SYS topics. You can also build an inline client of your own that behaves the same as the built-in one.

```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
  FixedHeader: packets.FixedHeader{
    Type: packets.Publish,
  },
  TopicName: "direct/publish",
  Payload: []byte("scheduled message"),
})
```

> MQTT packets must still be constructed correctly, so refer to [the test packets catalogue](packets/tpackets.go) and the [MQTT v5 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html).

See the [hooks example](examples/hooks/main.go) to see this feature in action.

### Testing
#### Unit Tests
Mochi MQTT's tests are designed to ensure each function behaves exactly as expected. To run the tests:
```
go test --cover ./...
```

#### Paho Interoperability Test
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker with `examples/paho/main.go` and running the MQTT v5 and v3 tests with `python3 client_test5.py` from the _interoperability_ folder.

> Note that there are a number of outstanding issues regarding false negatives in the Paho suite, so certain compatibility modes are enabled in the `paho/main.go` example.

## Performance Benchmarks
Mochi MQTT's performance is comparable to well-known brokers such as Mosquitto, EMQX, and others.

Benchmarks were performed with [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on an Apple Macbook Air M2, using the default settings of `cmd/main.go`. Taking into account bursts of high and low throughput, the median scores are the most reliable. Higher is better.

> The benchmark values do not represent true messages-per-second throughput. They rely on a calculation specific to mqtt-stresser, but are used consistently across all brokers.
> Benchmarks are provided only as a general performance-expectation guideline. Comparisons were run with the out-of-the-box default configuration.

`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.10 | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 |
| [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| [EMQX v5.0.11](https://github.com/emqx/emqx) | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
| [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 |

`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.10 | 41,825 | 31,663 | 23,008 | 144,058 | 65,903 | 37,618 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
| Rumqtt v0.21.0 | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 |

Million Message Challenge (sends 1 million messages to the server at once):

`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.10 | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
| Rumqtt v0.21.0 | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 |

> We're not sure what happened with EMQX here; perhaps the out-of-the-box Docker configuration was not optimal, so take this result as just one aspect of the software.

## Contribution Guidelines
Contributions and feedback are both welcome: report bugs, ask questions, and request new features. If you open a pull request, please try to follow these guidelines:
- Maintain test coverage where reasonably possible.
- Clearly state what the pull request does and why.
- Add an SPDX FileContributor tag to files where you have made a meaningful contribution.

[SPDX Annotations](https://spdx.dev) are recorded in each file in a machine-readable format to make the license, copyright, and contributors clear. If you add a new file to the repository, please make sure it has an SPDX header like the following:

```go
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt
// SPDX-FileContributor: Your name or alias <optional@email.address>

package name
```

Please make sure an `SPDX-FileContributor` line is added for each contributor to a file; refer to the other files for examples. Your contributions to this project are valuable and highly appreciated!

## Stargazers over time 🥰
[Stargazers over time](https://starchart.cc/mochi-mqtt/server)

Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-mqtt/server/issues)
README.md: 132 changed lines
@@ -1,6 +1,7 @@
|
||||
# Mochi-MQTT Server
|
||||
|
||||
<p align="center">
|
||||
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-mqtt/server?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-mqtt/server/v2)
|
||||
@@ -9,23 +10,19 @@
|
||||
|
||||
</p>
|
||||
|
||||
# Mochi MQTT Broker
|
||||
## The fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker server
|
||||
[English](README.md) | [简体中文](README-CN.md) | [日本語](README-JP.md) | [Translators Wanted!](https://github.com/orgs/mochi-mqtt/discussions/310)
|
||||
|
||||
🎆 **mochi-co/mqtt is now part of the new mochi-mqtt organisation.** [Read about this announcement here.](https://github.com/orgs/mochi-mqtt/discussions/271)
|
||||
|
||||
|
||||
### Mochi-MQTT is a fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker/server
|
||||
|
||||
Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.
|
||||
|
||||
### What is MQTT?
|
||||
#### What is MQTT?
|
||||
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
|
||||
|
||||
### When is this repo updated?
|
||||
Unless it's a critical issue, new releases typically go out over the weekend.
|
||||
|
||||
## What's new in Version 2?
|
||||
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
|
||||
|
||||
Don't forget to use the new v2 import paths:
|
||||
```go
|
||||
import "github.com/mochi-mqtt/server/v2"
|
||||
```
|
||||
#### Mochi-MQTT Features
|
||||
|
||||
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
|
||||
- User and MQTTv5 Packet Properties
|
||||
@@ -51,13 +48,14 @@ import "github.com/mochi-mqtt/server/v2"
|
||||
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
|
||||
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
|
||||
|
||||
> There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system).
|
||||
|
||||
### Compatibility Notes
|
||||
Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).
|
||||
|
||||
Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits.
|
||||
|
||||
#### When is this repo updated?
|
||||
Unless it's a critical issue, new releases typically go out over the weekend.
|
||||
|
||||
## Roadmap
|
||||
- Please [open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or event hooks!
|
||||
- Cluster support.
|
||||
@@ -74,6 +72,16 @@ go build -o mqtt && ./mqtt
|
||||
```
|
||||
|
||||
### Using Docker
|
||||
You can now pull and run the [official Mochi MQTT image](https://hub.docker.com/r/mochimqtt/server) from our Docker repo:
|
||||
|
||||
```sh
|
||||
docker pull mochimqtt/server
|
||||
or
|
||||
docker run mochimqtt/server
|
||||
```
|
||||
|
||||
This is a work in progress, and a [file-based configuration](https://github.com/orgs/mochi-mqtt/projects/2) is being developed to better support this implementation. _More substantial docker support is being discussed [here](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545) and [here](https://github.com/orgs/mochi-mqtt/discussions/209). Please join the discussion if you use Mochi-MQTT in this environment._
|
||||
|
||||
A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server:
|
||||
|
||||
```sh
|
||||
@@ -88,12 +96,21 @@ Importing Mochi MQTT as a package requires just a few lines of code to get start
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create signals channel to run server until interrupted
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Create the new MQTT Server.
|
||||
server := mqtt.New(nil)
|
||||
|
||||
@@ -107,10 +124,18 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Run server until interrupted
|
||||
<-done
|
||||
|
||||
// Cleanup
|
||||
}
|
||||
```
|
||||
|
||||
@@ -134,7 +159,8 @@ A `*listeners.Config` may be passed to configure TLS.
|
||||
|
||||
Examples of usage can be found in the [examples](examples) folder or [cmd/main.go](cmd/main.go).
|
||||
|
||||
### Server Options and Capabilities
|
||||
|
||||
## Server Options and Capabilities
|
||||
A number of configurable options are available which can be used to alter the behaviour or restrict access to certain features in the server.
|
||||
|
||||
```go
|
||||
@@ -148,25 +174,31 @@ server := mqtt.New(&mqtt.Options{
|
||||
ClientNetWriteBufferSize: 4096,
|
||||
ClientNetReadBufferSize: 4096,
|
||||
SysTopicResendInterval: 10,
|
||||
InlineClient: false,
|
||||
})
|
||||
```
|
||||
|
||||
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
|
||||
|
||||
### Default Configuration Notes
|
||||
|
||||
Some choices were made when deciding the default configuration that need to be mentioned here:
|
||||
|
||||
- By default, the value of `server.Options.Capabilities.MaximumMessageExpiryInterval` is set to 86400 (24 hours), in order to prevent exposing the broker to DOS attacks on hostile networks when using the out-of-the-box configuration (as an infinite expiry would allow an infinite number of retained/inflight messages to accumulate). If you are operating in a trusted environment, or you have capacity for a larger retention period, you may wish to override this (set to `0` for no expiry).
|
||||
|
||||
## Event Hooks
|
||||
A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools.
|
||||
|
||||
Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code.
|
||||
|
||||
| Type | Import | Info |
|----------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|
| Access Control | [mochi-mqtt/server/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
| Access Control | [mochi-mqtt/server/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
| Persistence | [mochi-mqtt/server/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-mqtt/server/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-mqtt/server/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-mqtt/server/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |

Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-mqtt/server/issues) and let everyone know!

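As a starting point, a custom hook can be as small as the sketch below (modelled on the ExampleHook in examples/hooks/main.go further down this changeset; the hook name and log messages are illustrative, and the `mqtt.OnConnect`/`mqtt.OnDisconnect` constants used in `Provides` are assumed from the hook API):

```go
package main

import (
    "bytes"
    "log"

    mqtt "github.com/mochi-mqtt/server/v2"
    "github.com/mochi-mqtt/server/v2/packets"
)

// MyHook embeds mqtt.HookBase so it only needs to implement the callbacks it cares about.
type MyHook struct {
    mqtt.HookBase
}

// ID returns a human-readable identifier for the hook.
func (h *MyHook) ID() string {
    return "my-hook-example"
}

// Provides indicates which hook methods this hook implements.
func (h *MyHook) Provides(b byte) bool {
    return bytes.Contains([]byte{
        mqtt.OnConnect,
        mqtt.OnDisconnect,
    }, []byte{b})
}

// OnConnect is called when a new client connects.
func (h *MyHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
    h.Log.Info("client connected", "client", cl.ID)
    return nil
}

// OnDisconnect is called when a client is disconnected for any reason.
func (h *MyHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
    h.Log.Info("client disconnected", "client", cl.ID, "expire", expire)
}

func main() {
    server := mqtt.New(nil)
    if err := server.AddHook(new(MyHook), nil); err != nil {
        log.Fatal(err)
    }
    if err := server.Serve(); err != nil {
        log.Fatal(err)
    }
}
```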
@@ -236,7 +268,7 @@ err := server.AddHook(new(auth.Hook), &auth.Options{

The ledger can also be stored as JSON or YAML and loaded using the Data field:
```go
err = server.AddHook(new(auth.Hook), &auth.Options{
err := server.AddHook(new(auth.Hook), &auth.Options{
  Data: data, // build ledger from byte slice: yaml or json
})
```
@@ -287,7 +319,7 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects, may return an error or packet code to halt the client connection process. |
| OnSessionEstablish | Called immediately after a new client connects and authenticates and immediately before the session is established and CONNACK is sent. |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
@@ -321,8 +353,18 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found

If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`; a bare-bones sketch follows below.

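For instance (a sketch only; the callback signatures and the `packets.ConnectParams` username/password fields are assumptions based on the bundled auth hooks, and the credential and topic rules are purely illustrative):

```go
package main

import (
    "bytes"
    "log"
    "strings"

    mqtt "github.com/mochi-mqtt/server/v2"
    "github.com/mochi-mqtt/server/v2/packets"
)

// NaiveAuthHook allows a single hard-coded user and restricts writes to one topic prefix.
type NaiveAuthHook struct {
    mqtt.HookBase
}

func (h *NaiveAuthHook) ID() string {
    return "naive-auth-example"
}

func (h *NaiveAuthHook) Provides(b byte) bool {
    return bytes.Contains([]byte{
        mqtt.OnConnectAuthenticate,
        mqtt.OnACLCheck,
    }, []byte{b})
}

// OnConnectAuthenticate returns true if the connecting client may proceed.
func (h *NaiveAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
    return string(pk.Connect.Username) == "mochi" && string(pk.Connect.Password) == "secret"
}

// OnACLCheck returns true if the client may read (write=false) or write (write=true) the topic.
func (h *NaiveAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
    if write {
        return strings.HasPrefix(topic, "devices/")
    }
    return true
}

func main() {
    server := mqtt.New(nil)
    if err := server.AddHook(new(NaiveAuthHook), nil); err != nil {
        log.Fatal(err)
    }
    if err := server.Serve(); err != nil {
        log.Fatal(err)
    }
}
```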
### Inline Client (v2.4.0+)
It's now possible to subscribe and publish to topics directly from the embedding code, by using the `inline client` feature. The Inline Client is an embedded client which operates as part of the server, and can be enabled in the server options:
```go
server := mqtt.New(&mqtt.Options{
  InlineClient: true,
})
```
Once enabled, you will be able to use the `server.Publish`, `server.Subscribe`, and `server.Unsubscribe` methods to issue and receive messages from broker-adjacent code.

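Putting the three methods together (a minimal sketch built only from the signatures documented in the subsections below; the topic names are illustrative):

```go
package main

import (
    "log"

    mqtt "github.com/mochi-mqtt/server/v2"
    "github.com/mochi-mqtt/server/v2/hooks/auth"
    "github.com/mochi-mqtt/server/v2/packets"
)

func main() {
    server := mqtt.New(&mqtt.Options{
        InlineClient: true, // required for server.Publish, server.Subscribe and server.Unsubscribe
    })
    _ = server.AddHook(new(auth.AllowHook), nil)

    go func() {
        if err := server.Serve(); err != nil {
            log.Fatal(err)
        }
    }()

    // Receive messages matching the filter via a callback (inline subscriptions are QoS 0 only).
    callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
        server.Log.Info("inline message", "topic", pk.TopicName, "payload", string(pk.Payload))
    }
    _ = server.Subscribe("sensors/#", 1, callbackFn)

    // Publish directly from the embedding code; retain=false, qos=0.
    _ = server.Publish("sensors/kitchen/temperature", []byte("21.5"), false, 0)

    // Later, stop receiving messages for this filter and subscription id.
    _ = server.Unsubscribe("sensors/#", 1)

    select {} // Block forever; real code would wait for a shutdown signal as in the quick-start example.
}
```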
### Direct Publish
> See [direct examples](examples/direct/main.go) for real-life usage examples.

#### Inline Publish
To publish a basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.

```go
@@ -330,11 +372,30 @@ err := server.Publish("direct/publish", []byte("packet scheduled message"), fals
```
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.

#### Inline Subscribe
To subscribe to a topic filter from within the embedding application, you can use the `server.Subscribe(filter string, subscriptionId int, handler InlineSubFn) error` method with a callback function. Note that only QoS 0 is supported for inline subscriptions. If you wish to have multiple callbacks for the same filter, you can use the MQTTv5 `subscriptionId` property to differentiate them.

```go
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
    server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Subscribe("direct/#", 1, callbackFn)
```

#### Inline Unsubscribe
You may wish to unsubscribe if you have subscribed to a filter using the inline client. You can do this easily with the `server.Unsubscribe(filter string, subscriptionId int) error` method:

```go
server.Unsubscribe("direct/#", 1)
```

### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values, you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (not just publish packets) directly into the runtime as though they had been received by a specific client.

Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have very custom requirements).

Most of the time you'll want to use the Inline Client described above, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics. In this case, you can create an inline client from scratch which will behave the same as the built-in inline client.

```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
```
@@ -351,7 +412,6 @@ server.InjectPacket(cl, packets.Packet{
See the [hooks example](examples/hooks/main.go) to see this feature in action.

### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand-written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
@@ -405,7 +465,7 @@ Million Message Challenge (hit the server with 1 million messages immediately):
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-mqtt/server/issues) to report a bug, ask a question, or make a feature request. If you open a pull request, please try to follow these guidelines:
- Try to maintain test coverage where reasonably possible.
- Clearly state what the PR does and why.
- Please remember to add your SPDX FileContributor tag to files where you have made a meaningful contribution.

[SPDX Annotations](https://spdx.dev) are used to clearly indicate the license, copyright, and contributions of each file in a machine-readable format. If you are adding a new file to the repository, please ensure it has the following SPDX header:
```go
```

clients.go
@@ -8,6 +8,7 @@ import (
    "bufio"
    "bytes"
    "context"
    "errors"
    "fmt"
    "io"
    "net"
@@ -21,8 +22,13 @@ import (
)

const (
    defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
    defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds.
    defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
    minimumKeepalive uint16 = 5 // the minimum recommended keepalive - values under with display a warning.
)

var (
    ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability")
)

// ReadFn is the function signature for the function used for reading and processing new packets.
@@ -99,7 +105,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
type Client struct {
    Properties ClientProperties // client properties
    State ClientState // the operational state of the client.
    Net ClientConnection // network connection state of the clinet
    Net ClientConnection // network connection state of the client
    ID string // the client id.
    ops *ops // ops provides a reference to server ops.
    sync.RWMutex // mutex
@@ -107,11 +113,12 @@ type Client struct {

// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
    Conn net.Conn // the net.Conn used to establish the connection
    bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
    Remote string // the remote address of the client
    Listener string // listener id of the client
    Inline bool // client is an inline programmetic client
    Conn net.Conn // the net.Conn used to establish the connection
    bconn *bufio.Reader // a buffered net.Conn for reading packets
    outbuf *bytes.Buffer // a buffer for writing packets
    Remote string // the remote address of the client
    Listener string // listener id of the client
    Inline bool // if true, the client is the built-in 'inline' embedded client
}

// ClientProperties contains the properties which define the client behaviour.
@@ -134,7 +141,7 @@ type Will struct {
    Retain bool // -
}

// State tracks the state of the client.
// ClientState tracks the state of the client.
type ClientState struct {
    TopicAliases TopicAliases // a map of topic aliases
    stopCause atomic.Value // reason for stopping
@@ -174,11 +181,8 @@ func newClient(c net.Conn, o *ops) *Client {

    if c != nil {
        cl.Net = ClientConnection{
            Conn: c,
            bconn: bufio.NewReadWriter(
                bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
                bufio.NewWriterSize(c, o.options.ClientNetWriteBufferSize),
            ),
            Conn: c,
            bconn: bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
            Remote: c.RemoteAddr().String(),
        }
    }
@@ -192,7 +196,8 @@ func (cl *Client) WriteLoop() {
    select {
    case pk := <-cl.State.outbound:
        if err := cl.WritePacket(*pk); err != nil {
            cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
            // TODO : Figure out what to do with error
            cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
        }
        atomic.AddInt32(&cl.State.outboundQty, -1)
    case <-cl.State.open.Done():
@@ -210,6 +215,19 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
    cl.Properties.Clean = pk.Connect.Clean
    cl.Properties.Props = pk.Properties.Copy(false)

    if cl.Properties.Props.ReceiveMaximum > cl.ops.options.Capabilities.MaximumInflight { // 3.3.4 Non-normative
        cl.Properties.Props.ReceiveMaximum = cl.ops.options.Capabilities.MaximumInflight
    }

    if pk.Connect.Keepalive <= minimumKeepalive {
        cl.ops.log.Warn(
            ErrMinimumKeepalive.Error(),
            "client", cl.ID,
            "keepalive", pk.Connect.Keepalive,
            "recommended", minimumKeepalive,
        )
    }

    cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
    cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
    cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
@@ -310,11 +328,27 @@ func (cl *Client) ResendInflightMessages(force bool) error {
    return nil
}

// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
func (cl *Client) ClearInflights() {
    for _, tk := range cl.State.Inflight.GetAll(false) {
        if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
            cl.ops.hooks.OnQosDropped(cl, tk)
            atomic.AddInt64(&cl.ops.info.Inflight, -1)
        }
    }
}

// ClearExpiredInflights deletes any inflight messages which have expired.
func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 {
    deleted := []uint16{}
    for _, tk := range cl.State.Inflight.GetAll(false) {
        if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
        expired := tk.ProtocolVersion == 5 && tk.Expiry > 0 && tk.Expiry < now // [MQTT-3.3.2-5]

        // If the maximum message expiry interval is set (greater than 0), and the message
        // retention period exceeds the maximum expiry, the message will be forcibly removed.
        enforced := maximumExpiry > 0 && now-tk.Created > maximumExpiry

        if expired || enforced {
            if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
                cl.ops.hooks.OnQosDropped(cl, tk)
                atomic.AddInt64(&cl.ops.info.Inflight, -1)
@@ -552,11 +586,35 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
        return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
    }

    nb := net.Buffers{buf.Bytes()}
    n, err := func() (int64, error) {
        cl.Lock()
        defer cl.Unlock()
        return nb.WriteTo(cl.Net.Conn)
        if len(cl.State.outbound) == 0 {
            if cl.Net.outbuf == nil {
                return buf.WriteTo(cl.Net.Conn)
            }

            // first write to buffer, then flush buffer
            n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
            err = cl.flushOutbuf()
            return int64(n), err
        }

        // there are more writes in the queue
        if cl.Net.outbuf == nil {
            if buf.Len() >= cl.ops.options.ClientNetWriteBufferSize {
                return buf.WriteTo(cl.Net.Conn)
            }
            cl.Net.outbuf = new(bytes.Buffer)
        }

        n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
        if cl.Net.outbuf.Len() < cl.ops.options.ClientNetWriteBufferSize {
            return int64(n), nil
        }

        err = cl.flushOutbuf()
        return int64(n), err
    }()
    if err != nil {
        return err
@@ -572,3 +630,15 @@ func (cl *Client) WritePacket(pk packets.Packet) error {

    return err
}

func (cl *Client) flushOutbuf() (err error) {
    if cl.Net.outbuf == nil {
        return
    }

    _, err = cl.Net.outbuf.WriteTo(cl.Net.Conn)
    if err == nil {
        cl.Net.outbuf = nil
    }
    return
}

clients_test.go
@@ -5,10 +5,14 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -29,10 +33,11 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: &logger,
|
||||
log: logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
MaximumInflight: 5,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
@@ -179,6 +184,45 @@ func TestClientParseConnect(t *testing.T) {
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectReceiveMaxExceedMaxInflight(t *testing.T) {
|
||||
const MaxInflight uint16 = 1
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumInflight = MaxInflight
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillTopic: "lwt",
|
||||
WillPayload: []byte("lol gg"),
|
||||
WillQos: 1,
|
||||
WillRetain: false,
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
ReceiveMaximum: uint16(5),
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
|
||||
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
|
||||
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(MaxInflight), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(MaxInflight), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
@@ -210,6 +254,27 @@ func TestClientParseConnectNoID(t *testing.T) {
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientParseConnectBelowMinimumKeepalive(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
var b bytes.Buffer
|
||||
x := bufio.NewWriter(&b)
|
||||
cl.ops.log = slog.New(slog.NewTextHandler(x, nil))
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Keepalive: minimumKeepalive - 1,
|
||||
ClientIdentifier: "mochi",
|
||||
},
|
||||
}
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
err := x.Flush()
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.Contains(b.String(), ErrMinimumKeepalive.Error()))
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
@@ -263,7 +328,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
|
||||
}
|
||||
|
||||
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
|
||||
@@ -277,19 +342,56 @@ func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
n := time.Now().Unix()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
|
||||
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
cl.ClearInflights()
|
||||
require.Equal(t, 0, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientClearExpiredInflights(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
deleted := cl.ClearInflights(n, 4)
|
||||
deleted := cl.ClearExpiredInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 11, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 12, Expiry: n - 2}) // expiry is ineffective for v3.
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 13, Created: n - 3}) // within bounds for v3
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 15, Created: n - 5}) // over max server expiry limit
|
||||
require.Equal(t, 6, cl.State.Inflight.Len())
|
||||
|
||||
deleted = cl.ClearExpiredInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{11, 12, 15}, deleted)
|
||||
require.Equal(t, 3, cl.State.Inflight.Len())
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 17, Created: n - 1})
|
||||
deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not process abandon messages
|
||||
require.Len(t, deleted, 0)
|
||||
require.Equal(t, 4, cl.State.Inflight.Len())
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{ProtocolVersion: 5, PacketID: 18, Expiry: n - 1})
|
||||
deleted = cl.ClearExpiredInflights(n, 0) // maximumExpiry = 0 do not abandon messages
|
||||
require.ElementsMatch(t, []uint16{18}, deleted) // expiry is still effective for v5.
|
||||
require.Len(t, deleted, 1)
|
||||
require.Equal(t, 4, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
@@ -303,7 +405,7 @@ func TestClientResendInflightMessages(t *testing.T) {
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
@@ -315,7 +417,7 @@ func TestClientResendInflightMessages(t *testing.T) {
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newTestClient()
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
@@ -342,8 +444,8 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
r.Close()
|
||||
_, _ = r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -357,8 +459,8 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
r.Close()
|
||||
_, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -372,8 +474,8 @@ func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
r.Close()
|
||||
_, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -387,7 +489,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -401,8 +503,8 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
r.Close()
|
||||
_, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -414,7 +516,7 @@ func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 18, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
@@ -424,7 +526,7 @@ func TestClientReadOK(t *testing.T) {
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
var pks []packets.Packet
|
||||
@@ -499,10 +601,10 @@ func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
cl.Net.bconn = nil
|
||||
@@ -516,13 +618,13 @@ func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
err := cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
@@ -536,13 +638,13 @@ func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -573,7 +675,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
t.Run(tt.Desc, func(t *testing.T) {
|
||||
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
|
||||
go func() {
|
||||
r.Write(tt.RawBytes)
|
||||
_, _ = r.Write(tt.RawBytes)
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -600,7 +702,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
_ = cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
@@ -624,7 +726,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cl.Net.Conn.Close()
|
||||
_ = cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
@@ -647,6 +749,86 @@ func TestClientWritePacket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWritePacketBuffer(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
|
||||
cl := newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
cl.ID = "mochi"
|
||||
cl.State.Inflight.maximumSendQuota = 5
|
||||
cl.State.Inflight.sendQuota = 5
|
||||
cl.State.Inflight.maximumReceiveQuota = 10
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
cl.ops.options.ClientNetWriteBufferSize = 10
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
small := packets.TPacketData[packets.Publish].Get(packets.TPublishNoPayload).Packet
|
||||
large := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||
|
||||
cl.State.outbound <- small
|
||||
|
||||
tt := []struct {
|
||||
pks []*packets.Packet
|
||||
size int
|
||||
}{
|
||||
{
|
||||
pks: []*packets.Packet{small, small},
|
||||
size: 18,
|
||||
},
|
||||
{
|
||||
pks: []*packets.Packet{large},
|
||||
size: 20,
|
||||
},
|
||||
{
|
||||
pks: []*packets.Packet{small},
|
||||
size: 0,
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
for i, tx := range tt {
|
||||
for _, pk := range tx.pks {
|
||||
cl.Properties.ProtocolVersion = pk.ProtocolVersion
|
||||
err := cl.WritePacket(*pk)
|
||||
require.NoError(t, err, "index: %d", i)
|
||||
if i == len(tt)-1 {
|
||||
cl.Net.Conn.Close()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var n int
|
||||
var err error
|
||||
for i, tx := range tt {
|
||||
buf := make([]byte, 100)
|
||||
if i == len(tt)-1 {
|
||||
buf, err = io.ReadAll(r)
|
||||
n = len(buf)
|
||||
} else {
|
||||
n, err = io.ReadAtLeast(r, buf, 1)
|
||||
}
|
||||
require.NoError(t, err, "index: %d", i)
|
||||
require.Equal(t, tx.size, n, "index: %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
@@ -660,13 +842,13 @@ func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
@@ -680,13 +862,13 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
_, _ = r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
_ = r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
@@ -706,7 +888,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
_ = cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -59,7 +59,8 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -77,7 +77,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -59,7 +59,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -46,7 +46,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -6,15 +6,15 @@ package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/debug"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -27,8 +27,12 @@ func main() {
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
level := new(slog.LevelVar)
|
||||
server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
level.Set(slog.LevelDebug)
|
||||
|
||||
err := server.AddHook(new(debug.Hook), &debug.Options{
|
||||
// ShowPacketData: true,
|
||||
@@ -56,7 +60,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
examples/direct/main.go (new file)
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
InlineClient: true, // you must enable inline client to use direct publishing and subscribing.
|
||||
})
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Start the server
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Demonstration of using an inline client to directly subscribe to a topic and receive a message when
|
||||
// that subscription is activated. The inline subscription method uses the same internal subscription logic
|
||||
// as used for external (normal) clients.
|
||||
go func() {
|
||||
// Inline subscriptions can also receive retained messages on subscription.
|
||||
_ = server.Publish("direct/retained", []byte("retained message"), true, 0)
|
||||
_ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0)
|
||||
|
||||
// Subscribe to a filter and handle any received messages via a callback function.
|
||||
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
|
||||
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
|
||||
}
|
||||
server.Log.Info("inline client subscribing")
|
||||
_ = server.Subscribe("direct/#", 1, callbackFn)
|
||||
_ = server.Subscribe("direct/#", 2, callbackFn)
|
||||
}()
|
||||
|
||||
// There is a shorthand convenience function, Publish, for easily sending publish packets if you are not
|
||||
// concerned with creating your own packets. If you want to have more control over your packets, you can
|
||||
//directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 3) {
|
||||
err := server.Publish("direct/publish", []byte("scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error("server.Publish", "error", err)
|
||||
}
|
||||
server.Log.Info("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Second * 10)
|
||||
// Unsubscribe from the same filter to stop receiving messages.
|
||||
server.Log.Info("inline client unsubscribing")
|
||||
_ = server.Unsubscribe("direct/#", 1)
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
@@ -6,13 +6,14 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
@@ -27,7 +28,10 @@ func main() {
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
InlineClient: true, // you must enable inline client to use direct publishing and subscribing.
|
||||
})
|
||||
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
@@ -35,7 +39,11 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(ExampleHook), map[string]any{})
|
||||
// Add custom hook (ExampleHook) to the server
|
||||
err = server.AddHook(new(ExampleHook), &ExampleHookOptions{
|
||||
Server: server,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -62,9 +70,9 @@ func main() {
|
||||
Payload: []byte("injected scheduled message"),
|
||||
})
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.InjectPacket")
|
||||
server.Log.Error("server.InjectPacket", "error", err)
|
||||
}
|
||||
server.Log.Info().Msgf("main.go injected packet to direct/publish")
|
||||
server.Log.Info("main.go injected packet to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -74,20 +82,26 @@ func main() {
|
||||
for range time.Tick(time.Second * 5) {
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.Publish")
|
||||
server.Log.Error("server.Publish", "error", err)
|
||||
}
|
||||
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
|
||||
server.Log.Info("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the hook.
|
||||
type ExampleHookOptions struct {
|
||||
Server *mqtt.Server
|
||||
}
|
||||
|
||||
type ExampleHook struct {
|
||||
mqtt.HookBase
|
||||
config *ExampleHookOptions
|
||||
}
|
||||
|
||||
func (h *ExampleHook) ID() string {
|
||||
@@ -106,39 +120,67 @@ func (h *ExampleHook) Provides(b byte) bool {
|
||||
}
|
||||
|
||||
func (h *ExampleHook) Init(config any) error {
|
||||
h.Log.Info().Msg("initialised")
|
||||
h.Log.Info("initialised")
|
||||
if _, ok := config.(*ExampleHookOptions); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
h.config = config.(*ExampleHookOptions)
|
||||
if h.config.Server == nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// subscribeCallback handles messages for subscribed topics
|
||||
func (h *ExampleHook) subscribeCallback(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
|
||||
h.Log.Info("hook subscribed message", "client", cl.ID, "topic", pk.TopicName)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
|
||||
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
|
||||
h.Log.Info("client connected", "client", cl.ID)
|
||||
|
||||
// Example demonstrating how to subscribe to a topic within the hook.
|
||||
h.config.Server.Subscribe("hook/direct/publish", 1, h.subscribeCallback)
|
||||
|
||||
// Example demonstrating how to publish a message within the hook
|
||||
err := h.config.Server.Publish("hook/direct/publish", []byte("packet hook message"), false, 0)
|
||||
if err != nil {
|
||||
h.Log.Error("hook.publish", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
|
||||
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
|
||||
if err != nil {
|
||||
h.Log.Info("client disconnected", "client", cl.ID, "expire", expire, "error", err)
|
||||
} else {
|
||||
h.Log.Info("client disconnected", "client", cl.ID, "expire", expire)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
|
||||
h.Log.Info(fmt.Sprintf("subscribed qos=%v", reasonCodes), "client", cl.ID, "filters", pk.Filters)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
|
||||
h.Log.Info("unsubscribed", "client", cl.ID, "filters", pk.Filters)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
|
||||
h.Log.Info("received from client", "client", cl.ID, "payload", string(pk.Payload))
|
||||
|
||||
pkx := pk
|
||||
if string(pk.Payload) == "hello" {
|
||||
pkx.Payload = []byte("hello world")
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
|
||||
h.Log.Info("received modified packet from client", "client", cl.ID, "payload", string(pkx.Payload))
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
|
||||
h.Log.Info("published to client", "client", cl.ID, "payload", string(pk.Payload))
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
@@ -45,9 +45,9 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
type pahoAuthHook struct {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
@@ -52,8 +52,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
@@ -54,7 +54,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -6,15 +6,15 @@ package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
rv8 "github.com/go-redis/redis/v8"
|
||||
)
|
||||
@@ -30,8 +30,12 @@ func main() {
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
level := new(slog.LevelVar)
|
||||
server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
level.Set(slog.LevelDebug)
|
||||
|
||||
err := server.AddHook(new(redis.Hook), &redis.Options{
|
||||
Options: &rv8.Options{
|
||||
@@ -58,8 +62,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -52,7 +52,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -111,7 +111,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
)
|
||||
@@ -41,7 +41,7 @@ func main() {
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
server.Log.Warn("caught signal, stopping...")
|
||||
_ = server.Close()
|
||||
server.Log.Info("main.go finished")
|
||||
}
|
||||
|
||||
go.mod
@@ -1,6 +1,6 @@
module github.com/mochi-mqtt/server/v2

go 1.19
go 1.21

require (
    github.com/alicebob/miniredis/v2 v2.23.0
@@ -10,7 +10,6 @@ require (
    github.com/gorilla/websocket v1.5.0
    github.com/jinzhu/copier v0.3.5
    github.com/rs/xid v1.4.0
    github.com/rs/zerolog v1.28.0
    github.com/stretchr/testify v1.7.1
    github.com/timshannon/badgerhold v1.0.0
    go.etcd.io/bbolt v1.3.5
@@ -28,13 +27,11 @@ require (
    github.com/dustin/go-humanize v1.0.0 // indirect
    github.com/golang/protobuf v1.5.0 // indirect
    github.com/golang/snappy v0.0.3 // indirect
    github.com/mattn/go-colorable v0.1.12 // indirect
    github.com/mattn/go-isatty v0.0.14 // indirect
    github.com/google/go-cmp v0.5.8 // indirect
    github.com/pkg/errors v0.9.1 // indirect
    github.com/pmezard/go-difflib v1.0.0 // indirect
    github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
    golang.org/x/net v0.7.0 // indirect
    golang.org/x/sys v0.5.0 // indirect
    golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
    golang.org/x/net v0.17.0 // indirect
    golang.org/x/sys v0.13.0 // indirect
    google.golang.org/protobuf v1.28.1 // indirect
)

go.sum
@@ -23,7 +23,6 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
|
||||
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -38,9 +37,9 @@ github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
|
||||
@@ -48,8 +47,9 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
@@ -62,15 +62,14 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@@ -79,8 +78,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
|
||||
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
|
||||
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
@@ -109,24 +106,21 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
@@ -136,8 +130,10 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
52
hooks.go
@@ -1,20 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop
|
||||
// SPDX-FileContributor: mochi-co, thedevop, dgduncan
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -70,7 +69,7 @@ type Hook interface {
|
||||
Provides(b byte) bool
|
||||
Init(config any) error
|
||||
Stop() error
|
||||
SetOpts(l *zerolog.Logger, o *HookOptions)
|
||||
SetOpts(l *slog.Logger, o *HookOptions)
|
||||
OnStarted()
|
||||
OnStopped()
|
||||
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
|
||||
@@ -117,11 +116,11 @@ type HookOptions struct {
|
||||
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
Log *slog.Logger // a logger for the hook (from the server)
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
@@ -179,9 +178,9 @@ func (h *Hooks) GetAll() []Hook {
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
|
||||
h.Log.Info("stopping hook", "hook", hook.ID())
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
|
||||
h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID())
|
||||
}
|
||||
|
||||
h.wg.Done()
|
||||
@@ -266,7 +265,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
|
||||
h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx)
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
@@ -394,10 +393,16 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil {
|
||||
if errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
|
||||
h.Log.Debug("publish packet rejected",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"packet", pkx)
|
||||
return pk, err
|
||||
}
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet error")
|
||||
h.Log.Error("publish packet error",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"packet", pkx)
|
||||
return pk, err
|
||||
}
|
||||
pkx = npk
|
||||
@@ -496,7 +501,10 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
|
||||
h.Log.Error("parse will error",
|
||||
"error", err,
|
||||
"hook", hook.ID(),
|
||||
"will", will)
|
||||
continue
|
||||
}
|
||||
will = mlwt
|
||||
@@ -540,7 +548,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
|
||||
h.Log.Error("failed to load clients", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -560,7 +568,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
|
||||
h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -580,7 +588,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
|
||||
h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -600,7 +608,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
|
||||
h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -619,7 +627,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
|
||||
h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID())
|
||||
return v, err
|
||||
}
|
||||
|
||||
@@ -668,7 +676,7 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
Hook
|
||||
Log *zerolog.Logger
|
||||
Log *slog.Logger
|
||||
Opts *HookOptions
|
||||
}
|
||||
|
||||
@@ -691,7 +699,7 @@ func (h *HookBase) Init(config any) error {
|
||||
|
||||
// SetOpts is called by the server to propagate internal values and generally should
|
||||
// not be called manually.
|
||||
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
|
||||
func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) {
|
||||
h.Log = l
|
||||
h.Opts = opts
|
||||
}
|
||||
|
||||
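Editorial note: the hooks.go changes above migrate hook logging from zerolog to the standard library's log/slog, so SetOpts now receives a *slog.Logger and call sites switch to slog's message plus key/value style. A minimal sketch of a custom hook written against the new interface follows; the ExampleHook type, its ID string, and the logged keys are illustrative and not part of this diff.

```go
package example

import (
	"bytes"

	mqtt "github.com/mochi-mqtt/server/v2"
	"github.com/mochi-mqtt/server/v2/packets"
)

// ExampleHook is a hypothetical hook showing the post-migration logging
// style; it embeds mqtt.HookBase, which now carries Log *slog.Logger and
// the SetOpts(*slog.Logger, *mqtt.HookOptions) method shown above.
type ExampleHook struct {
	mqtt.HookBase
}

// ID returns the identifier of this hook.
func (h *ExampleHook) ID() string {
	return "example-hook"
}

// Provides indicates which events this hook handles.
func (h *ExampleHook) Provides(b byte) bool {
	return bytes.Contains([]byte{mqtt.OnSessionEstablished}, []byte{b})
}

// Init replaces the old zerolog chain (h.Log.Info().Str(...).Msg(...))
// with slog's message plus key/value pairs.
func (h *ExampleHook) Init(config any) error {
	h.Log.Info("example hook initialised", "config", config)
	return nil
}

// OnSessionEstablished logs each new session using the injected logger.
func (h *ExampleHook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
	h.Log.Debug("session established", "client", cl.ID)
}
```

A hook like this would be registered with server.AddHook(new(ExampleHook), nil), at which point the server is expected to inject its logger via SetOpts before calling Init.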
@@ -7,7 +7,7 @@ package auth
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
)
|
||||
|
||||
@@ -67,10 +67,9 @@ func (h *Hook) Init(config any) error {
|
||||
}
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Int("authentication", len(h.ledger.Auth)).
|
||||
Int("acl", len(h.ledger.ACL)).
|
||||
Msg("loaded auth rules")
|
||||
h.Log.Info("loaded auth rules",
|
||||
"authentication", len(h.ledger.Auth),
|
||||
"acl", len(h.ledger.ACL))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -82,11 +81,9 @@ func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("username", string(pk.Connect.Username)).
|
||||
Str("remote", cl.Net.Remote).
|
||||
Msg("client failed authentication check")
|
||||
|
||||
h.Log.Info("client failed authentication check",
|
||||
"username", string(pk.Connect.Username),
|
||||
"remote", cl.Net.Remote)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -97,11 +94,10 @@ func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Debug().
|
||||
Str("client", cl.ID).
|
||||
Str("username", string(cl.Properties.Username)).
|
||||
Str("topic", topic).
|
||||
Msg("client failed allowed ACL check")
|
||||
h.Log.Debug("client failed allowed ACL check",
|
||||
"client", cl.ID,
|
||||
"username", string(cl.Properties.Username),
|
||||
"topic", topic)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
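Editorial note: the ledger-backed auth hook above is normally attached with AddHook. The sketch below is not part of this diff; the single allow-all rule and the omission of listeners are purely illustrative, though the Options, Ledger and AuthRules shapes match those exercised by the tests later in this diff.

```go
package main

import (
	"log"

	mqtt "github.com/mochi-mqtt/server/v2"
	"github.com/mochi-mqtt/server/v2/hooks/auth"
)

func main() {
	server := mqtt.New(nil)

	// Attach the ledger-backed auth hook; failed checks are logged using
	// the slog key/value style shown in the diff above.
	err := server.AddHook(new(auth.Hook), &auth.Options{
		Ledger: &auth.Ledger{
			Auth: auth.AuthRules{
				{Remote: "127.0.0.1", Allow: true}, // illustrative rule only
			},
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	if err := server.Serve(); err != nil {
		log.Fatal(err)
	}
}
```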
@@ -5,16 +5,16 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
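Editorial note: the replacement test logger above writes to stdout, whereas the zerolog logger it replaces was created at the Disabled level. If quiet test output is preferred, a discarded handler is one option; this is a suggestion, not part of the change.

```go
package auth

import (
	"io"
	"log/slog"
)

// A silenced slog logger for tests, roughly equivalent to the old
// zerolog Disabled level: records are formatted and then thrown away.
var logger = slog.New(slog.NewTextHandler(io.Discard, nil))
```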
// func teardown(t *testing.T, path string, h *Hook) {
|
||||
// h.Stop()
|
||||
@@ -34,7 +34,7 @@ func TestBasicProvides(t *testing.T) {
|
||||
|
||||
func TestBasicInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -42,7 +42,7 @@ func TestBasicInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
@@ -50,7 +50,7 @@ func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := &Ledger{
|
||||
Auth: []AuthRule{
|
||||
@@ -79,7 +79,7 @@ func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
@@ -93,7 +93,7 @@ func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
@@ -107,7 +107,7 @@ func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
|
||||
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
@@ -119,7 +119,7 @@ func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
|
||||
func TestOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
@@ -158,7 +158,7 @@ func TestOnConnectAuthenticate(t *testing.T) {
|
||||
|
||||
func TestOnACL(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
|
||||
@@ -80,8 +80,8 @@ func (r RString) Matches(a string) bool {
|
||||
}
|
||||
|
||||
// FilterMatches returns true if a filter matches a topic rule.
|
||||
func (f RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(f), a)
|
||||
func (r RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(r), a)
|
||||
return ok
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
|
||||
}
|
||||
|
||||
// ACLOk returns true if the rules indicate the user is allowed to read or write to
|
||||
// a specific filter or topic respectively, based on the write bool.
|
||||
// a specific filter or topic respectively, based on the `write` bool.
|
||||
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
@@ -209,7 +209,7 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo
|
||||
}
|
||||
}
|
||||
|
||||
for filter, _ := range rule.Filters {
|
||||
for filter := range rule.Filters {
|
||||
if filter.FilterMatches(topic) {
|
||||
return n, false
|
||||
}
|
||||
|
||||
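Editorial note: the receiver rename above (f to r) is cosmetic; FilterMatches still delegates to MatchTopic, which reports whether an MQTT filter (with + and # wildcards) matches a concrete topic. A rough illustration of that behaviour, reusing the package's MatchTopic helper and ignoring its first return value as the code above does; the filters and topics shown are invented for the example.

```go
package auth

import "fmt"

// matchExamples sketches how RString.FilterMatches behaves for a few
// filter/topic pairs; the values are illustrative only.
func matchExamples() {
	pairs := [][2]string{
		{"devices/+/status", "devices/d1/status"},  // '+' matches one level
		{"devices/#", "devices/d1/telemetry/temp"}, // '#' matches the remainder
		{"devices/+/status", "devices/d1/config"},  // no match
	}
	for _, p := range pairs {
		_, ok := MatchTopic(p[0], p[1])
		fmt.Printf("%-20s vs %-28s -> %v\n", p[0], p[1], ok)
	}
}
```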
@@ -561,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
new := &Ledger{
|
||||
n := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
{Remote: "192.168.*", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
old.Update(new)
|
||||
old.Update(n)
|
||||
require.Len(t, old.Auth, 2)
|
||||
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
|
||||
require.NotSame(t, new, old)
|
||||
require.NotSame(t, n, old)
|
||||
}
|
||||
|
||||
func TestLedgerToJSON(t *testing.T) {
|
||||
|
||||
@@ -5,13 +5,13 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Options contains configuration settings for the debug output.
|
||||
@@ -25,7 +25,7 @@ type Options struct {
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
Log *zerolog.Logger
|
||||
Log *slog.Logger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
@@ -54,25 +54,25 @@ func (h *Hook) Init(config any) error {
|
||||
}
|
||||
|
||||
// SetOpts is called when the hook receives inheritable server parameters.
|
||||
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
|
||||
func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) {
|
||||
h.Log = l
|
||||
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
|
||||
h.Log.Debug("", "method", "SetOpts")
|
||||
}
|
||||
|
||||
// Stop is called when the hook is stopped.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Debug().Str("method", "Stop").Send()
|
||||
h.Log.Debug("", "method", "Stop")
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *Hook) OnStarted() {
|
||||
h.Log.Debug().Str("method", "OnStarted").Send()
|
||||
h.Log.Debug("", "method", "OnStarted")
|
||||
}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *Hook) OnStopped() {
|
||||
h.Log.Debug().Str("method", "OnStopped").Send()
|
||||
h.Log.Debug("", "method", "OnStopped")
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a new packet is received from a client.
|
||||
@@ -81,8 +81,7 @@ func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet,
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
|
||||
h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
@@ -92,85 +91,72 @@ func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
|
||||
h.Log.Debug("retained message on topic", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
|
||||
h.Log.Debug("inflight out", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
|
||||
h.Log.Debug("inflight complete", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnQosDropped is called when the Qos flow for a message expires.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
|
||||
h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
|
||||
}
|
||||
|
||||
// OnLWTSent is called when a will message has been issued from a disconnecting client.
|
||||
// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
|
||||
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
|
||||
h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when the server clears expired retained messages.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
|
||||
h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter)
|
||||
}
|
||||
|
||||
// OnClientExpired is called when the server clears an expired client.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
|
||||
h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID)
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores clients from a store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
h.Log.Debug("", "method", "StoredClients")
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores subscriptions from a store.
|
||||
// StoredSubscriptions is called when the server restores subscriptions from a store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredSubscriptions").
|
||||
Send()
|
||||
|
||||
h.Log.Debug("", "method", "StoredSubscriptions")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores retained messages from a store.
|
||||
// StoredRetainedMessages is called when the server restores retained messages from a store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredRetainedMessages").
|
||||
Send()
|
||||
|
||||
h.Log.Debug("", "method", "StoredRetainedMessages")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores inflight messages from a store.
|
||||
// StoredInflightMessages is called when the server restores inflight messages from a store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredInflightMessages").
|
||||
Send()
|
||||
|
||||
h.Log.Debug("", "method", "StoredInflightMessages")
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores system info from a store.
|
||||
// StoredSysInfo is called when the server restores system info from a store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
h.Log.Debug("", "method", "StoredSysInfo")
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
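Editorial note: the debug hook now logs through slog as well, so its output depends entirely on the logger the server injects via SetOpts. A rough wiring sketch follows; it assumes the server's exported Log field is now a *slog.Logger (as this migration implies) and that an empty debug.Options is acceptable configuration, neither of which is shown in this diff.

```go
package main

import (
	"log"
	"log/slog"
	"os"

	mqtt "github.com/mochi-mqtt/server/v2"
	"github.com/mochi-mqtt/server/v2/hooks/debug"
)

func main() {
	server := mqtt.New(nil)

	// Hand the server a debug-level slog logger; hooks receive it via SetOpts.
	level := new(slog.LevelVar)
	level.Set(slog.LevelDebug)
	server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}))

	// Attach the debug hook with default options.
	if err := server.AddHook(new(debug.Hook), &debug.Options{}); err != nil {
		log.Fatal(err)
	}

	if err := server.Serve(); err != nil {
		log.Fatal(err)
	}
}
```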
@@ -7,9 +7,10 @@ package badger
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
@@ -127,8 +128,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -165,14 +165,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
|
||||
h.Log.Error("failed to upsert client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if their session has expired.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -188,14 +188,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
|
||||
h.Log.Error("failed to delete client data", "error", err, "data", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
|
||||
h.Log.Error("failed to upsert subscription data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,14 +223,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
|
||||
h.Log.Error("failed to delete subscription data", "error", err, "data", subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -238,14 +238,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
|
||||
h.Log.Error("failed to delete retained message data", "error", err, "data", retainedKey(pk.TopicName))
|
||||
}
|
||||
|
||||
return
|
||||
@@ -274,14 +274,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
|
||||
h.Log.Error("failed to upsert retained message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -310,27 +310,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
|
||||
h.Log.Error("failed to upsert qos inflight data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
|
||||
h.Log.Error("failed to delete inflight message data", "error", err, "data", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
@@ -339,7 +339,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -351,40 +351,40 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
|
||||
h.Log.Error("failed to upsert $SYS data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||
h.Log.Error("failed to delete expired retained message data", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||
h.Log.Error("failed to delete expired client data", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -429,7 +429,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -444,7 +444,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -458,20 +458,21 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
|
||||
// Errorf satisfies the badger interface for an error logger.
|
||||
func (h *Hook) Errorf(m string, v ...interface{}) {
|
||||
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
|
||||
}
|
||||
|
||||
// Warningf satisfies the badger interface for a warning logger.
|
||||
func (h *Hook) Warningf(m string, v ...interface{}) {
|
||||
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Infof satisfies the badger interface for an info logger.
|
||||
func (h *Hook) Infof(m string, v ...interface{}) {
|
||||
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
// Debugf satisfies the badger interface for a debug logger.
|
||||
func (h *Hook) Debugf(m string, v ...interface{}) {
|
||||
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
|
||||
}
|
||||
|
||||
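Editorial note: the Errorf/Warningf/Infof/Debugf methods above adapt badger's printf-style logger interface onto slog by formatting the message first and attaching the raw arguments as a value. The same bridging pattern can be written standalone, as in the sketch below; the printfLogger type and the sample messages are illustrative, not part of this diff.

```go
package main

import (
	"fmt"
	"log/slog"
	"os"
	"strings"
)

// printfLogger bridges a printf-style logging interface (such as the one
// badger expects) onto slog, mirroring the adapter methods in the diff above.
type printfLogger struct {
	log *slog.Logger
}

// Errorf formats the message, normalises it, and logs it at error level.
func (p printfLogger) Errorf(m string, v ...interface{}) {
	p.log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}

// Infof does the same at info level.
func (p printfLogger) Infof(m string, v ...interface{}) {
	p.log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}

func main() {
	l := printfLogger{log: slog.New(slog.NewTextHandler(os.Stdout, nil))}
	l.Infof("Opened database at %s\n", ".badger")                 // logged as "opened database at .badger"
	l.Errorf("Compaction failed: %v\n", fmt.Errorf("disk full")) // logged at error level with the raw args attached
}
```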
@@ -5,22 +5,22 @@
|
||||
package badger
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
@@ -38,8 +38,8 @@ var (
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
h.db.Badger().Close()
|
||||
_ = h.Stop()
|
||||
_ = h.db.Badger().Close()
|
||||
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func TestProvides(t *testing.T) {
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -103,7 +103,7 @@ func TestInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -146,7 +146,7 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -170,13 +170,13 @@ func TestOnClientExpired(t *testing.T) {
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -185,13 +185,13 @@ func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -200,7 +200,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -219,13 +219,13 @@ func TestOnWillSent(t *testing.T) {
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -234,7 +234,7 @@ func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -257,7 +257,7 @@ func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -279,13 +279,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -294,13 +294,13 @@ func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -309,7 +309,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -344,7 +344,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -371,13 +371,13 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -386,13 +386,13 @@ func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -401,7 +401,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -436,13 +436,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -451,13 +451,13 @@ func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -466,13 +466,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -494,13 +494,13 @@ func TestOnSysInfoTick(t *testing.T) {
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -509,7 +509,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -534,7 +534,7 @@ func TestStoredClients(t *testing.T) {
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -542,7 +542,7 @@ func TestStoredClientsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -567,7 +567,7 @@ func TestStoredSubscriptions(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -575,7 +575,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -603,7 +603,7 @@ func TestStoredRetainedMessages(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -611,7 +611,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -639,7 +639,7 @@ func TestStoredInflightMessages(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -647,7 +647,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -669,7 +669,7 @@ func TestStoredSysInfo(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -678,27 +678,27 @@ func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
func TestErrorf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Errorf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestWarningf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Warningf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Infof("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.Debugf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
|
||||
// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
package bolt
|
||||
|
||||
import (
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
@@ -132,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
@@ -141,7 +141,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -169,14 +169,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
|
||||
h.Log.Error("failed to save client data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -190,14 +190,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -217,10 +217,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save subscription data")
|
||||
h.Log.Error("failed to save subscription data", "error", err, "client", cl.ID, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,7 +225,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -237,9 +234,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
|
||||
Msg("failed to delete client")
|
||||
h.Log.Error("failed to delete client", "error", err, "id", subscriptionKey(cl, pk.Filters[i].Filter))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -247,7 +242,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -256,9 +251,7 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
ID: retainedKey(pk.TopicName),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", retainedKey(pk.TopicName)).
|
||||
Msg("failed to delete retained publish")
|
||||
h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(pk.TopicName))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -285,17 +278,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save retained publish data")
|
||||
h.Log.Error("failed to save retained publish data", "error", err, "client", cl.ID, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -323,17 +313,14 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save qos inflight data")
|
||||
h.Log.Error("failed to save qos inflight data", "error", err, "client", cl.ID, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -341,16 +328,14 @@ func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
ID: inflightKey(cl, pk),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", inflightKey(cl, pk)).
|
||||
Msg("failed to delete inflight data")
|
||||
h.Log.Error("failed to delete inflight data", "error", err, "id", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
@@ -359,7 +344,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -371,41 +356,39 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Interface("data", in).
|
||||
Msg("failed to save $SYS data")
|
||||
h.Log.Error("failed to save $SYS data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||
h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deletes expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -420,7 +403,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -435,7 +418,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -450,7 +433,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -465,7 +448,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
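Note: every hunk in this hook follows the same logging migration: zerolog's chained `.Err(...).Str(...).Msg(...)` calls are collapsed into a single `log/slog` call taking a message plus alternating key/value pairs. A minimal standalone sketch of that pattern (the logger construction and the values here are illustrative, not the broker's own wiring):

```go
package main

import (
	"errors"
	"log/slog"
	"os"
)

func main() {
	// slog takes a message followed by alternating key/value pairs,
	// replacing zerolog's .Err(...).Str(...).Msg(...) chain.
	log := slog.New(slog.NewTextHandler(os.Stdout, nil))

	err := errors.New("db file not open")
	log.Error("failed to save client data", "error", err, "id", "cl_test")
}
```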
@@ -5,22 +5,22 @@
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
@@ -38,7 +38,7 @@ var (
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
_ = h.Stop()
|
||||
err := os.Remove(path)
|
||||
require.NoError(t, err)
|
||||
}
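For comparison, a hedged sketch of the two test-logger styles this file swaps between. The old zerolog logger was silenced with `zerolog.Disabled`; the new tests write to `os.Stdout`. The `io.Discard` variant below is an assumption about how the tests could be kept quiet, not something this diff does:

```go
package main

import (
	"io"
	"log/slog"
	"os"
)

func main() {
	// What the updated tests construct: a text handler writing to stdout.
	verbose := slog.New(slog.NewTextHandler(os.Stdout, nil))
	verbose.Info("test logger ready")

	// Assumption, not in the diff: a handler writing to io.Discard keeps test
	// output quiet, roughly matching the old zerolog.Disabled behaviour.
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	quiet.Info("this line is never printed")
}
```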
|
||||
@@ -94,7 +94,7 @@ func TestProvides(t *testing.T) {
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
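The repeated `h.SetOpts(&logger, nil)` → `h.SetOpts(logger, nil)` edits reflect `SetOpts` now taking the `*slog.Logger` by value rather than a zerolog logger by address. A hedged sketch of the updated setup, using only calls that appear in this diff; the helper name is invented for illustration:

```go
package bolt

import (
	"log/slog"
	"os"
)

// newTestHook is an illustrative helper, not part of the repository. It shows
// the updated calls: SetOpts takes the *slog.Logger by value, and Init(nil)
// appears to fall back to the hook's default options.
func newTestHook() (*Hook, error) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

	h := new(Hook)
	h.SetOpts(logger, nil)
	if err := h.Init(nil); err != nil {
		return nil, err
	}
	return h, nil
}
```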
|
||||
@@ -102,7 +102,7 @@ func TestInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
|
||||
func TestInitBadPath(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Path: "..",
|
||||
})
|
||||
@@ -122,7 +122,7 @@ func TestInitBadPath(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -155,13 +155,13 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -170,7 +170,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -189,7 +189,7 @@ func TestOnWillSent(t *testing.T) {
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -213,7 +213,7 @@ func TestOnClientExpired(t *testing.T) {
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -222,19 +222,19 @@ func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -243,7 +243,7 @@ func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -266,7 +266,7 @@ func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -288,13 +288,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -303,13 +303,13 @@ func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -318,7 +318,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -353,7 +353,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -380,7 +380,7 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -389,19 +389,19 @@ func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -410,7 +410,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -445,13 +445,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -460,13 +460,13 @@ func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -475,13 +475,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -503,13 +503,13 @@ func TestOnSysInfoTick(t *testing.T) {
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -518,7 +518,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -543,7 +543,7 @@ func TestStoredClients(t *testing.T) {
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -551,7 +551,7 @@ func TestStoredClientsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -562,7 +562,7 @@ func TestStoredClientsClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -587,7 +587,7 @@ func TestStoredSubscriptions(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -595,7 +595,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -606,7 +606,7 @@ func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -634,7 +634,7 @@ func TestStoredRetainedMessages(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -642,7 +642,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -653,7 +653,7 @@ func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -681,7 +681,7 @@ func TestStoredInflightMessages(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -689,7 +689,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
@@ -700,7 +700,7 @@ func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
@@ -722,7 +722,7 @@ func TestStoredSysInfo(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
@@ -730,7 +730,7 @@ func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
|
||||
@@ -10,12 +10,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// defaultAddr is the default address to the redis service.
|
||||
@@ -117,12 +117,11 @@ func (h *Hook) Init(config any) error {
|
||||
h.config.HPrefix = defaultHPrefix
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("address", h.config.Options.Addr).
|
||||
Str("username", h.config.Options.Username).
|
||||
Int("password-len", len(h.config.Options.Password)).
|
||||
Int("db", h.config.Options.DB).
|
||||
Msg("connecting to redis service")
|
||||
h.Log.Info("connecting to redis service",
|
||||
"address", h.config.Options.Addr,
|
||||
"username", h.config.Options.Username,
|
||||
"password-len", len(h.config.Options.Password),
|
||||
"db", h.config.Options.DB)
|
||||
|
||||
h.db = redis.NewClient(h.config.Options)
|
||||
_, err := h.db.Ping(context.Background()).Result()
|
||||
@@ -130,14 +129,15 @@ func (h *Hook) Init(config any) error {
|
||||
return fmt.Errorf("failed to ping service: %w", err)
|
||||
}
|
||||
|
||||
h.Log.Info().Msg("connected to redis service")
|
||||
h.Log.Info("connected to redis service")
|
||||
|
||||
return nil
|
||||
}
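A standalone sketch of the connect-and-ping sequence shown in `Init`, using go-redis v8 directly; the address and DB values are placeholders, not the hook's defaults:

```go
package main

import (
	"context"
	"log/slog"
	"os"

	"github.com/go-redis/redis/v8"
)

func main() {
	log := slog.New(slog.NewTextHandler(os.Stdout, nil))

	opts := &redis.Options{Addr: "localhost:6379", DB: 0} // placeholder values
	log.Info("connecting to redis service",
		"address", opts.Addr,
		"username", opts.Username,
		"password-len", len(opts.Password),
		"db", opts.DB)

	db := redis.NewClient(opts)
	if _, err := db.Ping(context.Background()).Result(); err != nil {
		log.Error("failed to ping service", "error", err)
		return
	}
	log.Info("connected to redis service")
	_ = db.Close()
}
```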
|
||||
|
||||
// Close closes the redis connection.
|
||||
// Stop closes the redis connection.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Info().Msg("disconnecting from redis service")
|
||||
h.Log.Info("disconnecting from redis service")
|
||||
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
@@ -146,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -184,14 +183,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
|
||||
h.Log.Error("failed to hset client data", "error", err, "data", in)
|
||||
}
|
||||
}
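go-redis accepts any value implementing `encoding.BinaryMarshaler` as an HSET field value, which is how the storage types above are persisted. A hedged sketch with a stand-in type; the hash key and address are placeholders, not the hook's real prefix:

```go
package main

import (
	"context"
	"encoding/json"
	"log/slog"
	"os"

	"github.com/go-redis/redis/v8"
)

// record stands in for the storage types; go-redis accepts any HSET field value
// that implements encoding.BinaryMarshaler.
type record struct {
	ID     string `json:"id"`
	Remote string `json:"remote"`
}

func (r record) MarshalBinary() ([]byte, error) { return json.Marshal(r) }

func main() {
	log := slog.New(slog.NewTextHandler(os.Stdout, nil))
	ctx := context.Background()
	db := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // placeholder address

	in := record{ID: "cl_test", Remote: "127.0.0.1"}
	// "mochi:cl" is a placeholder hash key, not the hook's configured prefix.
	if err := db.HSet(ctx, "mochi:cl", in.ID, in).Err(); err != nil {
		log.Error("failed to hset client data", "error", err, "data", in)
	}
}
```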
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -205,14 +204,14 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -232,7 +231,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
|
||||
h.Log.Error("failed to hset subscription data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -240,14 +239,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
|
||||
h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -255,14 +254,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
|
||||
h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName))
|
||||
}
|
||||
|
||||
return
|
||||
@@ -291,14 +290,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
|
||||
h.Log.Error("failed to hset retained message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -326,27 +325,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
|
||||
h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk))
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
@@ -355,7 +354,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -367,53 +366,53 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
|
||||
h.Log.Error("failed to hset server info data", "error", err, "data", in)
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||
h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter))
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deletes expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
|
||||
h.Log.Error("failed to HGetAll client data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Client
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
|
||||
h.Log.Error("failed to unmarshal client data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -425,20 +424,20 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
|
||||
h.Log.Error("failed to HGetAll subscription data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Subscription
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
|
||||
h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -450,20 +449,20 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
|
||||
h.Log.Error("failed to HGetAll retained message data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
|
||||
h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -475,20 +474,20 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
|
||||
h.Log.Error("failed to HGetAll inflight message data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
@@ -500,7 +499,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -510,7 +509,7 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
}
|
||||
|
||||
if err = v.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
|
||||
h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
|
||||
@@ -5,24 +5,24 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2"
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/hooks/storage"
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
@@ -41,7 +41,7 @@ var (
|
||||
|
||||
func newHook(t *testing.T, addr string) *Hook {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
@@ -87,13 +87,13 @@ func TestSysInfoKey(t *testing.T) {
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
require.Equal(t, "redis-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
@@ -116,7 +116,7 @@ func TestHKey(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
h := newHook(t, defaultAddr)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h)
|
||||
@@ -137,7 +137,7 @@ func TestInitUseDefaults(t *testing.T) {
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
@@ -145,7 +145,7 @@ func TestInitBadConfig(t *testing.T) {
|
||||
|
||||
func TestInitBadAddr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.SetOpts(logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: "abc:123",
|
||||
|
||||
@@ -25,7 +25,7 @@ var (
|
||||
ErrDBFileNotOpen = errors.New("db file not open")
|
||||
)
|
||||
|
||||
// Client is a storable representation of an mqtt client.
|
||||
// Client is a storable representation of an MQTT client.
|
||||
type Client struct {
|
||||
Will ClientWill `json:"will"` // will topic and payload data if applicable
|
||||
Properties ClientProperties `json:"properties"` // the connect properties for the client
|
||||
@@ -147,7 +147,7 @@ func (d *Message) ToPacket() packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an mqtt subscription.
|
||||
// Subscription is a storable representation of an MQTT subscription.
|
||||
type Subscription struct {
|
||||
T string `json:"t"`
|
||||
ID string `json:"id" storm:"id"`
|
||||
|
||||
@@ -118,7 +118,7 @@ var (
|
||||
func TestClientMarshalBinary(t *testing.T) {
|
||||
data, err := clientStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientJSON, data)
|
||||
require.JSONEq(t, string(clientJSON), string(data))
|
||||
}
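The switch from `require.Equal` to `require.JSONEq` makes these assertions compare the marshalled output as JSON documents rather than raw bytes, so key order and whitespace no longer matter. A small illustration:

```go
package main

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestJSONEqIgnoresOrderAndWhitespace(t *testing.T) {
	want := `{"id":"test","remote":"127.0.0.1"}`
	got := `{"remote": "127.0.0.1", "id": "test"}` // same document, different layout

	require.JSONEq(t, want, got)   // passes: compares parsed JSON documents
	require.NotEqual(t, want, got) // the raw strings still differ byte-for-byte
}
```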
|
||||
|
||||
func TestClientUnmarshalBinary(t *testing.T) {
|
||||
@@ -138,7 +138,7 @@ func TestClientUnmarshalBinaryEmpty(t *testing.T) {
|
||||
func TestMessageMarshalBinary(t *testing.T) {
|
||||
data, err := messageStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageJSON, data)
|
||||
require.JSONEq(t, string(messageJSON), string(data))
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinary(t *testing.T) {
|
||||
@@ -158,7 +158,7 @@ func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
|
||||
func TestSubscriptionMarshalBinary(t *testing.T) {
|
||||
data, err := subscriptionStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionJSON, data)
|
||||
require.JSONEq(t, string(subscriptionJSON), string(data))
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinary(t *testing.T) {
|
||||
@@ -178,7 +178,7 @@ func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
|
||||
func TestSysInfoMarshalBinary(t *testing.T) {
|
||||
data, err := sysInfoStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoJSON, data)
|
||||
require.JSONEq(t, string(sysInfoJSON), string(data))
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinary(t *testing.T) {
|
||||
|
||||
@@ -215,7 +215,7 @@ func TestHooksAddInitFailure(t *testing.T) {
|
||||
|
||||
func TestHooksStop(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
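A hedged sketch of wiring the value-typed `*slog.Logger` into the hooks manager, mirroring the calls shown in this test; the import alias is the one used elsewhere in this diff:

```go
package main

import (
	"log/slog"
	"os"

	mqtt "github.com/mochi-mqtt/server/v2"
)

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

	h := new(mqtt.Hooks)
	h.Log = logger // Hooks.Log is now a *slog.Logger value; no & needed

	if err := h.Add(new(mqtt.HookBase), nil); err != nil {
		logger.Error("failed to add hook", "error", err)
	}
	h.Stop()
}
```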
|
||||
@@ -334,7 +334,7 @@ func TestHooksOnUnsubscribe(t *testing.T) {
|
||||
|
||||
func TestHooksOnPublish(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -360,7 +360,7 @@ func TestHooksOnPublish(t *testing.T) {
|
||||
|
||||
func TestHooksOnPacketRead(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -386,7 +386,7 @@ func TestHooksOnPacketRead(t *testing.T) {
|
||||
|
||||
func TestHooksOnAuthPacket(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -404,7 +404,7 @@ func TestHooksOnAuthPacket(t *testing.T) {
|
||||
|
||||
func TestHooksOnConnect(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -420,7 +420,7 @@ func TestHooksOnConnect(t *testing.T) {
|
||||
|
||||
func TestHooksOnPacketEncode(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -432,7 +432,7 @@ func TestHooksOnPacketEncode(t *testing.T) {
|
||||
|
||||
func TestHooksOnLWT(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
@@ -449,7 +449,7 @@ func TestHooksOnLWT(t *testing.T) {
|
||||
|
||||
func TestHooksStoredClients(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
@@ -471,7 +471,7 @@ func TestHooksStoredClients(t *testing.T) {
|
||||
|
||||
func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
@@ -493,7 +493,7 @@ func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
|
||||
func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
@@ -515,7 +515,7 @@ func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
|
||||
func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
@@ -537,7 +537,7 @@ func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
|
||||
func TestHooksStoredSysInfo(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
h.Log = logger
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
@@ -575,7 +575,7 @@ func TestHookBaseInit(t *testing.T) {
|
||||
|
||||
func TestHookBaseSetOpts(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
h.SetOpts(&logger, new(HookOptions))
|
||||
h.SetOpts(logger, new(HookOptions))
|
||||
require.NotNil(t, h.Log)
|
||||
require.NotNil(t, h.Opts)
|
||||
}
|
||||
|
||||
@@ -6,23 +6,21 @@ package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
|
||||
type HTTPHealthCheck struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address.
|
||||
@@ -57,9 +55,7 @@ func (l *HTTPHealthCheck) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPHealthCheck) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
func (l *HTTPHealthCheck) Init(_ *slog.Logger) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
@@ -83,9 +79,9 @@ func (l *HTTPHealthCheck) Init(log *zerolog.Logger) error {
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
_ = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
_ = l.listen.ListenAndServe()
|
||||
}
|
||||
}
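The healthcheck listener registers its handler at `/healthcheck` (see the `Init` hunk above). A hedged client-side probe; the address is a placeholder for whatever the listener was bound to:

```go
package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Placeholder address; the listener serves on whatever address it was created with.
	client := http.Client{Timeout: 2 * time.Second}
	resp, err := client.Get("http://localhost:8080/healthcheck")
	if err != nil {
		fmt.Println("healthcheck failed:", err)
		return
	}
	defer resp.Body.Close()
	// GET is accepted; other methods are rejected by the handler above.
	fmt.Println("healthcheck status:", resp.StatusCode)
}
```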
@@ -97,7 +93,7 @@ func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
|
||||
@@ -39,13 +39,13 @@ func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPHealthCheckInit(t *testing.T) {
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.listen)
|
||||
@@ -55,7 +55,7 @@ func TestHTTPHealthCheckInit(t *testing.T) {
|
||||
func TestHTTPHealthCheckServeAndClose(t *testing.T) {
|
||||
// setup http stats listener
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -91,7 +91,7 @@ func TestHTTPHealthCheckServeAndClose(t *testing.T) {
|
||||
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
|
||||
// setup http stats listener
|
||||
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -129,7 +129,7 @@ func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
|
||||
@@ -8,26 +8,25 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
log *zerolog.Logger // server logger
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
end uint32 // ensure the close methods are only called once
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
@@ -63,9 +62,8 @@ func (l *HTTPStats) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPStats) Init(log *zerolog.Logger) error {
|
||||
func (l *HTTPStats) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.jsonHandler)
|
||||
l.listen = &http.Server{
|
||||
@@ -84,10 +82,17 @@ func (l *HTTPStats) Init(log *zerolog.Logger) error {
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPStats) Serve(establish EstablishFn) {
|
||||
|
||||
var err error
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
err = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
err = l.listen.ListenAndServe()
|
||||
}
|
||||
|
||||
// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
|
||||
if err != nil && atomic.LoadUint32(&l.end) == 0 {
|
||||
l.log.Error("failed to serve.", "error", err, "listener", l.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +104,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
@@ -111,8 +116,8 @@ func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
out, err := json.MarshalIndent(info, "", "\t")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
_, _ = io.WriteString(w, err.Error())
|
||||
}
|
||||
|
||||
w.Write(out)
|
||||
_, _ = w.Write(out)
|
||||
}
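`jsonHandler` serves the `$SYS` info as indented JSON at the root path. A hedged client-side sketch that decodes the response back into `system.Info`; the address is a placeholder:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/mochi-mqtt/server/v2/system"
)

func main() {
	// Placeholder address; the stats handler is mounted at "/".
	resp, err := http.Get("http://localhost:8080/")
	if err != nil {
		fmt.Println("stats request failed:", err)
		return
	}
	defer resp.Body.Close()

	var info system.Info
	if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
		fmt.Println("failed to decode $SYS info:", err)
		return
	}
	fmt.Println("broker version:", info.Version)
}
```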
|
||||
|
||||
@@ -42,14 +42,14 @@ func TestHTTPStatsTLSProtocol(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, nil)
|
||||
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsInit(t *testing.T) {
|
||||
sysInfo := new(system.Info)
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.sysInfo)
|
||||
@@ -65,7 +65,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) {
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -113,7 +113,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, sysInfo)
|
||||
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -125,3 +125,28 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
|
||||
func TestHTTPStatsFailedToServe(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats("t1", "wrong_addr", nil, sysInfo)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
<-o
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
@@ -22,24 +22,24 @@ type Config struct {
|
||||
// EstablishFn is a callback function for establishing new clients.
|
||||
type EstablishFn func(id string, c net.Conn) error
|
||||
|
||||
// CloseFunc is a callback function for closing all listener clients.
|
||||
// CloseFn is a callback function for closing all listener clients.
|
||||
type CloseFn func(id string)
|
||||
|
||||
// Listener is an interface for network listeners. A network listener listens
|
||||
// for incoming client connections and adds them to the server.
|
||||
type Listener interface {
|
||||
Init(*zerolog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
Init(*slog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
}
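With the interface now taking a `*slog.Logger`, a custom listener only needs to accept the new logger type in `Init`. A minimal, illustrative implementation that satisfies the updated interface; it is not part of the repository:

```go
package listeners

import "log/slog"

// nopListener is an illustrative, do-nothing implementation of the updated
// Listener interface; it is not part of the repository.
type nopListener struct {
	id  string
	log *slog.Logger
}

func (l *nopListener) Init(log *slog.Logger) error { l.log = log; return nil }
func (l *nopListener) Serve(establish EstablishFn) {}
func (l *nopListener) ID() string                  { return l.id }
func (l *nopListener) Address() string             { return "" }
func (l *nopListener) Protocol() string            { return "nop" }
func (l *nopListener) Close(closer CloseFn)        { closer(l.id) }
```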
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
ClientsWg sync.WaitGroup // a waitgroup that waits for all clients in all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -86,8 +86,6 @@ func (l *Listeners) Serve(id string, establisher EstablishFn) {
|
||||
listener := l.internal[id]
|
||||
|
||||
go func(e EstablishFn) {
|
||||
defer l.wg.Done()
|
||||
l.wg.Add(1)
|
||||
listener.Serve(e)
|
||||
}(establisher)
|
||||
}
|
||||
@@ -131,5 +129,5 @@ func (l *Listeners) CloseAll(closer CloseFn) {
|
||||
for _, id := range ids {
|
||||
l.Close(id, closer)
|
||||
}
|
||||
l.wg.Wait()
|
||||
l.ClientsWg.Wait()
|
||||
}
|
||||
|
||||
@@ -11,14 +11,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAddr = ":22222"
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// MockEstablisher is a function signature which can be used in testing.
|
||||
@@ -53,7 +53,7 @@ func (l *MockListener) Serve(establisher EstablishFn) {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *MockListener) Init(log *zerolog.Logger) error {
|
||||
func (l *MockListener) Init(log *slog.Logger) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestMockEstablisher(t *testing.T) {
|
||||
_, w := net.Pipe()
|
||||
err := MockEstablisher("t1", w)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
}
|
||||
|
||||
func TestNewMockListener(t *testing.T) {
|
||||
@@ -86,7 +86,7 @@ func TestMockListenerServe(t *testing.T) {
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
|
||||
mocked.Init(nil)
|
||||
_ = mocked.Init(nil)
|
||||
}
|
||||
|
||||
func TestMockListenerClose(t *testing.T) {
|
||||
|
||||
@@ -9,16 +9,16 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Net is a listener for establishing client connections on basic TCP protocol.
|
||||
type Net struct { // [MQTT-4.2.0-1]
|
||||
mu sync.Mutex
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
|
||||
@@ -45,7 +45,7 @@ func (l *Net) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Net) Init(log *zerolog.Logger) error {
|
||||
func (l *Net) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
return nil
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func (l *Net) Serve(establish EstablishFn) {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
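As a side note, the zerolog chained call l.log.Warn().Err(err).Send() becomes flat key-value pairs under log/slog. A more descriptive message than the empty string used in the hunk above might read as follows; this is a suggestion only, not part of the commit:

l.log.Warn("failed to establish client connection", "error", err, "listener", l.id)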
|
||||
@@ -46,7 +46,7 @@ func TestNetInit(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
err = l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func TestNetServeAndClose(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
err = l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -84,7 +84,7 @@ func TestNetEstablishThenEnd(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
err = l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -98,7 +98,7 @@ func TestNetEstablishThenEnd(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", n.Addr().String())
|
||||
_, _ = net.Dial("tcp", n.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
|
||||
@@ -10,18 +10,18 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// TCP is a listener for establishing client connections on basic TCP protocol.
|
||||
type TCP struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config *Config // configuration values for the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config *Config // configuration values for the listener
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
@@ -44,6 +44,9 @@ func (l *TCP) ID() string {
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *TCP) Address() string {
|
||||
if l.listen != nil {
|
||||
return l.listen.Addr().String()
|
||||
}
|
||||
return l.address
|
||||
}
|
||||
|
||||
@@ -53,7 +56,7 @@ func (l *TCP) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *TCP) Init(log *zerolog.Logger) error {
|
||||
func (l *TCP) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
@@ -83,7 +86,7 @@ func (l *TCP) Serve(establish EstablishFn) {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -39,21 +39,21 @@ func TestTCPProtocolTLS(t *testing.T) {
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
l.Init(&logger)
|
||||
_ = l.Init(logger)
|
||||
defer l.listen.Close()
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPInit(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewTCP("t2", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err = l2.Init(&logger)
|
||||
err = l2.Init(logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l2.config.TLSConfig)
|
||||
@@ -61,7 +61,7 @@ func TestTCPInit(t *testing.T) {
|
||||
|
||||
func TestTCPServeAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -88,7 +88,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -110,7 +110,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
|
||||
func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -124,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", l.listen.Addr().String())
|
||||
_, _ = net.Dial("tcp", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
|
||||
@@ -10,17 +10,17 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *slog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||
@@ -47,7 +47,7 @@ func (l *UnixSock) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *zerolog.Logger) error {
|
||||
func (l *UnixSock) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
@@ -73,7 +73,7 @@ func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -38,19 +38,19 @@ func TestUnixSockProtocol(t *testing.T) {
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewUnixSock("t2", testUnixAddr)
|
||||
err = l2.Init(&logger)
|
||||
err = l2.Init(logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -75,7 +75,7 @@ func TestUnixSockServeAndClose(t *testing.T) {
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -89,7 +89,7 @@ func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("unix", l.listen.Addr().String())
|
||||
_, _ = net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
|
||||
@@ -14,8 +14,9 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -29,8 +30,8 @@ type Websocket struct { // [MQTT-4.2.0-1]
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // an http server for serving websocket connections
|
||||
log *zerolog.Logger // server logger
|
||||
listen *http.Server // an http server for serving websocket connections
|
||||
log *slog.Logger // server logger
|
||||
establish EstablishFn // the server's establish connection handler
|
||||
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
|
||||
end uint32 // ensure the close methods are only called once
|
||||
@@ -75,7 +76,7 @@ func (l *Websocket) Protocol() string {
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Websocket) Init(log *zerolog.Logger) error {
|
||||
func (l *Websocket) Init(log *slog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@@ -101,19 +102,25 @@ func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c})
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
l.log.Warn("", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts waiting for new Websocket connections, and calls the connection
|
||||
// establishment callback for any received.
|
||||
func (l *Websocket) Serve(establish EstablishFn) {
|
||||
var err error
|
||||
l.establish = establish
|
||||
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
err = l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
err = l.listen.ListenAndServe()
|
||||
}
|
||||
|
||||
// After the listener has been shutdown, no need to print the http.ErrServerClosed error.
|
||||
if err != nil && atomic.LoadUint32(&l.end) == 0 {
|
||||
l.log.Error("failed to serve.", "error", err, "listener", l.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +132,7 @@ func (l *Websocket) Close(closeClients CloseFn) {
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
_ = l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
@@ -136,7 +143,7 @@ type wsConn struct {
|
||||
net.Conn
|
||||
c *websocket.Conn
|
||||
|
||||
// reader for the current message (may be nil)
|
||||
// reader for the current message (can be nil)
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
|
||||
@@ -37,24 +37,24 @@ func TestWebsocketProtocol(t *testing.T) {
|
||||
require.Equal(t, "ws", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocoTLS(t *testing.T) {
|
||||
func TestWebsocketProtocolTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
require.Equal(t, "wss", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsockeInit(t *testing.T) {
|
||||
func TestWebsocketInit(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Nil(t, l.listen)
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketServeAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
@@ -77,7 +77,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(nil)
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
@@ -92,11 +92,33 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketFailedToServe(t *testing.T) {
|
||||
l := NewWebsocket("t1", "wrong_addr", &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
<-o
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestWebsocketUpgrade(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
_ = l.Init(logger)
|
||||
|
||||
e := make(chan bool)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
@@ -110,12 +132,12 @@ func TestWebsocketUpgrade(t *testing.T) {
|
||||
require.Equal(t, true, <-e)
|
||||
|
||||
s.Close()
|
||||
ws.Close()
|
||||
_ = ws.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketConnectionReads(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
_ = l.Init(nil)
|
||||
|
||||
recv := make(chan []byte)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
@@ -151,5 +173,5 @@ func TestWebsocketConnectionReads(t *testing.T) {
|
||||
require.Equal(t, pkt, got)
|
||||
|
||||
s.Close()
|
||||
ws.Close()
|
||||
_ = ws.Close()
|
||||
}
|
||||
|
||||
81
mempool/bufpool.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bufPool = NewBuffer(0)
|
||||
|
||||
// GetBuffer takes a Buffer from the default buffer pool
|
||||
func GetBuffer() *bytes.Buffer { return bufPool.Get() }
|
||||
|
||||
// PutBuffer returns Buffer to the default buffer pool
|
||||
func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) }
|
||||
|
||||
type BufferPool interface {
|
||||
Get() *bytes.Buffer
|
||||
Put(x *bytes.Buffer)
|
||||
}
|
||||
|
||||
// NewBuffer returns a buffer pool. The max parameter specifies the maximum capacity of the Buffer the pool will
|
||||
// return. If the Buffer becomes larger than max, it will no longer be returned to the pool. If
|
||||
// max <= 0, no limit will be enforced.
|
||||
func NewBuffer(max int) BufferPool {
|
||||
if max > 0 {
|
||||
return newBufferWithCap(max)
|
||||
}
|
||||
|
||||
return newBuffer()
|
||||
}
|
||||
|
||||
// Buffer is a Buffer pool.
|
||||
type Buffer struct {
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
func newBuffer() *Buffer {
|
||||
return &Buffer{
|
||||
pool: &sync.Pool{
|
||||
New: func() any { return new(bytes.Buffer) },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get a Buffer from the pool.
|
||||
func (b *Buffer) Get() *bytes.Buffer {
|
||||
return b.pool.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// Put the Buffer back into pool. It resets the Buffer for reuse.
|
||||
func (b *Buffer) Put(x *bytes.Buffer) {
|
||||
x.Reset()
|
||||
b.pool.Put(x)
|
||||
}
|
||||
|
||||
// BufferWithCap is a Buffer pool that discards buffers whose capacity exceeds a maximum limit.
|
||||
type BufferWithCap struct {
|
||||
bp *Buffer
|
||||
max int
|
||||
}
|
||||
|
||||
func newBufferWithCap(max int) *BufferWithCap {
|
||||
return &BufferWithCap{
|
||||
bp: newBuffer(),
|
||||
max: max,
|
||||
}
|
||||
}
|
||||
|
||||
// Get a Buffer from the pool.
|
||||
func (b *BufferWithCap) Get() *bytes.Buffer {
|
||||
return b.bp.Get()
|
||||
}
|
||||
|
||||
// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer
|
||||
// for reuse.
|
||||
func (b *BufferWithCap) Put(x *bytes.Buffer) {
|
||||
if x.Cap() > b.max {
|
||||
return
|
||||
}
|
||||
b.bp.Put(x)
|
||||
}
|
||||
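A hedged usage sketch of the new mempool package, mirroring how the packet encoders later in this diff borrow and return scratch buffers. The encodeExample function name is hypothetical and not part of the commit.

package main

import (
	"fmt"

	"github.com/mochi-mqtt/server/v2/mempool"
)

// encodeExample borrows a pooled buffer, uses it as scratch space, and
// returns it to the pool on exit (hypothetical example).
func encodeExample(payload []byte) []byte {
	nb := mempool.GetBuffer()
	defer mempool.PutBuffer(nb) // buffer is reset and recycled on return

	nb.WriteByte(byte(len(payload)))
	nb.Write(payload)

	out := make([]byte, nb.Len())
	copy(out, nb.Bytes()) // copy out before the buffer goes back to the pool
	return out
}

func main() {
	fmt.Printf("%v\n", encodeExample([]byte("hello mochi")))
}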
96
mempool/bufpool_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBuffer(t *testing.T) {
|
||||
defer debug.SetGCPercent(debug.SetGCPercent(-1))
|
||||
bp := NewBuffer(1000)
|
||||
require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String())
|
||||
|
||||
bp = NewBuffer(0)
|
||||
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
|
||||
|
||||
bp = NewBuffer(-1)
|
||||
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
|
||||
}
|
||||
|
||||
func TestBuffer(t *testing.T) {
|
||||
defer debug.SetGCPercent(debug.SetGCPercent(-1))
|
||||
Size := 101
|
||||
|
||||
bp := NewBuffer(0)
|
||||
buf := bp.Get()
|
||||
|
||||
for i := 0; i < Size; i++ {
|
||||
buf.WriteByte('a')
|
||||
}
|
||||
|
||||
bp.Put(buf)
|
||||
buf = bp.Get()
|
||||
require.Equal(t, 0, buf.Len())
|
||||
}
|
||||
|
||||
func TestBufferWithCap(t *testing.T) {
|
||||
defer debug.SetGCPercent(debug.SetGCPercent(-1))
|
||||
Size := 101
|
||||
bp := NewBuffer(100)
|
||||
buf := bp.Get()
|
||||
|
||||
for i := 0; i < Size; i++ {
|
||||
buf.WriteByte('a')
|
||||
}
|
||||
|
||||
bp.Put(buf)
|
||||
buf = bp.Get()
|
||||
require.Equal(t, 0, buf.Len())
|
||||
require.Equal(t, 0, buf.Cap())
|
||||
}
|
||||
|
||||
func BenchmarkBufferPool(b *testing.B) {
|
||||
bp := NewBuffer(0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := bp.Get()
|
||||
b.WriteString("this is a test")
|
||||
bp.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBufferPoolWithCapLarger(b *testing.B) {
|
||||
bp := NewBuffer(64 * 1024)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := bp.Get()
|
||||
b.WriteString("this is a test")
|
||||
bp.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBufferPoolWithCapLesser(b *testing.B) {
|
||||
bp := NewBuffer(10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := bp.Get()
|
||||
b.WriteString("this is a test")
|
||||
bp.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBufferWithoutPool(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("this is a test")
|
||||
_ = b
|
||||
}
|
||||
}
|
||||
@@ -126,6 +126,7 @@ var (
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestCodesString(t *testing.T) {
|
||||
require.Equal(t, "test", c.String())
|
||||
}
|
||||
|
||||
func TestCodesErrorr(t *testing.T) {
|
||||
func TestCodesError(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "error",
|
||||
Code: 0x1,
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/mempool"
|
||||
)
|
||||
|
||||
// All of the valid packet types and their packet identifier.
|
||||
// All valid packet types and their packet identifiers.
|
||||
const (
|
||||
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
|
||||
Connect // 1
|
||||
@@ -37,9 +39,9 @@ const (
|
||||
|
||||
var (
|
||||
// ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification.
|
||||
ErrNoValidPacketAvailable error = errors.New("no valid packet available")
|
||||
ErrNoValidPacketAvailable = errors.New("no valid packet available")
|
||||
|
||||
// PacketNames is a map of packet bytes to human readable names, for easier debugging.
|
||||
// PacketNames is a map of packet bytes to human-readable names, for easier debugging.
|
||||
PacketNames = map[byte]string{
|
||||
0: "Reserved",
|
||||
1: "Connect",
|
||||
@@ -272,33 +274,34 @@ func (s Subscription) Merge(n Subscription) Subscription {
|
||||
}
|
||||
|
||||
// encode encodes a subscription and properties into bytes.
|
||||
func (p Subscription) encode() byte {
|
||||
func (s Subscription) encode() byte {
|
||||
var flag byte
|
||||
flag |= p.Qos
|
||||
flag |= s.Qos
|
||||
|
||||
if p.NoLocal {
|
||||
if s.NoLocal {
|
||||
flag |= 1 << 2
|
||||
}
|
||||
|
||||
if p.RetainAsPublished {
|
||||
if s.RetainAsPublished {
|
||||
flag |= 1 << 3
|
||||
}
|
||||
|
||||
flag |= p.RetainHandling << 4
|
||||
flag |= s.RetainHandling << 4
|
||||
return flag
|
||||
}
|
||||
|
||||
// decode decodes subscription bytes into a subscription struct.
|
||||
func (p *Subscription) decode(b byte) {
|
||||
p.Qos = b & 3 // byte
|
||||
p.NoLocal = 1&(b>>2) > 0 // bool
|
||||
p.RetainAsPublished = 1&(b>>3) > 0 // bool
|
||||
p.RetainHandling = 3 & (b >> 4) // byte
|
||||
func (s *Subscription) decode(b byte) {
|
||||
s.Qos = b & 3 // byte
|
||||
s.NoLocal = 1&(b>>2) > 0 // bool
|
||||
s.RetainAsPublished = 1&(b>>3) > 0 // bool
|
||||
s.RetainHandling = 3 & (b >> 4) // byte
|
||||
}
|
||||
|
||||
// ConnectEncode encodes a connect packet.
|
||||
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeBytes(pk.Connect.ProtocolName))
|
||||
nb.WriteByte(pk.ProtocolVersion)
|
||||
|
||||
@@ -315,7 +318,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
nb.Write(encodeUint16(pk.Connect.Keepalive))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -324,7 +328,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.Connect.WillFlag {
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -343,7 +348,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
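The same pooled-buffer flow repeats across the encoders below: take a scratch buffer from the pool, build the variable header and payload into it, derive FixedHeader.Remaining from its length, then copy its bytes into the destination instead of draining it with WriteTo. A condensed sketch, assuming it sits inside the packets package alongside the encodeUint16 helper shown in this diff; the exampleEncode method is hypothetical, not a verbatim excerpt:

// exampleEncode illustrates the pooled-buffer encoding pattern (hypothetical).
func (pk *Packet) exampleEncode(buf *bytes.Buffer) error {
	nb := mempool.GetBuffer()   // scratch buffer from the pool
	defer mempool.PutBuffer(nb) // reset and recycled when we return

	nb.Write(encodeUint16(pk.PacketID)) // build variable header/payload in scratch space

	pk.FixedHeader.Remaining = nb.Len() // remaining length derived from the scratch buffer
	pk.FixedHeader.Encode(buf)
	buf.Write(nb.Bytes()) // copy out rather than draining the pooled buffer
	return nil
}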
@@ -493,19 +498,22 @@ func (pk *Packet) ConnectValidate() Code {
|
||||
|
||||
// ConnackEncode encodes a Connack packet.
|
||||
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.WriteByte(encodeBool(pk.SessionPresent))
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -536,19 +544,21 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
|
||||
// DisconnectEncode encodes a Disconnect packet.
|
||||
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -598,7 +608,8 @@ func (pk *Packet) PingrespDecode(buf []byte) error {
|
||||
|
||||
// PublishEncode encodes a Publish packet.
|
||||
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
|
||||
nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1]
|
||||
|
||||
@@ -610,16 +621,16 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
}
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
nb.Write(pk.Payload)
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload)
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
buf.Write(pk.Payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -690,11 +701,13 @@ func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code {
|
||||
|
||||
// encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet.
|
||||
func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
@@ -707,7 +720,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -831,11 +844,13 @@ func (pk *Packet) ReasonCodeValid() bool {
|
||||
|
||||
// SubackEncode encodes a Suback packet.
|
||||
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -844,7 +859,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -878,10 +893,12 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
return ErrProtocolViolationNoPacketID
|
||||
}
|
||||
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
xb := bytes.NewBuffer([]byte{}) // capture and write filters after length checks
|
||||
xb := mempool.GetBuffer() // capture and write filters after length checks
|
||||
defer mempool.PutBuffer(xb)
|
||||
for _, opts := range pk.Filters {
|
||||
xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1]
|
||||
if pk.ProtocolVersion == 5 {
|
||||
@@ -892,7 +909,8 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
}
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -901,7 +919,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -983,20 +1001,21 @@ func (pk *Packet) SubscribeValidate() Code {
|
||||
|
||||
// UnsubackEncode encodes an Unsuback packet.
|
||||
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
nb.Write(pk.ReasonCodes)
|
||||
}
|
||||
|
||||
nb.Write(pk.ReasonCodes)
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1031,16 +1050,19 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
return ErrProtocolViolationNoPacketID
|
||||
}
|
||||
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
xb := bytes.NewBuffer([]byte{}) // capture filters and write after length checks
|
||||
xb := mempool.GetBuffer() // capture filters and write after length checks
|
||||
defer mempool.PutBuffer(xb)
|
||||
for _, sub := range pk.Filters {
|
||||
xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1]
|
||||
}
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -1049,7 +1071,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1100,16 +1122,18 @@ func (pk *Packet) UnsubscribeValidate() Code {
|
||||
|
||||
// AuthEncode encodes an Auth packet.
|
||||
func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Encode(buf)
|
||||
nb.WriteTo(buf)
|
||||
buf.Write(nb.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ func TestPacketEncode(t *testing.T) {
|
||||
}
|
||||
|
||||
pk := new(Packet)
|
||||
copier.Copy(pk, wanted.Packet)
|
||||
_ = copier.Copy(pk, wanted.Packet)
|
||||
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
@@ -218,7 +218,7 @@ func TestPacketDecode(t *testing.T) {
|
||||
|
||||
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
_ = pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
if len(wanted.RawBytes) > 0 {
|
||||
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/mempool"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -77,7 +79,7 @@ type UserProperty struct { // [MQTT-1.5.7-1]
|
||||
Val string `json:"v"`
|
||||
}
|
||||
|
||||
// Properties contains all of the mqtt v5 properties available for a packet.
|
||||
// Properties contains all mqtt v5 properties available for a packet.
|
||||
// Some properties have valid values of 0 or not-present. In this case, we opt for
|
||||
// property flags to indicate the usage of property.
|
||||
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
|
||||
@@ -199,7 +201,8 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
return
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(buf)
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
@@ -230,7 +233,7 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
for _, v := range p.SubscriptionIdentifier {
|
||||
if v > 0 {
|
||||
buf.WriteByte(PropSubscriptionIdentifier)
|
||||
encodeLength(&buf, int64(v))
|
||||
encodeLength(buf, int64(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -321,7 +324,8 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
}
|
||||
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
pb.Write(encodeString(v.Key))
|
||||
@@ -355,7 +359,7 @@ func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
}
|
||||
|
||||
encodeLength(b, int64(buf.Len()))
|
||||
buf.WriteTo(b) // [MQTT-3.1.3-10]
|
||||
b.Write(buf.Bytes()) // [MQTT-3.1.3-10]
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
|
||||
@@ -40,7 +40,6 @@ const (
|
||||
TConnectMqtt5
|
||||
TConnectMqtt5LWT
|
||||
TConnectClean
|
||||
TConnectCleanLWT
|
||||
TConnectUserPass
|
||||
TConnectUserPassLWT
|
||||
TConnectMalProtocolName
|
||||
@@ -61,7 +60,6 @@ const (
|
||||
TConnectInvalidProtocolVersion2
|
||||
TConnectInvalidReservedBit
|
||||
TConnectInvalidClientIDTooLong
|
||||
TConnectInvalidPasswordNoUsername
|
||||
TConnectInvalidFlagNoUsername
|
||||
TConnectInvalidFlagNoPassword
|
||||
TConnectInvalidUsernameNoFlag
|
||||
@@ -131,12 +129,14 @@ const (
|
||||
TPublishSpecDenySysTopic
|
||||
TPuback
|
||||
TPubackMqtt5
|
||||
TPubackMqtt5NotAuthorized
|
||||
TPubackMalPacketID
|
||||
TPubackMalProperties
|
||||
TPubackUnexpectedError
|
||||
TPubrec
|
||||
TPubrecMqtt5
|
||||
TPubrecMqtt5IDInUse
|
||||
TPubrecMqtt5NotAuthorized
|
||||
TPubrecMalPacketID
|
||||
TPubrecMalProperties
|
||||
TPubrecMalReasonCode
|
||||
@@ -184,7 +184,6 @@ const (
|
||||
TUnsubscribe
|
||||
TUnsubscribeMany
|
||||
TUnsubscribeMqtt5
|
||||
TUnsubscribeDropProperties
|
||||
TUnsubscribeMalPacketID
|
||||
TUnsubscribeMalTopicName
|
||||
TUnsubscribeMalProperties
|
||||
@@ -202,7 +201,6 @@ const (
|
||||
TDisconnect
|
||||
TDisconnectTakeover
|
||||
TDisconnectMqtt5
|
||||
TDisconnectNormalMqtt5
|
||||
TDisconnectSecondConnect
|
||||
TDisconnectReceiveMaximum
|
||||
TDisconnectDropProperties
|
||||
@@ -1707,41 +1705,41 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPublishQos1Mqtt5,
|
||||
Desc: "mqtt v5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Publish<<4 | 1<<1, 37, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
// Properties
|
||||
16, // length
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Remaining: 37,
|
||||
Qos: 1,
|
||||
},
|
||||
PacketID: 7,
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
},
|
||||
Case: TPublishQos1Mqtt5,
|
||||
Desc: "mqtt v5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Publish<<4 | 1<<1, 37, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
// Properties
|
||||
16, // length
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Remaining: 37,
|
||||
Qos: 1,
|
||||
},
|
||||
PacketID: 7,
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TPublishQos1Dup,
|
||||
@@ -2274,6 +2272,40 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubackMqtt5NotAuthorized,
|
||||
Desc: "QOS 1 publish not authorized mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Puback << 4, 37, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrNotAuthorized.Code, // Reason Code
|
||||
33, // Properties Length
|
||||
31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
|
||||
't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Puback,
|
||||
Remaining: 31,
|
||||
},
|
||||
PacketID: 7,
|
||||
ReasonCode: ErrNotAuthorized.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrNotAuthorized.Reason,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubackUnexpectedError,
|
||||
Desc: "unexpected error",
|
||||
@@ -2412,6 +2444,40 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubrecMqtt5NotAuthorized,
|
||||
Desc: "QOS 2 publish not authorized mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Pubrec << 4, 37, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrNotAuthorized.Code, // Reason Code
|
||||
33, // Properties Length
|
||||
31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
|
||||
't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Pubrec,
|
||||
Remaining: 31,
|
||||
},
|
||||
PacketID: 7,
|
||||
ReasonCode: ErrNotAuthorized.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrNotAuthorized.Reason,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubrecMalReasonCode,
|
||||
Desc: "malformed reason code",
|
||||
|
||||
404
server.go
@@ -2,7 +2,7 @@
|
||||
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
|
||||
// Package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -22,51 +22,62 @@ import (
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
"github.com/mochi-mqtt/server/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "2.3.0" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
Version = "2.4.6" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
LocalListener = "local"
|
||||
InlineClientId = "inline"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultServerCapabilities defines the default features and capabilities provided by the server.
|
||||
DefaultServerCapabilities = &Capabilities{
|
||||
MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions
|
||||
MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over
|
||||
ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client
|
||||
MaximumQos: 2, // maxmimum qos value available to clients
|
||||
RetainAvailable: 1, // retain messages is available
|
||||
MaximumPacketSize: 0, // no maximum packet size
|
||||
TopicAliasMaximum: math.MaxUint16, // maximum topic alias value
|
||||
WildcardSubAvailable: 1, // wildcard subscriptions are available
|
||||
SubIDAvailable: 1, // subscription identifiers are available
|
||||
SharedSubAvailable: 1, // shared subscriptions are available
|
||||
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
|
||||
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
|
||||
}
|
||||
// Deprecated: Use NewDefaultServerCapabilities to avoid data race issue.
|
||||
DefaultServerCapabilities = NewDefaultServerCapabilities()
|
||||
|
||||
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
|
||||
ErrConnectionClosed = errors.New("connection not open") // connection is closed
|
||||
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists
|
||||
ErrConnectionClosed = errors.New("connection not open") // connection is closed
|
||||
ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default
|
||||
)
|
||||
|
||||
// Capabilities indicates the capabilities and features provided by the server.
|
||||
type Capabilities struct {
|
||||
MaximumMessageExpiryInterval int64
|
||||
MaximumClientWritesPending int32
|
||||
MaximumSessionExpiryInterval uint32
|
||||
MaximumPacketSize uint32
|
||||
MaximumMessageExpiryInterval int64 // maximum message expiry if message expiry is 0 or over
|
||||
MaximumClientWritesPending int32 // maximum number of pending message writes for a client
|
||||
MaximumSessionExpiryInterval uint32 // maximum number of seconds to keep disconnected sessions
|
||||
MaximumPacketSize uint32 // maximum packet size, no limit if 0
|
||||
maximumPacketID uint32 // unexported, used for testing only
|
||||
ReceiveMaximum uint16
|
||||
TopicAliasMaximum uint16
|
||||
SharedSubAvailable byte
|
||||
MinimumProtocolVersion byte
|
||||
ReceiveMaximum uint16 // maximum number of concurrent qos messages per client
|
||||
MaximumInflight uint16 // maximum number of qos > 0 messages that can be stored, 0(=8192)-65535
|
||||
TopicAliasMaximum uint16 // maximum topic alias value
|
||||
SharedSubAvailable byte // support of shared subscriptions
|
||||
MinimumProtocolVersion byte // minimum supported mqtt version
|
||||
Compatibilities Compatibilities
|
||||
MaximumQos byte
|
||||
RetainAvailable byte
|
||||
WildcardSubAvailable byte
|
||||
SubIDAvailable byte
|
||||
MaximumQos byte // maximum qos value available to clients
|
||||
RetainAvailable byte // support of retain messages
|
||||
WildcardSubAvailable byte // support of wildcard subscriptions
|
||||
SubIDAvailable byte // support of subscription identifiers
|
||||
}
|
||||
|
||||
// NewDefaultServerCapabilities defines the default features and capabilities provided by the server.
|
||||
func NewDefaultServerCapabilities() *Capabilities {
|
||||
return &Capabilities{
|
||||
MaximumMessageExpiryInterval: 60 * 60 * 24, // maximum message expiry if message expiry is 0 or over
|
||||
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
|
||||
MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions
|
||||
MaximumPacketSize: 0, // no maximum packet size
|
||||
maximumPacketID: math.MaxUint16,
|
||||
ReceiveMaximum: 1024, // maximum number of concurrent qos messages per client
|
||||
MaximumInflight: 1024 * 8, // maximum number of qos > 0 messages that can be stored
|
||||
TopicAliasMaximum: math.MaxUint16, // maximum topic alias value
|
||||
SharedSubAvailable: 1, // shared subscriptions are available
|
||||
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
|
||||
MaximumQos: 2, // maximum qos value available to clients
|
||||
RetainAvailable: 1, // retain messages is available
|
||||
WildcardSubAvailable: 1, // wildcard subscriptions are available
|
||||
SubIDAvailable: 1, // subscription identifiers are available
|
||||
}
|
||||
}
|
||||
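A hedged sketch of configuring a server with per-instance capabilities via the new constructor, so the shared (now deprecated) DefaultServerCapabilities value is never mutated across servers. It assumes only the Options fields shown in this diff.

package main

import (
	mqtt "github.com/mochi-mqtt/server/v2"
)

func main() {
	caps := mqtt.NewDefaultServerCapabilities()
	caps.MaximumInflight = 1024 // per-server override, no shared state

	server := mqtt.New(&mqtt.Options{
		Capabilities: caps,
	})
	_ = server
}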
|
||||
// Compatibilities provides flags for using compatibility modes.
|
||||
@@ -95,26 +106,34 @@ type Options struct {
|
||||
// the server's default logger configuration. If you wish to change the log level
|
||||
// of the default logger, you can do so by setting
|
||||
// server := mqtt.New(nil)
|
||||
// l := server.Log.Level(zerolog.DebugLevel)
|
||||
// server.Log = &l
|
||||
Logger *zerolog.Logger
|
||||
// level := new(slog.LevelVar)
|
||||
// server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
// Level: level,
|
||||
// }))
|
||||
// level.Set(slog.LevelDebug)
|
||||
Logger *slog.Logger
|
||||
|
||||
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
|
||||
SysTopicResendInterval int64
|
||||
|
||||
// Enable Inline client to allow direct subscribing and publishing from the parent codebase,
|
||||
// with negligible performance difference (disabled by default to prevent confusion in statistics).
|
||||
InlineClient bool
|
||||
}
|
||||
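Spelling out the doc comment above as a runnable sketch: configuring the slog logger through the new Options.Logger field with an adjustable level. This assumes only the Options and New APIs shown in this diff.

package main

import (
	"log/slog"
	"os"

	mqtt "github.com/mochi-mqtt/server/v2"
)

func main() {
	level := new(slog.LevelVar)
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
	level.Set(slog.LevelDebug)

	server := mqtt.New(&mqtt.Options{Logger: logger})
	_ = server
}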
|
||||
// Server is an MQTT broker server. It should be created with server.New()
|
||||
// in order to ensure all the internal fields are correctly populated.
|
||||
type Server struct {
|
||||
Options *Options // configurable server options
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections
|
||||
Clients *Clients // clients known to the broker
|
||||
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
|
||||
Info *system.Info // values about the server commonly known as $SYS topics
|
||||
loop *loop // loop contains tickers for the system event loop
|
||||
done chan bool // indicate that the server is ending
|
||||
Log *zerolog.Logger // minimal no-alloc logger
|
||||
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage.
|
||||
Options *Options // configurable server options
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections
|
||||
Clients *Clients // clients known to the broker
|
||||
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
|
||||
Info *system.Info // values about the server commonly known as $SYS topics
|
||||
loop *loop // loop contains tickers for the system event loop
|
||||
done chan bool // indicate that the server is ending
|
||||
Log *slog.Logger // minimal no-alloc logger
|
||||
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage
|
||||
inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish
|
||||
}
|
||||
|
||||
// loop contains interval tickers for the system events loop.
|
||||
@@ -123,16 +142,16 @@ type loop struct {
|
||||
clientExpiry *time.Ticker // interval ticker for cleaning expired clients
|
||||
inflightExpiry *time.Ticker // interval ticker for cleaning up expired inflight messages
|
||||
retainedExpiry *time.Ticker // interval ticker for cleaning retained messages
|
||||
willDelaySend *time.Ticker // interval ticker for sending will messages with a delay
|
||||
willDelaySend *time.Ticker // interval ticker for sending Will Messages with a delay
|
||||
willDelayed *packets.Packets // activate LWT packets which will be sent after a delay
|
||||
}
|
||||
|
||||
// ops contains server values which can be propagated to other structs.
|
||||
type ops struct {
|
||||
options *Options // a pointer to the server options and capabilities, for referencing in clients
|
||||
info *system.Info // pointers to server system info
|
||||
hooks *Hooks // pointer to the server hooks
|
||||
log *zerolog.Logger // a structured logger for the client
|
||||
options *Options // a pointer to the server options and capabilities, for referencing in clients
|
||||
info *system.Info // pointers to server system info
|
||||
hooks *Hooks // pointer to the server hooks
|
||||
log *slog.Logger // a structured logger for the client
|
||||
}
|
||||
|
||||
// New returns a new instance of mochi mqtt broker. Optional parameters
|
||||
@@ -168,17 +187,26 @@ func New(opts *Options) *Server {
|
||||
},
|
||||
}
|
||||
|
||||
if s.Options.InlineClient {
|
||||
s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true)
|
||||
s.Clients.Add(s.inlineClient)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// ensureDefaults ensures that the server starts with sane default values, if none are provided.
|
||||
func (o *Options) ensureDefaults() {
|
||||
if o.Capabilities == nil {
|
||||
o.Capabilities = DefaultServerCapabilities
|
||||
o.Capabilities = NewDefaultServerCapabilities()
|
||||
}
|
||||
|
||||
o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535
|
||||
|
||||
if o.Capabilities.MaximumInflight == 0 {
|
||||
o.Capabilities.MaximumInflight = 1024 * 8
|
||||
}
|
||||
|
||||
if o.SysTopicResendInterval == 0 {
|
||||
o.SysTopicResendInterval = defaultSysTopicInterval
|
||||
}
|
||||
@@ -192,8 +220,8 @@ func (o *Options) ensureDefaults() {
|
||||
}
|
||||
|
||||
if o.Logger == nil {
|
||||
log := zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.InfoLevel).Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||
o.Logger = &log
|
||||
log := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
o.Logger = log
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,8 +245,6 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool)
|
||||
// By default, we don't want to restrict developer publishes,
|
||||
// but if you do, reset this after creating inline client.
|
||||
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
|
||||
} else {
|
||||
go cl.WriteLoop() // can only write to real clients
|
||||
}
|
||||
|
||||
return cl
|
||||
@@ -227,12 +253,12 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool)
|
||||
// AddHook attaches a new Hook to the server. Ideally, this should be called
|
||||
// before the server is started with s.Serve().
|
||||
func (s *Server) AddHook(hook Hook, config any) error {
|
||||
nl := s.Log.With().Str("hook", hook.ID()).Logger()
|
||||
hook.SetOpts(&nl, &HookOptions{
|
||||
nl := s.Log.With("hook", hook.ID())
|
||||
hook.SetOpts(nl, &HookOptions{
|
||||
Capabilities: s.Options.Capabilities,
|
||||
})
|
||||
|
||||
s.Log.Info().Str("hook", hook.ID()).Msg("added hook")
|
||||
s.Log.Info("added hook", "hook", hook.ID())
|
||||
return s.hooks.Add(hook, config)
|
||||
}
|
||||
|
||||
@@ -242,23 +268,23 @@ func (s *Server) AddListener(l listeners.Listener) error {
|
||||
return ErrListenerIDExists
|
||||
}
|
||||
|
||||
nl := s.Log.With().Str("listener", l.ID()).Logger()
|
||||
err := l.Init(&nl)
|
||||
nl := s.Log.With(slog.String("listener", l.ID()))
|
||||
err := l.Init(nl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Listeners.Add(l)
|
||||
|
||||
s.Log.Info().Str("id", l.ID()).Str("protocol", l.Protocol()).Str("address", l.Address()).Msg("attached listener")
|
||||
s.Log.Info("attached listener", "id", l.ID(), "protocol", l.Protocol(), "address", l.Address())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts the event loops responsible for establishing client connections
|
||||
// on all attached listeners, publishing the system topics, and starting all hooks.
|
||||
func (s *Server) Serve() error {
|
||||
s.Log.Info().Str("version", Version).Msg("mochi mqtt starting")
|
||||
defer s.Log.Info().Msg("mochi mqtt server started")
|
||||
s.Log.Info("mochi mqtt starting", "version", Version)
|
||||
defer s.Log.Info("mochi mqtt server started")
|
||||
|
||||
if s.hooks.Provides(
|
||||
StoredClients,
|
||||
@@ -283,8 +309,8 @@ func (s *Server) Serve() error {
|
||||
|
||||
// eventLoop loops forever, running various server housekeeping methods at different intervals.
|
||||
func (s *Server) eventLoop() {
|
||||
s.Log.Debug().Msg("system event loop started")
|
||||
defer s.Log.Debug().Msg("system event loop halted")
|
||||
s.Log.Debug("system event loop started")
|
||||
defer s.Log.Debug("system event loop halted")
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -314,7 +340,12 @@ func (s *Server) EstablishConnection(listener string, c net.Conn) error {
|
||||
// attachClient validates an incoming client connection and if viable, attaches the client
|
||||
// to the server, performs session housekeeping, and reads incoming packets.
|
||||
func (s *Server) attachClient(cl *Client, listener string) error {
|
||||
defer s.Listeners.ClientsWg.Done()
|
||||
s.Listeners.ClientsWg.Add(1)
|
||||
|
||||
go cl.WriteLoop()
|
||||
defer cl.Stop(nil)
|
||||
|
||||
pk, err := s.readConnectionPacket(cl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read connection: %w", err)
|
||||
@@ -375,13 +406,13 @@ func (s *Server) attachClient(cl *Client, listener string) error {
|
||||
} else {
|
||||
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
|
||||
}
|
||||
s.Log.Debug("client disconnected", "error", err, "client", cl.ID, "remote", cl.Net.Remote, "listener", listener)
|
||||
|
||||
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
|
||||
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
||||
s.hooks.OnDisconnect(cl, err, expire)
|
||||
|
||||
if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
|
||||
cl.ClearInflights(math.MaxInt64, 0)
|
||||
cl.ClearInflights()
|
||||
s.UnsubscribeClient(cl)
|
||||
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
||||
}
|
||||
@@ -418,10 +449,10 @@ func (s *Server) receivePacket(cl *Client, pk packets.Packet) error {
|
||||
if code, ok := err.(packets.Code); ok &&
|
||||
cl.Properties.ProtocolVersion == 5 &&
|
||||
code.Code >= packets.ErrUnspecifiedError.Code {
|
||||
s.DisconnectClient(cl, code)
|
||||
_ = s.DisconnectClient(cl, code)
|
||||
}
|
||||
|
||||
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("pk", pk).Msg("error processing packet")
|
||||
s.Log.Warn("error processing packet", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "pk", pk)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -456,10 +487,10 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
|
||||
// session is abandoned.
|
||||
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
||||
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
|
||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
_ = s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
existing.ClearInflights()
|
||||
atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred
|
||||
return false // [MQTT-3.2.2-3]
|
||||
}
|
||||
@@ -484,11 +515,9 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
||||
// Clean the state of the existing client to prevent sequential take-overs
|
||||
// from increasing memory usage by inflights + subs * client-id.
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
s.Log.Debug().Str("client", cl.ID).
|
||||
Str("old_remote", existing.Net.Remote).
|
||||
Str("new_remote", cl.Net.Remote).
|
||||
Msg("session taken over")
|
||||
existing.ClearInflights()
|
||||
|
||||
s.Log.Debug("session taken over", "client", cl.ID, "old_remote", existing.Net.Remote, "new_remote", cl.Net.Remote)
|
||||
|
||||
return true // [MQTT-3.2.2-3]
|
||||
}
|
||||
@@ -643,13 +672,16 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
|
||||
// Publish publishes a publish packet into the broker as if it were sent from the specified client.
|
||||
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
|
||||
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
|
||||
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
|
||||
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
|
||||
cl := s.NewClient(nil, "local", "inline", true)
|
||||
return s.InjectPacket(cl, packets.Packet{
|
||||
if !s.Options.InlineClient {
|
||||
return ErrInlineClientNotEnabled
|
||||
}
|
||||
|
||||
return s.InjectPacket(s.inlineClient, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Qos: qos,
|
||||
@@ -661,6 +693,75 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er
|
||||
})
|
||||
}
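For reference, a minimal sketch (not part of this changeset) of calling the inline Publish API above from an embedding application; it assumes the v2 module path and an allow-all auth hook, with listener setup omitted:

```go
package main

import (
	"log"

	mqtt "github.com/mochi-mqtt/server/v2"
	"github.com/mochi-mqtt/server/v2/hooks/auth"
)

func main() {
	// Options.InlineClient must be enabled, otherwise Publish returns ErrInlineClientNotEnabled.
	server := mqtt.New(&mqtt.Options{InlineClient: true})
	_ = server.AddHook(new(auth.AllowHook), nil)

	go func() {
		if err := server.Serve(); err != nil {
			log.Fatal(err)
		}
	}()

	// Publish directly into the broker, bypassing ACL checks; the qos byte only
	// caps the outbound qos for MQTT v5 subscribers.
	if err := server.Publish("direct/publish", []byte("hello"), false, 0); err != nil {
		log.Println("publish error:", err)
	}
}
```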
|
||||
|
||||
// Subscribe adds an inline subscription for the specified topic filter and subscription identifier
|
||||
// with the provided handler function.
|
||||
func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error {
|
||||
if !s.Options.InlineClient {
|
||||
return ErrInlineClientNotEnabled
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
return packets.ErrInlineSubscriptionHandlerInvalid
|
||||
}
|
||||
|
||||
if !IsValidFilter(filter, false) {
|
||||
return packets.ErrTopicFilterInvalid
|
||||
}
|
||||
|
||||
subscription := packets.Subscription{
|
||||
Identifier: subscriptionId,
|
||||
Filter: filter,
|
||||
}
|
||||
|
||||
pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client.
|
||||
Origin: s.inlineClient.ID,
|
||||
FixedHeader: packets.FixedHeader{Type: packets.Subscribe},
|
||||
Filters: packets.Subscriptions{subscription},
|
||||
})
|
||||
|
||||
inlineSubscription := InlineSubscription{
|
||||
Subscription: subscription,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
s.Topics.InlineSubscribe(inlineSubscription)
|
||||
s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code})
|
||||
|
||||
// Handling retained messages.
|
||||
for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4]
|
||||
handler(s.inlineClient, inlineSubscription.Subscription, pkv)
|
||||
}
|
||||
return nil
|
||||
}
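A matching sketch for the inline Subscribe API above, continuing the previous example (same imports, plus github.com/mochi-mqtt/server/v2/packets); the callback follows the InlineSubFn signature added in topics.go:

```go
// Handler with the InlineSubFn signature; cl is the server's internal inline client.
callback := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
	log.Printf("inline sub %d received %s: %s", sub.Identifier, pk.TopicName, string(pk.Payload))
}

// The subscription identifier (1) keys the handler on this filter; retained
// messages matching "direct/#" are replayed to the handler straight away.
if err := server.Subscribe("direct/#", 1, callback); err != nil {
	log.Println("subscribe error:", err)
}
```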
|
||||
|
||||
// Unsubscribe removes an inline subscription for the specified subscription and topic filter.
|
||||
// It allows you to unsubscribe a specific subscription from the internal subscription
|
||||
// associated with the given topic filter.
|
||||
func (s *Server) Unsubscribe(filter string, subscriptionId int) error {
|
||||
if !s.Options.InlineClient {
|
||||
return ErrInlineClientNotEnabled
|
||||
}
|
||||
|
||||
if !IsValidFilter(filter, false) {
|
||||
return packets.ErrTopicFilterInvalid
|
||||
}
|
||||
|
||||
pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{
|
||||
Origin: s.inlineClient.ID,
|
||||
FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe},
|
||||
Filters: packets.Subscriptions{
|
||||
{
|
||||
Identifier: subscriptionId,
|
||||
Filter: filter,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.Topics.InlineUnsubscribe(subscriptionId, filter)
|
||||
s.hooks.OnUnsubscribed(s.inlineClient, pk)
|
||||
return nil
|
||||
}
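And the corresponding teardown, passing the same filter and identifier the handler was registered with (continuing the same sketch):

```go
// Remove the inline subscription registered above; an error is returned if the
// inline client is disabled or the filter is invalid.
if err := server.Unsubscribe("direct/#", 1); err != nil {
	log.Println("unsubscribe error:", err)
}
```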
|
||||
|
||||
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
|
||||
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
|
||||
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
|
||||
@@ -690,7 +791,21 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
}
|
||||
|
||||
if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) {
|
||||
return nil
|
||||
if pk.FixedHeader.Qos == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cl.Properties.ProtocolVersion != 5 {
|
||||
return s.DisconnectClient(cl, packets.ErrNotAuthorized)
|
||||
}
|
||||
|
||||
ackType := packets.Puback
|
||||
if pk.FixedHeader.Qos == 2 {
|
||||
ackType = packets.Pubrec
|
||||
}
|
||||
|
||||
ack := s.buildAck(pk.PacketID, ackType, 0, pk.Properties, packets.ErrNotAuthorized)
|
||||
return cl.WritePacket(ack)
|
||||
}
|
||||
|
||||
pk.Origin = cl.ID
|
||||
@@ -713,7 +828,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Qos > s.Options.Capabilities.MaximumQos {
|
||||
pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability
|
||||
pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce qos based on server max qos capability
|
||||
}
|
||||
|
||||
pkx, err := s.hooks.OnPublish(cl, pk)
|
||||
@@ -735,7 +850,10 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
s.retainMessage(cl, pk)
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Qos == 0 {
|
||||
// If it's the inline client, it can't handle PUBREC and PUBREL.
// When it publishes a packet with qos > 0, the server treats
// the packet as qos=0, while subscribers still receive it at qos=1 or 2.
|
||||
if pk.FixedHeader.Qos == 0 || cl.Net.Inline {
|
||||
s.publishToSubscribers(pk)
|
||||
s.hooks.OnPublished(cl, pk)
|
||||
return nil
|
||||
@@ -808,11 +926,15 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
subscribers.MergeSharedSelected()
|
||||
}
|
||||
|
||||
for _, inlineSubscription := range subscribers.InlineSubscriptions {
|
||||
inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk)
|
||||
}
|
||||
|
||||
for id, subs := range subscribers.Subscriptions {
|
||||
if cl, ok := s.Clients.Get(id); ok {
|
||||
_, err := s.publishToClient(cl, subs, pk)
|
||||
if err != nil {
|
||||
s.Log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
|
||||
s.Log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -824,6 +946,9 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
}
|
||||
|
||||
out := pk.Copy(false)
|
||||
if !s.hooks.OnACLCheck(cl, pk.TopicName, false) {
|
||||
return out, packets.ErrNotAuthorized
|
||||
}
|
||||
if !sub.FwdRetainedFlag && ((cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished) || cl.Properties.ProtocolVersion < 5) { // ![MQTT-3.3.1-13] [v3 MQTT-3.3.1-9]
|
||||
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
|
||||
}
|
||||
@@ -856,10 +981,18 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
}
|
||||
|
||||
if out.FixedHeader.Qos > 0 {
|
||||
if cl.State.Inflight.Len() >= int(s.Options.Capabilities.MaximumInflight) {
|
||||
// add hook?
|
||||
atomic.AddInt64(&s.Info.InflightDropped, 1)
|
||||
s.Log.Warn("client store quota reached", "client", cl.ID, "listener", cl.Net.Listener)
|
||||
return out, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
|
||||
if err != nil {
|
||||
s.hooks.OnPacketIDExhausted(cl, pk)
|
||||
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted")
|
||||
atomic.AddInt64(&s.Info.InflightDropped, 1)
|
||||
s.Log.Warn("packet ids exhausted", "error", err, "client", cl.ID, "listener", cl.Net.Listener)
|
||||
return out, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
@@ -889,8 +1022,10 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
default:
|
||||
atomic.AddInt64(&s.Info.MessagesDropped, 1)
|
||||
cl.ops.hooks.OnPublishDropped(cl, pk)
|
||||
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
|
||||
cl.State.Inflight.IncreaseSendQuota()
|
||||
if out.FixedHeader.Qos > 0 {
|
||||
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
|
||||
cl.State.Inflight.IncreaseSendQuota()
|
||||
}
|
||||
return out, packets.ErrPendingClientWritesExceeded
|
||||
}
|
||||
|
||||
@@ -910,7 +1045,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
|
||||
for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
|
||||
_, err := s.publishToClient(cl, sub, pkv)
|
||||
if err != nil {
|
||||
s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
|
||||
s.Log.Debug("failed to publish retained message", "error", err, "client", cl.ID, "listener", cl.Net.Listener, "packet", pkv)
|
||||
continue
|
||||
}
|
||||
s.hooks.OnRetainPublished(cl, pkv)
|
||||
@@ -1238,27 +1373,28 @@ func (s *Server) publishSysTopics() {
|
||||
atomic.StoreInt64(&s.Info.ClientsTotal, int64(s.Clients.Len()))
|
||||
atomic.StoreInt64(&s.Info.ClientsDisconnected, atomic.LoadInt64(&s.Info.ClientsTotal)-atomic.LoadInt64(&s.Info.ClientsConnected))
|
||||
|
||||
info := s.Info.Clone()
|
||||
topics := map[string]string{
|
||||
SysPrefix + "/broker/version": s.Info.Version,
|
||||
SysPrefix + "/broker/time": AtomicItoa(&s.Info.Time),
|
||||
SysPrefix + "/broker/uptime": AtomicItoa(&s.Info.Uptime),
|
||||
SysPrefix + "/broker/started": AtomicItoa(&s.Info.Started),
|
||||
SysPrefix + "/broker/load/bytes/received": AtomicItoa(&s.Info.BytesReceived),
|
||||
SysPrefix + "/broker/load/bytes/sent": AtomicItoa(&s.Info.BytesSent),
|
||||
SysPrefix + "/broker/clients/connected": AtomicItoa(&s.Info.ClientsConnected),
|
||||
SysPrefix + "/broker/clients/disconnected": AtomicItoa(&s.Info.ClientsDisconnected),
|
||||
SysPrefix + "/broker/clients/maximum": AtomicItoa(&s.Info.ClientsMaximum),
|
||||
SysPrefix + "/broker/clients/total": AtomicItoa(&s.Info.ClientsTotal),
|
||||
SysPrefix + "/broker/packets/received": AtomicItoa(&s.Info.PacketsReceived),
|
||||
SysPrefix + "/broker/packets/sent": AtomicItoa(&s.Info.PacketsSent),
|
||||
SysPrefix + "/broker/messages/received": AtomicItoa(&s.Info.MessagesReceived),
|
||||
SysPrefix + "/broker/messages/sent": AtomicItoa(&s.Info.MessagesSent),
|
||||
SysPrefix + "/broker/messages/dropped": AtomicItoa(&s.Info.MessagesDropped),
|
||||
SysPrefix + "/broker/messages/inflight": AtomicItoa(&s.Info.Inflight),
|
||||
SysPrefix + "/broker/retained": AtomicItoa(&s.Info.Retained),
|
||||
SysPrefix + "/broker/subscriptions": AtomicItoa(&s.Info.Subscriptions),
|
||||
SysPrefix + "/broker/system/memory": AtomicItoa(&s.Info.MemoryAlloc),
|
||||
SysPrefix + "/broker/system/threads": AtomicItoa(&s.Info.Threads),
|
||||
SysPrefix + "/broker/time": Int64toa(info.Time),
|
||||
SysPrefix + "/broker/uptime": Int64toa(info.Uptime),
|
||||
SysPrefix + "/broker/started": Int64toa(info.Started),
|
||||
SysPrefix + "/broker/load/bytes/received": Int64toa(info.BytesReceived),
|
||||
SysPrefix + "/broker/load/bytes/sent": Int64toa(info.BytesSent),
|
||||
SysPrefix + "/broker/clients/connected": Int64toa(info.ClientsConnected),
|
||||
SysPrefix + "/broker/clients/disconnected": Int64toa(info.ClientsDisconnected),
|
||||
SysPrefix + "/broker/clients/maximum": Int64toa(info.ClientsMaximum),
|
||||
SysPrefix + "/broker/clients/total": Int64toa(info.ClientsTotal),
|
||||
SysPrefix + "/broker/packets/received": Int64toa(info.PacketsReceived),
|
||||
SysPrefix + "/broker/packets/sent": Int64toa(info.PacketsSent),
|
||||
SysPrefix + "/broker/messages/received": Int64toa(info.MessagesReceived),
|
||||
SysPrefix + "/broker/messages/sent": Int64toa(info.MessagesSent),
|
||||
SysPrefix + "/broker/messages/dropped": Int64toa(info.MessagesDropped),
|
||||
SysPrefix + "/broker/messages/inflight": Int64toa(info.Inflight),
|
||||
SysPrefix + "/broker/retained": Int64toa(info.Retained),
|
||||
SysPrefix + "/broker/subscriptions": Int64toa(info.Subscriptions),
|
||||
SysPrefix + "/broker/system/memory": Int64toa(info.MemoryAlloc),
|
||||
SysPrefix + "/broker/system/threads": Int64toa(info.Threads),
|
||||
}
|
||||
|
||||
for topic, payload := range topics {
|
||||
@@ -1268,17 +1404,18 @@ func (s *Server) publishSysTopics() {
|
||||
s.publishToSubscribers(pk)
|
||||
}
|
||||
|
||||
s.hooks.OnSysInfoTick(s.Info)
|
||||
s.hooks.OnSysInfoTick(info)
|
||||
}
|
||||
|
||||
// Close attempts to gracefully shut down the server, all listeners, clients, and stores.
|
||||
func (s *Server) Close() error {
|
||||
close(s.done)
|
||||
s.Log.Info("gracefully stopping server")
|
||||
s.Listeners.CloseAll(s.closeListenerClients)
|
||||
s.hooks.OnStopped()
|
||||
s.hooks.Stop()
|
||||
|
||||
s.Log.Info().Msg("mochi mqtt server stopped")
|
||||
s.Log.Info("mochi mqtt server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1286,7 +1423,7 @@ func (s *Server) Close() error {
|
||||
func (s *Server) closeListenerClients(listener string) {
|
||||
clients := s.Clients.GetByListener(listener)
|
||||
for _, cl := range clients {
|
||||
s.DisconnectClient(cl, packets.ErrServerShuttingDown)
|
||||
_ = s.DisconnectClient(cl, packets.ErrServerShuttingDown)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1337,9 +1474,7 @@ func (s *Server) readStore() error {
|
||||
return fmt.Errorf("failed to load clients; %w", err)
|
||||
}
|
||||
s.loadClients(clients)
|
||||
s.Log.Debug().
|
||||
Int("len", len(clients)).
|
||||
Msg("loaded clients from store")
|
||||
s.Log.Debug("loaded clients from store", "len", len(clients))
|
||||
}
|
||||
|
||||
if s.hooks.Provides(StoredSubscriptions) {
|
||||
@@ -1348,9 +1483,7 @@ func (s *Server) readStore() error {
|
||||
return fmt.Errorf("load subscriptions; %w", err)
|
||||
}
|
||||
s.loadSubscriptions(subs)
|
||||
s.Log.Debug().
|
||||
Int("len", len(subs)).
|
||||
Msg("loaded subscriptions from store")
|
||||
s.Log.Debug("loaded subscriptions from store", "len", len(subs))
|
||||
}
|
||||
|
||||
if s.hooks.Provides(StoredInflightMessages) {
|
||||
@@ -1359,9 +1492,7 @@ func (s *Server) readStore() error {
|
||||
return fmt.Errorf("load inflight; %w", err)
|
||||
}
|
||||
s.loadInflight(inflight)
|
||||
s.Log.Debug().
|
||||
Int("len", len(inflight)).
|
||||
Msg("loaded inflights from store")
|
||||
s.Log.Debug("loaded inflights from store", "len", len(inflight))
|
||||
}
|
||||
|
||||
if s.hooks.Provides(StoredRetainedMessages) {
|
||||
@@ -1370,9 +1501,7 @@ func (s *Server) readStore() error {
|
||||
return fmt.Errorf("load retained; %w", err)
|
||||
}
|
||||
s.loadRetained(retained)
|
||||
s.Log.Debug().
|
||||
Int("len", len(retained)).
|
||||
Msg("loaded retained messages from store")
|
||||
s.Log.Debug("loaded retained messages from store", "len", len(retained))
|
||||
}
|
||||
|
||||
if s.hooks.Provides(StoredSysInfo) {
|
||||
@@ -1381,8 +1510,7 @@ func (s *Server) readStore() error {
|
||||
return fmt.Errorf("load server info; %w", err)
|
||||
}
|
||||
s.loadServerInfo(sysInfo.Info)
|
||||
s.Log.Debug().
|
||||
Msg("loaded $SYS info from store")
|
||||
s.Log.Debug("loaded $SYS info from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1448,7 +1576,18 @@ func (s *Server) loadClients(v []storage.Client) {
|
||||
MaximumPacketSize: c.Properties.MaximumPacketSize,
|
||||
}
|
||||
cl.Properties.Will = Will(c.Will)
|
||||
s.Clients.Add(cl)
|
||||
|
||||
// cancel the context and update cl.State fields such as the disconnected time and stopCause.
|
||||
cl.Stop(packets.ErrServerShuttingDown)
|
||||
|
||||
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
||||
s.hooks.OnDisconnect(cl, packets.ErrServerShuttingDown, expire)
|
||||
if expire {
|
||||
cl.ClearInflights()
|
||||
s.UnsubscribeClient(cl)
|
||||
} else {
|
||||
s.Clients.Add(cl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1492,7 +1631,14 @@ func (s *Server) clearExpiredClients(dt int64) {
|
||||
// clearExpiredRetainedMessage deletes retained messages from topics if they have expired.
|
||||
func (s *Server) clearExpiredRetainedMessages(now int64) {
|
||||
for filter, pk := range s.Topics.Retained.GetAll() {
|
||||
if (pk.Expiry > 0 && pk.Expiry < now) || pk.Created+s.Options.Capabilities.MaximumMessageExpiryInterval < now {
|
||||
expired := pk.ProtocolVersion == 5 && pk.Expiry > 0 && pk.Expiry < now // [MQTT-3.3.2-5]
|
||||
|
||||
// If the maximum message expiry interval is set (greater than 0), and the message
|
||||
// retention period exceeds the maximum expiry, the message will be forcibly removed.
|
||||
enforced := s.Options.Capabilities.MaximumMessageExpiryInterval > 0 &&
|
||||
now-pk.Created > s.Options.Capabilities.MaximumMessageExpiryInterval
|
||||
|
||||
if expired || enforced {
|
||||
s.Topics.Retained.Delete(filter)
|
||||
s.hooks.OnRetainedExpired(filter)
|
||||
}
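The retained-message expiry rule above splits into two independent conditions; a hypothetical helper (illustration only, not part of this changeset) that mirrors the branch:

```go
// retainedShouldExpire mirrors the logic above: an MQTT v5 message expires once
// its own Expiry timestamp has passed [MQTT-3.3.2-5], and any message is
// force-expired when MaximumMessageExpiryInterval is set (> 0) and the
// message's age exceeds that interval.
func retainedShouldExpire(pk packets.Packet, now, maxExpiryInterval int64) bool {
	expired := pk.ProtocolVersion == 5 && pk.Expiry > 0 && pk.Expiry < now
	enforced := maxExpiryInterval > 0 && now-pk.Created > maxExpiryInterval
	return expired || enforced
}
```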
|
||||
@@ -1502,7 +1648,7 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
|
||||
// clearExpiredInflights deletes any inflight messages which have expired.
|
||||
func (s *Server) clearExpiredInflights(now int64) {
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
|
||||
if deleted := client.ClearExpiredInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
|
||||
for _, id := range deleted {
|
||||
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
|
||||
}
|
||||
@@ -1527,7 +1673,7 @@ func (s *Server) sendDelayedLWT(dt int64) {
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicItoa converts an int64 point to a string.
|
||||
func AtomicItoa(ptr *int64) string {
|
||||
return strconv.FormatInt(atomic.LoadInt64(ptr), 10)
|
||||
// Int64toa converts an int64 to a string.
|
||||
func Int64toa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
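Since Int64toa and the Info field are exported, an embedding application or hook can format values the same way publishSysTopics now does, drawing every value from one consistent Clone rather than repeated atomic reads; a short sketch:

```go
info := server.Info.Clone()                       // one consistent snapshot of the counters
connected := mqtt.Int64toa(info.ClientsConnected) // e.g. "3"
log.Println("clients connected:", connected)
```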
|
||||
|
||||
1005
server_test.go
File diff suppressed because it is too large
153
topics.go
@@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio
|
||||
return m
|
||||
}
|
||||
|
||||
// InlineSubFn is the signature for a callback function which will be called
|
||||
// when an inline client receives a message on a topic it is subscribed to.
|
||||
// The sub argument contains information about the subscription that was matched for any filters.
|
||||
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
|
||||
|
||||
// InlineSubscriptions represents a map of internal subscriptions keyed on subscription identifier.
|
||||
type InlineSubscriptions struct {
|
||||
internal map[int]InlineSubscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
|
||||
func NewInlineSubscriptions() *InlineSubscriptions {
|
||||
return &InlineSubscriptions{
|
||||
internal: map[int]InlineSubscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new internal subscription keyed on its subscription identifier.
|
||||
func (s *InlineSubscriptions) Add(val InlineSubscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.internal[val.Identifier] = val
|
||||
}
|
||||
|
||||
// GetAll returns all internal subscriptions.
|
||||
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[int]InlineSubscription{}
|
||||
for k, v := range s.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns an internal subscription by its subscription identifier.
|
||||
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val, ok = s.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the number of internal subscriptions.
|
||||
func (s *InlineSubscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes an internal subscription by its subscription identifier.
|
||||
func (s *InlineSubscriptions) Delete(id int) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal, id)
|
||||
}
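Note that the map is keyed on the subscription identifier alone, so adding a second inline subscription with the same identifier replaces the earlier handler instead of accumulating; a small in-package sketch in the style of the tests further down:

```go
subs := NewInlineSubscriptions()

first := func(cl *Client, sub packets.Subscription, pk packets.Packet) { /* original handler */ }
second := func(cl *Client, sub packets.Subscription, pk packets.Packet) { /* replacement handler */ }

subs.Add(InlineSubscription{Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, Handler: first})
subs.Add(InlineSubscription{Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, Handler: second})

sub, ok := subs.Get(1) // ok == true; sub.Handler is now `second`, and subs.Len() == 1
_, _ = sub, ok
```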
|
||||
|
||||
// Subscriptions is a map of subscriptions keyed on client.
|
||||
type Subscriptions struct {
|
||||
internal map[string]packets.Subscription
|
||||
@@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) {
|
||||
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
||||
type ClientSubscriptions map[string]packets.Subscription
|
||||
|
||||
type InlineSubscription struct {
|
||||
packets.Subscription
|
||||
Handler InlineSubFn
|
||||
}
|
||||
|
||||
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
||||
type Subscribers struct {
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
InlineSubscriptions map[int]InlineSubscription
|
||||
}
|
||||
|
||||
// SelectShared returns one subscriber for each shared subscription group.
|
||||
@@ -298,6 +363,39 @@ func NewTopicsIndex() *TopicsIndex {
|
||||
}
|
||||
}
|
||||
|
||||
// InlineSubscribe adds a new internal subscription for a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
n := x.set(subscription.Filter, 0)
|
||||
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
|
||||
n.inlineSubscriptions.Add(subscription)
|
||||
|
||||
return !existed
|
||||
}
|
||||
|
||||
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific subscription identifier,
|
||||
// returning true if the subscription existed.
|
||||
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
particle := x.seek(filter, 0)
|
||||
if particle == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
particle.inlineSubscriptions.Delete(id)
|
||||
|
||||
if particle.inlineSubscriptions.Len() == 0 {
|
||||
x.trim(particle)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||
@@ -484,9 +582,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack
|
||||
// their subscription ids and highest qos.
|
||||
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
||||
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
InlineSubscriptions: map[int]InlineSubscription{},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -508,10 +607,12 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
|
||||
} else {
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
|
||||
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
|
||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||
x.gatherSharedSubscriptions(wild, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -520,6 +621,7 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
|
||||
if particle := n.particles.get("#"); particle != nil {
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
x.gatherInlineSubscriptions(particle, subs)
|
||||
}
|
||||
|
||||
return subs
|
||||
@@ -562,6 +664,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr
|
||||
}
|
||||
}
|
||||
|
||||
// gatherInlineSubscriptions gathers all inline subscriptions for a particle.
|
||||
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
|
||||
if subs.InlineSubscriptions == nil {
|
||||
subs.InlineSubscriptions = map[int]InlineSubscription{}
|
||||
}
|
||||
|
||||
for id, inline := range particle.inlineSubscriptions.GetAll() {
|
||||
subs.InlineSubscriptions[id] = inline
|
||||
}
|
||||
}
|
||||
|
||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||
var next, end int
|
||||
@@ -592,7 +705,7 @@ func IsSharedFilter(filter string) bool {
|
||||
|
||||
// IsValidFilter returns true if the filter is valid.
|
||||
func IsValidFilter(filter string, forPublish bool) bool {
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
|
||||
return false // [MQTT-4.7.3-1]
|
||||
}
|
||||
|
||||
@@ -633,23 +746,25 @@ func IsValidFilter(filter string, forPublish bool) bool {
|
||||
|
||||
// particle is a child node on the tree.
|
||||
type particle struct {
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
}
|
||||
|
||||
// newParticle returns a pointer to a new instance of particle.
|
||||
func newParticle(key string, parent *particle) *particle {
|
||||
return &particle{
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
inlineSubscriptions: NewInlineSubscriptions(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
213
topics_test.go
@@ -5,6 +5,7 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/packets"
|
||||
@@ -853,3 +854,215 @@ func TestNewTopicAliases(t *testing.T) {
|
||||
require.NotNil(t, a.Outbound)
|
||||
require.Equal(t, uint16(5), a.Outbound.maximum)
|
||||
}
|
||||
|
||||
func TestNewInlineSubscriptions(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
require.NotNil(t, subscriptions)
|
||||
require.NotNil(t, subscriptions.internal)
|
||||
require.Equal(t, 0, subscriptions.Len())
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionAdd(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
subscription := InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
Handler: handler,
|
||||
}
|
||||
subscriptions.Add(subscription)
|
||||
|
||||
sub, ok := subscriptions.Get(1)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a/b/c", sub.Filter)
|
||||
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionGet(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
subscription := InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
Handler: handler,
|
||||
}
|
||||
subscriptions.Add(subscription)
|
||||
|
||||
sub, ok := subscriptions.Get(1)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a/b/c", sub.Filter)
|
||||
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
|
||||
|
||||
_, ok = subscriptions.Get(999)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionsGetAll(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
})
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
})
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2},
|
||||
})
|
||||
subscriptions.Add(InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3},
|
||||
})
|
||||
|
||||
allSubs := subscriptions.GetAll()
|
||||
require.Len(t, allSubs, 3)
|
||||
require.Contains(t, allSubs, 1)
|
||||
require.Contains(t, allSubs, 2)
|
||||
require.Contains(t, allSubs, 3)
|
||||
}
|
||||
|
||||
func TestInlineSubscriptionDelete(t *testing.T) {
|
||||
subscriptions := NewInlineSubscriptions()
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
subscription := InlineSubscription{
|
||||
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||
Handler: handler,
|
||||
}
|
||||
subscriptions.Add(subscription)
|
||||
|
||||
subscriptions.Delete(1)
|
||||
_, ok := subscriptions.Get(1)
|
||||
require.False(t, ok)
|
||||
require.Empty(t, subscriptions.GetAll())
|
||||
require.Zero(t, subscriptions.Len())
|
||||
}
|
||||
|
||||
func TestInlineSubscribe(t *testing.T) {
|
||||
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
desc string
|
||||
filter string
|
||||
subscription InlineSubscription
|
||||
wasNew bool
|
||||
}{
|
||||
{
|
||||
desc: "subscribe",
|
||||
filter: "a/b/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "subscribe existed",
|
||||
filter: "a/b/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
|
||||
wasNew: false,
|
||||
},
|
||||
{
|
||||
desc: "subscribe different identifier",
|
||||
filter: "a/b/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "subscribe case sensitive didnt exist",
|
||||
filter: "A/B/c",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard+ sub",
|
||||
filter: "d/+",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard# sub",
|
||||
filter: "d/e/#",
|
||||
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}},
|
||||
wasNew: true,
|
||||
},
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
for _, tx := range tt {
|
||||
t.Run(tx.desc, func(t *testing.T) {
|
||||
require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription))
|
||||
})
|
||||
}
|
||||
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
}
|
||||
|
||||
func TestInlineUnsubscribe(t *testing.T) {
|
||||
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||
// handler logic
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
|
||||
sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index = NewTopicsIndex()
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}})
|
||||
sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
ok := index.InlineUnsubscribe(1, "a/b/c/d")
|
||||
require.True(t, ok)
|
||||
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
|
||||
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
|
||||
require.NotNil(t, sub)
|
||||
require.True(t, exists)
|
||||
|
||||
ok = index.InlineUnsubscribe(1, "d/e/f")
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f"))
|
||||
|
||||
ok = index.InlineUnsubscribe(1, "not/exist")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
1
vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
@@ -1 +0,0 @@
|
||||
language: go
|
||||
35
vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
@@ -1,35 +0,0 @@
|
||||
bbloom.go
|
||||
|
||||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
siphash.go
|
||||
|
||||
// https://github.com/dchest/siphash
|
||||
//
|
||||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
131
vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
@@ -1,131 +0,0 @@
|
||||
## bbloom: a bitset Bloom filter for go/golang
|
||||
===
|
||||
|
||||
[](http://travis-ci.org/AndreasBriese/bbloom)
|
||||
|
||||
package implements a fast bloom filter with real 'bitset' and JSONMarshal/JSONUnmarshal to store/reload the Bloom filter.
|
||||
|
||||
NOTE: the package uses unsafe.Pointer to set and read the bits from the bitset. If you're uncomfortable with using the unsafe package, please consider using my bloom filter package at github.com/AndreasBriese/bloom
|
||||
|
||||
===
|
||||
|
||||
changelog 11/2015: new thread safe methods AddTS(), HasTS(), AddIfNotHasTS() following a suggestion from Srdjan Marinovic (github @a-little-srdjan), who used this to code a bloomfilter cache.
|
||||
|
||||
This bloom filter was developed to strengthen a website-log database and was tested and optimized for this log-entry mask: "2014/%02i/%02i %02i:%02i:%02i /info.html".
|
||||
Nonetheless bbloom should work with any other form of entries.
|
||||
|
||||
~~Hash function is a modified Berkeley DB sdbm hash (to optimize for smaller strings). sdbm http://www.cse.yorku.ca/~oz/hash.html~~
|
||||
|
||||
Found sipHash (SipHash-2-4, a fast short-input PRF created by Jean-Philippe Aumasson and Daniel J. Bernstein.) to be about as fast. sipHash had been ported by Dimtry Chestnyk to Go (github.com/dchest/siphash )
|
||||
|
||||
Minimum hashset size is: 512 ([4]uint64; will be set automatically).
|
||||
|
||||
###install
|
||||
|
||||
```sh
|
||||
go get github.com/AndreasBriese/bbloom
|
||||
```
|
||||
|
||||
###test
|
||||
+ change to folder ../bbloom
|
||||
+ create wordlist in file "words.txt" (you might use `python permut.py`)
|
||||
+ run 'go test -bench=.' within the folder
|
||||
|
||||
```go
|
||||
go test -bench=.
|
||||
```
|
||||
|
||||
~~If you've installed the GOCONVEY TDD-framework http://goconvey.co/ you can run the tests automatically.~~
|
||||
|
||||
using go's testing framework now (have in mind that the op timing is related to 65536 operations of Add, Has, AddIfNotHas respectively)
|
||||
|
||||
### usage
|
||||
|
||||
after installation add
|
||||
|
||||
```go
|
||||
import (
|
||||
...
|
||||
"github.com/AndreasBriese/bbloom"
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
at your header. In the program use
|
||||
|
||||
```go
|
||||
// create a bloom filter for 65536 items and 1 % wrong-positive ratio
|
||||
bf := bbloom.New(float64(1<<16), float64(0.01))
|
||||
|
||||
// or
|
||||
// create a bloom filter with 650000 for 65536 items and 7 locs per hash explicitly
|
||||
// bf = bbloom.New(float64(650000), float64(7))
|
||||
// or
|
||||
bf = bbloom.New(650000.0, 7.0)
|
||||
|
||||
// add one item
|
||||
bf.Add([]byte("butter"))
|
||||
|
||||
// Number of elements added is exposed now
|
||||
// Note: ElemNum will not be included in JSON export (for compatability to older version)
|
||||
nOfElementsInFilter := bf.ElemNum
|
||||
|
||||
// check if item is in the filter
|
||||
isIn := bf.Has([]byte("butter")) // should be true
|
||||
isNotIn := bf.Has([]byte("Butter")) // should be false
|
||||
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added := bf.AddIfNotHas([]byte("butter")) // should be false because 'butter' is already in the set
|
||||
added = bf.AddIfNotHas([]byte("buTTer")) // should be true because 'buTTer' is new
|
||||
|
||||
// thread safe versions for concurrent use: AddTS, HasTS, AddIfNotHasTS
|
||||
// add one item
|
||||
bf.AddTS([]byte("peanutbutter"))
|
||||
// check if item is in the filter
|
||||
isIn = bf.HasTS([]byte("peanutbutter")) // should be true
|
||||
isNotIn = bf.HasTS([]byte("peanutButter")) // should be false
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added = bf.AddIfNotHasTS([]byte("butter")) // should be false because 'peanutbutter' is already in the set
|
||||
added = bf.AddIfNotHasTS([]byte("peanutbuTTer")) // should be true because 'penutbuTTer' is new
|
||||
|
||||
// convert to JSON ([]byte)
|
||||
Json := bf.JSONMarshal()
|
||||
|
||||
// bloomfilters Mutex is exposed for external un-/locking
|
||||
// i.e. mutex lock while doing JSON conversion
|
||||
bf.Mtx.Lock()
|
||||
Json = bf.JSONMarshal()
|
||||
bf.Mtx.Unlock()
|
||||
|
||||
// restore a bloom filter from storage
|
||||
bfNew := bbloom.JSONUnmarshal(Json)
|
||||
|
||||
isInNew := bfNew.Has([]byte("butter")) // should be true
|
||||
isNotInNew := bfNew.Has([]byte("Butter")) // should be false
|
||||
|
||||
```
|
||||
|
||||
to work with the bloom filter.
|
||||
|
||||
### why 'fast'?
|
||||
|
||||
It's about 3 times faster than William Fitzgeralds bitset bloom filter https://github.com/willf/bloom . And it is about so fast as my []bool set variant for Boom filters (see https://github.com/AndreasBriese/bloom ) but having a 8times smaller memory footprint:
|
||||
|
||||
|
||||
Bloom filter (filter size 524288, 7 hashlocs)
|
||||
github.com/AndreasBriese/bbloom 'Add' 65536 items (10 repetitions): 6595800 ns (100 ns/op)
|
||||
github.com/AndreasBriese/bbloom 'Has' 65536 items (10 repetitions): 5986600 ns (91 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Add' 65536 items (10 repetitions): 6304684 ns (96 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Has' 65536 items (10 repetitions): 6568663 ns (100 ns/op)
|
||||
|
||||
github.com/willf/bloom 'Add' 65536 items (10 repetitions): 24367224 ns (371 ns/op)
|
||||
github.com/willf/bloom 'Test' 65536 items (10 repetitions): 21881142 ns (333 ns/op)
|
||||
github.com/dataence/bloom/standard 'Add' 65536 items (10 repetitions): 23041644 ns (351 ns/op)
|
||||
github.com/dataence/bloom/standard 'Check' 65536 items (10 repetitions): 19153133 ns (292 ns/op)
|
||||
github.com/cabello/bloom 'Add' 65536 items (10 repetitions): 131921507 ns (2012 ns/op)
|
||||
github.com/cabello/bloom 'Contains' 65536 items (10 repetitions): 131108962 ns (2000 ns/op)
|
||||
|
||||
(on MBPro15 OSX10.8.5 i7 4Core 2.4Ghz)
|
||||
|
||||
|
||||
With 32bit bloom filters (bloom32) using modified sdbm, bloom32 does hashing with only 2 bit shifts, one xor and one substraction per byte. smdb is about as fast as fnv64a but gives less collisions with the dataset (see mask above). bloom.New(float64(10 * 1<<16),float64(7)) populated with 1<<16 random items from the dataset (see above) and tested against the rest results in less than 0.05% collisions.
|
||||
284
vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
@@ -1,284 +0,0 @@
|
||||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
// 2019/08/25 code revision to reduce unsafe use
|
||||
// Parts are adopted from the fork at ipfs/bbloom after performance rev by
|
||||
// Steve Allen (https://github.com/Stebalien)
|
||||
// (see https://github.com/ipfs/bbloom/blob/master/bbloom.go)
|
||||
// -> func Has
|
||||
// -> func set
|
||||
// -> func add
|
||||
|
||||
package bbloom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// helper
|
||||
// not needed anymore by Set
|
||||
// var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128}
|
||||
|
||||
func getSize(ui64 uint64) (size uint64, exponent uint64) {
|
||||
if ui64 < uint64(512) {
|
||||
ui64 = uint64(512)
|
||||
}
|
||||
size = uint64(1)
|
||||
for size < ui64 {
|
||||
size <<= 1
|
||||
exponent++
|
||||
}
|
||||
return size, exponent
|
||||
}
|
||||
|
||||
func calcSizeByWrongPositives(numEntries, wrongs float64) (uint64, uint64) {
|
||||
size := -1 * numEntries * math.Log(wrongs) / math.Pow(float64(0.69314718056), 2)
|
||||
locs := math.Ceil(float64(0.69314718056) * size / numEntries)
|
||||
return uint64(size), uint64(locs)
|
||||
}
|
||||
|
||||
// New
|
||||
// returns a new bloomfilter
|
||||
func New(params ...float64) (bloomfilter Bloom) {
|
||||
var entries, locs uint64
|
||||
if len(params) == 2 {
|
||||
if params[1] < 1 {
|
||||
entries, locs = calcSizeByWrongPositives(params[0], params[1])
|
||||
} else {
|
||||
entries, locs = uint64(params[0]), uint64(params[1])
|
||||
}
|
||||
} else {
|
||||
log.Fatal("usage: New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(3)) or New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(0.03))")
|
||||
}
|
||||
size, exponent := getSize(uint64(entries))
|
||||
bloomfilter = Bloom{
|
||||
Mtx: &sync.Mutex{},
|
||||
sizeExp: exponent,
|
||||
size: size - 1,
|
||||
setLocs: locs,
|
||||
shift: 64 - exponent,
|
||||
}
|
||||
bloomfilter.Size(size)
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// NewWithBoolset
|
||||
// takes a []byte slice and number of locs per entry
|
||||
// returns the bloomfilter with a bitset populated according to the input []byte
|
||||
func NewWithBoolset(bs *[]byte, locs uint64) (bloomfilter Bloom) {
|
||||
bloomfilter = New(float64(len(*bs)<<3), float64(locs))
|
||||
for i, b := range *bs {
|
||||
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bloomfilter.bitset[0])) + uintptr(i))) = b
|
||||
}
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// bloomJSONImExport
|
||||
// Im/Export structure used by JSONMarshal / JSONUnmarshal
|
||||
type bloomJSONImExport struct {
|
||||
FilterSet []byte
|
||||
SetLocs uint64
|
||||
}
|
||||
|
||||
// JSONUnmarshal
|
||||
// takes JSON-Object (type bloomJSONImExport) as []bytes
|
||||
// returns Bloom object
|
||||
func JSONUnmarshal(dbData []byte) Bloom {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
json.Unmarshal(dbData, &bloomImEx)
|
||||
buf := bytes.NewBuffer(bloomImEx.FilterSet)
|
||||
bs := buf.Bytes()
|
||||
bf := NewWithBoolset(&bs, bloomImEx.SetLocs)
|
||||
return bf
|
||||
}
|
||||
|
||||
//
|
||||
// Bloom filter
|
||||
type Bloom struct {
|
||||
Mtx *sync.Mutex
|
||||
ElemNum uint64
|
||||
bitset []uint64
|
||||
sizeExp uint64
|
||||
size uint64
|
||||
setLocs uint64
|
||||
shift uint64
|
||||
}
|
||||
|
||||
// <--- http://www.cse.yorku.ca/~oz/hash.html
|
||||
// modified Berkeley DB Hash (32bit)
|
||||
// hash is casted to l, h = 16bit fragments
|
||||
// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.shift >> bl.shift
|
||||
// return l, h
|
||||
// }
|
||||
|
||||
// Update: found sipHash of Jean-Philippe Aumasson & Daniel J. Bernstein to be even faster than absdbm()
|
||||
// https://131002.net/siphash/
|
||||
// siphash was implemented for Go by Dmitry Chestnykh https://github.com/dchest/siphash
|
||||
|
||||
// Add
|
||||
// set the bit(s) for entry; Adds an entry to the Bloom filter
|
||||
func (bl *Bloom) Add(entry []byte) {
|
||||
l, h := bl.sipHash(entry)
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
bl.set((h + i*l) & bl.size)
|
||||
bl.ElemNum++
|
||||
}
|
||||
}
|
||||
|
||||
// AddTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) AddTS(entry []byte) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
bl.Add(entry)
|
||||
}
|
||||
|
||||
// Has
|
||||
// check if bit(s) for entry is/are set
|
||||
// returns true if the entry was added to the Bloom Filter
|
||||
func (bl Bloom) Has(entry []byte) bool {
|
||||
l, h := bl.sipHash(entry)
|
||||
res := true
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
res = res && bl.isSet((h+i*l)&bl.size)
|
||||
// https://github.com/ipfs/bbloom/commit/84e8303a9bfb37b2658b85982921d15bbb0fecff
|
||||
// // Branching here (early escape) is not worth it
|
||||
// // This is my conclusion from benchmarks
|
||||
// // (prevents loop unrolling)
|
||||
// switch bl.IsSet((h + i*l) & bl.size) {
|
||||
// case false:
|
||||
// return false
|
||||
// }
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// HasTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) HasTS(entry []byte) bool {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.Has(entry)
|
||||
}
|
||||
|
||||
// AddIfNotHas
|
||||
// Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl Bloom) AddIfNotHas(entry []byte) (added bool) {
|
||||
if bl.Has(entry) {
|
||||
return added
|
||||
}
|
||||
bl.Add(entry)
|
||||
return true
|
||||
}
|
||||
|
||||
// AddIfNotHasTS
|
||||
// Tread safe: Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl *Bloom) AddIfNotHasTS(entry []byte) (added bool) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.AddIfNotHas(entry)
|
||||
}
|
||||

// Size
// make Bloom filter with a bitset of size sz
func (bl *Bloom) Size(sz uint64) {
    bl.bitset = make([]uint64, sz>>6)
}

// Clear
// resets the Bloom filter
func (bl *Bloom) Clear() {
    bs := bl.bitset
    for i := range bs {
        bs[i] = 0
    }
}

// Set
// set the bit[idx] of the bitset
func (bl *Bloom) set(idx uint64) {
    // omit unsafe
    // *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3))) |= mask[idx%8]
    bl.bitset[idx>>6] |= 1 << (idx % 64)
}

// IsSet
// check if bit[idx] of the bitset is set
// returns true/false
func (bl *Bloom) isSet(idx uint64) bool {
    // omit unsafe
    // return (((*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)))) >> (idx % 8)) & 1) == 1
    return bl.bitset[idx>>6]&(1<<(idx%64)) != 0
}
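`set` and `isSet` pack bits into 64-bit words: `idx>>6` selects the word and `idx%64` the bit inside it. The same arithmetic on a bare `[]uint64`, as a standalone sketch:

```go
package main

import "fmt"

// The same word/bit arithmetic as set and isSet above, on a bare []uint64:
// idx>>6 picks the 64-bit word, idx%64 the bit inside that word.
func setBit(words []uint64, idx uint64) { words[idx>>6] |= 1 << (idx % 64) }

func bitIsSet(words []uint64, idx uint64) bool { return words[idx>>6]&(1<<(idx%64)) != 0 }

func main() {
    words := make([]uint64, 1024>>6) // room for 1024 bits
    setBit(words, 777)
    fmt.Println(bitIsSet(words, 777), bitIsSet(words, 778)) // true false
}
```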

// JSONMarshal
// returns JSON-object (type bloomJSONImExport) as []byte
func (bl Bloom) JSONMarshal() []byte {
    bloomImEx := bloomJSONImExport{}
    bloomImEx.SetLocs = uint64(bl.setLocs)
    bloomImEx.FilterSet = make([]byte, len(bl.bitset)<<3)
    for i := range bloomImEx.FilterSet {
        bloomImEx.FilterSet[i] = *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[0])) + uintptr(i)))
    }
    data, err := json.Marshal(bloomImEx)
    if err != nil {
        log.Fatal("json.Marshal failed: ", err)
    }
    return data
}
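JSONMarshal reads the bitset bytes through unsafe pointer arithmetic. A portable sketch of the same export using encoding/binary; it mirrors the bloomJSONImExport fields but is illustrative, not the library's code (the byte order is fixed to little-endian here, whereas the unsafe copy uses host order):

```go
package main

import (
    "encoding/binary"
    "encoding/json"
    "fmt"
)

// exported mirrors the bloomJSONImExport idea used above: the set-location
// count plus the bitset flattened to bytes.
type exported struct {
    FilterSet []byte
    SetLocs   uint64
}

func marshalBitset(bitset []uint64, setLocs uint64) ([]byte, error) {
    buf := make([]byte, len(bitset)*8)
    for i, w := range bitset {
        binary.LittleEndian.PutUint64(buf[i*8:], w)
    }
    return json.Marshal(exported{FilterSet: buf, SetLocs: setLocs})
}

func main() {
    out, err := marshalBitset([]uint64{1, 1 << 63}, 4)
    if err != nil {
        panic(err)
    }
    fmt.Println(string(out))
}
```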

// // alternative hashFn
// func (bl Bloom) fnv64a(b *[]byte) (l, h uint64) {
//     h64 := fnv.New64a()
//     h64.Write(*b)
//     hash := h64.Sum64()
//     h = hash >> 32
//     l = hash << 32 >> 32
//     return l, h
// }
//
// // <-- http://partow.net/programming/hashfunctions/index.html
// // citation: An algorithm proposed by Donald E. Knuth in The Art Of Computer Programming Volume 3,
// // under the topic of sorting and search chapter 6.4.
// // modified to fit with boolset-length
// func (bl Bloom) DEKHash(b *[]byte) (l, h uint64) {
//     hash := uint64(len(*b))
//     for _, c := range *b {
//         hash = ((hash << 5) ^ (hash >> bl.shift)) ^ uint64(c)
//     }
//     h = hash >> bl.shift
//     l = hash << bl.sizeExp >> bl.sizeExp
//     return l, h
// }
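The commented-out fnv64a variant above is easy to try on its own with the standard library's hash/fnv; a minimal sketch splitting the 64-bit sum into low/high halves exactly as the comment does:

```go
package main

import (
    "fmt"
    "hash/fnv"
)

// fnv64aSplit hashes b with FNV-1a and returns the low and high 32-bit halves
// as two uint64 values, matching the commented-out alternative hashFn above.
func fnv64aSplit(b []byte) (l, h uint64) {
    h64 := fnv.New64a()
    h64.Write(b)
    hash := h64.Sum64()
    h = hash >> 32
    l = hash << 32 >> 32
    return l, h
}

func main() {
    l, h := fnv64aSplit([]byte("mochi-mqtt"))
    fmt.Printf("l=%d h=%d\n", l, h)
}
```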
vendor/github.com/AndreasBriese/bbloom/sipHash.go (generated, vendored, 225 lines removed) @@ -1,225 +0,0 @@
|
||||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
|
||||
package bbloom
|
||||
|
||||
// Hash returns the 64-bit SipHash-2-4 of the given byte slice with two 64-bit
|
||||
// parts of 128-bit key: k0 and k1.
|
||||
func (bl Bloom) sipHash(p []byte) (l, h uint64) {
|
||||
// Initialization.
|
||||
v0 := uint64(8317987320269560794) // k0 ^ 0x736f6d6570736575
|
||||
v1 := uint64(7237128889637516672) // k1 ^ 0x646f72616e646f6d
|
||||
v2 := uint64(7816392314733513934) // k0 ^ 0x6c7967656e657261
|
||||
v3 := uint64(8387220255325274014) // k1 ^ 0x7465646279746573
|
||||
t := uint64(len(p)) << 56
|
||||
|
||||
// Compression.
|
||||
for len(p) >= 8 {
|
||||
|
||||
m := uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
|
||||
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
|
||||
|
||||
v3 ^= m
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= m
|
||||
p = p[8:]
|
||||
}
|
||||
|
||||
// Compress last block.
|
||||
switch len(p) {
|
||||
case 7:
|
||||
t |= uint64(p[6]) << 48
|
||||
fallthrough
|
||||
case 6:
|
||||
t |= uint64(p[5]) << 40
|
||||
fallthrough
|
||||
case 5:
|
||||
t |= uint64(p[4]) << 32
|
||||
fallthrough
|
||||
case 4:
|
||||
t |= uint64(p[3]) << 24
|
||||
fallthrough
|
||||
case 3:
|
||||
t |= uint64(p[2]) << 16
|
||||
fallthrough
|
||||
case 2:
|
||||
t |= uint64(p[1]) << 8
|
||||
fallthrough
|
||||
case 1:
|
||||
t |= uint64(p[0])
|
||||
}
|
||||
|
||||
v3 ^= t
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= t
|
||||
|
||||
// Finalization.
|
||||
v2 ^= 0xff
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 3.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 4.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// return v0 ^ v1 ^ v2 ^ v3
|
||||
|
||||
hash := v0 ^ v1 ^ v2 ^ v3
|
||||
h = hash >> bl.shift
|
||||
l = hash << bl.shift >> bl.shift
|
||||
return l, h
|
||||
|
||||
}
|
||||
vendor/github.com/AndreasBriese/bbloom/words.txt (generated, vendored, 140 lines removed) @@ -1,140 +0,0 @@
|
||||
2014/01/01 00:00:00 /info.html
|
||||
2014/01/01 00:00:00 /info.html
|
||||
2014/01/01 00:00:01 /info.html
|
||||
2014/01/01 00:00:02 /info.html
|
||||
2014/01/01 00:00:03 /info.html
|
||||
2014/01/01 00:00:04 /info.html
|
||||
2014/01/01 00:00:05 /info.html
|
||||
2014/01/01 00:00:06 /info.html
|
||||
2014/01/01 00:00:07 /info.html
|
||||
2014/01/01 00:00:08 /info.html
|
||||
2014/01/01 00:00:09 /info.html
|
||||
2014/01/01 00:00:10 /info.html
|
||||
2014/01/01 00:00:11 /info.html
|
||||
2014/01/01 00:00:12 /info.html
|
||||
2014/01/01 00:00:13 /info.html
|
||||
2014/01/01 00:00:14 /info.html
|
||||
2014/01/01 00:00:15 /info.html
|
||||
2014/01/01 00:00:16 /info.html
|
||||
2014/01/01 00:00:17 /info.html
|
||||
2014/01/01 00:00:18 /info.html
|
||||
2014/01/01 00:00:19 /info.html
|
||||
2014/01/01 00:00:20 /info.html
|
||||
2014/01/01 00:00:21 /info.html
|
||||
2014/01/01 00:00:22 /info.html
|
||||
2014/01/01 00:00:23 /info.html
|
||||
2014/01/01 00:00:24 /info.html
|
||||
2014/01/01 00:00:25 /info.html
|
||||
2014/01/01 00:00:26 /info.html
|
||||
2014/01/01 00:00:27 /info.html
|
||||
2014/01/01 00:00:28 /info.html
|
||||
2014/01/01 00:00:29 /info.html
|
||||
2014/01/01 00:00:30 /info.html
|
||||
2014/01/01 00:00:31 /info.html
|
||||
2014/01/01 00:00:32 /info.html
|
||||
2014/01/01 00:00:33 /info.html
|
||||
2014/01/01 00:00:34 /info.html
|
||||
2014/01/01 00:00:35 /info.html
|
||||
2014/01/01 00:00:36 /info.html
|
||||
2014/01/01 00:00:37 /info.html
|
||||
2014/01/01 00:00:38 /info.html
|
||||
2014/01/01 00:00:39 /info.html
|
||||
2014/01/01 00:00:40 /info.html
|
||||
2014/01/01 00:00:41 /info.html
|
||||
2014/01/01 00:00:42 /info.html
|
||||
2014/01/01 00:00:43 /info.html
|
||||
2014/01/01 00:00:44 /info.html
|
||||
2014/01/01 00:00:45 /info.html
|
||||
2014/01/01 00:00:46 /info.html
|
||||
2014/01/01 00:00:47 /info.html
|
||||
2014/01/01 00:00:48 /info.html
|
||||
2014/01/01 00:00:49 /info.html
|
||||
2014/01/01 00:00:50 /info.html
|
||||
2014/01/01 00:00:51 /info.html
|
||||
2014/01/01 00:00:52 /info.html
|
||||
2014/01/01 00:00:53 /info.html
|
||||
2014/01/01 00:00:54 /info.html
|
||||
2014/01/01 00:00:55 /info.html
|
||||
2014/01/01 00:00:56 /info.html
|
||||
2014/01/01 00:00:57 /info.html
|
||||
2014/01/01 00:00:58 /info.html
|
||||
2014/01/01 00:00:59 /info.html
|
||||
2014/01/01 00:01:00 /info.html
|
||||
2014/01/01 00:01:01 /info.html
|
||||
2014/01/01 00:01:02 /info.html
|
||||
2014/01/01 00:01:03 /info.html
|
||||
2014/01/01 00:01:04 /info.html
|
||||
2014/01/01 00:01:05 /info.html
|
||||
2014/01/01 00:01:06 /info.html
|
||||
2014/01/01 00:01:07 /info.html
|
||||
2014/01/01 00:01:08 /info.html
|
||||
2014/01/01 00:01:09 /info.html
|
||||
2014/01/01 00:01:10 /info.html
|
||||
2014/01/01 00:01:11 /info.html
|
||||
2014/01/01 00:01:12 /info.html
|
||||
2014/01/01 00:01:13 /info.html
|
||||
2014/01/01 00:01:14 /info.html
|
||||
2014/01/01 00:01:15 /info.html
|
||||
2014/01/01 00:01:16 /info.html
|
||||
2014/01/01 00:01:17 /info.html
|
||||
2014/01/01 00:01:18 /info.html
|
||||
2014/01/01 00:01:19 /info.html
|
||||
2014/01/01 00:01:20 /info.html
|
||||
2014/01/01 00:01:21 /info.html
|
||||
2014/01/01 00:01:22 /info.html
|
||||
2014/01/01 00:01:23 /info.html
|
||||
2014/01/01 00:01:24 /info.html
|
||||
2014/01/01 00:01:25 /info.html
|
||||
2014/01/01 00:01:26 /info.html
|
||||
2014/01/01 00:01:27 /info.html
|
||||
2014/01/01 00:01:28 /info.html
|
||||
2014/01/01 00:01:29 /info.html
|
||||
2014/01/01 00:01:30 /info.html
|
||||
2014/01/01 00:01:31 /info.html
|
||||
2014/01/01 00:01:32 /info.html
|
||||
2014/01/01 00:01:33 /info.html
|
||||
2014/01/01 00:01:34 /info.html
|
||||
2014/01/01 00:01:35 /info.html
|
||||
2014/01/01 00:01:36 /info.html
|
||||
2014/01/01 00:01:37 /info.html
|
||||
2014/01/01 00:01:38 /info.html
|
||||
2014/01/01 00:01:39 /info.html
|
||||
2014/01/01 00:01:40 /info.html
|
||||
2014/01/01 00:01:41 /info.html
|
||||
2014/01/01 00:01:42 /info.html
|
||||
2014/01/01 00:01:43 /info.html
|
||||
2014/01/01 00:01:44 /info.html
|
||||
2014/01/01 00:01:45 /info.html
|
||||
2014/01/01 00:01:46 /info.html
|
||||
2014/01/01 00:01:47 /info.html
|
||||
2014/01/01 00:01:48 /info.html
|
||||
2014/01/01 00:01:49 /info.html
|
||||
2014/01/01 00:01:50 /info.html
|
||||
2014/01/01 00:01:51 /info.html
|
||||
2014/01/01 00:01:52 /info.html
|
||||
2014/01/01 00:01:53 /info.html
|
||||
2014/01/01 00:01:54 /info.html
|
||||
2014/01/01 00:01:55 /info.html
|
||||
2014/01/01 00:01:56 /info.html
|
||||
2014/01/01 00:01:57 /info.html
|
||||
2014/01/01 00:01:58 /info.html
|
||||
2014/01/01 00:01:59 /info.html
|
||||
2014/01/01 00:02:00 /info.html
|
||||
2014/01/01 00:02:01 /info.html
|
||||
2014/01/01 00:02:02 /info.html
|
||||
2014/01/01 00:02:03 /info.html
|
||||
2014/01/01 00:02:04 /info.html
|
||||
2014/01/01 00:02:05 /info.html
|
||||
2014/01/01 00:02:06 /info.html
|
||||
2014/01/01 00:02:07 /info.html
|
||||
2014/01/01 00:02:08 /info.html
|
||||
2014/01/01 00:02:09 /info.html
|
||||
2014/01/01 00:02:10 /info.html
|
||||
2014/01/01 00:02:11 /info.html
|
||||
2014/01/01 00:02:12 /info.html
|
||||
2014/01/01 00:02:13 /info.html
|
||||
2014/01/01 00:02:14 /info.html
|
||||
2014/01/01 00:02:15 /info.html
|
||||
2014/01/01 00:02:16 /info.html
|
||||
2014/01/01 00:02:17 /info.html
|
||||
2014/01/01 00:02:18 /info.html
|
||||
vendor/github.com/alicebob/gopher-json/LICENSE (generated, vendored, 24 lines removed) @@ -1,24 +0,0 @@
|
||||
This is free and unencumbered software released into the public domain.
|
||||
|
||||
Anyone is free to copy, modify, publish, use, compile, sell, or
|
||||
distribute this software, either in source code form or as a compiled
|
||||
binary, for any purpose, commercial or non-commercial, and by any
|
||||
means.
|
||||
|
||||
In jurisdictions that recognize copyright laws, the author or authors
|
||||
of this software dedicate any and all copyright interest in the
|
||||
software to the public domain. We make this dedication for the benefit
|
||||
of the public at large and to the detriment of our heirs and
|
||||
successors. We intend this dedication to be an overt act of
|
||||
relinquishment in perpetuity of all present and future rights to this
|
||||
software under copyright law.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
For more information, please refer to <http://unlicense.org/>
|
||||
vendor/github.com/alicebob/gopher-json/README.md (generated, vendored, 7 lines removed) @@ -1,7 +0,0 @@
# gopher-json [](https://godoc.org/layeh.com/gopher-json)

Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).

## License

Public domain
vendor/github.com/alicebob/gopher-json/doc.go (generated, vendored, 33 lines removed) @@ -1,33 +0,0 @@
// Package json is a simple JSON encoder/decoder for gopher-lua.
//
// Documentation
//
// The following functions are exposed by the library:
//  decode(string): Decodes a JSON string. Returns nil and an error string if
//      the string could not be decoded.
//  encode(value): Encodes a value into a JSON string. Returns nil and an error
//      string if the value could not be encoded.
//
// The following types are supported:
//
//  Lua      | JSON
//  ---------+-----
//  nil      | null
//  number   | number
//  string   | string
//  table    | object: when table is non-empty and has only string keys
//           | array:  when table is empty, or has only sequential numeric keys
//           |         starting from 1
//
// Attempting to encode any other Lua type will result in an error.
//
// Example
//
// Below is an example usage of the library:
//  import (
//      luajson "layeh.com/gopher-json"
//  )
//
//  L := lua.NewState()
//  luajson.Preload(L)
package json
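A runnable sketch of the Preload/require flow described above; the JSON payload and the `result` global are illustrative:

```go
package main

import (
    "fmt"

    lua "github.com/yuin/gopher-lua"
    luajson "layeh.com/gopher-json"
)

func main() {
    L := lua.NewState()
    defer L.Close()

    // Make the module available to require("json"), as described above.
    luajson.Preload(L)

    err := L.DoString(`
        local json = require("json")
        local t, e = json.decode('{"qos": 1, "topics": ["a/b", "c/d"]}')
        result = json.encode(t.topics)
    `)
    if err != nil {
        panic(err)
    }
    fmt.Println(L.GetGlobal("result")) // ["a/b","c/d"]
}
```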
vendor/github.com/alicebob/gopher-json/json.go (generated, vendored, 189 lines removed) @@ -1,189 +0,0 @@
|
||||
package json
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// Preload adds json to the given Lua state's package.preload table. After it
|
||||
// has been preloaded, it can be loaded using require:
|
||||
//
|
||||
// local json = require("json")
|
||||
func Preload(L *lua.LState) {
|
||||
L.PreloadModule("json", Loader)
|
||||
}
|
||||
|
||||
// Loader is the module loader function.
|
||||
func Loader(L *lua.LState) int {
|
||||
t := L.NewTable()
|
||||
L.SetFuncs(t, api)
|
||||
L.Push(t)
|
||||
return 1
|
||||
}
|
||||
|
||||
var api = map[string]lua.LGFunction{
|
||||
"decode": apiDecode,
|
||||
"encode": apiEncode,
|
||||
}
|
||||
|
||||
func apiDecode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to decode"), 1)
|
||||
return 0
|
||||
}
|
||||
str := L.CheckString(1)
|
||||
|
||||
value, err := Decode(L, []byte(str))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(value)
|
||||
return 1
|
||||
}
|
||||
|
||||
func apiEncode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to encode"), 1)
|
||||
return 0
|
||||
}
|
||||
value := L.CheckAny(1)
|
||||
|
||||
data, err := Encode(value)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(lua.LString(string(data)))
|
||||
return 1
|
||||
}
|
||||
|
||||
var (
|
||||
errNested = errors.New("cannot encode recursively nested tables to JSON")
|
||||
errSparseArray = errors.New("cannot encode sparse array")
|
||||
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
|
||||
)
|
||||
|
||||
type invalidTypeError lua.LValueType
|
||||
|
||||
func (i invalidTypeError) Error() string {
|
||||
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
|
||||
}
|
||||
|
||||
// Encode returns the JSON encoding of value.
|
||||
func Encode(value lua.LValue) ([]byte, error) {
|
||||
return json.Marshal(jsonValue{
|
||||
LValue: value,
|
||||
visited: make(map[*lua.LTable]bool),
|
||||
})
|
||||
}
|
||||
|
||||
type jsonValue struct {
|
||||
lua.LValue
|
||||
visited map[*lua.LTable]bool
|
||||
}
|
||||
|
||||
func (j jsonValue) MarshalJSON() (data []byte, err error) {
|
||||
switch converted := j.LValue.(type) {
|
||||
case lua.LBool:
|
||||
data, err = json.Marshal(bool(converted))
|
||||
case lua.LNumber:
|
||||
data, err = json.Marshal(float64(converted))
|
||||
case *lua.LNilType:
|
||||
data = []byte(`null`)
|
||||
case lua.LString:
|
||||
data, err = json.Marshal(string(converted))
|
||||
case *lua.LTable:
|
||||
if j.visited[converted] {
|
||||
return nil, errNested
|
||||
}
|
||||
j.visited[converted] = true
|
||||
|
||||
key, value := converted.Next(lua.LNil)
|
||||
|
||||
switch key.Type() {
|
||||
case lua.LTNil: // empty table
|
||||
data = []byte(`[]`)
|
||||
case lua.LTNumber:
|
||||
arr := make([]jsonValue, 0, converted.Len())
|
||||
expectedKey := lua.LNumber(1)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTNumber {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
if expectedKey != key {
|
||||
err = errSparseArray
|
||||
return
|
||||
}
|
||||
arr = append(arr, jsonValue{value, j.visited})
|
||||
expectedKey++
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(arr)
|
||||
case lua.LTString:
|
||||
obj := make(map[string]jsonValue)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTString {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
obj[key.String()] = jsonValue{value, j.visited}
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(obj)
|
||||
default:
|
||||
err = errInvalidKeys
|
||||
}
|
||||
default:
|
||||
err = invalidTypeError(j.LValue.Type())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode converts the JSON encoded data to Lua values.
|
||||
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
|
||||
var value interface{}
|
||||
err := json.Unmarshal(data, &value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeValue(L, value), nil
|
||||
}
|
||||
|
||||
// DecodeValue converts the value to a Lua value.
|
||||
//
|
||||
// This function only converts values that the encoding/json package decodes to.
|
||||
// All other values will return lua.LNil.
|
||||
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
|
||||
switch converted := value.(type) {
|
||||
case bool:
|
||||
return lua.LBool(converted)
|
||||
case float64:
|
||||
return lua.LNumber(converted)
|
||||
case string:
|
||||
return lua.LString(converted)
|
||||
case json.Number:
|
||||
return lua.LString(converted)
|
||||
case []interface{}:
|
||||
arr := L.CreateTable(len(converted), 0)
|
||||
for _, item := range converted {
|
||||
arr.Append(DecodeValue(L, item))
|
||||
}
|
||||
return arr
|
||||
case map[string]interface{}:
|
||||
tbl := L.CreateTable(0, len(converted))
|
||||
for key, item := range converted {
|
||||
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
|
||||
}
|
||||
return tbl
|
||||
case nil:
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
return lua.LNil
|
||||
}
|
||||
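A hedged sketch of the encoding rules implemented above: tables with sequential numeric keys become JSON arrays, tables with string keys become objects. The import path follows the doc comment (layeh.com/gopher-json); the values are illustrative:

```go
package main

import (
    "fmt"

    lua "github.com/yuin/gopher-lua"
    luajson "layeh.com/gopher-json"
)

func main() {
    L := lua.NewState()
    defer L.Close()

    // Sequential numeric keys encode as a JSON array...
    arr := L.CreateTable(2, 0)
    arr.Append(lua.LString("a"))
    arr.Append(lua.LNumber(2))
    out, _ := luajson.Encode(arr)
    fmt.Println(string(out)) // ["a",2]

    // ...while string keys encode as a JSON object.
    obj := L.CreateTable(0, 1)
    obj.RawSetString("name", lua.LString("mochi"))
    out, _ = luajson.Encode(obj)
    fmt.Println(string(out)) // {"name":"mochi"}
}
```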
vendor/github.com/alicebob/miniredis/v2/.gitignore (generated, vendored, 6 lines removed) @@ -1,6 +0,0 @@
/integration/redis_src/
/integration/dump.rdb
*.swp
/integration/nodes.conf
.idea/
miniredis.iml
vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md (generated, vendored, 225 lines removed) @@ -1,225 +0,0 @@
|
||||
## Changelog
|
||||
|
||||
|
||||
### v2.23.0
|
||||
|
||||
- basic INFO support (thanks @kirill-a-belov)
|
||||
- support COUNT in SSCAN (thanks @Abdi-dd)
|
||||
- test and support Go 1.19
|
||||
- support LPOS (thanks @ianstarz)
|
||||
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
|
||||
|
||||
|
||||
### v2.22.0
|
||||
|
||||
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
|
||||
- fix invalid resposne of COMMAND (thanks @zsh1995)
|
||||
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
|
||||
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
|
||||
|
||||
|
||||
### v2.21.0
|
||||
|
||||
- support for GETEX (thanks @dntj)
|
||||
- support for GT and LT in ZADD (thanks @lsgndln)
|
||||
- support for XAUTOCLAIM (thanks @randall-fulton)
|
||||
|
||||
|
||||
### v2.20.0
|
||||
|
||||
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
|
||||
|
||||
|
||||
### v2.19.0
|
||||
|
||||
- support for TYPE in SCAN (thanks @0xDiddi)
|
||||
- update BITPOS (thanks @dirkm)
|
||||
- fix a lua redis.call() return value (thanks @mpetronic)
|
||||
- update ZRANGE (thanks @valdemarpereira)
|
||||
|
||||
|
||||
### v2.18.0
|
||||
|
||||
- support for ZUNION (thanks @propan)
|
||||
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
|
||||
- support for LMOVE (thanks @btwear)
|
||||
|
||||
|
||||
### v2.17.0
|
||||
|
||||
- added miniredis.RunT(t)
|
||||
|
||||
|
||||
### v2.16.1
|
||||
|
||||
- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
|
||||
- fix exclusive ranges in XRANGE (thanks @joseotoro)
|
||||
|
||||
|
||||
### v2.16.0
|
||||
|
||||
- simplify some code (thanks @zonque)
|
||||
- support for EXAT/PXAT in SET
|
||||
- support for XTRIM (thanks @joseotoro)
|
||||
- support for ZRANDMEMBER
|
||||
- support for redis.log() in lua (thanks @dirkm)
|
||||
|
||||
|
||||
### v2.15.2
|
||||
|
||||
- Fix race condition in blocking code (thanks @zonque and @robx)
|
||||
- XREAD accepts '$' as ID (thanks @bradengroom)
|
||||
|
||||
|
||||
### v2.15.1
|
||||
|
||||
- EVAL should cache the script (thanks @guoshimin)
|
||||
|
||||
|
||||
### v2.15.0
|
||||
|
||||
- target redis 6.2 and added new args to various commands
|
||||
- support for all hyperlog commands (thanks @ilbaktin)
|
||||
- support for GETDEL (thanks @wszaranski)
|
||||
|
||||
|
||||
### v2.14.5
|
||||
|
||||
- added XPENDING
|
||||
- support for BLOCK option in XREAD and XREADGROUP
|
||||
|
||||
|
||||
### v2.14.4
|
||||
|
||||
- fix BITPOS error (thanks @xiaoyuzdy)
|
||||
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
|
||||
- fix empty EXEC return type (thanks @ashanbrown)
|
||||
- fix XDEL (thanks @svakili and @yvesf)
|
||||
- fix FLUSHALL for streams (thanks @svakili)
|
||||
|
||||
|
||||
### v2.14.3
|
||||
|
||||
- fix problem where Lua code didn't set the selected DB
|
||||
- update to redis 6.0.10 (thanks @lazappa)
|
||||
|
||||
|
||||
### v2.14.2
|
||||
|
||||
- update LUA dependency
|
||||
- deal with (p)unsubscribe when there are no channels
|
||||
|
||||
|
||||
### v2.14.1
|
||||
|
||||
- mod tidy
|
||||
|
||||
|
||||
### v2.14.0
|
||||
|
||||
- support for HELLO and the RESP3 protocol
|
||||
- KEEPTTL in SET (thanks @johnpena)
|
||||
|
||||
|
||||
### v2.13.3
|
||||
|
||||
- support Go 1.14 and 1.15
|
||||
- update the `Check...()` methods
|
||||
- support for XREAD (thanks @pieterlexis)
|
||||
|
||||
|
||||
### v2.13.2
|
||||
|
||||
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
|
||||
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
|
||||
- changed unit and integration tests to compare raw payloads, not parsed payloads
|
||||
- remove "redigo" dependency
|
||||
|
||||
|
||||
### v2.13.1
|
||||
|
||||
- added HSTRLEN
|
||||
- minimal support for ACL users in AUTH
|
||||
|
||||
|
||||
### v2.13.0
|
||||
|
||||
- added RunTLS(...)
|
||||
- added SetError(...)
|
||||
|
||||
|
||||
### v2.12.0
|
||||
|
||||
- redis 6
|
||||
- Lua json update (thanks @gsmith85)
|
||||
- CLUSTER commands (thanks @kratisto)
|
||||
- fix TOUCH
|
||||
- fix a shutdown race condition
|
||||
|
||||
|
||||
### v2.11.4
|
||||
|
||||
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
|
||||
|
||||
|
||||
### v2.11.3
|
||||
|
||||
- support for TOUCH (thanks @cleroux)
|
||||
- support for cluster and stream commands (thanks @kak-tus)
|
||||
|
||||
|
||||
### v2.11.2
|
||||
|
||||
- make sure Lua code is executed concurrently
|
||||
- add command GEORADIUSBYMEMBER (thanks @kyeett)
|
||||
|
||||
|
||||
### v2.11.1
|
||||
|
||||
- globals protection for Lua code (thanks @vk-outreach)
|
||||
- HSET update (thanks @carlgreen)
|
||||
- fix BLPOP block on shutdown (thanks @Asalle)
|
||||
|
||||
|
||||
### v2.11.0
|
||||
|
||||
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
|
||||
- added GEODIST
|
||||
- improved precision for geohashes, closer to what real redis does
|
||||
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
|
||||
|
||||
|
||||
### v2.10.1
|
||||
|
||||
- added m.Server()
|
||||
|
||||
|
||||
### v2.10.0
|
||||
|
||||
- added UNLINK
|
||||
- fix DEL zero-argument case
|
||||
- cleanup some direct access commands
|
||||
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
|
||||
|
||||
|
||||
### v2.9.1
|
||||
|
||||
- fix issue with ZRANGEBYLEX
|
||||
- fix issue with BRPOPLPUSH and direct access
|
||||
|
||||
|
||||
### v2.9.0
|
||||
|
||||
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
|
||||
- fix messages generated by PSUBSCRIBE
|
||||
- optional internal seed (thanks @zikaeroh)
|
||||
|
||||
|
||||
### v2.8.0
|
||||
|
||||
Proper `v2` in go.mod.
|
||||
|
||||
|
||||
### older
|
||||
|
||||
See https://github.com/alicebob/miniredis/releases for the full changelog
|
||||
vendor/github.com/alicebob/miniredis/v2/LICENSE (generated, vendored, 21 lines removed) @@ -1,21 +0,0 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Harmen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
vendor/github.com/alicebob/miniredis/v2/Makefile (generated, vendored, 12 lines removed) @@ -1,12 +0,0 @@
.PHONY: all test testrace int

all: test

test:
	go test ./...

testrace:
	go test -race ./...

int:
	${MAKE} -C integration all
vendor/github.com/alicebob/miniredis/v2/README.md (generated, vendored, 333 lines removed) @@ -1,333 +0,0 @@
|
||||
# Miniredis
|
||||
|
||||
Pure Go Redis test server, used in Go unittests.
|
||||
|
||||
|
||||
##
|
||||
|
||||
Sometimes you want to test code which uses Redis, without making it a full-blown
|
||||
integration test.
|
||||
Miniredis implements (parts of) the Redis server, to be used in unittests. It
|
||||
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
|
||||
|
||||
It saves you from using mock code, and since the redis server lives in the
|
||||
test process you can query for values directly, without going through the server
|
||||
stack.
|
||||
|
||||
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
|
||||
|
||||
Be sure to import v2:
|
||||
```
|
||||
import "github.com/alicebob/miniredis/v2"
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
Implemented commands:
|
||||
|
||||
- Connection (complete)
|
||||
- AUTH -- see RequireAuth()
|
||||
- ECHO
|
||||
- HELLO -- see RequireUserAuth()
|
||||
- PING
|
||||
- SELECT
|
||||
- SWAPDB
|
||||
- QUIT
|
||||
- Key
|
||||
- COPY
|
||||
- DEL
|
||||
- EXISTS
|
||||
- EXPIRE
|
||||
- EXPIREAT
|
||||
- KEYS
|
||||
- MOVE
|
||||
- PERSIST
|
||||
- PEXPIRE
|
||||
- PEXPIREAT
|
||||
- PTTL
|
||||
- RENAME
|
||||
- RENAMENX
|
||||
- RANDOMKEY -- see m.Seed(...)
|
||||
- SCAN
|
||||
- TOUCH
|
||||
- TTL
|
||||
- TYPE
|
||||
- UNLINK
|
||||
- Transactions (complete)
|
||||
- DISCARD
|
||||
- EXEC
|
||||
- MULTI
|
||||
- UNWATCH
|
||||
- WATCH
|
||||
- Server
|
||||
- DBSIZE
|
||||
- FLUSHALL
|
||||
- FLUSHDB
|
||||
- TIME -- returns time.Now() or value set by SetTime()
|
||||
- COMMAND -- partly
|
||||
- INFO -- partly, returns only "clients" section with one field "connected_clients"
|
||||
- String keys (complete)
|
||||
- APPEND
|
||||
- BITCOUNT
|
||||
- BITOP
|
||||
- BITPOS
|
||||
- DECR
|
||||
- DECRBY
|
||||
- GET
|
||||
- GETBIT
|
||||
- GETRANGE
|
||||
- GETSET
|
||||
- GETDEL
|
||||
- GETEX
|
||||
- INCR
|
||||
- INCRBY
|
||||
- INCRBYFLOAT
|
||||
- MGET
|
||||
- MSET
|
||||
- MSETNX
|
||||
- PSETEX
|
||||
- SET
|
||||
- SETBIT
|
||||
- SETEX
|
||||
- SETNX
|
||||
- SETRANGE
|
||||
- STRLEN
|
||||
- Hash keys (complete)
|
||||
- HDEL
|
||||
- HEXISTS
|
||||
- HGET
|
||||
- HGETALL
|
||||
- HINCRBY
|
||||
- HINCRBYFLOAT
|
||||
- HKEYS
|
||||
- HLEN
|
||||
- HMGET
|
||||
- HMSET
|
||||
- HSET
|
||||
- HSETNX
|
||||
- HSTRLEN
|
||||
- HVALS
|
||||
- HSCAN
|
||||
- List keys (complete)
|
||||
- BLPOP
|
||||
- BRPOP
|
||||
- BRPOPLPUSH
|
||||
- LINDEX
|
||||
- LINSERT
|
||||
- LLEN
|
||||
- LPOP
|
||||
- LPUSH
|
||||
- LPUSHX
|
||||
- LRANGE
|
||||
- LREM
|
||||
- LSET
|
||||
- LTRIM
|
||||
- RPOP
|
||||
- RPOPLPUSH
|
||||
- RPUSH
|
||||
- RPUSHX
|
||||
- LMOVE
|
||||
- Pub/Sub (complete)
|
||||
- PSUBSCRIBE
|
||||
- PUBLISH
|
||||
- PUBSUB
|
||||
- PUNSUBSCRIBE
|
||||
- SUBSCRIBE
|
||||
- UNSUBSCRIBE
|
||||
- Set keys (complete)
|
||||
- SADD
|
||||
- SCARD
|
||||
- SDIFF
|
||||
- SDIFFSTORE
|
||||
- SINTER
|
||||
- SINTERSTORE
|
||||
- SISMEMBER
|
||||
- SMEMBERS
|
||||
- SMOVE
|
||||
- SPOP -- see m.Seed(...)
|
||||
- SRANDMEMBER -- see m.Seed(...)
|
||||
- SREM
|
||||
- SUNION
|
||||
- SUNIONSTORE
|
||||
- SSCAN
|
||||
- Sorted Set keys (complete)
|
||||
- ZADD
|
||||
- ZCARD
|
||||
- ZCOUNT
|
||||
- ZINCRBY
|
||||
- ZINTERSTORE
|
||||
- ZLEXCOUNT
|
||||
- ZPOPMIN
|
||||
- ZPOPMAX
|
||||
- ZRANDMEMBER
|
||||
- ZRANGE
|
||||
- ZRANGEBYLEX
|
||||
- ZRANGEBYSCORE
|
||||
- ZRANK
|
||||
- ZREM
|
||||
- ZREMRANGEBYLEX
|
||||
- ZREMRANGEBYRANK
|
||||
- ZREMRANGEBYSCORE
|
||||
- ZREVRANGE
|
||||
- ZREVRANGEBYLEX
|
||||
- ZREVRANGEBYSCORE
|
||||
- ZREVRANK
|
||||
- ZSCORE
|
||||
- ZUNION
|
||||
- ZUNIONSTORE
|
||||
- ZSCAN
|
||||
- Stream keys
|
||||
- XACK
|
||||
- XADD
|
||||
- XAUTOCLAIM
|
||||
- XCLAIM
|
||||
- XDEL
|
||||
- XGROUP CREATE
|
||||
- XGROUP CREATECONSUMER
|
||||
- XGROUP DESTROY
|
||||
- XGROUP DELCONSUMER
|
||||
- XINFO STREAM -- partly
|
||||
- XINFO GROUPS
|
||||
- XINFO CONSUMERS -- partly
|
||||
- XLEN
|
||||
- XRANGE
|
||||
- XREAD
|
||||
- XREADGROUP
|
||||
- XREVRANGE
|
||||
- XPENDING
|
||||
- XTRIM
|
||||
- Scripting
|
||||
- EVAL
|
||||
- EVALSHA
|
||||
- SCRIPT LOAD
|
||||
- SCRIPT EXISTS
|
||||
- SCRIPT FLUSH
|
||||
- GEO
|
||||
- GEOADD
|
||||
- GEODIST
|
||||
- ~~GEOHASH~~
|
||||
- GEOPOS
|
||||
- GEORADIUS
|
||||
- GEORADIUS_RO
|
||||
- GEORADIUSBYMEMBER
|
||||
- GEORADIUSBYMEMBER_RO
|
||||
- Cluster
|
||||
- CLUSTER SLOTS
|
||||
- CLUSTER KEYSLOT
|
||||
- CLUSTER NODES
|
||||
- HyperLogLog (complete)
|
||||
- PFADD
|
||||
- PFCOUNT
|
||||
- PFMERGE
|
||||
|
||||
|
||||
## TTLs, key expiration, and time

Since miniredis is intended to be used in unittests, TTLs don't decrease
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
key. It will return 0 when no TTL is set.

`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
0 will be removed.

EXPIREAT and PEXPIREAT values will be converted to a duration. For that you can
either set m.SetTime(t) to use that time as the base for the (P)EXPIREAT
conversion, or don't call SetTime(), in which case time.Now() will be used.

SetTime() also sets the value returned by TIME, which defaults to time.Now().
It is not updated by FastForward, only by SetTime.

## Randomness and Seed()

Miniredis will use `math/rand`'s global RNG for randomness unless a seed is
provided by calling `m.Seed(...)`. If a seed is provided, then miniredis will
use its own RNG based on that seed.

Commands which use randomness are: RANDOMKEY, SPOP, and SRANDMEMBER.
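Before the package's own example below, a minimal test sketch of the TTL behaviour just described. It assumes `TTL(key)` returns a `time.Duration` as stated above; the other calls appear in the example that follows:

```go
package example_test

import (
    "testing"
    "time"

    "github.com/alicebob/miniredis/v2"
)

// A sketch of the TTL behaviour described above, assuming TTL(key) returns a
// time.Duration as the text states.
func TestTTLFastForward(t *testing.T) {
    s := miniredis.RunT(t)

    s.Set("session:42", "alive")
    s.SetTTL("session:42", 30*time.Second)

    // TTLs do not tick down by themselves in tests...
    if s.TTL("session:42") != 30*time.Second {
        t.Fatal("unexpected TTL")
    }

    // ...until FastForward moves the clock; the key then expires.
    s.FastForward(31 * time.Second)
    if s.Exists("session:42") {
        t.Fatal("session should have expired")
    }
}
```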
## Example

``` Go

import (
    ...
    "github.com/alicebob/miniredis/v2"
    ...
)

func TestSomething(t *testing.T) {
    s := miniredis.RunT(t)

    // Optionally set some keys your code expects:
    s.Set("foo", "bar")
    s.HSet("some", "other", "key")

    // Run your code and see if it behaves.
    // An example using the redigo library from "github.com/gomodule/redigo/redis":
    c, err := redis.Dial("tcp", s.Addr())
    _, err = c.Do("SET", "foo", "bar")

    // Optionally check values in redis...
    if got, err := s.Get("foo"); err != nil || got != "bar" {
        t.Error("'foo' has the wrong value")
    }
    // ... or use a helper for that:
    s.CheckGet(t, "foo", "bar")

    // TTL and expiration:
    s.Set("foo", "bar")
    s.SetTTL("foo", 10*time.Second)
    s.FastForward(11 * time.Second)
    if s.Exists("foo") {
        t.Fatal("'foo' should not have existed anymore")
    }
}
```
## Not supported
|
||||
|
||||
Commands which will probably not be implemented:
|
||||
|
||||
- CLUSTER (all)
|
||||
- ~~CLUSTER *~~
|
||||
- ~~READONLY~~
|
||||
- ~~READWRITE~~
|
||||
- Key
|
||||
- ~~DUMP~~
|
||||
- ~~MIGRATE~~
|
||||
- ~~OBJECT~~
|
||||
- ~~RESTORE~~
|
||||
- ~~WAIT~~
|
||||
- Scripting
|
||||
- ~~SCRIPT DEBUG~~
|
||||
- ~~SCRIPT KILL~~
|
||||
- Server
|
||||
- ~~BGSAVE~~
|
||||
- ~~BGWRITEAOF~~
|
||||
- ~~CLIENT *~~
|
||||
- ~~CONFIG *~~
|
||||
- ~~DEBUG *~~
|
||||
- ~~LASTSAVE~~
|
||||
- ~~MONITOR~~
|
||||
- ~~ROLE~~
|
||||
- ~~SAVE~~
|
||||
- ~~SHUTDOWN~~
|
||||
- ~~SLAVEOF~~
|
||||
- ~~SLOWLOG~~
|
||||
- ~~SYNC~~
|
||||
|
||||
|
||||
## &c.
|
||||
|
||||
Integration tests are run against Redis 6.2.6. The [./integration](./integration/) subdir
|
||||
compares miniredis against a real redis instance.
|
||||
|
||||
The Redis 6 RESP3 protocol is supported. If there are problems, please open
|
||||
an issue.
|
||||
|
||||
If you want to test Redis Sentinel have a look at [minisentinel](https://github.com/Bose/minisentinel).
|
||||
|
||||
A changelog is kept at [CHANGELOG.md](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md).
|
||||
|
||||
[](https://pkg.go.dev/github.com/alicebob/miniredis/v2)
|
||||
vendor/github.com/alicebob/miniredis/v2/check.go (generated, vendored, 63 lines removed) @@ -1,63 +0,0 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// T is implemented by Testing.T
|
||||
type T interface {
|
||||
Helper()
|
||||
Errorf(string, ...interface{})
|
||||
}
|
||||
|
||||
// CheckGet does not call Errorf() iff there is a string key with the
|
||||
// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
|
||||
func (m *Miniredis) CheckGet(t T, key, expected string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Get(key)
|
||||
if err != nil {
|
||||
t.Errorf("GET error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if found != expected {
|
||||
t.Errorf("GET error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckList does not call Errorf() iff there is a list key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`.
|
||||
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.List(key)
|
||||
if err != nil {
|
||||
t.Errorf("List error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("List error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckSet does not call Errorf() iff there is a set key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
|
||||
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Members(key)
|
||||
if err != nil {
|
||||
t.Errorf("Set error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
sort.Strings(expected)
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("Set error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
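A short sketch of how the Check helpers above read in a test, combined with the RunT helper from the README:

```go
package example_test

import (
    "testing"

    "github.com/alicebob/miniredis/v2"
)

// Sketch: CheckGet fails the test via t.Errorf when the stored value does not
// match the expectation, and is silent when it does.
func TestCheckGet(t *testing.T) {
    m := miniredis.RunT(t)
    m.Set("username", "theking")
    m.CheckGet(t, "username", "theking")
}
```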
vendor/github.com/alicebob/miniredis/v2/cmd_cluster.go (generated, vendored, 67 lines removed) @@ -1,67 +0,0 @@
|
||||
// Commands from https://redis.io/commands#cluster
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsCluster handles some cluster operations.
|
||||
func commandsCluster(m *Miniredis) {
|
||||
m.srv.Register("CLUSTER", m.cmdCluster)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdCluster(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "SLOTS":
|
||||
m.cmdClusterSlots(c, cmd, args)
|
||||
case "KEYSLOT":
|
||||
m.cmdClusterKeySlot(c, cmd, args)
|
||||
case "NODES":
|
||||
m.cmdClusterNodes(c, cmd, args)
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("ERR 'CLUSTER %s' not supported", strings.Join(args, " ")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CLUSTER SLOTS
|
||||
func (m *Miniredis) cmdClusterSlots(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteLen(1)
|
||||
c.WriteLen(3)
|
||||
c.WriteInt(0)
|
||||
c.WriteInt(16383)
|
||||
c.WriteLen(3)
|
||||
c.WriteBulk(m.srv.Addr().IP.String())
|
||||
c.WriteInt(m.srv.Addr().Port)
|
||||
c.WriteBulk("09dbe9720cda62f7865eabc5fd8857c5d2678366")
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER KEYSLOT
|
||||
func (m *Miniredis) cmdClusterKeySlot(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteInt(163)
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER NODES
|
||||
func (m *Miniredis) cmdClusterNodes(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk("e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca 127.0.0.1:7000@7000 myself,master - 0 0 1 connected 0-16383")
|
||||
})
|
||||
}
|
||||
vendor/github.com/alicebob/miniredis/v2/cmd_command.go (generated, vendored, 2045 lines removed): file diff suppressed because it is too large.
vendor/github.com/alicebob/miniredis/v2/cmd_connection.go (generated, vendored, 284 lines removed) @@ -1,284 +0,0 @@
|
||||
// Commands from https://redis.io/commands#connection
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
func commandsConnection(m *Miniredis) {
|
||||
m.srv.Register("AUTH", m.cmdAuth)
|
||||
m.srv.Register("ECHO", m.cmdEcho)
|
||||
m.srv.Register("HELLO", m.cmdHello)
|
||||
m.srv.Register("PING", m.cmdPing)
|
||||
m.srv.Register("QUIT", m.cmdQuit)
|
||||
m.srv.Register("SELECT", m.cmdSelect)
|
||||
m.srv.Register("SWAPDB", m.cmdSwapdb)
|
||||
}
|
||||
|
||||
// PING
|
||||
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
payload := ""
|
||||
if len(args) > 0 {
|
||||
payload = args[0]
|
||||
}
|
||||
|
||||
// PING is allowed in subscribed state
|
||||
if sub := getCtx(c).subscriber; sub != nil {
|
||||
c.Block(func(c *server.Writer) {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("pong")
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if payload == "" {
|
||||
c.WriteInline("PONG")
|
||||
return
|
||||
}
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
}
|
||||
|
||||
// AUTH
|
||||
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 2 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
if getCtx(c).nested {
|
||||
c.WriteError(msgNotFromScripts)
|
||||
return
|
||||
}
|
||||
|
||||
var opts = struct {
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
username: "default",
|
||||
password: args[0],
|
||||
}
|
||||
if len(args) == 2 {
|
||||
opts.username, opts.password = args[0], args[1]
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
c.WriteError("ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
|
||||
return
|
||||
}
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.authenticated = true
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// HELLO
|
||||
func (m *Miniredis) cmdHello(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
version int
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.version, "ERR Protocol version is not an integer or out of range"); !ok {
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
switch opts.version {
|
||||
case 2, 3:
|
||||
default:
|
||||
c.WriteError("NOPROTO unsupported protocol version")
|
||||
return
|
||||
}
|
||||
|
||||
var checkAuth bool
|
||||
for len(args) > 0 {
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "AUTH":
|
||||
if len(args) < 3 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
opts.username, opts.password, args = args[1], args[2], args[3:]
|
||||
checkAuth = true
|
||||
case "SETNAME":
|
||||
if len(args) < 2 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
_, args = args[1], args[2:]
|
||||
default:
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
// redis ignores legacy "AUTH" if it's not enabled.
|
||||
checkAuth = false
|
||||
}
|
||||
if checkAuth {
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
getCtx(c).authenticated = true
|
||||
}
|
||||
|
||||
c.Resp3 = opts.version == 3
|
||||
|
||||
c.WriteMapLen(7)
|
||||
c.WriteBulk("server")
|
||||
c.WriteBulk("miniredis")
|
||||
c.WriteBulk("version")
|
||||
c.WriteBulk("6.0.5")
|
||||
c.WriteBulk("proto")
|
||||
c.WriteInt(opts.version)
|
||||
c.WriteBulk("id")
|
||||
c.WriteInt(42)
|
||||
c.WriteBulk("mode")
|
||||
c.WriteBulk("standalone")
|
||||
c.WriteBulk("role")
|
||||
c.WriteBulk("master")
|
||||
c.WriteBulk("modules")
|
||||
c.WriteLen(0)
|
||||
}
|
||||
|
||||
// ECHO
|
||||
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
msg := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk(msg)
|
||||
})
|
||||
}
|
||||
|
||||
// SELECT
|
||||
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.isValidCMD(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id int
|
||||
}
|
||||
if ok := optInt(c, args[0], &opts.id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.selectedDB = opts.id
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// SWAPDB
|
||||
func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id1 int
|
||||
id2 int
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.id1, "ERR invalid first DB index"); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optIntErr(c, args[1], &opts.id2, "ERR invalid second DB index"); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id1 < 0 || opts.id2 < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
m.swapDB(opts.id1, opts.id2)
|
||||
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// QUIT
|
||||
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
|
||||
// QUIT isn't transactionfied and accepts any arguments.
|
||||
c.WriteOK()
|
||||
c.Close()
|
||||
}
|
||||
vendor/github.com/alicebob/miniredis/v2/cmd_generic.go (generated, vendored, 669 lines removed) @@ -1,669 +0,0 @@
|
||||
// Commands from https://redis.io/commands#generic
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsGeneric handles EXPIRE, TTL, PERSIST, &c.
|
||||
func commandsGeneric(m *Miniredis) {
|
||||
m.srv.Register("COPY", m.cmdCopy)
|
||||
m.srv.Register("DEL", m.cmdDel)
|
||||
// DUMP
|
||||
m.srv.Register("EXISTS", m.cmdExists)
|
||||
m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
|
||||
m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
|
||||
m.srv.Register("KEYS", m.cmdKeys)
|
||||
// MIGRATE
|
||||
m.srv.Register("MOVE", m.cmdMove)
|
||||
// OBJECT
|
||||
m.srv.Register("PERSIST", m.cmdPersist)
|
||||
m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
|
||||
m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
|
||||
m.srv.Register("PTTL", m.cmdPTTL)
|
||||
m.srv.Register("RANDOMKEY", m.cmdRandomkey)
|
||||
m.srv.Register("RENAME", m.cmdRename)
|
||||
m.srv.Register("RENAMENX", m.cmdRenamenx)
|
||||
// RESTORE
|
||||
m.srv.Register("TOUCH", m.cmdTouch)
|
||||
m.srv.Register("TTL", m.cmdTTL)
|
||||
m.srv.Register("TYPE", m.cmdType)
|
||||
m.srv.Register("SCAN", m.cmdScan)
|
||||
// SORT
|
||||
m.srv.Register("UNLINK", m.cmdDel)
|
||||
}
|
||||
|
||||
// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
|
||||
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
|
||||
// converted to a duration.
|
||||
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
|
||||
return func(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
value int
|
||||
}
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.value); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
// Key must be present.
|
||||
if _, ok := db.keys[opts.key]; !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if unix {
|
||||
db.ttl[opts.key] = m.at(opts.value, d)
|
||||
} else {
|
||||
db.ttl[opts.key] = time.Duration(opts.value) * d
|
||||
}
|
||||
db.keyVersion[opts.key]++
|
||||
db.checkTTL(opts.key)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TOUCH
|
||||
func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TTL
|
||||
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
(remainder of deleted vendored file: generic key-command handlers — TTL, PTTL, PERSIST, DEL, UNLINK, TYPE, EXISTS, MOVE, KEYS, RANDOMKEY, RENAME, RENAMENX, SCAN, COPY)
609  vendor/github.com/alicebob/miniredis/v2/cmd_geo.go  generated  vendored
@@ -1,609 +0,0 @@
(deleted vendored file: GEO command handlers from https://redis.io/commands#geo — GEOADD, GEODIST, GEOPOS, GEORADIUS, GEORADIUS_RO, GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO)
683  vendor/github.com/alicebob/miniredis/v2/cmd_hash.go  generated  vendored
@@ -1,683 +0,0 @@
(deleted vendored file: hash command handlers from https://redis.io/commands#hash — HDEL, HEXISTS, HGET, HGETALL, HINCRBY, HINCRBYFLOAT, HKEYS, HLEN, HMGET, HMSET, HSET, HSETNX, HSTRLEN, HVALS, HSCAN)
95  vendor/github.com/alicebob/miniredis/v2/cmd_hll.go  generated  vendored
@@ -1,95 +0,0 @@
package miniredis

import "github.com/alicebob/miniredis/v2/server"

// commandsHll handles all hll related operations.
func commandsHll(m *Miniredis) {
    m.srv.Register("PFADD", m.cmdPfadd)
    m.srv.Register("PFCOUNT", m.cmdPfcount)
    m.srv.Register("PFMERGE", m.cmdPfmerge)
}

// PFADD
func (m *Miniredis) cmdPfadd(c *server.Peer, cmd string, args []string) {
    if len(args) < 2 {
        setDirty(c)
        c.WriteError(errWrongNumber(cmd))
        return
    }
    if !m.handleAuth(c) {
        return
    }
    if m.checkPubsub(c, cmd) {
        return
    }

    key, items := args[0], args[1:]

    withTx(m, c, func(c *server.Peer, ctx *connCtx) {
        db := m.db(ctx.selectedDB)

        if db.exists(key) && db.t(key) != "hll" {
            c.WriteError(ErrNotValidHllValue.Error())
            return
        }

        altered := db.hllAdd(key, items...)
        c.WriteInt(altered)
    })
}

// PFCOUNT
func (m *Miniredis) cmdPfcount(c *server.Peer, cmd string, args []string) {
    if len(args) < 1 {
        setDirty(c)
        c.WriteError(errWrongNumber(cmd))
        return
    }
    if !m.handleAuth(c) {
        return
    }
    if m.checkPubsub(c, cmd) {
        return
    }

    keys := args

    withTx(m, c, func(c *server.Peer, ctx *connCtx) {
        db := m.db(ctx.selectedDB)

        count, err := db.hllCount(keys)
        if err != nil {
            c.WriteError(err.Error())
            return
        }

        c.WriteInt(count)
    })
}

// PFMERGE
func (m *Miniredis) cmdPfmerge(c *server.Peer, cmd string, args []string) {
    if len(args) < 1 {
        setDirty(c)
        c.WriteError(errWrongNumber(cmd))
        return
    }
    if !m.handleAuth(c) {
        return
    }
    if m.checkPubsub(c, cmd) {
        return
    }

    keys := args

    withTx(m, c, func(c *server.Peer, ctx *connCtx) {
        db := m.db(ctx.selectedDB)

        if err := db.hllMerge(keys); err != nil {
            c.WriteError(err.Error())
            return
        }
        c.WriteOK()
    })
}
40  vendor/github.com/alicebob/miniredis/v2/cmd_info.go  generated  vendored
@@ -1,40 +0,0 @@
package miniredis

import (
    "fmt"

    "github.com/alicebob/miniredis/v2/server"
)

// Command 'INFO' from https://redis.io/commands/info/
func (m *Miniredis) cmdInfo(c *server.Peer, cmd string, args []string) {
    if !m.isValidCMD(c, cmd) {
        return
    }

    if len(args) > 1 {
        setDirty(c)
        c.WriteError(errWrongNumber(cmd))
        return
    }

    withTx(m, c, func(c *server.Peer, ctx *connCtx) {
        const (
            clientsSectionName    = "clients"
            clientsSectionContent = "# Clients\nconnected_clients:%d\r\n"
        )

        var result string

        for _, key := range args {
            if key != clientsSectionName {
                setDirty(c)
                c.WriteError(fmt.Sprintf("section (%s) is not supported", key))
                return
            }
        }
        result = fmt.Sprintf(clientsSectionContent, m.Server().ClientsLen())

        c.WriteBulk(result)
    })
}
986  vendor/github.com/alicebob/miniredis/v2/cmd_list.go  generated  vendored
@@ -1,986 +0,0 @@
(deleted vendored file: list command handlers from https://redis.io/commands#list — BLPOP, BRPOP, BRPOPLPUSH, LINDEX, LPOS, LINSERT, LLEN, LPOP, LPUSH, LPUSHX, LRANGE, LREM, LSET, LTRIM, RPOP, RPOPLPUSH, RPUSH, RPUSHX, LMOVE)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if db.t(opts.key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
if opts.count < 0 {
|
||||
reverseSlice(l)
|
||||
}
|
||||
deleted := 0
|
||||
newL := []string{}
|
||||
toDelete := len(l)
|
||||
if opts.count < 0 {
|
||||
toDelete = -opts.count
|
||||
}
|
||||
if opts.count > 0 {
|
||||
toDelete = opts.count
|
||||
}
|
||||
for _, el := range l {
|
||||
if el == opts.value {
|
||||
if toDelete > 0 {
|
||||
deleted++
|
||||
toDelete--
|
||||
continue
|
||||
}
|
||||
}
|
||||
newL = append(newL, el)
|
||||
}
|
||||
if opts.count < 0 {
|
||||
reverseSlice(newL)
|
||||
}
|
||||
if len(newL) == 0 {
|
||||
db.del(opts.key, true)
|
||||
} else {
|
||||
db.listKeys[opts.key] = newL
|
||||
db.keyVersion[opts.key]++
|
||||
}
|
||||
|
||||
c.WriteInt(deleted)
|
||||
})
|
||||
}
|
||||
|
||||
// LSET
|
||||
func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
index int
|
||||
value string
|
||||
}
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.index); !ok {
|
||||
return
|
||||
}
|
||||
opts.value = args[2]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
if db.t(opts.key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
index := opts.index
|
||||
if index < 0 {
|
||||
index = len(l) + index
|
||||
}
|
||||
if index < 0 || index > len(l)-1 {
|
||||
c.WriteError(msgOutOfRange)
|
||||
return
|
||||
}
|
||||
l[index] = opts.value
|
||||
db.keyVersion[opts.key]++
|
||||
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// LTRIM
|
||||
func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.start); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optInt(c, args[2], &opts.end); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[opts.key]
|
||||
if !ok {
|
||||
c.WriteOK()
|
||||
return
|
||||
}
|
||||
if t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
rs, re := redisRange(len(l), opts.start, opts.end, false)
|
||||
l = l[rs:re]
|
||||
if len(l) == 0 {
|
||||
db.del(opts.key, true)
|
||||
} else {
|
||||
db.listKeys[opts.key] = l
|
||||
db.keyVersion[opts.key]++
|
||||
}
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// RPOPLPUSH
|
||||
func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
src, dst := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(src) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
elem := db.listPop(src)
|
||||
db.listLpush(dst, elem)
|
||||
c.WriteBulk(elem)
|
||||
})
|
||||
}
|
||||
|
||||
// BRPOPLPUSH
|
||||
func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
src string
|
||||
dst string
|
||||
timeout int
|
||||
}
|
||||
opts.src = args[0]
|
||||
opts.dst = args[1]
|
||||
if ok := optIntErr(c, args[2], &opts.timeout, msgInvalidTimeout); !ok {
|
||||
return
|
||||
}
|
||||
if opts.timeout < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgNegTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
blocking(
|
||||
m,
|
||||
c,
|
||||
time.Duration(opts.timeout)*time.Second,
|
||||
func(c *server.Peer, ctx *connCtx) bool {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.src) {
|
||||
return false
|
||||
}
|
||||
if db.t(opts.src) != "list" || (db.exists(opts.dst) && db.t(opts.dst) != "list") {
|
||||
c.WriteError(msgWrongType)
|
||||
return true
|
||||
}
|
||||
if len(db.listKeys[opts.src]) == 0 {
|
||||
return false
|
||||
}
|
||||
elem := db.listPop(opts.src)
|
||||
db.listLpush(opts.dst, elem)
|
||||
c.WriteBulk(elem)
|
||||
return true
|
||||
},
|
||||
func(c *server.Peer) {
|
||||
// timeout
|
||||
c.WriteLen(-1)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// LMOVE
|
||||
func (m *Miniredis) cmdLmove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 4 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
src string
|
||||
dst string
|
||||
srcDir string
|
||||
dstDir string
|
||||
}{
|
||||
src: args[0],
|
||||
dst: args[1],
|
||||
srcDir: strings.ToLower(args[2]),
|
||||
dstDir: strings.ToLower(args[3]),
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.src) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(opts.src) != "list" || (db.exists(opts.dst) && db.t(opts.dst) != "list") {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
var elem string
|
||||
switch opts.srcDir {
|
||||
case "left":
|
||||
elem = db.listLpop(opts.src)
|
||||
case "right":
|
||||
elem = db.listPop(opts.src)
|
||||
default:
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
switch opts.dstDir {
|
||||
case "left":
|
||||
db.listLpush(opts.dst, elem)
|
||||
case "right":
|
||||
db.listPush(opts.dst, elem)
|
||||
default:
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
c.WriteBulk(elem)
|
||||
})
|
||||
}
|
||||
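The handlers above are the vendored miniredis implementations of the Redis list commands (LPUSH, LPOP, LRANGE, RPOPLPUSH and friends). As a rough, non-authoritative sketch only, nothing below is part of this diff: the go-redis client, the test name and the key names are assumptions made for illustration of how such handlers are usually exercised from a Go test.

package example_test

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// TestListCommands exercises LPUSH/LRANGE/RPOPLPUSH against an in-process
// miniredis instance, i.e. the kind of handlers shown in cmd_list.go above.
func TestListCommands(t *testing.T) {
	s, err := miniredis.Run() // start the fake server on a random port
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer rdb.Close()

	// LPUSH pushes to the head, so the list reads c, b, a.
	if err := rdb.LPush(ctx, "jobs", "a", "b", "c").Err(); err != nil {
		t.Fatal(err)
	}
	got, err := rdb.LRange(ctx, "jobs", 0, -1).Result()
	if err != nil || len(got) != 3 {
		t.Fatalf("LRANGE = %v, %v", got, err)
	}

	// RPOPLPUSH moves the tail of "jobs" to the head of "done".
	moved, err := rdb.RPopLPush(ctx, "jobs", "done").Result()
	if err != nil || moved != "a" {
		t.Fatalf("RPOPLPUSH = %q, %v", moved, err)
	}
}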
256 vendor/github.com/alicebob/miniredis/v2/cmd_pubsub.go (generated, vendored)
@@ -1,256 +0,0 @@
// Commands from https://redis.io/commands#pubsub

package miniredis

import (
	"fmt"
	"strings"

	"github.com/alicebob/miniredis/v2/server"
)

// commandsPubsub handles all PUB/SUB operations.
func commandsPubsub(m *Miniredis) {
	m.srv.Register("SUBSCRIBE", m.cmdSubscribe)
	m.srv.Register("UNSUBSCRIBE", m.cmdUnsubscribe)
	m.srv.Register("PSUBSCRIBE", m.cmdPsubscribe)
	m.srv.Register("PUNSUBSCRIBE", m.cmdPunsubscribe)
	m.srv.Register("PUBLISH", m.cmdPublish)
	m.srv.Register("PUBSUB", m.cmdPubSub)
}

// SUBSCRIBE
func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		sub := m.subscribedState(c)
		for _, channel := range args {
			n := sub.Subscribe(channel)
			c.Block(func(w *server.Writer) {
				w.WritePushLen(3)
				w.WriteBulk("subscribe")
				w.WriteBulk(channel)
				w.WriteInt(n)
			})
		}
	})
}

// UNSUBSCRIBE
func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) {
	if !m.handleAuth(c) {
		return
	}
	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	channels := args

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		sub := m.subscribedState(c)

		if len(channels) == 0 {
			channels = sub.Channels()
		}

		// there is no de-duplication
		for _, channel := range channels {
			n := sub.Unsubscribe(channel)
			c.Block(func(w *server.Writer) {
				w.WritePushLen(3)
				w.WriteBulk("unsubscribe")
				w.WriteBulk(channel)
				w.WriteInt(n)
			})
		}
		if len(channels) == 0 {
			// special case: there is always a reply
			c.Block(func(w *server.Writer) {
				w.WritePushLen(3)
				w.WriteBulk("unsubscribe")
				w.WriteNull()
				w.WriteInt(0)
			})
		}

		if sub.Count() == 0 {
			endSubscriber(m, c)
		}
	})
}

// PSUBSCRIBE
func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		sub := m.subscribedState(c)
		for _, pat := range args {
			n := sub.Psubscribe(pat)
			c.Block(func(w *server.Writer) {
				w.WritePushLen(3)
				w.WriteBulk("psubscribe")
				w.WriteBulk(pat)
				w.WriteInt(n)
			})
		}
	})
}

// PUNSUBSCRIBE
func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) {
	if !m.handleAuth(c) {
		return
	}
	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	patterns := args

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		sub := m.subscribedState(c)

		if len(patterns) == 0 {
			patterns = sub.Patterns()
		}

		// there is no de-duplication
		for _, pat := range patterns {
			n := sub.Punsubscribe(pat)
			c.Block(func(w *server.Writer) {
				w.WritePushLen(3)
				w.WriteBulk("punsubscribe")
				w.WriteBulk(pat)
				w.WriteInt(n)
			})
		}
		if len(patterns) == 0 {
			// special case: there is always a reply
			c.Block(func(w *server.Writer) {
				w.WritePushLen(3)
				w.WriteBulk("punsubscribe")
				w.WriteNull()
				w.WriteInt(0)
			})
		}

		if sub.Count() == 0 {
			endSubscriber(m, c)
		}
	})
}

// PUBLISH
func (m *Miniredis) cmdPublish(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	channel, mesg := args[0], args[1]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		c.WriteInt(m.publish(channel, mesg))
	})
}

// PUBSUB
func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}

	if m.checkPubsub(c, cmd) {
		return
	}

	subcommand := strings.ToUpper(args[0])
	subargs := args[1:]
	var argsOk bool

	switch subcommand {
	case "CHANNELS":
		argsOk = len(subargs) < 2
	case "NUMSUB":
		argsOk = true
	case "NUMPAT":
		argsOk = len(subargs) == 0
	default:
		argsOk = false
	}

	if !argsOk {
		setDirty(c)
		c.WriteError(fmt.Sprintf(msgFPubsubUsage, subcommand))
		return
	}

	if !m.handleAuth(c) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		switch subcommand {
		case "CHANNELS":
			pat := ""
			if len(subargs) == 1 {
				pat = subargs[0]
			}

			allsubs := m.allSubscribers()
			channels := activeChannels(allsubs, pat)

			c.WriteLen(len(channels))
			for _, channel := range channels {
				c.WriteBulk(channel)
			}

		case "NUMSUB":
			subs := m.allSubscribers()
			c.WriteLen(len(subargs) * 2)
			for _, channel := range subargs {
				c.WriteBulk(channel)
				c.WriteInt(countSubs(subs, channel))
			}

		case "NUMPAT":
			c.WriteInt(countPsubs(m.allSubscribers()))
		}
	})
}
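A similar hedged sketch for the SUBSCRIBE/PUBLISH handlers above, again assuming the go-redis client and an illustrative channel name rather than anything from this repository.

package example_test

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// TestPubSub subscribes, publishes, and reads the message back.
func TestPubSub(t *testing.T) {
	s, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer rdb.Close()

	sub := rdb.Subscribe(ctx, "events")
	defer sub.Close()
	// Wait for the subscription confirmation before publishing.
	if _, err := sub.Receive(ctx); err != nil {
		t.Fatal(err)
	}

	if err := rdb.Publish(ctx, "events", "hello").Err(); err != nil {
		t.Fatal(err)
	}
	msg, err := sub.ReceiveMessage(ctx)
	if err != nil || msg.Payload != "hello" {
		t.Fatalf("got %v, %v", msg, err)
	}
}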
281 vendor/github.com/alicebob/miniredis/v2/cmd_scripting.go (generated, vendored)
@@ -1,281 +0,0 @@
package miniredis

import (
	"crypto/sha1"
	"encoding/hex"
	"fmt"
	"io"
	"strconv"
	"strings"

	luajson "github.com/alicebob/gopher-json"
	lua "github.com/yuin/gopher-lua"
	"github.com/yuin/gopher-lua/parse"

	"github.com/alicebob/miniredis/v2/server"
)

func commandsScripting(m *Miniredis) {
	m.srv.Register("EVAL", m.cmdEval)
	m.srv.Register("EVALSHA", m.cmdEvalsha)
	m.srv.Register("SCRIPT", m.cmdScript)
}

// Execute lua. Needs to run m.Lock()ed, from within withTx().
// Returns true if the lua was OK (and hence should be cached).
func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) bool {
	l := lua.NewState(lua.Options{SkipOpenLibs: true})
	defer l.Close()

	// Taken from the go-lua manual
	for _, pair := range []struct {
		n string
		f lua.LGFunction
	}{
		{lua.LoadLibName, lua.OpenPackage},
		{lua.BaseLibName, lua.OpenBase},
		{lua.CoroutineLibName, lua.OpenCoroutine},
		{lua.TabLibName, lua.OpenTable},
		{lua.StringLibName, lua.OpenString},
		{lua.MathLibName, lua.OpenMath},
		{lua.DebugLibName, lua.OpenDebug},
	} {
		if err := l.CallByParam(lua.P{
			Fn:      l.NewFunction(pair.f),
			NRet:    0,
			Protect: true,
		}, lua.LString(pair.n)); err != nil {
			panic(err)
		}
	}

	luajson.Preload(l)
	requireGlobal(l, "cjson", "json")

	// set global variable KEYS
	keysTable := l.NewTable()
	keysS, args := args[0], args[1:]
	keysLen, err := strconv.Atoi(keysS)
	if err != nil {
		c.WriteError(msgInvalidInt)
		return false
	}
	if keysLen < 0 {
		c.WriteError(msgNegativeKeysNumber)
		return false
	}
	if keysLen > len(args) {
		c.WriteError(msgInvalidKeysNumber)
		return false
	}
	keys, args := args[:keysLen], args[keysLen:]
	for i, k := range keys {
		l.RawSet(keysTable, lua.LNumber(i+1), lua.LString(k))
	}
	l.SetGlobal("KEYS", keysTable)

	argvTable := l.NewTable()
	for i, a := range args {
		l.RawSet(argvTable, lua.LNumber(i+1), lua.LString(a))
	}
	l.SetGlobal("ARGV", argvTable)

	redisFuncs, redisConstants := mkLua(m.srv, c)
	// Register command handlers
	l.Push(l.NewFunction(func(l *lua.LState) int {
		mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
		for k, v := range redisConstants {
			mod.RawSetString(k, v)
		}
		l.Push(mod)
		return 1
	}))

	l.DoString(protectGlobals)

	l.Push(lua.LString("redis"))
	l.Call(1, 0)

	if err := l.DoString(script); err != nil {
		c.WriteError(errLuaParseError(err))
		return false
	}

	luaToRedis(l, c, l.Get(1))
	return true
}

func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	script, args := args[0], args[1:]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		ok := m.runLuaScript(c, script, args)
		if ok {
			sha := sha1Hex(script)
			m.scripts[sha] = script
		}
	})
}

func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}
	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	sha, args := args[0], args[1:]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		script, ok := m.scripts[sha]
		if !ok {
			c.WriteError(msgNoScriptFound)
			return
		}

		m.runLuaScript(c, script, args)
	})
}

func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	if getCtx(c).nested {
		c.WriteError(msgNotFromScripts)
		return
	}

	subcmd, args := args[0], args[1:]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		switch strings.ToLower(subcmd) {
		case "load":
			if len(args) != 1 {
				c.WriteError(fmt.Sprintf(msgFScriptUsage, "LOAD"))
				return
			}
			script := args[0]

			if _, err := parse.Parse(strings.NewReader(script), "user_script"); err != nil {
				c.WriteError(errLuaParseError(err))
				return
			}
			sha := sha1Hex(script)
			m.scripts[sha] = script
			c.WriteBulk(sha)

		case "exists":
			c.WriteLen(len(args))
			for _, arg := range args {
				if _, ok := m.scripts[arg]; ok {
					c.WriteInt(1)
				} else {
					c.WriteInt(0)
				}
			}

		case "flush":
			if len(args) == 1 {
				switch strings.ToUpper(args[0]) {
				case "SYNC", "ASYNC":
					args = args[1:]
				default:
				}
			}
			if len(args) != 0 {
				c.WriteError(msgScriptFlush)
				return
			}

			m.scripts = map[string]string{}
			c.WriteOK()

		default:
			c.WriteError(fmt.Sprintf(msgFScriptUsage, strings.ToUpper(subcmd)))
		}
	})
}

func sha1Hex(s string) string {
	h := sha1.New()
	io.WriteString(h, s)
	return hex.EncodeToString(h.Sum(nil))
}

// requireGlobal imports module modName into the global namespace with the
// identifier id. panics if an error results from the function execution
func requireGlobal(l *lua.LState, id, modName string) {
	if err := l.CallByParam(lua.P{
		Fn:      l.GetGlobal("require"),
		NRet:    1,
		Protect: true,
	}, lua.LString(modName)); err != nil {
		panic(err)
	}
	mod := l.Get(-1)
	l.Pop(1)

	l.SetGlobal(id, mod)
}

// the following script protects globals
// it is based on: http://metalua.luaforge.net/src/lib/strict.lua.html
var protectGlobals = `
local dbg=debug
local mt = {}
setmetatable(_G, mt)
mt.__newindex = function (t, n, v)
  if dbg.getinfo(2) then
    local w = dbg.getinfo(2, "S").what
    if w ~= "C" then
      error("Script attempted to create global variable '"..tostring(n).."'", 2)
    end
  end
  rawset(t, n, v)
end
mt.__index = function (t, n)
  if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then
    error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2)
  end
  return rawget(t, n)
end
debug = nil

`
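A minimal sketch of driving the EVAL handler above, under the same assumptions (go-redis client, made-up key and script); KEYS and ARGV are populated exactly as runLuaScript sets them up.

package example_test

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// TestEval runs a small Lua script through the scripting handlers above.
func TestEval(t *testing.T) {
	s, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer rdb.Close()

	// SET through redis.call, then return the stored value.
	script := `redis.call("SET", KEYS[1], ARGV[1]) return redis.call("GET", KEYS[1])`
	got, err := rdb.Eval(ctx, script, []string{"greeting"}, "hello").Result()
	if err != nil || got != "hello" {
		t.Fatalf("EVAL = %v, %v", got, err)
	}
}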
112 vendor/github.com/alicebob/miniredis/v2/cmd_server.go (generated, vendored)
@@ -1,112 +0,0 @@
// Commands from https://redis.io/commands#server

package miniredis

import (
	"strconv"
	"strings"

	"github.com/alicebob/miniredis/v2/server"
)

func commandsServer(m *Miniredis) {
	m.srv.Register("COMMAND", m.cmdCommand)
	m.srv.Register("DBSIZE", m.cmdDbsize)
	m.srv.Register("FLUSHALL", m.cmdFlushall)
	m.srv.Register("FLUSHDB", m.cmdFlushdb)
	m.srv.Register("INFO", m.cmdInfo)
	m.srv.Register("TIME", m.cmdTime)
}

// DBSIZE
func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		c.WriteInt(len(db.keys))
	})
}

// FLUSHALL
func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 && strings.ToLower(args[0]) == "async" {
		args = args[1:]
	}
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		m.flushAll()
		c.WriteOK()
	})
}

// FLUSHDB
func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 && strings.ToLower(args[0]) == "async" {
		args = args[1:]
	}
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		m.db(ctx.selectedDB).flush()
		c.WriteOK()
	})
}

// TIME
func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		now := m.effectiveNow()
		nanos := now.UnixNano()
		seconds := nanos / 1_000_000_000
		microseconds := (nanos / 1_000) % 1_000_000

		c.WriteLen(2)
		c.WriteBulk(strconv.FormatInt(seconds, 10))
		c.WriteBulk(strconv.FormatInt(microseconds, 10))
	})
}
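A short sketch touching DBSIZE and FLUSHDB from the file above, with the same assumed client and an illustrative key.

package example_test

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// TestServerCommands checks DBSIZE before and after FLUSHDB.
func TestServerCommands(t *testing.T) {
	s, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer rdb.Close()

	if err := rdb.Set(ctx, "k", "v", 0).Err(); err != nil {
		t.Fatal(err)
	}
	if n, err := rdb.DBSize(ctx).Result(); err != nil || n != 1 {
		t.Fatalf("DBSIZE = %d, %v", n, err)
	}
	if err := rdb.FlushDB(ctx).Err(); err != nil {
		t.Fatal(err)
	}
	if n, _ := rdb.DBSize(ctx).Result(); n != 0 {
		t.Fatalf("DBSIZE after FLUSHDB = %d", n)
	}
}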
704
vendor/github.com/alicebob/miniredis/v2/cmd_set.go
generated
vendored
704
vendor/github.com/alicebob/miniredis/v2/cmd_set.go
generated
vendored
@@ -1,704 +0,0 @@
|
||||
// Commands from https://redis.io/commands#set
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsSet handles all set value operations.
|
||||
func commandsSet(m *Miniredis) {
|
||||
m.srv.Register("SADD", m.cmdSadd)
|
||||
m.srv.Register("SCARD", m.cmdScard)
|
||||
m.srv.Register("SDIFF", m.cmdSdiff)
|
||||
m.srv.Register("SDIFFSTORE", m.cmdSdiffstore)
|
||||
m.srv.Register("SINTER", m.cmdSinter)
|
||||
m.srv.Register("SINTERSTORE", m.cmdSinterstore)
|
||||
m.srv.Register("SISMEMBER", m.cmdSismember)
|
||||
m.srv.Register("SMEMBERS", m.cmdSmembers)
|
||||
m.srv.Register("SMOVE", m.cmdSmove)
|
||||
m.srv.Register("SPOP", m.cmdSpop)
|
||||
m.srv.Register("SRANDMEMBER", m.cmdSrandmember)
|
||||
m.srv.Register("SREM", m.cmdSrem)
|
||||
m.srv.Register("SUNION", m.cmdSunion)
|
||||
m.srv.Register("SUNIONSTORE", m.cmdSunionstore)
|
||||
m.srv.Register("SSCAN", m.cmdSscan)
|
||||
}
|
||||
|
||||
// SADD
|
||||
func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, elems := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
added := db.setAdd(key, elems...)
|
||||
c.WriteInt(added)
|
||||
})
|
||||
}
|
||||
|
||||
// SCARD
|
||||
func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.setMembers(key)
|
||||
c.WriteInt(len(members))
|
||||
})
|
||||
}
|
||||
|
||||
// SDIFF
|
||||
func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setDiff(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteSetLen(len(set))
|
||||
for k := range set {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SDIFFSTORE
|
||||
func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
dest, keys := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setDiff(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db.del(dest, true)
|
||||
db.setSet(dest, set)
|
||||
c.WriteInt(len(set))
|
||||
})
|
||||
}
|
||||
|
||||
// SINTER
|
||||
func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setInter(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(set))
|
||||
for k := range set {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SINTERSTORE
|
||||
func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
dest, keys := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setInter(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db.del(dest, true)
|
||||
db.setSet(dest, set)
|
||||
c.WriteInt(len(set))
|
||||
})
|
||||
}
|
||||
|
||||
// SISMEMBER
|
||||
func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, value := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if db.setIsMember(key, value) {
|
||||
c.WriteInt(1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(0)
|
||||
})
|
||||
}
|
||||
|
||||
// SMEMBERS
|
||||
func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteSetLen(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.setMembers(key)
|
||||
|
||||
c.WriteSetLen(len(members))
|
||||
for _, elem := range members {
|
||||
c.WriteBulk(elem)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SMOVE
|
||||
func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
src, dst, member := args[0], args[1], args[2]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(src) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(src) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if db.exists(dst) && db.t(dst) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !db.setIsMember(src, member) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
db.setRem(src, member)
|
||||
db.setAdd(dst, member)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// SPOP
|
||||
func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
withCount bool
|
||||
count int
|
||||
}{
|
||||
count: 1,
|
||||
}
|
||||
opts.key, args = args[0], args[1:]
|
||||
|
||||
if len(args) > 0 {
|
||||
v, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if v < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgOutOfRange)
|
||||
return
|
||||
}
|
||||
opts.count = v
|
||||
opts.withCount = true
|
||||
args = args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
if !opts.withCount {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(opts.key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var deleted []string
|
||||
for i := 0; i < opts.count; i++ {
|
||||
members := db.setMembers(opts.key)
|
||||
if len(members) == 0 {
|
||||
break
|
||||
}
|
||||
member := members[m.randIntn(len(members))]
|
||||
db.setRem(opts.key, member)
|
||||
deleted = append(deleted, member)
|
||||
}
|
||||
// without `count` return a single value
|
||||
if !opts.withCount {
|
||||
if len(deleted) == 0 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteBulk(deleted[0])
|
||||
return
|
||||
}
|
||||
// with `count` return a list
|
||||
c.WriteLen(len(deleted))
|
||||
for _, v := range deleted {
|
||||
c.WriteBulk(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SRANDMEMBER
|
||||
func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if len(args) > 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
count := 0
|
||||
withCount := false
|
||||
if len(args) == 2 {
|
||||
var err error
|
||||
count, err = strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
withCount = true
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.setMembers(key)
|
||||
if count < 0 {
|
||||
// Non-unique elements is allowed with negative count.
|
||||
c.WriteLen(-count)
|
||||
for count != 0 {
|
||||
member := members[m.randIntn(len(members))]
|
||||
c.WriteBulk(member)
|
||||
count++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Must be unique elements.
|
||||
m.shuffle(members)
|
||||
if count > len(members) {
|
||||
count = len(members)
|
||||
}
|
||||
if !withCount {
|
||||
c.WriteBulk(members[0])
|
||||
return
|
||||
}
|
||||
c.WriteLen(count)
|
||||
for i := range make([]struct{}, count) {
|
||||
c.WriteBulk(members[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SREM
|
||||
func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, fields := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(db.setRem(key, fields...))
|
||||
})
|
||||
}
|
||||
|
||||
// SUNION
|
||||
func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setUnion(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(set))
|
||||
for k := range set {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SUNIONSTORE
|
||||
func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
dest, keys := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setUnion(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db.del(dest, true)
|
||||
db.setSet(dest, set)
|
||||
c.WriteInt(len(set))
|
||||
})
|
||||
}
|
||||
|
||||
// SSCAN
|
||||
func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
value int
|
||||
cursor int
|
||||
count int
|
||||
withMatch bool
|
||||
match string
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
|
||||
// MATCH and COUNT options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
count, err := strconv.Atoi(args[1])
|
||||
if err != nil || count < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if count == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.count = count
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match = args[1]
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// return _all_ (matched) keys every time
|
||||
if db.exists(opts.key) && db.t(opts.key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
members := db.setMembers(opts.key)
|
||||
if opts.withMatch {
|
||||
members, _ = matchKeys(members, opts.match)
|
||||
}
|
||||
low := opts.cursor
|
||||
high := low + opts.count
|
||||
// validate high is correct
|
||||
if high > len(members) || high == 0 {
|
||||
high = len(members)
|
||||
}
|
||||
if opts.cursor > high {
|
||||
// invalid cursor
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
cursorValue := low + opts.count
|
||||
if cursorValue > len(members) {
|
||||
cursorValue = 0 // no next cursor
|
||||
}
|
||||
members = members[low:high]
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%d", cursorValue))
|
||||
c.WriteLen(len(members))
|
||||
for _, k := range members {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
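The set command handlers above (cmd_set.go) can be exercised the same way; a hedged sketch with assumed keys and members, not taken from this repository.

package example_test

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// TestSetCommands exercises SADD/SISMEMBER/SREM from cmd_set.go above.
func TestSetCommands(t *testing.T) {
	s, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer rdb.Close()

	if err := rdb.SAdd(ctx, "colors", "red", "green").Err(); err != nil {
		t.Fatal(err)
	}
	ok, err := rdb.SIsMember(ctx, "colors", "red").Result()
	if err != nil || !ok {
		t.Fatalf("SISMEMBER = %v, %v", ok, err)
	}
	if n, err := rdb.SRem(ctx, "colors", "red").Result(); err != nil || n != 1 {
		t.Fatalf("SREM = %d, %v", n, err)
	}
}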
1880 vendor/github.com/alicebob/miniredis/v2/cmd_sorted_set.go (generated, vendored): file diff suppressed because it is too large
1704 vendor/github.com/alicebob/miniredis/v2/cmd_stream.go (generated, vendored): file diff suppressed because it is too large
1350 vendor/github.com/alicebob/miniredis/v2/cmd_string.go (generated, vendored): file diff suppressed because it is too large
179 vendor/github.com/alicebob/miniredis/v2/cmd_transactions.go (generated, vendored)
@@ -1,179 +0,0 @@
// Commands from https://redis.io/commands#transactions

package miniredis

import (
	"github.com/alicebob/miniredis/v2/server"
)

// commandsTransaction handles MULTI &c.
func commandsTransaction(m *Miniredis) {
	m.srv.Register("DISCARD", m.cmdDiscard)
	m.srv.Register("EXEC", m.cmdExec)
	m.srv.Register("MULTI", m.cmdMulti)
	m.srv.Register("UNWATCH", m.cmdUnwatch)
	m.srv.Register("WATCH", m.cmdWatch)
}

// MULTI
func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	ctx := getCtx(c)
	if ctx.nested {
		c.WriteError(msgNotFromScripts)
		return
	}
	if inTx(ctx) {
		c.WriteError("ERR MULTI calls can not be nested")
		return
	}

	startTx(ctx)

	c.WriteOK()
}

// EXEC
func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	ctx := getCtx(c)
	if ctx.nested {
		c.WriteError(msgNotFromScripts)
		return
	}
	if !inTx(ctx) {
		c.WriteError("ERR EXEC without MULTI")
		return
	}

	if ctx.dirtyTransaction {
		c.WriteError("EXECABORT Transaction discarded because of previous errors.")
		// a failed EXEC finishes the tx
		stopTx(ctx)
		return
	}

	m.Lock()
	defer m.Unlock()

	// Check WATCHed keys.
	for t, version := range ctx.watch {
		if m.db(t.db).keyVersion[t.key] > version {
			// Abort! Abort!
			stopTx(ctx)
			c.WriteLen(-1)
			return
		}
	}

	c.WriteLen(len(ctx.transaction))
	for _, cb := range ctx.transaction {
		cb(c, ctx)
	}
	// wake up anyone who waits on anything.
	m.signal.Broadcast()

	stopTx(ctx)
}

// DISCARD
func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	ctx := getCtx(c)
	if !inTx(ctx) {
		c.WriteError("ERR DISCARD without MULTI")
		return
	}

	stopTx(ctx)
	c.WriteOK()
}

// WATCH
func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) {
	if len(args) == 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	ctx := getCtx(c)
	if ctx.nested {
		c.WriteError(msgNotFromScripts)
		return
	}
	if inTx(ctx) {
		c.WriteError("ERR WATCH in MULTI")
		return
	}

	m.Lock()
	defer m.Unlock()
	db := m.db(ctx.selectedDB)

	for _, key := range args {
		watch(db, ctx, key)
	}
	c.WriteOK()
}

// UNWATCH
func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}

	// Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me.
	unwatch(getCtx(c))

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		// Do nothing if it's called in a transaction.
		c.WriteOK()
	})
}
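Finally, a sketch of the WATCH/MULTI/EXEC flow implemented above, assuming go-redis's Watch helper and an illustrative counter key; the keyVersion check in cmdExec is what makes this optimistic lock work.

package example_test

import (
	"context"
	"testing"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// TestWatchExec increments a counter inside WATCH/MULTI/EXEC.
func TestWatchExec(t *testing.T) {
	s, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()

	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer rdb.Close()

	if err := rdb.Set(ctx, "counter", "1", 0).Err(); err != nil {
		t.Fatal(err)
	}

	// Watch issues WATCH counter; the queued commands then run under MULTI/EXEC.
	err = rdb.Watch(ctx, func(tx *redis.Tx) error {
		n, err := tx.Get(ctx, "counter").Int()
		if err != nil {
			return err
		}
		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			pipe.Set(ctx, "counter", n+1, 0)
			return nil
		})
		return err
	}, "counter")
	if err != nil {
		t.Fatal(err)
	}
}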
Some files were not shown because too many files have changed in this diff.