init commit

This commit is contained in:
notch
2020-12-10 08:53:42 +08:00
parent 031a0531bd
commit b47b0cd6c2
61 changed files with 31801 additions and 12 deletions

40
.gitignore vendored
View File

@@ -1,15 +1,31 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
*.dylib
# Test binary, built with `go test -c`
# Folders
_obj
_test
logs
bin
# file
tomatox
routetable.json
users.json
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
*.prof

35
.vscode/launch.json vendored Executable file
View File

@@ -0,0 +1,35 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "tomatox",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceRoot}",
"env": {},
"args": []
},
{
"name": "tomatox -log-tofile",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceRoot}",
"env": {},
"args": ["-log-tofile"]
},
{
"name": "tomatox -h",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceRoot}",
"env": {},
"args": ["-h"]
}
]
}

37
config/config.go Executable file
View File

@@ -0,0 +1,37 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"flag"
)
// config 服务配置
type config struct {
ListenAddr string `json:"listen"` // 服务侦听地址和端口
Auth bool `json:"auth"` // 启用安全验证
CacheGop bool `json:"cache_gop"` // 缓存图像组,以便提高播放端打开速度,但内存需求大
HlsPath string `json:"hlspath,omitempty"` // Hls临时缓存目录
Profile bool `json:"profile"` // 是否启动Profile
TLS *TLSConfig `json:"tls,omitempty"` // https安全端口交互
Routetable *ProviderConfig `json:"routetable,omitempty"` // 路由表
Users *ProviderConfig `json:"users,omitempty"` // 用户
Log LogConfig `json:"log"` // 日志配置
}
func (c *config) initFlags() {
// 服务的端口
flag.StringVar(&c.ListenAddr, "listen", ":554", "Set server listen address")
flag.BoolVar(&c.Auth, "auth", false,
"Determines if requires permission verification to access stream media")
flag.BoolVar(&c.CacheGop, "cachegop", false,
"Determines if Gop should be cached to memory")
flag.StringVar(&c.HlsPath, "hlspath", "", "Set HLS live dir")
flag.BoolVar(&c.Profile, "pprof", false,
"Determines if profile enabled")
// 初始化日志配置
c.Log.initFlags()
}

229
config/global.go Executable file
View File

@@ -0,0 +1,229 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
cfg "github.com/cnotch/loader"
"github.com/cnotch/tomatox/provider/auth"
"github.com/cnotch/tomatox/utils"
"github.com/cnotch/xlog"
)
// 服务名
const (
Vendor = "CAOHONGJU"
Name = "tomatox"
Version = "V1.0.0"
)
var (
globalC *config
consoleAppDir string
demosAppDir string
)
// InitConfig 初始化 Config
func InitConfig() {
exe, err := os.Executable()
if err != nil {
xlog.Panic(err.Error())
}
configPath := filepath.Join(filepath.Dir(exe), Name+".conf")
consoleAppDir = filepath.Join(filepath.Dir(exe), "console")
demosAppDir = filepath.Join(filepath.Dir(exe), "demos")
globalC = new(config)
globalC.initFlags()
// 创建或加载配置文件
if err := cfg.Load(globalC,
&cfg.JSONLoader{Path: configPath, CreatedIfNonExsit: true},
&cfg.EnvLoader{Prefix: strings.ToUpper(Name)},
&cfg.FlagLoader{}); err != nil {
// 异常,直接退出
xlog.Panic(err.Error())
}
if globalC.HlsPath != "" {
if !filepath.IsAbs(globalC.HlsPath) {
globalC.HlsPath = filepath.Join(filepath.Dir(exe), globalC.HlsPath)
}
_, err = os.Stat(globalC.HlsPath)
if err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(globalC.HlsPath, os.ModePerm); err != nil {
panic(err)
}
} else {
panic(err)
}
}
}
// 初始化日志
globalC.Log.initLogger()
}
// Addr Listen addr
func Addr() string {
if globalC == nil {
return ":554"
}
return globalC.ListenAddr
}
// Auth 是否启用验证
func Auth() bool {
if globalC == nil {
return false
}
return globalC.Auth
}
// CacheGop 是否Cache Gop
func CacheGop() bool {
if globalC == nil {
return false
}
return globalC.CacheGop
}
// Profile 是否启动 Http Profile
func Profile() bool {
if globalC == nil {
return false
}
return globalC.Profile
}
// GetTLSConfig 获取TLSConfig
func GetTLSConfig() *TLSConfig {
if globalC == nil {
return nil
}
return globalC.TLS
}
// ConsoleAppDir 管理员控制台应用的目录
func ConsoleAppDir() (string, bool) {
if consoleAppDir == "" {
return "", false
}
finfo, err := os.Stat(consoleAppDir)
if err != nil || !finfo.IsDir() {
return "", false
}
return consoleAppDir, true
}
// DemosAppDir 例子应用目录
func DemosAppDir() (string, bool) {
if demosAppDir == "" {
return "", false
}
finfo, err := os.Stat(demosAppDir)
if err != nil || !finfo.IsDir() {
return "", false
}
return demosAppDir, true
}
// NetTimeout 返回网络超时设置
func NetTimeout() time.Duration {
return time.Second * 45
}
// NetHeartbeatInterval 返回网络心跳间隔
func NetHeartbeatInterval() time.Duration {
return time.Second * 30
}
// NetBufferSize 网络通讯时的BufferSize
func NetBufferSize() int {
return 128 * 1024
}
// NetFlushRate 网络刷新频率
func NetFlushRate() int {
return 30
}
// RtspAuthMode rtsp 认证模式
func RtspAuthMode() auth.Mode {
if globalC == nil || !globalC.Auth {
return auth.NoneAuth
}
return auth.DigestAuth
}
// MulticastTTL 组播TTL值
func MulticastTTL() int {
return 127
}
// ChunkSize Rtmp ChunkSize
func ChunkSize(ip net.IP) uint32 {
if utils.IsLocalhostIP(ip) {
return 48 * 1024
}
return 16 * 1024
}
// HlsEnable 是否启动 Hls
func HlsEnable() bool {
return true
}
// HlsFragment TS片段时长s
func HlsFragment() int {
return 3
}
// HlsPath hls 存储目录
func HlsPath() string {
if globalC == nil {
return ""
}
return globalC.HlsPath
}
// LoadRoutetableProvider 加载路由表提供者
func LoadRoutetableProvider(providers ...Provider) Provider {
if globalC == nil {
return LoadProvider(nil, providers...)
}
return LoadProvider(globalC.Routetable, providers...)
}
// LoadUsersProvider 加载用户提供者
func LoadUsersProvider(providers ...Provider) Provider {
if globalC == nil {
return LoadProvider(nil, providers...)
}
return LoadProvider(globalC.Users, providers...)
}
// DetectFfmpeg 判断ffmpeg命令行是否存在
func DetectFfmpeg(l *xlog.Logger) bool {
out, err := exec.Command("ffmpeg", "-version").Output()
if err != nil {
return false
}
i := strings.Index(string(out), "Copyright")
if i > 0 {
l.Infof("detect %s", out[:i])
}
return true
}

80
config/log.go Executable file
View File

@@ -0,0 +1,80 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"flag"
"os"
"github.com/cnotch/xlog"
lumberjack "gopkg.in/natefinch/lumberjack.v2"
)
// LogConfig 日志配置
type LogConfig struct {
// Level 是否启动记录调试日志
Level xlog.Level `json:"level"`
// ToFile 是否将日志记录到文件
ToFile bool `json:"tofile"`
// Filename 日志文件名称
Filename string `json:"filename"`
// MaxSize 日志文件的最大尺寸,以兆为单位
MaxSize int `json:"maxsize"`
// MaxDays 旧日志最多保存多少天
MaxDays int `json:"maxdays"`
// MaxBackups 旧日志最多保持数量。
// 注意:旧日志保存的条件包括 <=MaxAge && <=MaxBackups
MaxBackups int `json:"maxbackups"`
// Compress 是否用 gzip 压缩
Compress bool `json:"compress"`
}
func (c *LogConfig) initFlags() {
// 日志配置的 Flag
flag.Var(&c.Level, "log-level",
"Set the log level to output")
flag.BoolVar(&c.ToFile, "log-tofile", false,
"Determines if logs should be saved to file")
flag.StringVar(&c.Filename, "log-filename",
"./logs/"+Name+".log", "Set the file to write logs to")
flag.IntVar(&c.MaxSize, "log-maxsize", 20,
"Set the maximum size in megabytes of the log file before it gets rotated")
flag.IntVar(&c.MaxDays, "log-maxdays", 7,
"Set the maximum days of old log files to retain")
flag.IntVar(&c.MaxBackups, "log-maxbackups", 14,
"Set the maximum number of old log files to retain")
flag.BoolVar(&c.Compress, "log-compress", false,
"Determines if the log files should be compressed")
}
// 初始化跟日志
func (c *LogConfig) initLogger() {
if c.ToFile {
// 文件输出
fileWriter := &lumberjack.Logger{
Filename: c.Filename, // 日志文件路径
MaxSize: c.MaxSize, // 每个日志文件保存的最大尺寸 单位M
MaxBackups: c.MaxBackups, // 日志文件最多保存多少个备份
MaxAge: c.MaxDays, // 文件最多保存多少天
LocalTime: true, // 使用本地时间
Compress: c.Compress, // 日志压缩
}
xlog.ReplaceGlobal(
xlog.New(xlog.NewTee(xlog.NewCore(xlog.NewConsoleEncoder(xlog.LstdFlags|xlog.Lmicroseconds|xlog.Llongfile), xlog.Lock(os.Stderr), c.Level),
xlog.NewCore(xlog.NewJSONEncoder(xlog.Llongfile), fileWriter, c.Level)),
xlog.AddCaller()))
} else {
xlog.ReplaceGlobal(
xlog.New(xlog.NewCore(xlog.NewConsoleEncoder(xlog.LstdFlags|xlog.Lmicroseconds|xlog.Llongfile), xlog.Lock(os.Stderr), c.Level),
xlog.AddCaller()))
}
}

60
config/provider.go Executable file
View File

@@ -0,0 +1,60 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"errors"
"strings"
)
// Provider 提供者接口
type Provider interface {
Name() string
Configure(config map[string]interface{}) error
}
// ProviderConfig 可扩展提供者配置
type ProviderConfig struct {
Provider string `json:"provider"` // 提供者类型
Config map[string]interface{} `json:"config,omitempty"` // 提供者配置
}
// Load 加载Provider
func (c *ProviderConfig) Load(builtins ...Provider) (Provider, error) {
for _, builtin := range builtins {
if strings.ToLower(builtin.Name()) == strings.ToLower(c.Provider) {
if err := builtin.Configure(c.Config); err != nil {
return nil, errors.New("The provider '" + c.Provider + "' could not be loaded. " + err.Error())
}
return builtin, nil
}
}
// TODO: load a plugin provider
return nil, errors.New("The provider '" + c.Provider + "' could not be loaded. ")
}
// LoadOrPanic 加载 Provider 如果失败直接 panics.
func (c *ProviderConfig) LoadOrPanic(builtins ...Provider) Provider {
provider, err := c.Load(builtins...)
if err != nil {
panic(err)
}
return provider
}
// LoadProvider 加载Provider或Panic默认值为第一个provider
func LoadProvider(config *ProviderConfig, providers ...Provider) Provider {
if config == nil || config.Provider == "" {
config = &ProviderConfig{
Provider: providers[0].Name(),
}
}
// Load the provider according to the configuration
return config.LoadOrPanic(providers...)
}

19
config/rtmp.go Executable file
View File

@@ -0,0 +1,19 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"flag"
)
// RtmpConfig rtsp 配置
type RtmpConfig struct {
ChunkSize int `json:"chunksize"`
}
func (c *RtmpConfig) initFlags() {
// RTSP 组播
flag.IntVar(&c.ChunkSize, "rtmp-chunksize", 16*1024, "Set RTMP ChunkSize")
}

21
config/rtsp.go Executable file
View File

@@ -0,0 +1,21 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"flag"
"github.com/cnotch/tomatox/provider/auth"
)
// RtspConfig rtsp 配置
type RtspConfig struct {
AuthMode auth.Mode `json:"authmode"`
}
func (c *RtspConfig) initFlags() {
// RTSP 组播
flag.Var(&c.AuthMode, "rtsp-auth", "Set RTSP auth mode")
}

58
config/tls.go Executable file
View File

@@ -0,0 +1,58 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package config
import (
"crypto/tls"
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
)
// TLSConfig TLS listen 配置.
type TLSConfig struct {
ListenAddr string `json:"listen"`
Certificate string `json:"cert"`
PrivateKey string `json:"key"`
}
// Load loads the certificates from the cache or the configuration.
func (c *TLSConfig) Load() (*tls.Config, error) {
if c.PrivateKey == "" || c.Certificate == "" {
return &tls.Config{}, errors.New("No certificate or private key configured")
}
// If the certificate provided is in plain text, write to file so we can read it.
if strings.HasPrefix(c.Certificate, "---") {
if err := ioutil.WriteFile("broker.crt", []byte(c.Certificate), os.ModePerm); err == nil {
c.Certificate = Name+".crt"
}
}
// If the private key provided is in plain text, write to file so we can read it.
if strings.HasPrefix(c.PrivateKey, "---") {
if err := ioutil.WriteFile("broker.key", []byte(c.PrivateKey), os.ModePerm); err == nil {
c.PrivateKey = Name+".key"
}
}
// Make sure the paths are absolute, otherwise we won't be able to read the files.
c.Certificate = resolvePath(c.Certificate)
c.PrivateKey = resolvePath(c.PrivateKey)
// Load the certificate from the cert/key files.
cer, err := tls.LoadX509KeyPair(c.Certificate, c.PrivateKey)
return &tls.Config{
Certificates: []tls.Certificate{cer},
}, err
}
func resolvePath(path string) string {
// Make sure the path is absolute
path, _ = filepath.Abs(path)
return path
}

1
console/readme.md Normal file
View File

@@ -0,0 +1 @@
在此目录加入管理员控制台web项目

108
demos/flv/demo.css Executable file
View File

@@ -0,0 +1,108 @@
.mainContainer {
display: block;
width: 100%;
margin-left: auto;
margin-right: auto;
}
@media screen and (min-width: 1152px) {
.mainContainer {
display: block;
width: 1152px;
margin-left: auto;
margin-right: auto;
}
}
.video-container {
position: relative;
margin-top: 8px;
}
.video-container:before {
display: block;
content: "";
width: 100%;
padding-bottom: 56.25%;
}
.video-container > div {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
}
.video-container video {
width: 100%;
height: 100%;
}
.urlInput {
display: block;
width: 100%;
margin-left: auto;
margin-right: auto;
margin-top: 8px;
margin-bottom: 8px;
}
.centeredVideo {
display: block;
width: 100%;
height: 100%;
margin-left: auto;
margin-right: auto;
margin-bottom: auto;
}
.controls {
display: block;
width: 100%;
text-align: left;
margin-left: auto;
margin-right: auto;
margin-top: 8px;
margin-bottom: 10px;
}
.logcatBox {
border-color: #CCCCCC;
font-size: 11px;
font-family: Menlo, Consolas, monospace;
display: block;
width: 100%;
text-align: left;
margin-left: auto;
margin-right: auto;
}
.url-input , .options {
font-size: 13px;
}
.url-input {
display: flex;
}
.url-input label {
flex: initial;
}
.url-input input {
flex: auto;
margin-left: 8px;
}
.url-input button {
flex: initial;
margin-left: 8px;
}
.options {
margin-top: 5px;
}
.hidden {
display: none;
}

12047
demos/flv/flv.js Executable file

File diff suppressed because it is too large Load Diff

