diff --git a/config/types.go b/config/types.go index b463781..01bbc59 100755 --- a/config/types.go +++ b/config/types.go @@ -2,7 +2,9 @@ package config import ( "context" + "fmt" "net/http" + "regexp" "strings" "time" @@ -75,16 +77,59 @@ func (c *Subscribe) GetSubscribeConfig() *Subscribe { } type Pull struct { - RePull int // 断开后自动重拉,0 表示不自动重拉,-1 表示无限重拉,高于0 的数代表最大重拉次数 - PullOnStart map[string]string // 启动时拉流的列表 - PullOnSub map[string]string // 订阅时自动拉流的列表 - Proxy string // 代理地址 + RePull int // 断开后自动重拉,0 表示不自动重拉,-1 表示无限重拉,高于0 的数代表最大重拉次数 + EnableRegexp bool // 是否启用正则表达式 + PullOnStart map[string]string // 启动时拉流的列表 + PullOnSub map[string]string // 订阅时自动拉流的列表 + Proxy string // 代理地址 } func (p *Pull) GetPullConfig() *Pull { return p } +func (p *Pull) CheckPullOnStart(streamPath string) string { + if p.PullOnStart == nil { + return "" + } + url, ok := p.PullOnStart[streamPath] + if !ok && p.EnableRegexp { + for k, url := range p.PullOnStart { + if r, err := regexp.Compile(k); err != nil { + if group := r.FindStringSubmatch(streamPath); group != nil { + for i, value := range group { + url = strings.Replace(url, fmt.Sprintf("$%d", i), value, -1) + } + return url + } + } + return "" + } + } + return url +} + +func (p *Pull) CheckPullOnSub(streamPath string) string { + if p.PullOnSub == nil { + return "" + } + url, ok := p.PullOnSub[streamPath] + if !ok && p.EnableRegexp { + for k, url := range p.PullOnSub { + if r, err := regexp.Compile(k); err == nil { + if group := r.FindStringSubmatch(streamPath); group != nil { + for i, value := range group { + url = strings.Replace(url, fmt.Sprintf("$%d", i), value, -1) + } + return url + } + } + return "" + } + } + return url +} + func (p *Pull) AddPullOnStart(streamPath string, url string) { if p.PullOnStart == nil { p.PullOnStart = make(map[string]string) diff --git a/util/socket.go b/util/socket.go index 5552fb5..79461fc 100644 --- a/util/socket.go +++ b/util/socket.go @@ -6,6 +6,7 @@ import ( "encoding/json" "net" "net/http" + "strconv" "time" "gopkg.in/yaml.v3" @@ -83,7 +84,57 @@ func ReturnError(code int, msg string, rw http.ResponseWriter, r *http.Request) } } } - +func ReturnFetchList[T any](fetch func() []T, rw http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + isYaml := query.Get("format") == "yaml" + isJson := query.Get("format") == "json" + pageSize := query.Get("pageSize") + pageNum := query.Get("pageNum") + data := fetch() + var output any + output = data + if pageSize != "" && pageNum != "" { + pageSizeInt, _ := strconv.Atoi(pageSize) + pageNumInt, _ := strconv.Atoi(pageNum) + if pageSizeInt > 0 && pageNumInt > 0 { + start := (pageNumInt - 1) * pageSizeInt + end := pageNumInt * pageSizeInt + if start > len(data) { + start = len(data) + } + if end > len(data) { + end = len(data) + } + output = map[string]any{ + "total": len(data), + "list": data[start:end], + "pageSize": pageSizeInt, + "pageNum": pageNumInt, + } + } + } + rw.Header().Set("Content-Type", Conditoinal(isYaml, "text/yaml", "application/json")) + if isYaml { + if err := yaml.NewEncoder(rw).Encode(output); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } + } else if isJson { + if err := json.NewEncoder(rw).Encode(APIResult{ + Code: 0, + Data: output, + Message: "ok", + }); err != nil { + json.NewEncoder(rw).Encode(APIError{ + Code: APIErrorJSONEncode, + Message: err.Error(), + }) + } + } else { + if err := json.NewEncoder(rw).Encode(output); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } + } +} func ReturnFetchValue[T any](fetch func() T, rw http.ResponseWriter, r *http.Request) { query := r.URL.Query() isYaml := query.Get("format") == "yaml"