229
demos/flv/index.html Executable file
View File

@@ -0,0 +1,229 @@
<!DOCTYPE html>
<html>
<head>
<meta content="text/html; charset=utf-8" http-equiv="Content-Type">
<title>flv.js demo</title>
<link rel="stylesheet" type="text/css" href="demo.css" />
</head>
<body>
<div class="mainContainer">
<div>
<div id="streamURL">
<div class="url-input">
<label for="sURL">Stream URL:</label>
<input id="sURL" type="text" value="http://127.0.0.1:1554/streams/test/live3.flv" />
<!-- <button onclick="switch_mds()">Switch to MediaDataSource</button> -->
</div>
<div class="options">
<input type="checkbox" id="isLive" onchange="saveSettings()" />
<label for="isLive">isLive</label>
<input type="checkbox" id="withCredentials" onchange="saveSettings()" />
<label for="withCredentials">withCredentials</label>
<input type="checkbox" id="hasAudio" onchange="saveSettings()" checked />
<label for="hasAudio">hasAudio</label>
<input type="checkbox" id="hasVideo" onchange="saveSettings()" checked />
<label for="hasVideo">hasVideo</label>
</div>
</div>
<div id="mediaSourceURL" class="hidden">
<div class="url-input">
<label for="msURL">MediaDataSource JsonURL:</label>
<input id="msURL" type="text" value="http://127.0.0.1/flv/7182741.json" />
<button onclick="switch_url()">Switch to URL</button>
</div>
</div>
</div>
<div class="video-container">
<div>
<video name="videoElement" class="centeredVideo" controls autoplay>
Your browser is too old which doesn't support HTML5 video.
</video>
</div>
</div>
<div class="controls">
<button onclick="flv_load()">Load</button>
<button onclick="flv_start()">Start</button>
<button onclick="flv_pause()">Pause</button>
<button onclick="flv_destroy()">Destroy</button>
<input style="width:100px" type="text" name="seekpoint"/>
<button onclick="flv_seekto()">SeekTo</button>
</div>
<textarea name="logcatbox" class="logcatBox" rows="10" readonly></textarea>
</div>
<script src="flv.js"></script>
<script>
var checkBoxFields = ['isLive', 'withCredentials', 'hasAudio', 'hasVideo'];
var streamURL, mediaSourceURL;
function flv_load() {
console.log('isSupported: ' + flvjs.isSupported());
if (mediaSourceURL.className === '') {
var url = document.getElementById('msURL').value;
var xhr = new XMLHttpRequest();
xhr.open('GET', url, true);
xhr.onload = function (e) {
var mediaDataSource = JSON.parse(xhr.response);
flv_load_mds(mediaDataSource);
}
xhr.send();
} else {
var i;
var mediaDataSource = {
type: 'flv'
};
for (i = 0; i < checkBoxFields.length; i++) {
var field = checkBoxFields[i];
/** @type {HTMLInputElement} */
var checkbox = document.getElementById(field);
mediaDataSource[field] = checkbox.checked;
}
mediaDataSource['url'] = document.getElementById('sURL').value;
console.log('MediaDataSource', mediaDataSource);
flv_load_mds(mediaDataSource);
}
}
function flv_load_mds(mediaDataSource) {
var element = document.getElementsByName('videoElement')[0];
if (typeof player !== "undefined") {
if (player != null) {
player.unload();
player.detachMediaElement();
player.destroy();
player = null;
}
}
player = flvjs.createPlayer(mediaDataSource, {
enableWorker: false,
lazyLoadMaxDuration: 3 * 60,
seekType: 'range',
// my change config
fixAudioTimestampGap: false,
enableWorker: true,
enableStashBuffer: false,
stashInitialSize: 128,// 减少首桢显示等待时长
});
player.attachMediaElement(element);
player.load();
}
function flv_start() {
player.play();
}
function flv_pause() {
player.pause();
}
function flv_destroy() {
player.pause();
player.unload();
player.detachMediaElement();
player.destroy();
player = null;
}
function flv_seekto() {
var input = document.getElementsByName('seekpoint')[0];
player.currentTime = parseFloat(input.value);
}
function switch_url() {
streamURL.className = '';
mediaSourceURL.className = 'hidden';
saveSettings();
}
function switch_mds() {
streamURL.className = 'hidden';
mediaSourceURL.className = '';
saveSettings();
}
function ls_get(key, def) {
try {
var ret = localStorage.getItem('flvjs_demo.' + key);
if (ret === null) {
ret = def;
}
return ret;
} catch (e) {}
return def;
}
function ls_set(key, value) {
try {
localStorage.setItem('flvjs_demo.' + key, value);
} catch (e) {}
}
function saveSettings() {
if (mediaSourceURL.className === '') {
ls_set('inputMode', 'MediaDataSource');
} else {
ls_set('inputMode', 'StreamURL');
}
var i;
for (i = 0; i < checkBoxFields.length; i++) {
var field = checkBoxFields[i];
/** @type {HTMLInputElement} */
var checkbox = document.getElementById(field);
ls_set(field, checkbox.checked ? '1' : '0');
}
var msURL = document.getElementById('msURL');
var sURL = document.getElementById('sURL');
ls_set('msURL', msURL.value);
ls_set('sURL', sURL.value);
console.log('save');
}
function loadSettings() {
var i;
for (i = 0; i < checkBoxFields.length; i++) {
var field = checkBoxFields[i];
/** @type {HTMLInputElement} */
var checkbox = document.getElementById(field);
var c = ls_get(field, checkbox.checked ? '1' : '0');
checkbox.checked = c === '1' ? true : false;
}
var msURL = document.getElementById('msURL');
var sURL = document.getElementById('sURL');
msURL.value = ls_get('msURL', msURL.value);
sURL.value = ls_get('sURL', sURL.value);
if (ls_get('inputMode', 'StreamURL') === 'StreamURL') {
switch_url();
} else {
switch_mds();
}
}
function showVersion() {
var version = flvjs.version;
document.title = document.title + " (v" + version + ")";
}
var logcatbox = document.getElementsByName('logcatbox')[0];
flvjs.LoggingControl.addLogListener(function(type, str) {
logcatbox.value = logcatbox.value + str + '\n';
logcatbox.scrollTop = logcatbox.scrollHeight;
});
document.addEventListener('DOMContentLoaded', function () {
streamURL = document.getElementById('streamURL');
mediaSourceURL = document.getElementById('mediaSourceURL');
loadSettings();
showVersion();
// flv_load();
});
</script>
</body>
</html>

155
demos/rtsp/index.html Executable file
View File

@@ -0,0 +1,155 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>RTSP player example(based rtsp websockcet client)</title>
<link rel="stylesheet" href="style.css">
</head>
<body>
<div id="sourcesNode"></div>
<div>
<input id="stream_url" value= "ws://localhost:1554/ws/test/live1" size="36">
<button id="load_url">load</button>
<button id="unload_url">unload</button>
</div>
<div>
<p style="color:#808080">Enter your ws link to the stream, for example: "ws://localhost:1554/ws/test/live1"</p>
</div>
<video id="test_video" controls autoplay>
<source src="rtsp://placehold" type="application/x-rtsp">
</video>
<div class="controls form">
<div>
<button id="to_end" class="btn btn-success">live(to end)</button>
</div>
</div>
<p>View HTML5 RTSP video player log</p>
<div id="pllogs" class="logs"></div>
<button class="btn btn-success" onclick="cleanLog(pllogs)">clear</button>
<button class="btn btn-success" onclick="scrollset(pllogs, true)">scroll up</button>
<button class="btn btn-success" onclick="scrollset(pllogs, false)">scroll down</button>
<button id="scrollSetPl" class="btn btn-success" onclick="scrollswitch(pllogs)">Scroll off</button>
</br></br>
<!-- <script src="jquery-1.9.1.js"></script> -->
<script>
var scrollStatPl = true;
var scrollStatWs = true;
var pllogs = document.getElementById("pllogs");
var wslogs = document.getElementById("wslogs");
// define a new console
var console=(function(oldConsole){
return {
log: function(){
oldConsole.log(newConsole(arguments, "black", "#A9F5A9"));
},
info: function () {
oldConsole.info(newConsole(arguments, "black", "#A9F5A9"));
},
warn: function () {
oldConsole.warn(newConsole(arguments, "black", "#F3F781"));
},
error: function () {
oldConsole.error(newConsole(arguments, "black", "#F5A9A9"));
}
};
}(window.console));
function newConsole(args, textColor, backColor){
let text = '';
let node = document.createElement("div");
for (let arg in args){
text +=' ' + args[arg];
}
node.appendChild(document.createTextNode(text));
node.style.color = textColor;
node.style.backgroundColor = backColor;
pllogs.appendChild(node);
autoscroll(pllogs);
return text;
}
//Then redefine the old console
window.console = console;
function cleanLog(element){
while (element.firstChild) {
element.removeChild(element.firstChild);
}
}
function autoscroll(element){
if(scrollStatus(element)){
element.scrollTop = element.scrollHeight;
}
if(element.childElementCount > 1000){
element.removeChild(element.firstChild);
}
}
function scrollset(element, state){
if(state){
element.scrollTop = 0;
scrollChange(element, false);
} else {
element.scrollTop = element.scrollHeight;
scrollChange(element, true);
}
}
function scrollswitch(element){
if(scrollStatus(element)){
scrollChange(element, false);
} else {
scrollChange(element, true);
}
}
function scrollChange(element, status){
if(scrollStatus(element)){
scrollStatPl = false;
document.getElementById("scrollSetPl").innerText = "Scroll on";
} else {
scrollStatPl = true;
document.getElementById("scrollSetPl").innerText = "Scroll off";
}
}
function scrollStatus(element){
if(element.id === "pllogs"){
return scrollStatPl;
} else {
return scrollStatWs;
}
}
</script>
<script src="rtsp.dev.js" ></script>
<script>
let videoElement = document.getElementById('test_video');
let loadButton = document.getElementById("load_url");
let unloadButton = document.getElementById("unload_url");
let urlEdit = document.getElementById("stream_url");
let rtspPlayer = rtspjs.createPlayer('test_video');
loadButton.onclick = ()=> {
rtspPlayer.load(urlEdit.value)
};
unloadButton.onclick = ()=> {
rtspPlayer.unload()
};
// 设置到实时视频
var set_live = document.getElementById('to_end');
set_live.addEventListener('click', function () {
videoElement.playbackRate = 1;
videoElement.currentTime = videoElement.buffered.end(0);//videoElement.seekable.end(videoElement.seekable.length - 1);
});
</script>
</body>
</html>

4879
demos/rtsp/rtsp.dev.js Executable file

File diff suppressed because it is too large Load Diff

28
demos/rtsp/style.css Executable file
View File

@@ -0,0 +1,28 @@
body {
max-width: 720px;
margin: 50px auto;
}
#test_video {
width: 720px;
}
.controls {
display: flex;
justify-content: space-around;
align-items: center;
}
input.input, .form-inline .input-group>.form-control {
width: 300px;
}
.logs {
overflow: auto;
width: 720px;
height: 150px;
padding: 5px;
border-top: solid 1px gray;
border-bottom: solid 1px gray;
}
button {
margin: 5px
}

9466
demos/wsp/free.player.1.8.js Normal file

File diff suppressed because it is too large Load Diff

251
demos/wsp/index.html Executable file
View File

@@ -0,0 +1,251 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>RTSP player example(based streamedian)</title>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css">
<link rel="stylesheet" href="style.css">
</head>
<body>
<div id="sourcesNode"></div>
<div>
<input id="stream_url" size="80" value="rtsp://localhost:1554/test/live1">
<button id="set_new_url">Set</button>
</div>
<div>
<p style="color:#808080">Enter your rtsp link to the stream, for example: "rtsp://localhost:1554/test/live1"</p>
<p style="color:#808080">If need token,for example: "rtsp://localhost:1554/test/live1?token=4df8f5d5d680385cb07c2e354dd0f3f3"</p>
</div>
<div>
<input id="buffer_duration" type="range" min="10" max="200" style="width:40%;">
<span id="buffer_value">120sec.</span>
</div>
<div>
<p style="color:#808080">Change buffer duration</p>
</div>
<video id="test_video" controls autoplay>
<!--<source src="rtsp://192.168.10.205:554/ch01.264" type="application/x-rtsp">-->
<!--<source src="rtsp://wowzaec2demo.streamlock.net/vod/mp4:BigBuckBunny_115k.mov" type="application/x-rtsp">-->
</video>
<div class="controls form">
<div>
Playback rate:&nbsp;
<input id="rate" class="input" type="range" min="0.5" max="5.0" value="1.0" step="0.5">
<output for="rate" id="rate_res">live</output>
</div>
<div>
<button id="to_end" class="btn btn-success">live</button>
</div>
</div>
<p>View HTML5 RTSP video player log</p>
<div id="pllogs" class="logs"></div>
<button class="btn btn-success" onclick="cleanLog(pllogs)">clear</button>
<button class="btn btn-success" onclick="scrollset(pllogs, true)">scroll up</button>
<button class="btn btn-success" onclick="scrollset(pllogs, false)">scroll down</button>
<button id="scrollSetPl" class="btn btn-success" onclick="scrollswitch(pllogs)">Scroll off</button>
<br/><br/>
<script src="free.player.1.8.js"></script> <!-- Path to player js-->
<script>
var scrollStatPl = true;
var scrollStatWs = true;
var pllogs = document.getElementById("pllogs");
var wslogs = document.getElementById("wslogs");
// define a new console
var console=(function(oldConsole){
return {
log: function(){
oldConsole.log(newConsole(arguments, "black", "#A9F5A9"));
},
info: function () {
oldConsole.info(newConsole(arguments, "black", "#A9F5A9"));
},
warn: function () {
oldConsole.warn(newConsole(arguments, "black", "#F3F781"));
},
error: function () {
oldConsole.error(newConsole(arguments, "black", "#F5A9A9"));
}
};
}(window.console));
function newConsole(args, textColor, backColor){
let text = '';
let node = document.createElement("div");
for (let arg in args){
text +=' ' + args[arg];
}
node.appendChild(document.createTextNode(text));
node.style.color = textColor;
node.style.backgroundColor = backColor;
pllogs.appendChild(node);
autoscroll(pllogs);
return text;
}
//Then redefine the old console
window.console = console;
function cleanLog(element){
while (element.firstChild) {
element.removeChild(element.firstChild);
}
}
function autoscroll(element){
if(scrollStatus(element)){
element.scrollTop = element.scrollHeight;
}
if(element.childElementCount > 1000){
element.removeChild(element.firstChild);
}
}
function scrollset(element, state){
if(state){
element.scrollTop = 0;
scrollChange(element, false);
} else {
element.scrollTop = element.scrollHeight;
scrollChange(element, true);
}
}
function scrollswitch(element){
if(scrollStatus(element)){
scrollChange(element, false);
} else {
scrollChange(element, true);
}
}
function scrollChange(element, status){
if(scrollStatus(element)){
scrollStatPl = false;
document.getElementById("scrollSetPl").innerText = "Scroll on";
} else {
scrollStatPl = true;
document.getElementById("scrollSetPl").innerText = "Scroll off";
}
}
function scrollStatus(element){
if(element.id === "pllogs"){
return scrollStatPl;
} else {
return scrollStatWs;
}
}
</script>
<script>
if (window.Streamedian) {
let errHandler = function(err){
alert(err.message);
};
let infHandler = function(inf) {
let sourcesNode = document.getElementById("sourcesNode");
let clients = inf.clients;
sourcesNode.innerHTML = "";
for (let client in clients) {
clients[client].forEach((sources) => {
let nodeButton = document.createElement("button");
nodeButton.setAttribute('data', sources.url + ' ' + client);
nodeButton.appendChild(document.createTextNode(sources.description));
nodeButton.onclick = (event)=> {
setPlayerSource(event.target.getAttribute('data'));
};
sourcesNode.appendChild(nodeButton);
});
}
};
var playerOptions = {
socket: "ws://localhost:1554/ws/test/live1", redirectNativeMediaErrors : true,
bufferDuration: 30,
errorHandler: errHandler,
infoHandler: infHandler
};
var html5Player = document.getElementById("test_video");
var urlButton = document.getElementById("set_new_url");
var urlEdit = document.getElementById("stream_url");
var bufferRange = document.getElementById("buffer_duration");
var bufferValue = document.getElementById("buffer_value");
var player = Streamedian.player('test_video', playerOptions);
var nativePlayer = document.getElementById('test_video');
var range = document.getElementById('rate');
var set_live = document.getElementById('to_end');
var range_out = document.getElementById('rate_res');
range.addEventListener('input', function () {
nativePlayer.playbackRate = range.value;
range_out.innerHTML = `x${range.value}`;
});
set_live.addEventListener('click', function () {
range.value = 1.0;
range_out.innerHTML = `live`;
nativePlayer.playbackRate = 1;
nativePlayer.currentTime = nativePlayer.buffered.end(0);
});
var updateRangeControls = function(){
bufferRange.value = player.bufferDuration;
bufferValue.innerHTML = bufferRange.value + "sec.";
};
bufferRange.addEventListener('input', function(){
var iValue = parseInt(this.value, 10);
player.bufferDuration = iValue;
bufferValue.innerHTML = this.value + "sec.";
});
bufferRange.innerHTML = player.bufferDuration + "sec.";
updateRangeControls();
urlButton.onclick = ()=> {
setPlayerSource(urlEdit.value);
};
function setPlayerSource(newSource) {
player.destroy();
player = null;
html5Player.src = newSource;
// 修改原例子 begn =======>
// 我们直接使用ws来决定播放的路径
// 比如ws://192.168.1.100:1554/ws/test/live1
// 表示要播放服务器上路径为/test/live1
// 如果播放失败,可能是以下情况(1和2发生在升级websocket阶段3发生在rtsp通讯阶段)
// 1. 如果服务器rtsp的验证模式不为NONE则需要登录后获取到token才能访问像这样ws://.../test/live1?token=...
// 2. 可能没有权限,需要联系人对你登录的用户授权
// 3. 可能找不到流媒体
//
// 如果不是使用例子,我们可以这样
// html5Player.src = "rtsp://placehold"
// playerOptions.socket ="ws://localhost:1554/ws/test/live1"
// 知道ws主机后实际上只需要提供流媒体path即可这也更好
//
let rtspUrl = new URL(newSource)
rtspUrl.protocol = "ws"
rtspUrl.pathname = "/ws"+rtspUrl.pathname
playerOptions.socket = rtspUrl.href
// <========= end
player = Streamedian.player("test_video", playerOptions);
updateRangeControls();
}
}
</script>
</body>
</html>

28
demos/wsp/style.css Executable file
View File

@@ -0,0 +1,28 @@
body {
max-width: 720px;
margin: 50px auto;
}
#test_video {
width: 720px;
}
.controls {
display: flex;
justify-content: space-around;
align-items: center;
}
input.input, .form-inline .input-group>.form-control {
width: 300px;
}
.logs {
overflow: auto;
width: 720px;
height: 150px;
padding: 5px;
border-top: solid 1px gray;
border-bottom: solid 1px gray;
}
button {
margin: 5px
}

17
go.mod Normal file
View File

@@ -0,0 +1,17 @@
module github.com/cnotch/tomatox
go 1.14
require (
github.com/BurntSushi/toml v0.3.1 // indirect
github.com/cnotch/loader v0.0.0-20200405015128-d9d964d09439
github.com/cnotch/scheduler v0.0.0-20200522024700-1d2da93eefc5
github.com/cnotch/xlog v0.0.0-20201208005456-cfda439cd3a0
github.com/emitter-io/address v1.0.0
github.com/gorilla/websocket v1.4.2
github.com/kelindar/process v0.0.0-20170730150328-69a29e249ec3
github.com/kelindar/rate v1.0.0
github.com/stretchr/testify v1.6.1
golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)

63
go.sum Normal file
View File

@@ -0,0 +1,63 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/cnotch/loader v0.0.0-20200405015128-d9d964d09439 h1:iNWyllf6zuby+nDNC6zKEkM7aUFbp4RccfWVdQ3HFfQ=
github.com/cnotch/loader v0.0.0-20200405015128-d9d964d09439/go.mod h1:oWpDagHB6p+Kqqq7RoRZKyC4XAXft50hR8pbTxdbYYs=
github.com/cnotch/scheduler v0.0.0-20200522024700-1d2da93eefc5 h1:m9Wx/d4iPXFmE0f2zJ6iQ8tXZ52kOZO9qs/kMevEHxk=
github.com/cnotch/scheduler v0.0.0-20200522024700-1d2da93eefc5/go.mod h1:F4GE3SZkJZ8an1Y0ZCqvSM3jeozNuKzoC67erG1PhIo=
github.com/cnotch/xlog v0.0.0-20201208005456-cfda439cd3a0 h1:YXATGJEn/ymZjZOGCFfE5248ABcLbfwpd/dQGfByxGQ=
github.com/cnotch/xlog v0.0.0-20201208005456-cfda439cd3a0/go.mod h1:RW9oHsR79ffl3sR3yMGgxYupMn2btzdtJUwoxFPUE5E=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emitter-io/address v1.0.0 h1:j8mAEIV2TipN2TOf/sTNveJjf8nTBq2ov7/qBG/19vg=
github.com/emitter-io/address v1.0.0/go.mod h1:GfZb5+S/o8694B1GMGK2imUYQyn2skszMvGNA5D84Ug=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kelindar/process v0.0.0-20170730150328-69a29e249ec3 h1:6If+E1dikQbdT7DlhZqLplfGkEt6dSoz7+MK+TFC7+U=
github.com/kelindar/process v0.0.0-20170730150328-69a29e249ec3/go.mod h1:+lTCLnZFXOkqwD8sLPl6u4erAc0cP8wFegQHfipz7KE=
github.com/kelindar/rate v1.0.0 h1:JNZdufLjtDzr/E/rCtWkqo2OVU4yJSScZngJ8LuZ7kU=
github.com/kelindar/rate v1.0.0/go.mod h1:AjT4G+hTItNwt30lucEGZIz8y7Uk5zPho6vurIZ+1Es=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9 h1:sYNJzB4J8toYPQTM6pAkcmBRgw9SnQKP9oXCHfgy604=
golang.org/x/crypto v0.0.0-20201208171446-5f87f3452ae9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

40
main.go Executable file
View File

@@ -0,0 +1,40 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package main
import (
"github.com/cnotch/scheduler"
"github.com/cnotch/tomatox/config"
"github.com/cnotch/tomatox/provider/auth"
"github.com/cnotch/tomatox/provider/route"
"github.com/cnotch/xlog"
)
func main() {
// 初始化配置
config.InitConfig()
// 初始化全局计划任务
scheduler.SetPanicHandler(func(job *scheduler.ManagedJob, r interface{}) {
xlog.Errorf("scheduler task panic. tag: %v, recover: %v", job.Tag, r)
})
// 初始化各类提供者
// 路由表提供者
routetableProvider := config.LoadRoutetableProvider(route.JSON)
route.Reset(routetableProvider.(route.Provider))
// 用户提供者
userProvider := config.LoadUsersProvider(auth.JSON)
auth.Reset(userProvider.(auth.UserProvider))
// // Start new service
// svc, err := service.NewService(context.Background(), xlog.L())
// if err != nil {
// xlog.L().Panic(err.Error())
// }
// // Listen and serve
// svc.Listen()
}

52
makefile Executable file
View File

@@ -0,0 +1,52 @@
# Go parameters
GOCMD=go
GOBUILD=$(GOCMD) build
GOCLEAN=$(GOCMD) clean
GOTEST=$(GOCMD) test
GOGET=$(GOCMD) get
ENABLED_CGO=0
BINARY_NAME=tomatox
BINARY_DIR= bin/v1.0.0
build:
CGO_ENABLED=$(ENABLED_CGO) $(GOBUILD) -o bin/$(BINARY_NAME) .
cp -r demos bin/
cp -r docs bin/
# linux compilation
build-linux-amd64:
CGO_ENABLED=$(ENABLED_CGO) GOOS=linux GOARCH=amd64 $(GOBUILD) -o $(BINARY_DIR)/linux/amd64/$(BINARY_NAME)$(VERSION) .
cp -r demos $(BINARY_DIR)/linux/amd64/
cp -r docs $(BINARY_DIR)/linux/amd64/
build-linux-386:
CGO_ENABLED=$(ENABLED_CGO) GOOS=linux GOARCH=386 $(GOBUILD) -o $(BINARY_DIR)/linux/386/$(BINARY_NAME)$(VERSION) .
build-linux-arm:
CGO_ENABLED=$(ENABLED_CGO) GOOS=linux GOARCH=arm $(GOBUILD) -o $(BINARY_DIR)/linux/arm/$(BINARY_NAME)$(VERSION) .
# window compilation
build-windows-amd64:
CGO_ENABLED=$(ENABLED_CGO) GOOS=windows GOARCH=amd64 $(GOBUILD) -o $(BINARY_DIR)/windows/amd64/$(BINARY_NAME)$(VERSION).exe .
cp -r demos $(BINARY_DIR)/windows/amd64/
cp -r docs $(BINARY_DIR)/windows/amd64/
build-windows-386:
CGO_ENABLED=$(ENABLED_CGO) GOOS=windows GOARCH=386 $(GOBUILD) -o $(BINARY_DIR)/windows/386/$(BINARY_NAME)$(VERSION).exe .
# darwin compilation
build-darwin-amd64:
CGO_ENABLED=$(ENABLED_CGO) GOOS=darwin GOARCH=amd64 $(GOBUILD) -o $(BINARY_DIR)/darwin/amd64/$(BINARY_NAME)$(VERSION) .
cp -r demos $(BINARY_DIR)/darwin/amd64/
cp -r docs $(BINARY_DIR)/darwin/amd64/
build-darwin-386:
CGO_ENABLED=$(ENABLED_CGO) GOOS=darwin GOARCH=386 $(GOBUILD) -o $(BINARY_DIR)/darwin/386/$(BINARY_NAME)$(VERSION) .
# amd64 all platform compilation
build-amd64: build-linux-amd64 build-windows-amd64 build-darwin-amd64
# all
build-all: build-linux-amd64 build-windows-amd64 build-darwin-amd64 build-linux-386 build-windows-386 build-darwin-386 build-linux-arm
test:
$(GOTEST) -v ./...
clean:
$(GOCLEAN)
rm -f bin/$(BINARY_NAME)
rm -rf $(BINARY_DIR)

View File

@@ -0,0 +1,202 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package buffered
import (
"bufio"
"bytes"
"net"
"time"
"github.com/kelindar/rate"
)
const (
defaultRate = 50
defaultBufferSize = 64 * 1024
minBufferSize = 8 * 1024
)
// Conn wraps a net.Conn and provides buffered ability.
type Conn struct {
socket net.Conn // The underlying network connection.
reader *bufio.Reader // The buffered reader
writer *bytes.Buffer // The buffered write queue.
limit *rate.Limiter // The write rate limiter.
bufferSize int // The read and write max buffer size
}
// NewConn creates a new sniffed connection.
func NewConn(c net.Conn, options ...Option) *Conn {
conn, ok := c.(*Conn)
if !ok {
conn = &Conn{
socket: c,
}
}
for _, option := range options {
option.apply(conn)
}
// 设置默认值刷新频率
if conn.limit == nil {
conn.limit = rate.New(defaultRate, time.Second)
}
if conn.bufferSize <= 0 {
conn.bufferSize = defaultBufferSize
}
// 设置IO缓冲对象
conn.reader = bufio.NewReaderSize(conn.socket, conn.bufferSize)
conn.writer = bytes.NewBuffer(make([]byte, 0, conn.bufferSize))
return conn
}
// Buffered returns the pending buffer size.
func (m *Conn) Buffered() (n int) {
return m.writer.Len()
}
// Reader 返回内部的 bufio.Reader
func (m *Conn) Reader() *bufio.Reader {
return m.reader
}
// Flush flushes the underlying buffer by writing into the underlying connection.
func (m *Conn) Flush() (n int, err error) {
if m.Buffered() == 0 {
return 0, nil
}
// Flush everything and reset the buffer
n, err = m.writeFull(m.writer.Bytes())
m.writer.Reset()
return
}
// Read reads the block of data from the underlying buffer.
func (m *Conn) Read(p []byte) (int, error) {
return m.reader.Read(p)
}
// Write writes the block of data into the underlying buffer.
func (m *Conn) Write(p []byte) (nn int, err error) {
var n int
// 没有足够的空间容纳 p
for len(p) > m.bufferSize-m.Buffered() && err == nil {
if m.Buffered() == 0 {
// Large write, empty buffer.
// Write directly from p to avoid copy.
n, err = m.socket.Write(p)
} else {
// write buffer to full stateand flush
n, err = m.writer.Write(p[:m.bufferSize-m.writer.Len()])
_, err = m.Flush()
}
nn += n
p = p[n:]
}
if err != nil {
return nn, err
}
// 未到达时间频率的间隔,直接写到缓存
if m.limit.Limit() {
n, err = m.writer.Write(p)
return nn + n, err
}
// 缓存中有数据flush
if m.Buffered() > 0 {
n, err = m.writer.Write(p)
_, err = m.Flush()
return nn + n, err
}
// 缓存中无数据,直接写避免内存拷贝
n, err = m.writeFull(p)
return nn + n, err
}
func (m *Conn) writeFull(p []byte) (nn int, err error) {
var n int
for len(p) > 0 && err == nil {
n, err = m.socket.Write(p)
nn += n
p = p[n:]
}
return nn, err
}
// Close closes the connection. Any blocked Read or Write operations will be unblocked
// and return errors.
func (m *Conn) Close() error {
return m.socket.Close()
}
// LocalAddr returns the local network address.
func (m *Conn) LocalAddr() net.Addr {
return m.socket.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (m *Conn) RemoteAddr() net.Addr {
return m.socket.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
func (m *Conn) SetDeadline(t time.Time) error {
return m.socket.SetDeadline(t)
}
// SetReadDeadline sets the deadline for future Read calls
// and any currently-blocked Read call.
func (m *Conn) SetReadDeadline(t time.Time) error {
return m.socket.SetReadDeadline(t)
}
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
func (m *Conn) SetWriteDeadline(t time.Time) error {
return m.socket.SetWriteDeadline(t)
}
// Option 配置 Conn 的选项接口
type Option interface {
apply(*Conn)
}
// OptionFunc 包装函数以便它满足 Option 接口
type optionFunc func(*Conn)
func (f optionFunc) apply(c *Conn) {
f(c)
}
// FlushRate Conn 写操作的每秒刷新频率
func FlushRate(r int) Option {
return optionFunc(func(c *Conn) {
if r < 1 { // 如果不合规,设置成默认值
r = defaultRate
}
c.limit = rate.New(r, time.Second)
})
}
// BufferSize Conn 缓冲大小
func BufferSize(bufferSize int) Option {
return optionFunc(func(c *Conn) {
if bufferSize < minBufferSize { // 如果不合规,设置成最小值
bufferSize = minBufferSize
}
c.bufferSize = bufferSize
})
}

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package buffered
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/kelindar/rate"
)
func TestConn(t *testing.T) {
conn := NewConn(new(fakeConn))
defer conn.Close()
assert.Equal(t, 0, conn.Buffered())
assert.Nil(t, conn.LocalAddr())
assert.Nil(t, conn.RemoteAddr())
assert.Nil(t, conn.SetDeadline(time.Now()))
assert.Nil(t, conn.SetReadDeadline(time.Now()))
assert.Nil(t, conn.SetWriteDeadline(time.Now()))
conn.limit = rate.New(1, time.Millisecond)
for i := 0; i < 100; i++ {
_, err := conn.Write([]byte{1, 2, 3})
assert.NoError(t, err)
}
time.Sleep(10 * time.Millisecond)
_, err := conn.Write([]byte{1, 2, 3})
assert.NoError(t, err)
conn.Write(make([]byte, 122*1024))
assert.Equal(t, defaultBufferSize, conn.writer.Cap(), "buffer can't extend")
}
// ------------------------------------------------------------------------------------
type fakeConn struct{}
func (m *fakeConn) Read(p []byte) (int, error) {
return 0, nil
}
func (m *fakeConn) Write(p []byte) (int, error) {
if len(p) > minBufferSize {
return minBufferSize, nil
}
return len(p), nil
}
func (m *fakeConn) Close() error {
return nil
}
func (m *fakeConn) LocalAddr() net.Addr {
return nil
}
func (m *fakeConn) RemoteAddr() net.Addr {
return nil
}
func (m *fakeConn) SetDeadline(t time.Time) error {
return nil
}
func (m *fakeConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *fakeConn) SetWriteDeadline(t time.Time) error {
return nil
}

View File

@@ -0,0 +1,330 @@
/**********************************************************************************
* Copyright (c) 2009-2017 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
************************************************************************************/
package listener
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
)
// Server represents a server which can serve requests.
type Server interface {
Serve(listener net.Listener)
}
// Matcher matches a connection based on its content.
type Matcher func(io.Reader) bool
// SettingsHandler 处理连接使用前的设置
type SettingsHandler func(net.Conn)
// ErrorHandler handles an error and notifies the listener on whether
// it should continue serving.
type ErrorHandler func(error) bool
var _ net.Error = ErrNotMatched{}
// ErrNotMatched is returned whenever a connection is not matched by any of
// the matchers registered in the multiplexer.
type ErrNotMatched struct {
c net.Conn
}
func (e ErrNotMatched) Error() string {
return fmt.Sprintf("Unable to match connection %v", e.c.RemoteAddr())
}
// Temporary implements the net.Error interface.
func (e ErrNotMatched) Temporary() bool { return true }
// Timeout implements the net.Error interface.
func (e ErrNotMatched) Timeout() bool { return false }
type errListenerClosed string
func (e errListenerClosed) Error() string { return string(e) }
func (e errListenerClosed) Temporary() bool { return false }
func (e errListenerClosed) Timeout() bool { return false }
// ErrListenerClosed is returned from muxListener.Accept when the underlying
// listener is closed.
var ErrListenerClosed = errListenerClosed("mux: listener closed")
// for readability of readTimeout
var noTimeout time.Duration
// New announces on the local network address laddr. The syntax of laddr is
// "host:port", like "127.0.0.1:8080". If host is omitted, as in ":8080",
// New listens on all available interfaces instead of just the interface
// with the given host address. Listening on a hostname is not recommended
// because this creates a socket for at most one of its IP addresses.
func New(address string, config *tls.Config) (*Listener, error) {
l, err := net.Listen("tcp", address)
if err != nil {
return nil, err
}
// If we have a TLS configuration provided, wrap the listener in TLS
if config != nil {
l = tls.NewListener(l, config)
}
return &Listener{
root: l,
bufferSize: 1024,
errorHandler: func(_ error) bool { return true },
closing: make(chan struct{}),
readTimeout: noTimeout,
settingsHandler: func(_ net.Conn) {},
}, nil
}
type processor struct {
matchers []Matcher
listen muxListener
}
// Listener represents a listener used for multiplexing protocols.
type Listener struct {
root net.Listener
bufferSize int
errorHandler ErrorHandler
closing chan struct{}
matchers []processor
readTimeout time.Duration
settingsHandler SettingsHandler
}
// Accept waits for and returns the next connection to the listener.
func (m *Listener) Accept() (net.Conn, error) {
return m.root.Accept()
}
// ServeAsync adds a protocol based on the matcher and serves it.
func (m *Listener) ServeAsync(matcher Matcher, serve func(l net.Listener) error) {
l := m.Match(matcher)
go serve(l)
}
// Match returns a net.Listener that sees (i.e., accepts) only
// the connections matched by at least one of the matcher.
func (m *Listener) Match(matchers ...Matcher) net.Listener {
ml := muxListener{
Listener: m.root,
connections: make(chan net.Conn, m.bufferSize),
}
m.matchers = append(m.matchers, processor{matchers: matchers, listen: ml})
return ml
}
// SetReadTimeout sets a timeout for the read of matchers.
func (m *Listener) SetReadTimeout(t time.Duration) {
m.readTimeout = t
}
// Serve starts multiplexing the listener.
func (m *Listener) Serve() error {
var wg sync.WaitGroup
defer func() {
close(m.closing)
wg.Wait()
for _, sl := range m.matchers {
close(sl.listen.connections)
// Drain the connections enqueued for the listener.
for c := range sl.listen.connections {
_ = c.Close()
}
}
}()
for {
c, err := m.root.Accept()
if err != nil {
if !m.handleErr(err) {
return err
}
continue
}
wg.Add(1)
go m.serve(c, m.closing, &wg)
}
}
func (m *Listener) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
m.settingsHandler(c)
muc := newConn(c)
if m.readTimeout > noTimeout {
_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
}
for _, sl := range m.matchers {
for _, processor := range sl.matchers {
matched := processor(muc.startSniffing())
if matched {
muc.doneSniffing()
if m.readTimeout > noTimeout {
_ = c.SetReadDeadline(time.Time{})
}
select {
case sl.listen.connections <- muc:
case <-donec:
_ = c.Close()
}
return
}
}
}
_ = c.Close()
err := ErrNotMatched{c: c}
if !m.handleErr(err) {
_ = m.root.Close()
}
}
// HandleSettings 处理连接设置的函数,给予调用者一个干预系统级设置的机会
func (m *Listener) HandleSettings(h SettingsHandler) {
if h != nil {
m.settingsHandler = h
}
}
// HandleError registers an error handler that handles listener errors.
func (m *Listener) HandleError(h ErrorHandler) {
m.errorHandler = h
}
func (m *Listener) handleErr(err error) bool {
if !m.errorHandler(err) {
return false
}
if ne, ok := err.(net.Error); ok {
return ne.Temporary()
}
return false
}
// Close closes the listener
func (m *Listener) Close() error {
return m.root.Close()
}
// Addr returns the listener's network address.
func (m *Listener) Addr() net.Addr {
return m.root.Addr()
}
// ------------------------------------------------------------------------------------
type muxListener struct {
net.Listener
connections chan net.Conn
}
func (l muxListener) Accept() (net.Conn, error) {
c, ok := <-l.connections
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
// ------------------------------------------------------------------------------------
// Conn wraps a net.Conn and provides transparent sniffing of connection data.
type Conn struct {
net.Conn
sniffer sniffer
reader io.Reader
}
// NewConn creates a new sniffed connection.
func newConn(c net.Conn) *Conn {
m := &Conn{
Conn: c,
sniffer: sniffer{source: c},
}
m.sniffer.conn = m
m.reader = &m.sniffer
return m
}
// Read reads the block of data from the underlying buffer.
func (m *Conn) Read(p []byte) (int, error) {
return m.reader.Read(p)
}
func (m *Conn) startSniffing() io.Reader {
m.sniffer.reset(true)
return &m.sniffer
}
func (m *Conn) doneSniffing() {
m.sniffer.reset(false)
}
// ------------------------------------------------------------------------------------
// Sniffer represents a io.Reader which can peek incoming bytes and reset back to normal.
type sniffer struct {
conn *Conn
source io.Reader
buffer bytes.Buffer
bufferRead int
bufferSize int
sniffing bool
lastErr error
}
// Read reads data from the buffer.
func (s *sniffer) Read(p []byte) (int, error) {
if s.bufferSize > s.bufferRead {
bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize])
s.bufferRead += bn
return bn, s.lastErr
} else if !s.sniffing && s.buffer.Cap() != 0 {
s.buffer = bytes.Buffer{}
s.conn.reader = s.conn.Conn // 重置到直接从Conn读取减少判断
}
sn, sErr := s.source.Read(p)
if sn > 0 && s.sniffing {
s.lastErr = sErr
if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil {
return wn, wErr
}
}
return sn, sErr
}
// Reset resets the buffer.
func (s *sniffer) reset(snif bool) {
s.sniffing = snif
s.bufferRead = 0
s.bufferSize = s.buffer.Len()
}

View File

@@ -0,0 +1,329 @@
/**********************************************************************************
* Copyright (c) 2009-2017 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
************************************************************************************/
package listener
import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/rpc"
"runtime"
"sort"
"strings"
"sync"
"testing"
"time"
)
const (
testHTTP1Resp = "http1"
rpcVal = 1234
)
func safeServe(errCh chan<- error, muxl *Listener) {
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed") {
errCh <- err
}
}
func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
c, err := rpc.Dial(addr.Network(), addr.String())
if err != nil {
t.Fatal(err)
}
return c, func() {
if err := c.Close(); err != nil {
t.Fatal(err)
}
}
}
func testListener(t *testing.T) (*Listener, func()) {
l, err := New(":0", nil)
if err != nil {
t.Fatal(err)
}
var once sync.Once
return l, func() {
once.Do(func() {
if err := l.Close(); err != nil {
t.Fatal(err)
}
})
}
}
type testHTTP1Handler struct{}
func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, testHTTP1Resp)
}
func runTestHTTPServer(errCh chan<- error, l net.Listener) {
var mu sync.Mutex
conns := make(map[net.Conn]struct{})
defer func() {
mu.Lock()
for c := range conns {
if err := c.Close(); err != nil {
errCh <- err
}
}
mu.Unlock()
}()
s := &http.Server{
Handler: &testHTTP1Handler{},
ConnState: func(c net.Conn, state http.ConnState) {
mu.Lock()
switch state {
case http.StateNew:
conns[c] = struct{}{}
case http.StateClosed:
delete(conns, c)
}
mu.Unlock()
},
}
if err := s.Serve(l); err != ErrListenerClosed {
errCh <- err
}
}
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
r, err := http.Get("http://" + addr.String())
if err != nil {
t.Fatal(err)
}
defer func() {
if err = r.Body.Close(); err != nil {
t.Fatal(err)
}
}()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != testHTTP1Resp {
t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b)
}
}
type TestRPCRcvr struct{}
func (r TestRPCRcvr) Test(i int, j *int) error {
*j = i
return nil
}
func runTestRPCServer(errCh chan<- error, l net.Listener) {
s := rpc.NewServer()
if err := s.Register(TestRPCRcvr{}); err != nil {
errCh <- err
}
for {
c, err := l.Accept()
if err != nil {
if err != ErrListenerClosed {
errCh <- err
}
return
}
go s.ServeConn(c)
}
}
func runTestRPCClient(t *testing.T, addr net.Addr) {
c, cleanup := safeDial(t, addr)
defer cleanup()
var num int
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil {
t.Fatal(err)
}
if num != rpcVal {
t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num)
}
}
const (
handleHTTP1Close = 1
handleHTTP1Request = 2
handleAnyClose = 3
handleAnyRequest = 4
)
func TestTimeout(t *testing.T) {
defer leakCheck(t)()
m, Close := testListener(t)
defer Close()
result := make(chan int, 5)
testDuration := time.Millisecond * 100
m.SetReadTimeout(testDuration)
http1 := m.Match(MatchHTTP())
any := m.Match(MatchAny())
go func() {
_ = m.Serve()
}()
go func() {
con, err := http1.Accept()
if err != nil {
result <- handleHTTP1Close
} else {
_, _ = con.Write([]byte("http1"))
_ = con.Close()
result <- handleHTTP1Request
}
}()
go func() {
con, err := any.Accept()
if err != nil {
result <- handleAnyClose
} else {
_, _ = con.Write([]byte("any"))
_ = con.Close()
result <- handleAnyRequest
}
}()
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
client, err := net.Dial("tcp", m.Addr().String())
if err != nil {
log.Fatal("testTimeout client failed: ", err)
}
defer func() {
_ = client.Close()
}()
time.Sleep(testDuration / 2)
if len(result) != 0 {
log.Print("tcp ")
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
}
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
buffer := make([]byte, 10)
rl, err := client.Read(buffer)
if err != nil {
t.Fatal("testTimeout failed: client error: ", err, rl)
}
Close()
if rl != 3 {
log.Print("testTimeout failed: response from wrong service ", rl)
}
if string(buffer[0:3]) != "any" {
log.Print("testTimeout failed: response from wrong service ")
}
time.Sleep(testDuration * 2)
if len(result) != 2 {
t.Fatal("testTimeout failed: accepted to less: ", len(result))
}
if a := <-result; a != handleAnyRequest {
t.Fatal("testTimeout failed: any rule did not match")
}
if a := <-result; a != handleHTTP1Close {
t.Fatal("testTimeout failed: no close an http rule")
}
}
func TestAny(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
muxl, cleanup := testListener(t)
defer cleanup()
httpl := muxl.Match(MatchAny())
go runTestHTTPServer(errCh, httpl)
go safeServe(errCh, muxl)
runTestHTTP1Client(t, muxl.Addr())
}
// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines() (gs []string) {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
for _, g := range strings.Split(string(buf), "\n\n") {
sl := strings.SplitN(g, "\n", 2)
if len(sl) != 2 {
continue
}
stack := strings.TrimSpace(sl[1])
if strings.HasPrefix(stack, "testing.RunTests") {
continue
}
if stack == "" ||
strings.Contains(stack, "main.main()") ||
strings.Contains(stack, "testing.Main(") ||
strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") ||
strings.Contains(stack, "interestingGoroutines") ||
strings.Contains(stack, "runtime.MHeap_Scavenger") {
continue
}
gs = append(gs, g)
}
sort.Strings(gs)
return
}
// leakCheck snapshots the currently-running goroutines and returns a
// function to be run at the end of tests to see whether any
// goroutines leaked.
func leakCheck(t testing.TB) func() {
orig := map[string]bool{}
for _, g := range interestingGoroutines() {
orig[g] = true
}
return func() {
// Loop, waiting for goroutines to shut down.
// Wait up to 5 seconds, but finish as quickly as possible.
deadline := time.Now().Add(5 * time.Second)
for {
var leaked []string
for _, g := range interestingGoroutines() {
if !orig[g] {
leaked = append(leaked, g)
}
}
if len(leaked) == 0 {
return
}
if time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
continue
}
for _, g := range leaked {
t.Errorf("Leaked goroutine: %v", g)
}
return
}
}
}

View File

@@ -0,0 +1,221 @@
/**********************************************************************************
* Copyright (c) 2009-2017 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
*
* This file was originally developed by The CMux Authors and released under Apache
* License, Version 2.0 in 2016.
************************************************************************************/
package listener
import (
"bytes"
"io"
)
var defaultHTTPMethods = []string{
"OPTIONS",
"GET",
"HEAD",
"POST",
"PATCH",
"PUT",
"DELETE",
"TRACE",
"CONNECT",
}
// ------------------------------------------------------------------------------------
// MatchAny matches any connection.
func MatchAny() Matcher {
return func(r io.Reader) bool { return true }
}
// MatchPrefix returns a matcher that matches a connection if it
// starts with any of the strings in strs.
func MatchPrefix(strs ...string) Matcher {
pt := newPatriciaTreeString(strs...)
return pt.matchPrefix
}
// MatchHTTP only matches the methods in the HTTP request.
func MatchHTTP(extMethods ...string) Matcher {
return MatchPrefix(append(defaultHTTPMethods, extMethods...)...)
}
// MatchPrefixBytes 匹配前缀字节数组
func MatchPrefixBytes(bs ...[]byte) Matcher {
pt := newPatriciaTree(bs...)
return pt.matchPrefix
}
// ------------------------------------------------------------------------------------
// patriciaTree is a simple patricia tree that handles []byte instead of string
// and cannot be changed after instantiation.
type patriciaTree struct {
root *ptNode
maxDepth int // max depth of the tree.
}
func newPatriciaTree(bs ...[]byte) *patriciaTree {
max := 0
for _, b := range bs {
if max < len(b) {
max = len(b)
}
}
return &patriciaTree{
root: newNode(bs),
maxDepth: max + 1,
}
}
func newPatriciaTreeString(strs ...string) *patriciaTree {
b := make([][]byte, len(strs))
for i, s := range strs {
b[i] = []byte(s)
}
return newPatriciaTree(b...)
}
func (t *patriciaTree) matchPrefix(r io.Reader) bool {
buf := make([]byte, t.maxDepth)
n, _ := io.ReadFull(r, buf)
return t.root.match(buf[:n], true)
}
func (t *patriciaTree) match(r io.Reader) bool {
buf := make([]byte, t.maxDepth)
n, _ := io.ReadFull(r, buf)
return t.root.match(buf[:n], false)
}
type ptNode struct {
prefix []byte
next map[byte]*ptNode
terminal bool
}
func newNode(strs [][]byte) *ptNode {
if len(strs) == 0 {
return &ptNode{
prefix: []byte{},
terminal: true,
}
}
if len(strs) == 1 {
return &ptNode{
prefix: strs[0],
terminal: true,
}
}
p, strs := splitPrefix(strs)
n := &ptNode{
prefix: p,
}
nexts := make(map[byte][][]byte)
for _, s := range strs {
if len(s) == 0 {
n.terminal = true
continue
}
nexts[s[0]] = append(nexts[s[0]], s[1:])
}
n.next = make(map[byte]*ptNode)
for first, rests := range nexts {
n.next[first] = newNode(rests)
}
return n
}
func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) {
if len(bss) == 0 || len(bss[0]) == 0 {
return prefix, bss
}
if len(bss) == 1 {
return bss[0], [][]byte{{}}
}
for i := 0; ; i++ {
var cur byte
eq := true
for j, b := range bss {
if len(b) <= i {
eq = false
break
}
if j == 0 {
cur = b[i]
continue
}
if cur != b[i] {
eq = false
break
}
}
if !eq {
break
}
prefix = append(prefix, cur)
}
rest = make([][]byte, 0, len(bss))
for _, b := range bss {
rest = append(rest, b[len(prefix):])
}
return prefix, rest
}
func (n *ptNode) match(b []byte, prefix bool) bool {
l := len(n.prefix)
if l > 0 {
if l > len(b) {
l = len(b)
}
if !bytes.Equal(b[:l], n.prefix) {
return false
}
}
if n.terminal && (prefix || len(n.prefix) == len(b)) {
return true
}
if l >= len(b) {
return false
}
nextN, ok := n.next[b[l]]
if !ok {
return false
}
if l == len(b) {
b = b[l:l]
} else {
b = b[l+1:]
}
return nextN.match(b, prefix)
}

View File

@@ -0,0 +1,59 @@
/**********************************************************************************
* Copyright (c) 2009-2017 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
*
* This file was originally developed by The CMux Authors and released under Apache
* License, Version 2.0 in 2016.
************************************************************************************/
package listener
import (
"strings"
"testing"
)
func testPTree(t *testing.T, strs ...string) {
pt := newPatriciaTreeString(strs...)
for _, s := range strs {
if !pt.match(strings.NewReader(s)) {
t.Errorf("%s is not matched by %s", s, s)
}
if !pt.matchPrefix(strings.NewReader(s + s)) {
t.Errorf("%s is not matched as a prefix by %s", s+s, s)
}
if pt.match(strings.NewReader(s + s)) {
t.Errorf("%s matches %s", s+s, s)
}
// The following tests are just to catch index out of
// range and off-by-one errors and not the functionality.
pt.matchPrefix(strings.NewReader(s[:len(s)-1]))
pt.match(strings.NewReader(s[:len(s)-1]))
pt.matchPrefix(strings.NewReader(s + "$"))
pt.match(strings.NewReader(s + "$"))
}
}
func TestPatriciaOnePrefix(t *testing.T) {
testPTree(t, "prefix")
}
func TestPatriciaNonOverlapping(t *testing.T) {
testPTree(t, "foo", "bar", "dummy")
}
func TestPatriciaOverlapping(t *testing.T) {
testPTree(t, "foo", "far", "farther", "boo", "ba", "bar")
}

239
network/websocket/websocket.go Executable file
View File

@@ -0,0 +1,239 @@
/**********************************************************************************
* Copyright (c) 2009-2017 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
************************************************************************************/
//
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package websocket
import (
"io"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Conn websocket连接
type Conn interface {
net.Conn
Subprotocol() string // 获取子协议
TextTransport() Conn // 获取文本传输通道
Path() string // 接入时的ws后的路径
Username() string // 接入是http验证后的用户名称
}
type websocketConn interface {
NextReader() (messageType int, r io.Reader, err error)
NextWriter(messageType int) (io.WriteCloser, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
Subprotocol() string
}
// websocketConn represents a websocket connection.
type websocketTransport struct {
sync.Mutex
socket websocketConn
reader io.Reader
closing chan bool
path string
username string
}
const (
writeWait = 10 * time.Second // Time allowed to write a message to the peer.
pongWait = 60 * time.Second // Time allowed to read the next pong message from the peer.
pingPeriod = (pongWait * 9) / 10 // Send pings to peer with this period. Must be less than pongWait.
closeGracePeriod = 10 * time.Second // Time to wait before force close on connection.
)
// The default upgrader to use
var upgrader = &websocket.Upgrader{
Subprotocols: []string{"rtsp", "control", "data"},
CheckOrigin: func(r *http.Request) bool { return true },
// ReadBufferSize: 64 * 1024, WriteBufferSize: 64 * 1024,
}
// TryUpgrade attempts to upgrade an HTTP request to rtsp/wsp over websocket.
func TryUpgrade(w http.ResponseWriter, r *http.Request, path, username string) (Conn, bool) {
if w == nil || r == nil {
return nil, false
}
if ws, err := upgrader.Upgrade(w, r, nil); err == nil {
return newConn(ws, path, username), true
}
return nil, false
}
// newConn creates a new transport from websocket.
func newConn(ws websocketConn, path, username string) Conn {
conn := &websocketTransport{
socket: ws,
closing: make(chan bool),
path: path,
username: username,
}
/*ws.SetReadLimit(maxMessageSize)
ws.SetReadDeadline(time.Now().Add(pongWait))
ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil })
ws.SetCloseHandler(func(code int, text string) error {
return conn.Close()
})
utils.Repeat(func() {
log.Println("ping")
if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
log.Println("ping:", err)
}
}, pingPeriod, conn.closing)*/
return conn
}
// Read reads data from the connection. It is possible to allow reader to time
// out and return a Error with Timeout() == true after a fixed time limit by
// using SetDeadline and SetReadDeadline on the websocket.
func (c *websocketTransport) Read(b []byte) (n int, err error) {
var opCode int
if c.reader == nil {
// New message
var r io.Reader
for {
if opCode, r, err = c.socket.NextReader(); err != nil {
return
}
if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
continue
}
c.reader = r
break
}
}
// Read from the reader
n, err = c.reader.Read(b)
if err != nil {
if err == io.EOF {
c.reader = nil
err = nil
}
}
return
}
// Write writes data to the connection. It is possible to allow writer to time
// out and return a Error with Timeout() == true after a fixed time limit by
// using SetDeadline and SetWriteDeadline on the websocket.
func (c *websocketTransport) Write(b []byte) (n int, err error) {
// Serialize write to avoid concurrent write
c.Lock()
defer c.Unlock()
var w io.WriteCloser
if w, err = c.socket.NextWriter(websocket.BinaryMessage); err == nil {
if n, err = w.Write(b); err == nil {
err = w.Close()
}
}
return
}
// Close terminates the connection.
func (c *websocketTransport) Close() error {
return c.socket.Close()
}
// LocalAddr returns the local network address.
func (c *websocketTransport) LocalAddr() net.Addr {
return c.socket.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *websocketTransport) RemoteAddr() net.Addr {
return c.socket.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
func (c *websocketTransport) SetDeadline(t time.Time) (err error) {
if err = c.socket.SetReadDeadline(t); err == nil {
err = c.socket.SetWriteDeadline(t)
}
return
}
// SetReadDeadline sets the deadline for future Read calls
// and any currently-blocked Read call.
func (c *websocketTransport) SetReadDeadline(t time.Time) error {
return c.socket.SetReadDeadline(t)
}
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
func (c *websocketTransport) SetWriteDeadline(t time.Time) error {
return c.socket.SetWriteDeadline(t)
}
// Subprotocol 获取子协议名称
func (c *websocketTransport) Subprotocol() string {
return c.socket.Subprotocol()
}
// TextTransport 获取文本传输Conn
func (c *websocketTransport) TextTransport() Conn {
return &websocketTextTransport{c}
}
func (c *websocketTransport) Path() string {
return c.path
}
func (c *websocketTransport) Username() string {
return c.username
}
type websocketTextTransport struct {
*websocketTransport
}
// Write writes data to the connection. It is possible to allow writer to time
// out and return a Error with Timeout() == true after a fixed time limit by
// using SetDeadline and SetWriteDeadline on the websocket.
func (c *websocketTextTransport) Write(b []byte) (n int, err error) {
// Serialize write to avoid concurrent write
c.Lock()
defer c.Unlock()
var w io.WriteCloser
if w, err = c.socket.NextWriter(websocket.TextMessage); err == nil {
if n, err = w.Write(b); err == nil {
err = w.Close()
}
}
return
}

View File

@@ -0,0 +1,135 @@
package websocket
import (
"bytes"
"io"
"net"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
type writer bytes.Buffer
func (w *writer) Close() error { return nil }
func (w *writer) Write(data []byte) (n int, err error) { return ((*bytes.Buffer)(w)).Write(data) }
type conn struct {
read []byte
write *writer
}
func (c *conn) NextReader() (messageType int, r io.Reader, err error) {
messageType = websocket.BinaryMessage
r = bytes.NewBuffer(c.read)
if c.read == nil {
err = io.EOF
}
return
}
func (c *conn) NextWriter(messageType int) (w io.WriteCloser, err error) {
w = c.write
if c.write == nil {
err = io.EOF
}
return
}
func (c *conn) Close() error { return nil }
func (c *conn) LocalAddr() net.Addr { return &net.IPAddr{} }
func (c *conn) RemoteAddr() net.Addr { return &net.IPAddr{} }
func (c *conn) SetReadDeadline(t time.Time) error { return nil }
func (c *conn) SetWriteDeadline(t time.Time) error { return nil }
func (c *conn) Subprotocol() string { return "" }
func TestTryUpgradeNil(t *testing.T) {
_, ok := TryUpgrade(nil, nil, "", "")
assert.Equal(t, false, ok)
}
func TestTryUpgrade(t *testing.T) {
//httptest.NewServer(handler)
r := httptest.NewRequest("GET", "http://127.0.0.1/", bytes.NewBuffer([]byte{}))
r.Header.Set("Connection", "upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits")
r.Header.Set("Sec-WebSocket-Key", "D1icfJz+khA9kj5/14dRXQ==")
r.Header.Set("Sec-WebSocket-Protocol", "mqttv3.1")
r.Header.Set("Sec-WebSocket-Version", "13")
w := httptest.NewRecorder()
assert.NotPanics(t, func() {
TryUpgrade(w, r, "", "")
})
// TODO: need to have a hijackable response writer to test properly
//ws, ok := TryUpgrade(w, r)
//assert.NotNil(t, ws)
//assert.True(t, ok)
}
func TestRead_EOF(t *testing.T) {
c := newConn(new(conn), "", "")
_, err := c.Read([]byte{})
assert.Error(t, io.EOF, err)
}
func TestRead(t *testing.T) {
message := []byte("hello world")
c := &websocketTransport{
socket: &conn{
read: message,
},
closing: make(chan bool),
}
buffer := make([]byte, 64)
n, err := c.Read(buffer)
assert.NoError(t, err)
assert.Equal(t, message, buffer[:n])
}
func TestWrite(t *testing.T) {
message := []byte("hello world")
buffer := new(bytes.Buffer)
c := &websocketTransport{
socket: &conn{
write: (*writer)(buffer),
},
closing: make(chan bool),
}
_, err := c.Write(message)
assert.NoError(t, err)
assert.Equal(t, message, buffer.Bytes())
}
func TestMisc(t *testing.T) {
c := &websocketTransport{
socket: &conn{},
closing: make(chan bool),
}
err := c.Close()
assert.NoError(t, err)
err = c.SetDeadline(time.Now())
assert.NoError(t, err)
err = c.SetReadDeadline(time.Now())
assert.NoError(t, err)
err = c.SetWriteDeadline(time.Now())
assert.NoError(t, err)
addr1 := c.LocalAddr()
assert.Equal(t, "", addr1.String())
addr2 := c.RemoteAddr()
assert.Equal(t, "", addr2.String())
}

77
provider/auth/json.go Executable file
View File

@@ -0,0 +1,77 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/cnotch/tomatox/utils"
)
// JSON json 提供者
var JSON = &jsonProvider{}
type jsonProvider struct {
filePath string
}
func (p *jsonProvider) Name() string {
return "json"
}
func (p *jsonProvider) Configure(config map[string]interface{}) error {
path, ok := config["file"]
if ok {
switch v := path.(type) {
case string:
p.filePath = v
default:
return fmt.Errorf("invalid user config, file attr: %v", path)
}
} else {
p.filePath = "users.json"
}
if !filepath.IsAbs(p.filePath) {
exe, err := os.Executable()
if err != nil {
return err
}
p.filePath = filepath.Join(filepath.Dir(exe), p.filePath)
}
return nil
}
func (p *jsonProvider) LoadAll() ([]*User, error) {
path := p.filePath
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
// 从文件读
b, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
var users []*User
if err := json.Unmarshal(b, &users); err != nil {
return nil, err
}
return users, nil
}
func (p *jsonProvider) Flush(full []*User, saves []*User, removes []*User) error {
return utils.EncodeJSONFile(p.filePath, full)
}

205
provider/auth/manager.go Executable file
View File

@@ -0,0 +1,205 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import (
"strings"
"sync"
"github.com/cnotch/xlog"
)
var globalM = &manager{
m: make(map[string]*User),
}
func init() {
// 默认为内存提供者,避免没有初始化全局函数调用问题
globalM.Reset(&memProvider{})
}
// Reset 重置用户提供者
func Reset(provider UserProvider) {
globalM.Reset(provider)
}
// All 获取所有的用户
func All() []*User {
return globalM.All()
}
// Get 获取取指定名称的用户
func Get(userName string) *User {
return globalM.Get(userName)
}
// Del 删除指定名称的用户
func Del(userName string) error {
return globalM.Del(userName)
}
// Save 保存用户
func Save(src *User, updatePassword bool) error {
return globalM.Save(src, updatePassword)
}
// Flush 刷新用户
func Flush() error {
return globalM.Flush()
}
type manager struct {
lock sync.RWMutex
m map[string]*User // 用户map
l []*User // 用户list
saves []*User // 自上次Flush后新的保存和删除的用户
removes []*User
provider UserProvider
}
func (m *manager) Reset(provider UserProvider) {
m.lock.Lock()
defer m.lock.Unlock()
m.m = make(map[string]*User)
m.l = m.l[:0]
m.saves = m.saves[:0]
m.removes = m.removes[:0]
m.provider = provider
users, err := provider.LoadAll()
if err != nil {
panic("Load user fail")
}
if cap(m.l) < len(users) {
m.l = make([]*User, 0, len(users))
}
// 加入缓存
for _, u := range users {
if err := u.init(); err != nil {
xlog.Warnf("user table init failed: `%v`", err)
continue // 忽略错误的配置
}
m.m[u.Name] = u
m.l = append(m.l, u)
}
}
func (m *manager) Get(userName string) *User {
m.lock.RLock()
defer m.lock.RUnlock()
userName = strings.ToLower(userName)
u, ok := m.m[userName]
if ok {
return u
}
return nil
}
func (m *manager) Del(userName string) error {
m.lock.Lock()
defer m.lock.Unlock()
userName = strings.ToLower(userName)
u, ok := m.m[userName]
if ok {
delete(m.m, userName)
// 从完整列表中删除
for i, u2 := range m.l {
if u.Name == u2.Name {
m.l = append(m.l[:i], m.l[i+1:]...)
break
}
}
// 从保存列表中删除
for i, u2 := range m.saves {
if u.Name == u2.Name {
m.saves = append(m.saves[:i], m.saves[i+1:]...)
break
}
}
m.removes = append(m.removes, u)
}
return nil
}
func (m *manager) Save(newu *User, updatePassword bool) error {
m.lock.Lock()
defer m.lock.Unlock()
err := newu.init()
if err != nil {
return err
}
u, ok := m.m[newu.Name]
if ok { // 更新
u.CopyFrom(newu, updatePassword)
save := true
// 如果保存列表存在,不新增
for _, u2 := range m.saves {
if u.Name == u2.Name {
save = false
break
}
}
if save {
m.saves = append(m.saves, u)
}
} else { // 新增
u = newu
m.m[u.Name] = u
m.l = append(m.l, u)
m.saves = append(m.saves, u)
for i, u2 := range m.removes {
if u.Name == u2.Name {
m.removes = append(m.removes[:i], m.removes[i+1:]...)
break
}
}
}
return nil
}
func (m *manager) Flush() error {
m.lock.Lock()
defer m.lock.Unlock()
if len(m.saves)+len(m.removes) == 0 {
return nil
}
err := m.provider.Flush(m.l, m.saves, m.removes)
if err != nil {
return err
}
m.saves = m.saves[:0]
m.removes = m.removes[:0]
return nil
}
func (m *manager) All() []*User {
m.lock.RLock()
defer m.lock.RUnlock()
users := make([]*User, len(m.l))
copy(users, m.l)
return users
}

28
provider/auth/memory.go Executable file
View File

@@ -0,0 +1,28 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
type memProvider struct {
}
func (p *memProvider) Name() string {
return "memory"
}
func (p *memProvider) Configure(config map[string]interface{}) error {
return nil
}
func (p *memProvider) LoadAll() ([]*User, error) {
return []*User{{
Name: "admin",
Password: "admin",
Admin: true,
}}, nil
}
func (p *memProvider) Flush(full []*User, saves []*User, removes []*User) error {
return nil
}

78
provider/auth/mode.go Executable file
View File

@@ -0,0 +1,78 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import (
"bytes"
"errors"
"fmt"
)
// Mode 认证模式
type Mode int
// 认证模式常量
const (
NoneAuth Mode = iota
BasicAuth
DigestAuth
)
var errUnmarshalNilMode = errors.New("can't unmarshal a nil *Mode")
// String 返回认证模式字串
func (m Mode) String() string {
switch m {
case NoneAuth:
return "NONE"
case BasicAuth:
return "BASIC"
case DigestAuth:
return "DIGEST"
default:
return fmt.Sprintf("AuthMode(%d)", m)
}
}
// MarshalText 编入认证模式到文本
func (m Mode) MarshalText() ([]byte, error) {
return []byte(m.String()), nil
}
// UnmarshalText 从文本编出认证模式
// 典型的用于 YAML、TOML、JSON等文件编出
func (m *Mode) UnmarshalText(text []byte) error {
if m == nil {
return errUnmarshalNilMode
}
if !m.unmarshalText(text) && !m.unmarshalText(bytes.ToLower(text)) {
return fmt.Errorf("unrecognized Mode: %q", text)
}
return nil
}
func (m *Mode) unmarshalText(text []byte) bool {
switch string(text) {
case "none", "NONE", "": // make the zero value useful
*m = NoneAuth
case "basic", "BASIC":
*m = BasicAuth
case "digest", "DIGEST":
*m = DigestAuth
default:
return false
}
return true
}
// Set flag.Value 接口实现.
func (m *Mode) Set(s string) error {
return m.UnmarshalText([]byte(s))
}
// Get flag.Getter 接口实现
func (m *Mode) Get() interface{} {
return *m
}

91
provider/auth/path_matcher.go Executable file
View File

@@ -0,0 +1,91 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import (
"strings"
"unicode"
"github.com/cnotch/tomatox/utils/scan"
)
const (
sectionWildcard = "+" // 单段通配符
endWildcard = "*" // 0-n段通配符必须位于结尾
)
// 行分割
var pathScanner = scan.NewScanner('/', unicode.IsSpace)
// PathMatcher 路径匹配接口
type PathMatcher interface {
Match(path string) bool
}
// NewPathMatcher 创建匹配器
func NewPathMatcher(pathMask string) PathMatcher {
if strings.TrimSpace(pathMask) == endWildcard {
return alwaysMatcher{}
}
parts := strings.Split(strings.ToLower(strings.Trim(pathMask, "/")), "/")
wildcard := parts[len(parts)-1] == endWildcard
if wildcard {
parts = parts[0 : len(parts)-1]
}
return &pathMacher{parts: parts, wildcardEnd: wildcard}
}
type alwaysMatcher struct {
}
func (m alwaysMatcher) Match(path string) bool {
return true
}
type pathMacher struct {
parts []string
wildcardEnd bool
}
func (m *pathMacher) Match(path string) bool {
path = strings.ToLower(strings.Trim(path, "/"))
count := partCount(path) + 1
if count < len(m.parts) {
return false
}
if count > len(m.parts) && !m.wildcardEnd {
return false
}
ok := true
advance := path
token := ""
for i := 0; i < len(m.parts) && ok; i++ {
advance, token, ok = pathScanner.Scan(advance)
if sectionWildcard == m.parts[i] {
continue // 跳过
}
if token != m.parts[i] {
return false
}
}
return true
}
func partCount(s string) int {
n := 0
for {
i := strings.IndexByte(s, '/')
if i == -1 {
return n
}
n++
s = s[i+1:]
}
}

View File

@@ -0,0 +1,55 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import "testing"
func Test_alwaysMatcher_Match(t *testing.T) {
tests := []struct {
name string
path string
want bool
}{
{"always", "/a/b", true},
{"always", "/a", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := NewPathMatcher(endWildcard)
if got := m.Match(tt.path); got != tt.want {
t.Errorf("alwaysMatcher.Match() = %v, want %v", got, tt.want)
}
})
}
}
func Test_pathMacher_Match(t *testing.T) {
tests := []struct {
name string
pathMask string
path string
want bool
}{
{"g1", "/a", "/a", true},
{"g2", "/a", "/a/b", false},
{"e1", "/a/*", "/a", true},
{"e2", "/a/*", "/a/b", true},
{"e3", "/a/*", "/a/b/c", true},
{"e4", "/a/*", "/b", false},
{"c1", "/a/+/c/*", "/a/b/c", true},
{"c2", "/a/+/c/*", "/a/d/c", true},
{"c3", "/a/+/c/*", "/a/b/c/d", true},
{"c4", "/a/+/c/*", "/a/b/c/d/e", true},
{"c5", "/a/+/c/*", "/a/c/d/e", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := NewPathMatcher(tt.pathMask)
if got := m.Match(tt.path); got != tt.want {
t.Errorf("pathMacher.Match() = %v, want %v", got, tt.want)
}
})
}
}

87
provider/auth/token.go Executable file
View File

@@ -0,0 +1,87 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import (
"sync"
"time"
"github.com/cnotch/tomatox/provider/security"
)
// Token 用户登录后的Token
type Token struct {
Username string `json:"-"`
AToken string `json:"access_token"`
AExp int64 `json:"-"`
RToken string `json:"refresh_token"`
RExp int64 `json:"-"`
}
// TokenManager token管理
type TokenManager struct {
tokens sync.Map // token->Token
}
// NewToken 给用户新建Token
func (tm *TokenManager) NewToken(username string) *Token {
token := &Token{
Username: username,
AToken: security.NewID().MD5(),
AExp: time.Now().Add(time.Hour * time.Duration(2)).Unix(),
RToken: security.NewID().MD5(),
RExp: time.Now().Add(time.Hour * time.Duration(7*24)).Unix(),
}
tm.tokens.Store(token.AToken, token)
tm.tokens.Store(token.RToken, token)
return token
}
// Refresh 刷新指定的Token
func (tm *TokenManager) Refresh(rtoken string) *Token {
ti, ok := tm.tokens.Load(rtoken)
if ok {
oldToken := ti.(*Token)
username := oldToken.Username
if rtoken == oldToken.RToken { // 是refresh token
tm.tokens.Delete(oldToken.AToken)
tm.tokens.Delete(oldToken.RToken)
if oldToken.RExp > time.Now().Unix() {
return tm.NewToken(username)
}
}
}
return nil
}
// AccessCheck 访问检测
func (tm *TokenManager) AccessCheck(atoken string) string {
ti, ok := tm.tokens.Load(atoken)
if ok {
token := ti.(*Token)
if token.AToken == atoken { // 访问token
if token.AExp > time.Now().Unix() {
return token.Username
}
tm.tokens.Delete(token.AToken)
}
}
return ""
}
// ExpCheck 过期检测
func (tm *TokenManager) ExpCheck() {
tm.tokens.Range(func(k, v interface{}) bool {
token := v.(*Token)
if time.Now().Unix() > token.AExp {
tm.tokens.Delete(token.AToken)
}
if time.Now().Unix() > token.RExp {
tm.tokens.Delete(token.RToken)
}
return true
})
}

141
provider/auth/user.go Executable file
View File

@@ -0,0 +1,141 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import (
"crypto/md5"
"encoding/hex"
"errors"
"strings"
"github.com/cnotch/tomatox/utils/scan"
)
// AccessRight 访问权限类型
type AccessRight int
// 权限常量
const (
PullRight AccessRight = 1 << iota // 拉流权限
PushRight // 推流权限
)
// UserProvider 用户提供者
type UserProvider interface {
LoadAll() ([]*User, error)
Flush(full []*User, saves []*User, removes []*User) error
}
// User 用户
type User struct {
Name string `json:"name"`
Password string `json:"password,omitempty"`
Admin bool `json:"admin,omitempty"`
PushAccess string `json:"push,omitempty"`
PullAccess string `json:"pull,omitempty"`
pushMatchers []PathMatcher
pullMatchers []PathMatcher
}
func initMatchers(access string, destMatcher *[]PathMatcher) {
advance := access
pathMask := ""
continueScan := true
for continueScan {
advance, pathMask, continueScan = scan.Semicolon.Scan(advance)
if len(pathMask) == 0 {
continue
}
*destMatcher = append(*destMatcher, NewPathMatcher(pathMask))
}
}
func (u *User) init() error {
u.Name = strings.ToLower(u.Name)
if u.Admin {
if len(u.PullAccess) == 0 {
u.PullAccess = "*"
}
if len(u.PushAccess) == 0 {
u.PushAccess = "*"
}
}
initMatchers(u.PushAccess, &u.pushMatchers)
initMatchers(u.PullAccess, &u.pullMatchers)
return nil
}
// PasswordMD5 返回口令的MD5字串
func (u *User) PasswordMD5() string {
if passwordNeedMD5(u.Password) {
pw := md5.Sum([]byte(u.Password))
return hex.EncodeToString(pw[:])
}
return u.Password
}
// ValidatePassword 验证密码
func (u *User) ValidatePassword(password string) error {
if passwordNeedMD5(password) {
pw := md5.Sum([]byte(password))
password = hex.EncodeToString(pw[:])
}
if strings.EqualFold(u.PasswordMD5(), password) {
return nil
}
return errors.New("password error")
}
// ValidatePermission 验证权限
func (u *User) ValidatePermission(path string, right AccessRight) bool {
var matchers []PathMatcher
switch right {
case PushRight:
matchers = u.pushMatchers
case PullRight:
matchers = u.pullMatchers
}
if matchers == nil {
return false
}
path = strings.TrimSpace(path)
for _, matcher := range matchers {
if matcher.Match(path) {
return true
}
}
return false
}
// CopyFrom 从源属性并初始化
func (u *User) CopyFrom(src *User, withPassword bool) {
if withPassword {
u.Password = src.Password
}
u.Admin = src.Admin
u.PushAccess = src.PushAccess
u.PullAccess = src.PullAccess
u.init()
}
// 密码是否需要进行md5处理如果已经是md5则不处理
func passwordNeedMD5(password string) bool {
if len(password) != 32 {
return true
}
_, err := hex.DecodeString(password)
if err != nil {
return true
}
return false
}

37
provider/auth/user_test.go Executable file
View File

@@ -0,0 +1,37 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package auth
import "testing"
func TestUser_ValidePermission(t *testing.T) {
u := User{
Name: "cao",
Password: "ok",
PushAccess: "/a/+/c",
PullAccess: "/a/*",
}
u.init()
tests := []struct {
name string
path string
right AccessRight
want bool
}{
{"2", "/a/b/c", PushRight, true},
{"3", "/a/c", PushRight, false},
{"4", "/a", PullRight, true},
{"5", "/a/c", PullRight, true},
{"6", "/a/c/d", PullRight, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := u.ValidatePermission(tt.path, tt.right); got != tt.want {
t.Errorf("User.ValidePermission() = %v, want %v", got, tt.want)
}
})
}
}

77
provider/route/json.go Executable file
View File

@@ -0,0 +1,77 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package route
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/cnotch/tomatox/utils"
)
// JSON json 提供者
var JSON = &jsonProvider{}
type jsonProvider struct {
filePath string
}
func (p *jsonProvider) Name() string {
return "json"
}
func (p *jsonProvider) Configure(config map[string]interface{}) error {
path, ok := config["file"]
if ok {
switch v := path.(type) {
case string:
p.filePath = v
default:
return fmt.Errorf("invalid route table config, file attr: %v", path)
}
} else {
p.filePath = "routetable.json"
}
if !filepath.IsAbs(p.filePath) {
exe, err := os.Executable()
if err != nil {
return err
}
p.filePath = filepath.Join(filepath.Dir(exe), p.filePath)
}
return nil
}
func (p *jsonProvider) LoadAll() ([]*Route, error) {
path := p.filePath
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
// 从文件读
b, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
var routes []*Route
if err := json.Unmarshal(b, &routes); err != nil {
return nil, err
}
return routes, nil
}
func (p *jsonProvider) Flush(full []*Route, saves []*Route, removes []*Route) error {
return utils.EncodeJSONFile(p.filePath, full)
}

24
provider/route/memory.go Executable file
View File

@@ -0,0 +1,24 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package route
type memProvider struct {
}
func (p *memProvider) Name() string {
return "memory"
}
func (p *memProvider) Configure(config map[string]interface{}) error {
return nil
}
func (p *memProvider) LoadAll() ([]*Route, error) {
return nil, nil
}
func (p *memProvider) Flush(full []*Route, saves []*Route, removes []*Route) error {
return nil
}

39
provider/route/route.go Executable file
View File

@@ -0,0 +1,39 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package route
import (
"net/url"
"github.com/cnotch/tomatox/utils"
)
// Route 路由
type Route struct {
Pattern string `json:"pattern"` // 路由模式字串
URL string `json:"url"` // 目标url
KeepAlive bool `json:"keepalive,omitempty"` // 是否一直保持连接,直到对方断开;默认 false会在没有人使用时关闭
}
func (r *Route) init() error {
r.Pattern = utils.CanonicalPath(r.Pattern)
_, err := url.Parse(r.URL)
if err != nil {
return err
}
return nil
}
// CopyFrom 从源拷贝
func (r *Route) CopyFrom(src *Route) {
r.URL = src.URL
r.KeepAlive = src.KeepAlive
}
// Provider 路由提供者
type Provider interface {
LoadAll() ([]*Route, error)
Flush(full []*Route, saves []*Route, removes []*Route) error
}

261
provider/route/routetable.go Executable file
View File

@@ -0,0 +1,261 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package route
import (
"sync"
"github.com/cnotch/tomatox/utils"
"github.com/cnotch/xlog"
)
var globalT = &routetable{
m: make(map[string]*Route),
}
func init() {
// 默认为内存提供者,避免没有初始化全局函数调用问题
globalT.Reset(&memProvider{})
}
// Reset 重置路由表提供者
func Reset(provider Provider) {
globalT.Reset(provider)
}
// Match 从路由表中获取和路径匹配的路由实例
func Match(path string) *Route {
return globalT.Match(path)
}
// All 获取所有的路由
func All() []*Route {
return globalT.All()
}
// Get 获取取指定模式的路由
func Get(pattern string) *Route {
return globalT.Get(pattern)
}
// Del 删除指定模式的路由
func Del(pattern string) error {
return globalT.Del(pattern)
}
// Save 保存路由
func Save(src *Route) error {
return globalT.Save(src)
}
// Flush 刷新路由
func Flush() error {
return globalT.Flush()
}
type routetable struct {
lock sync.RWMutex
m map[string]*Route // 路由map
l []*Route // 路由list
saves []*Route // 自上次Flush后新的保存和删除的路由
removes []*Route
provider Provider
}
func (t *routetable) Reset(provider Provider) {
t.lock.Lock()
defer t.lock.Unlock()
t.m = make(map[string]*Route)
t.l = t.l[:0]
t.saves = t.saves[:0]
t.removes = t.removes[:0]
t.provider = provider
routes, err := provider.LoadAll()
if err != nil {
panic("Load route table fail")
}
if cap(t.l) < len(routes) {
t.l = make([]*Route, 0, len(routes))
}
// 加入缓存
for _, r := range routes {
if err := r.init(); err != nil {
xlog.Warnf("route table init failed: `%v`", err)
continue // 忽略错误的配置
}
t.m[r.Pattern] = r
t.l = append(t.l, r)
}
}
func (t *routetable) Match(path string) *Route {
t.lock.RLock()
defer t.lock.RUnlock()
path = utils.CanonicalPath(path)
if path[len(path)-1] == '/' { // 必须有具体的子路径
return nil
}
r, ok := t.m[path]
if ok { // 精确匹配
ret := *r
return &ret
}
// 获取最长有效匹配的路由
var n = 0
for k, v := range t.m {
if !pathMatch(k, path) {
continue
}
if r == nil || len(k) > n {
n = len(k)
r = v
}
}
if r != nil {
ret := *r
r = &ret
if r.URL[len(r.URL)-1] == '/' {
r.URL = r.URL + path[len(r.Pattern):]
} else {
r.URL = r.URL + path[len(r.Pattern)-1:]
}
r.Pattern = path
}
return r
}
func (t *routetable) Get(pattern string) *Route {
t.lock.RLock()
defer t.lock.RUnlock()
pattern = utils.CanonicalPath(pattern)
r, _ := t.m[pattern]
return r
}
func (t *routetable) Del(pattern string) error {
t.lock.Lock()
defer t.lock.Unlock()
pattern = utils.CanonicalPath(pattern)
r, ok := t.m[pattern]
if ok {
delete(t.m, pattern)
// 从完整列表中删除
for i, r2 := range t.l {
if r.Pattern == r2.Pattern {
t.l = append(t.l[:i], t.l[i+1:]...)
break
}
}
// 从保存列表中删除
for i, r2 := range t.saves {
if r.Pattern == r2.Pattern {
t.saves = append(t.saves[:i], t.saves[i+1:]...)
break
}
}
t.removes = append(t.removes, r)
}
return nil
}
func (t *routetable) Save(newr *Route) error {
t.lock.Lock()
defer t.lock.Unlock()
err := newr.init()
if err != nil {
return err
}
r, ok := t.m[newr.Pattern]
if ok { // 更新
r.CopyFrom(newr)
save := true
// 如果保存列表存在,不新增
for _, r2 := range t.saves {
if r.Pattern == r2.Pattern {
save = false
break
}
}
if save {
t.saves = append(t.saves, r)
}
} else { // 新增
r = newr
t.m[r.Pattern] = r
t.l = append(t.l, r)
t.saves = append(t.saves, r)
for i, r2 := range t.removes {
if r.Pattern == r2.Pattern {
t.removes = append(t.removes[:i], t.removes[i+1:]...)
break
}
}
}
return nil
}
func (t *routetable) Flush() error {
t.lock.Lock()
defer t.lock.Unlock()
if len(t.saves)+len(t.removes) == 0 {
return nil
}
err := t.provider.Flush(t.l, t.saves, t.removes)
if err != nil {
return err
}
t.saves = t.saves[:0]
t.removes = t.removes[:0]
return nil
}
func (t *routetable) All() []*Route {
t.lock.RLock()
defer t.lock.RUnlock()
routes := make([]*Route, len(t.l))
copy(routes, t.l)
return routes
}
// Does path match pattern?
func pathMatch(pattern, path string) bool {
if len(pattern) == 0 {
// should not happen
return false
}
n := len(pattern)
if pattern[n-1] != '/' {
return pattern == path
}
return len(path) >= n && path[0:n] == pattern
}

View File

@@ -0,0 +1,36 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package route
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_routetable(t *testing.T) {
t.Run("routetable", func(t *testing.T) {
Save(&Route{"/test/live1", "rtsp://localhost:5540/live1", false})
assert.Equal(t, 1, len(globalT.l))
r := Get("/test/live1")
assert.NotNil(t, r)
Save(&Route{"/easy/", "rtsp://localhost:5540/test", false})
assert.Equal(t, 2, len(globalT.l))
r = Match("/easy/live4")
assert.NotNil(t, r)
assert.Equal(t, "rtsp://localhost:5540/test/live4", r.URL)
Del("/test/live1")
Save(&Route{"/test/live1", "rtsp://localhost:5540/live1", false})
Save(&Route{"/test/live1", "rtsp://localhost:5540/live1", false})
Del("/test/live1")
Save(&Route{"/test/live1", "rtsp://localhost:5540/live1", false})
assert.Equal(t, 2, len(globalT.saves))
assert.Equal(t, 0, len(globalT.removes))
Flush()
assert.Equal(t, 0, len(globalT.saves))
})
}

85
provider/security/id.go Executable file
View File

@@ -0,0 +1,85 @@
/**********************************************************************************
* Copyright (c) 2009-2017 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
************************************************************************************/
//
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package security
import (
"crypto/md5"
"crypto/sha1"
"encoding/base32"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"strconv"
"strings"
"sync/atomic"
"time"
"golang.org/x/crypto/pbkdf2"
)
// ID represents a process-wide unique ID.
type ID uint64
// next is the next identifier. We seed it with the time in seconds
// to avoid collisions of ids between process restarts.
var next = uint64(
time.Now().Sub(time.Date(2017, 9, 17, 0, 0, 0, 0, time.UTC)).Seconds(),
)
// NewID generates a new, process-wide unique ID.
func NewID() ID {
return ID(atomic.AddUint64(&next, 1))
}
// Unique generates unique id based on the current id with a prefix and salt.
func (id ID) Unique(prefix uint64, salt string) string {
buffer := [16]byte{}
binary.BigEndian.PutUint64(buffer[:8], prefix)
binary.BigEndian.PutUint64(buffer[8:], uint64(id))
enc := pbkdf2.Key(buffer[:], []byte(salt), 4096, 16, sha1.New)
return strings.Trim(base32.StdEncoding.EncodeToString(enc), "=")
}
// String converts the ID to a string representation.
func (id ID) String() string {
return strconv.FormatUint(uint64(id), 10)
}
// Base64 Base64格式
func (id ID) Base64() string {
buf := [10]byte{}
l := binary.PutUvarint(buf[:], uint64(id))
return base64.RawURLEncoding.EncodeToString(buf[:l])
}
// Hex 二进制格式
func (id ID) Hex() string {
buf := [10]byte{}
l := binary.PutUvarint(buf[:], uint64(id))
return strings.ToUpper(hex.EncodeToString(buf[:l]))
}
// MD5 获取ID的MD5值
func (id ID) MD5() string {
buf := [10]byte{}
l := binary.PutUvarint(buf[:], uint64(id))
md5Digest := md5.Sum(buf[:l])
return hex.EncodeToString(md5Digest[:])
}

59
stats/conns.go Executable file
View File

@@ -0,0 +1,59 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package stats
import (
"sync/atomic"
)
// 全局变量
var (
RtspConns = NewConns() // RTSP连接统计
RtmpConns = NewConns() // RTMP连接统计
WspConns = NewConns() // WSP连接统计
FlvConns = NewConns() // flv连接统计
)
// ConnsSample 连接计数采样
type ConnsSample struct {
Total int64 `json:"total"`
Active int64 `json:"active"`
}
// Conns 连接统计
type Conns interface {
Add() int64
Release() int64
GetSample() ConnsSample
}
func (s *ConnsSample) clone() ConnsSample {
return ConnsSample{
Total: atomic.LoadInt64(&s.Total),
Active: atomic.LoadInt64(&s.Active),
}
}
type conns struct {
sample ConnsSample
}
// NewConns 新建连接计数
func NewConns() Conns {
return &conns{}
}
func (c *conns) Add() int64 {
atomic.AddInt64(&c.sample.Total, 1)
return atomic.AddInt64(&c.sample.Active, 1)
}
func (c *conns) Release() int64 {
return atomic.AddInt64(&c.sample.Active, -1)
}
func (c *conns) GetSample() ConnsSample {
return c.sample.clone()
}

82
stats/flow.go Executable file
View File

@@ -0,0 +1,82 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package stats
import (
"sync/atomic"
)
// FlowSample 流统计采样
type FlowSample struct {
InBytes int64 `json:"inbytes"`
OutBytes int64 `json:"outbytes"`
}
// Flow 流统计接口
type Flow interface {
AddIn(size int64) // 增加输入
AddOut(size int64) // 增加输出
GetSample() FlowSample // 获取当前时点采样
}
func (fs *FlowSample) clone() FlowSample {
return FlowSample{
InBytes: atomic.LoadInt64(&fs.InBytes),
OutBytes: atomic.LoadInt64(&fs.OutBytes),
}
}
// Add 采样累加
func (fs *FlowSample) Add(f FlowSample) {
fs.InBytes = fs.InBytes + f.InBytes
fs.OutBytes = fs.OutBytes + f.OutBytes
}
type flow struct {
sample FlowSample
}
// NewFlow 创建流量统计
func NewFlow() Flow {
return &flow{}
}
func (r *flow) AddIn(size int64) {
atomic.AddInt64(&r.sample.InBytes, size)
}
func (r *flow) AddOut(size int64) {
atomic.AddInt64(&r.sample.OutBytes, size)
}
func (r *flow) GetSample() FlowSample {
return r.sample.clone()
}
type childFlow struct {
parent Flow
sample FlowSample
}
// NewChildFlow 创建子流量计数它会把自己的计数Add到parent上
func NewChildFlow(parent Flow) Flow {
return &childFlow{
parent: parent,
}
}
func (r *childFlow) AddIn(size int64) {
atomic.AddInt64(&r.sample.InBytes, size)
r.parent.AddIn(size)
}
func (r *childFlow) AddOut(size int64) {
atomic.AddInt64(&r.sample.OutBytes, size)
r.parent.AddOut(size)
}
func (r *childFlow) GetSample() FlowSample {
return r.sample.clone()
}

29
stats/flow_test.go Executable file
View File

@@ -0,0 +1,29 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package stats
import (
"testing"
)
func TestFlow(t *testing.T) {
totalFlow := NewFlow()
sub1 := NewChildFlow(totalFlow)
sub2 := NewChildFlow(totalFlow)
t.Run("", func(t *testing.T) {
sub1.AddIn(100)
sample := sub1.GetSample()
if sample.InBytes != 100 {
t.Error("InBytes not is 100")
}
sub2.AddIn(200)
sample = totalFlow.GetSample()
if sample.InBytes != 300 {
t.Error("InBytes not is 300")
}
})
}

131
stats/runtime.go Executable file
View File

@@ -0,0 +1,131 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package stats
import (
"runtime"
"time"
"github.com/kelindar/process"
)
// 创建时间
var (
StartingTime = time.Now()
)
// Runtime 运行时统计
type Runtime struct {
Heap Heap `json:"heap"`
MCache Memory `json:"mcache"` // MemStats.MCacheInuse/MCacheSys
MSpan Memory `json:"mspan"` // MemStats.MSpanInuse/MSpanSys
Stack Memory `json:"stack"` // MemStats.StackInuse/StackSys
GC GC `json:"gc"`
Go Go `json:"go"`
}
// Proc 进程信息统计
type Proc struct {
CPU float64 `json:"cpu"` // cpu使用情况
Priv int32 `json:"priv"` // 私有内存 KB
Virt int32 `json:"virt"` // 虚拟内存 KB
Uptime int32 `json:"uptime"` // 运行时间 S
}
// Heap 运行是堆信息
type Heap struct {
Inuse int32 `json:"inuse"` // KB MemStats.HeapInuse
Sys int32 `json:"sys"` // KB MemStats.HeapSys
Alloc int32 `json:"alloc"` // KB MemStats.HeapAlloc
Idle int32 `json:"idle"` // KB MemStats.HeapIdle
Released int32 `json:"released"` // KB MemStats.HeapReleased
Objects int32 `json:"objects"` // = MemStats.HeapObjects
}
// Memory 通用内存信息
type Memory struct {
Inuse int32 `json:"inuse"` // KB
Sys int32 `json:"sys"` // KB
}
// GC 垃圾回收信息
type GC struct {
CPU float64 `json:"cpu"` // cpu使用情况
Sys int32 `json:"sys"` // KB MemStats.GCSys
}
// Go Go运行时 goroutines 、threads 和 total memory
type Go struct {
Count int32 `json:"count"` // runtime.NumGoroutine()
Procs int32 `json:"procs"` //runtime.NumCPU()
Sys int32 `json:"sys"` // KB MemStats.Sys
Alloc int32 `json:"alloc"` // KB MemStats.TotalAlloc
}
// MeasureRuntime 获取运行时信息。
func MeasureRuntime() Proc {
defer recover()
var memoryPriv, memoryVirtual int64
var cpu float64
process.ProcUsage(&cpu, &memoryPriv, &memoryVirtual)
return Proc{
CPU: cpu,
Priv: toKB(uint64(memoryPriv)),
Virt: toKB(uint64(memoryVirtual)),
Uptime: int32(time.Now().Sub(StartingTime).Seconds()),
}
}
// MeasureFullRuntime 获取运行时信息。
func MeasureFullRuntime() *Runtime {
defer recover()
// Collect stats
var memory runtime.MemStats
runtime.ReadMemStats(&memory)
return &Runtime{
// Measure heap information
Heap: Heap{
Alloc: toKB(memory.HeapAlloc),
Idle: toKB(memory.HeapIdle),
Inuse: toKB(memory.HeapInuse),
Objects: int32(memory.HeapObjects),
Released: toKB(memory.HeapReleased),
Sys: toKB(memory.HeapSys),
},
// Measure off heap memory
MCache: Memory{
Inuse: toKB(memory.MCacheInuse),
Sys: toKB(memory.MCacheSys),
},
MSpan: Memory{
Inuse: toKB(memory.MSpanInuse),
Sys: toKB(memory.MSpanSys),
},
// Measure memory
Stack: Memory{
Inuse: toKB(memory.StackInuse),
Sys: toKB(memory.StackSys),
},
// Measure GC
GC: GC{
CPU: memory.GCCPUFraction,
Sys: toKB(memory.GCSys),
},
// Measure goroutines and threads and total memory
Go: Go{
Count: int32(runtime.NumGoroutine()),
Procs: int32(runtime.NumCPU()),
Sys: toKB(memory.Sys),
Alloc: toKB(memory.TotalAlloc),
},
}
}
// Converts the memory in bytes to KBs, otherwise it would overflow our int32
func toKB(v uint64) int32 {
return int32(v / 1024)
}

15
tomatox.conf Executable file
View File

@@ -0,0 +1,15 @@
{
"listen": ":1554",
"auth": false,
"cache_gop": true,
"profile": true,
"log": {
"level": "INFO",
"tofile": false,
"filename": "./logs/tomatox.log",
"maxsize": 20,
"maxdays": 7,
"maxbackups": 14,
"compress": false
}
}

54
utils/addr.go Executable file
View File

@@ -0,0 +1,54 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
"net"
"strings"
"github.com/emitter-io/address"
)
// GetIP 获取IP信息
func GetIP(addr net.Addr) string {
s := addr.String()
i := strings.LastIndex(s, ":")
return s[:i]
}
// IsLocalhostIP 判断是否为本机IP
func IsLocalhostIP(ip net.IP) bool {
for _, localhost := range loopbackBlocks {
if localhost.Contains(ip) {
return true
}
}
privs, err := address.GetPrivate()
if err != nil {
return false
}
for _, priv := range privs {
if priv.IP.Equal(ip) {
return true
}
}
return false
}
var loopbackBlocks = []*net.IPNet{
parseCIDR("0.0.0.0/8"), // RFC 1918 IPv4 loopback address
parseCIDR("127.0.0.0/8"), // RFC 1122 IPv4 loopback address
parseCIDR("::1/128"), // RFC 1884 IPv6 loopback address
}
func parseCIDR(s string) *net.IPNet {
_, block, err := net.ParseCIDR(s)
if err != nil {
panic(fmt.Sprintf("Bad CIDR %s: %s", s, err))
}
return block
}

40
utils/io.go Executable file
View File

@@ -0,0 +1,40 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package utils
import (
"bytes"
"encoding/json"
"os"
)
// EncodeJSONFile 编码 JSON 文件
func EncodeJSONFile(path string, obj interface{}) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)
if err != nil {
return err
}
defer f.Close()
var formatted bytes.Buffer
body, err := json.Marshal(obj)
if err != nil {
return err
}
if err := json.Indent(&formatted, body, "", "\t"); err != nil {
return err
}
if _, err := f.Write(formatted.Bytes()); err != nil {
return err
}
if err := f.Sync(); err != nil {
return err
}
return nil
}

55
utils/multicast.go Executable file
View File

@@ -0,0 +1,55 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package utils
import (
"encoding/binary"
"net"
"sync"
)
// MulticastIPS 全局组播池
var (
Multicast = &multicast{
ipseed: minIP,
portseed: minPort,
}
minIP = binary.BigEndian.Uint32([]byte{235, 0, 0, 0})
maxIP = binary.BigEndian.Uint32([]byte{235, 255, 255, 255})
minPort uint16 = 16666
maxPort uint16 = 39999
)
// multicast 组播IP地址池
type multicast struct {
ipseed uint32
portseed uint16
l sync.Mutex
}
// NextIP 获取组播地址
func (p *multicast) NextIP() string {
p.l.Lock()
defer p.l.Unlock()
var ipbytes [4]byte
binary.BigEndian.PutUint32(ipbytes[:], p.ipseed)
ip := net.IP(ipbytes[:]).String()
p.ipseed++
if p.ipseed > maxIP {
p.ipseed = minIP
}
return ip
}
func (p *multicast) NextPort() int {
p.l.Lock()
defer p.l.Unlock()
port := p.portseed
p.portseed++
if p.portseed > maxPort {
p.portseed = minPort
}
return int(port)
}

102
utils/murmur/murmur.go Normal file
View File

@@ -0,0 +1,102 @@
/**********************************************************************************
* Copyright (c) 2009-2019 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
************************************************************************************/
// Package murmur 是murmur算法的实现方网站https://sites.google.com/site/murmurhash/
//
// MurmurHash算法高运算性能低碰撞率由Austin Appleby创建于2008年
// 现已应用到Hadoop、libstdc++、nginx、libmemcached等开源系统。
// 2011年Appleby被Google雇佣随后Google推出其变种的CityHash算法。
//
// 当key的长度大于10字节的时候MurmurHash的运算速度才快于DJB。
// “从计算速度上来看MurmurHash只适用于已知长度的、长度比较长的字符”。
package murmur
import (
"reflect"
"unsafe"
)
const (
c1_32 uint32 = 0xcc9e2d51
c2_32 uint32 = 0x1b873593
)
// OfString returns a murmur32 hash for the string
func OfString(value string) uint32 {
return Of(stringToBinary(value))
}
// Of returns a murmur32 hash for the data slice.
func Of(data []byte) uint32 {
// Seed is set to 37, same as C# version of emitter
var h1 uint32 = 37
nblocks := len(data) / 4
var p uintptr
if len(data) > 0 {
p = uintptr(unsafe.Pointer(&data[0]))
}
p1 := p + uintptr(4*nblocks)
for ; p < p1; p += 4 {
k1 := *(*uint32)(unsafe.Pointer(p))
k1 *= c1_32
k1 = (k1 << 15) | (k1 >> 17) // rotl32(k1, 15)
k1 *= c2_32
h1 ^= k1
h1 = (h1 << 13) | (h1 >> 19) // rotl32(h1, 13)
h1 = h1*5 + 0xe6546b64
}
tail := data[nblocks*4:]
var k1 uint32
switch len(tail) & 3 {
case 3:
k1 ^= uint32(tail[2]) << 16
fallthrough
case 2:
k1 ^= uint32(tail[1]) << 8
fallthrough
case 1:
k1 ^= uint32(tail[0])
k1 *= c1_32
k1 = (k1 << 15) | (k1 >> 17) // rotl32(k1, 15)
k1 *= c2_32
h1 ^= k1
}
h1 ^= uint32(len(data))
h1 ^= h1 >> 16
h1 *= 0x85ebca6b
h1 ^= h1 >> 13
h1 *= 0xc2b2ae35
h1 ^= h1 >> 16
return (h1 << 24) | (((h1 >> 8) << 16) & 0xFF0000) | (((h1 >> 16) << 8) & 0xFF00) | (h1 >> 24)
}
func stringToBinary(v string) (b []byte) {
strHeader := (*reflect.StringHeader)(unsafe.Pointer(&v))
byteHeader := (*reflect.SliceHeader)(unsafe.Pointer(&b))
byteHeader.Data = strHeader.Data
l := len(v)
byteHeader.Len = l
byteHeader.Cap = l
return
}

View File

@@ -0,0 +1,72 @@
/**********************************************************************************
* Copyright (c) 2009-2019 Misakai Ltd.
* This program is free software: you can redistribute it and/or modify it under the
* terms of the GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or(at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along
* with this program. If not, see<http://www.gnu.org/licenses/>.
************************************************************************************/
package murmur
import (
"testing"
"github.com/stretchr/testify/assert"
)
// BenchmarkOf-8 100000000 14.5 ns/op 0 B/op 0 allocs/op
func BenchmarkOf(b *testing.B) {
v := []byte("a/b/c/d/e/f/g/h/this/is/tomatox")
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = Of(v)
}
}
// BenchmarkOfString-8 100000000 18.4 ns/op 0 B/op 0 allocs/op
func BenchmarkOfString(b *testing.B) {
v := "a/b/c/d/e/f/g/h/this/is/tomatox"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = OfString(v)
}
}
func TestMeHash(t *testing.T) {
h := OfString("me")
assert.Equal(t, uint32(2539734036), h)
}
func TestShareHash(t *testing.T) {
h := Of([]byte("$share"))
assert.Equal(t, uint32(1480642916), h)
}
func TestLinkHash(t *testing.T) {
h := Of([]byte("link"))
assert.Equal(t, uint32(2667034312), h)
}
func TestGetHash(t *testing.T) {
h := Of([]byte("+"))
if h != 1815237614 {
t.Errorf("Hash %d is not equal to %d", h, 1815237614)
}
}
func TestGetHash2(t *testing.T) {
h := Of([]byte("hello world"))
if h != 4008393376 {
t.Errorf("Hash %d is not equal to %d", h, 1815237614)
}
}

37
utils/path.go Executable file
View File

@@ -0,0 +1,37 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package utils
import (
"path"
"strings"
)
// CanonicalPath 获取合法的path
func CanonicalPath(p string) string {
p = strings.ToLower(strings.TrimSpace(p))
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
// Fast path for common case of p being the string we want:
if len(p) == len(np)+1 && strings.HasPrefix(p, np) {
np = p
} else {
np += "/"
}
}
return np
}

61
utils/scan/pair.go Normal file
View File

@@ -0,0 +1,61 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package scan
import (
"strings"
"unicode"
"unicode/utf8"
)
// 预定义Pair扫描对象
var (
// EqualPair 扫描 K=V这类形式的Pair字串
EqualPair = NewPair('=',
func(r rune) bool {
return unicode.IsSpace(r) || r == '"'
})
// ColonPair 扫描 K:V 这类形式的Pair字串
ColonPair = NewPair(':',
func(r rune) bool {
return unicode.IsSpace(r) || r == '"'
})
)
// Pair 从字串扫描Key Value 值
type Pair struct {
delim rune // Key Value 间的分割
delimLen int // 分割符长度
trimFunc func(r rune) bool // 返回前 Trim使用的函数
}
// NewPair 新建 Pair 扫描器
func NewPair(delim rune, trimFunc func(r rune) bool) Pair {
pair := Pair{
delim: delim,
trimFunc: trimFunc,
}
pair.delimLen = utf8.RuneLen(delim)
if trimFunc == nil {
pair.trimFunc = func(r rune) bool { return false }
}
return pair
}
// Scan 提取 K V
func (p Pair) Scan(s string) (key, value string, found bool) {
if p.delim == 0 {
return s, "", false
}
i := strings.IndexRune(s, p.delim)
if i < 0 {
return s, "", false
}
return strings.TrimFunc(s[:i], p.trimFunc),
strings.TrimFunc(s[i+p.delimLen:], p.trimFunc), true
}

111
utils/scan/pair_test.go Normal file
View File

@@ -0,0 +1,111 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package scan
import (
"testing"
"unicode"
)
func TestPair_Scan(t *testing.T) {
tests := []struct {
name string
args string
wantKey string
wantValue string
wantOk bool
}{
{
"不带引号",
"a=chj",
"a",
"chj",
true,
},
{
"带引号",
"a=\"chj\"",
"a",
"chj",
true,
},
{
"带空个",
" \ta= \"chj\"\t",
"a",
"chj",
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotKey, gotValue, gotOk := EqualPair.Scan(tt.args)
if gotKey != tt.wantKey {
t.Errorf("Pair.Scan() gotKey = %v, want %v", gotKey, tt.wantKey)
}
if gotValue != tt.wantValue {
t.Errorf("Pair.Scan() gotValue = %v, want %v", gotValue, tt.wantValue)
}
if gotOk != tt.wantOk {
t.Errorf("Pair.Scan() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestPair_ScanMultiRune(t *testing.T) {
chinesePair := NewPair('是', unicode.IsSpace)
tests := []struct {
name string
args string
wantKey string
wantValue string
wantOk bool
}{
{
"不带空格",
"a是chj",
"a",
"chj",
true,
},
{
"不带空格",
"a是 chj\t",
"a",
"chj",
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotKey, gotValue, gotOk := chinesePair.Scan(tt.args)
if gotKey != tt.wantKey {
t.Errorf("Pair.Scan() gotKey = %v, want %v", gotKey, tt.wantKey)
}
if gotValue != tt.wantValue {
t.Errorf("Pair.Scan() gotValue = %v, want %v", gotValue, tt.wantValue)
}
if gotOk != tt.wantOk {
t.Errorf("Pair.Scan() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func Benchmark_Pair_Scan(b *testing.B) {
s := `realm="Another Streaming Media"`
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
key, value, ok := EqualPair.Scan(s)
_ = key
_ = value
_ = ok
}
})
}

53
utils/scan/scanner.go Normal file
View File

@@ -0,0 +1,53 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package scan
import (
"strings"
"unicode"
"unicode/utf8"
)
// 扫描器
var (
// 逗号分割
Comma = NewScanner(',', unicode.IsSpace)
// 分号分割
Semicolon = NewScanner(';', unicode.IsSpace)
// 空格分割
Space = NewScanner(' ', nil)
// 行分割
Line = NewScanner('\n', unicode.IsSpace)
)
// Scanner 扫描器
type Scanner struct {
delim rune
delimLen int
trimFunc func(r rune) bool
}
// NewScanner 创建扫描器
func NewScanner(delim rune, trimFunc func(r rune) bool) Scanner {
scanner := Scanner{
delim: delim,
trimFunc: trimFunc,
}
scanner.delimLen = utf8.RuneLen(delim)
if trimFunc == nil {
scanner.trimFunc = func(r rune) bool { return false }
}
return scanner
}
// Scan 扫描字串
func (s Scanner) Scan(str string) (advance, token string, continueScan bool) {
i := strings.IndexRune(str, s.delim)
if i < 0 {
return "", strings.TrimFunc(str, s.trimFunc), false
}
return strings.TrimFunc(str[i+s.delimLen:], s.trimFunc), strings.TrimFunc(str[:i], s.trimFunc), true
}

View File

@@ -0,0 +1,57 @@
// Copyright (c) 2019,CAOHONGJU All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package scan
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestScanner_Scan(t *testing.T) {
raw := "cao,hong,ju,ok"
t.Run("Scan", func(t *testing.T) {
advance, token, ok := Comma.Scan(raw)
assert.True(t, ok)
assert.Equal(t, "cao", token)
assert.Equal(t, "hong,ju,ok", advance)
i := 0
for ok {
advance, token, ok = Comma.Scan(advance)
if ok {
i++
}
}
assert.Equal(t, 2, i)
assert.Equal(t, "ok", token)
})
}
func Benchmark_Scanner_Scan(b *testing.B) {
s := `realm="Another Streaming Media", nonce="60a76a995a0cb012f1707abc188f60cb"`
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
realm := ""
nonce := ""
ok := true
advance := s
token := ""
for ok {
advance, token, ok = Comma.Scan(advance)
k, v, _ := EqualPair.Scan(token)
switch k {
case "realm":
realm = v
case "nonce":
nonce = v
}
}
_ = realm
_ = nonce
}
})
}