mirror of
https://github.com/langhuihui/monibuca.git
synced 2025-09-26 18:51:13 +08:00
321 lines
8.6 KiB
Go
321 lines
8.6 KiB
Go
package util
|
||
|
||
import (
|
||
"crypto/sha256"
|
||
"crypto/subtle"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net"
|
||
"net/http"
|
||
"reflect"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
myip "github.com/husanpao/ip"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// FetchValue wraps an already-computed value in a zero-argument getter, so a
// plain value can be passed wherever a fetch callback is expected.
func FetchValue[T any](t T) func() T {
	getter := func() T { return t }
	return getter
}
|
||
|
||
// API status codes reported in APIError/APIResult "code" fields.
// APIErrorNone (0) means success.

// Request errors (4000 range).
const (
	APIErrorNone       = 0
	APIErrorDecode     = 4001
	APIErrorQueryParse = 4002
	APIErrorNoBody     = 4003
)

// Not-found errors (4040 range).
const (
	APIErrorNotFound     = 4040
	APIErrorNoStream     = 4041
	APIErrorNoConfig     = 4042
	APIErrorNoPusher     = 4043
	APIErrorNoSubscriber = 4044
	APIErrorNoSEI        = 4045
)

// Server-side errors (5000 range).
const (
	APIErrorInternal   = 5000
	APIErrorJSONEncode = 5001
	APIErrorPublish    = 5002
	APIErrorSave       = 5003
	APIErrorOpen       = 5004
)
|
||
|
||
type (
	// APIError is the JSON body of an error/status response:
	// {"code": <int>, "msg": <string>}. Field order matters: it fixes the JSON
	// key order and is relied on by positional literals such as APIError{code, msg}.
	APIError struct {
		Code    int    `json:"code"`
		Message string `json:"msg"`
	}

	// APIResult is the JSON envelope for a successful response:
	// {"code": <int>, "data": <any>, "msg": <string>}.
	APIResult struct {
		Code    int    `json:"code"`
		Data    any    `json:"data"`
		Message string `json:"msg"`
	}
)
|
||
|
||
func ReturnValue(v any, rw http.ResponseWriter, r *http.Request) {
|
||
ReturnFetchValue(FetchValue(v), rw, r)
|
||
}
|
||
|
||
func ReturnOK(rw http.ResponseWriter, r *http.Request) {
|
||
ReturnError(0, "ok", rw, r)
|
||
}
|
||
|
||
func ReturnError(code int, msg string, rw http.ResponseWriter, r *http.Request) {
|
||
query := r.URL.Query()
|
||
isJson := query.Get("format") == "json"
|
||
if isJson {
|
||
if err := json.NewEncoder(rw).Encode(APIError{code, msg}); err != nil {
|
||
json.NewEncoder(rw).Encode(APIError{
|
||
Code: APIErrorJSONEncode,
|
||
Message: err.Error(),
|
||
})
|
||
}
|
||
} else {
|
||
switch true {
|
||
case code == 0:
|
||
http.Error(rw, msg, http.StatusOK)
|
||
case code/10 == 404:
|
||
http.Error(rw, msg, http.StatusNotFound)
|
||
case code > 5000:
|
||
http.Error(rw, msg, http.StatusInternalServerError)
|
||
default:
|
||
http.Error(rw, msg, http.StatusBadRequest)
|
||
}
|
||
}
|
||
}
|
||
func ReturnFetchList[T any](fetch func() []T, rw http.ResponseWriter, r *http.Request) {
|
||
query := r.URL.Query()
|
||
isYaml := query.Get("format") == "yaml"
|
||
isJson := query.Get("format") == "json"
|
||
pageSize := query.Get("pageSize")
|
||
pageNum := query.Get("pageNum")
|
||
data := fetch()
|
||
var output any
|
||
output = data
|
||
if pageSize != "" && pageNum != "" {
|
||
pageSizeInt, _ := strconv.Atoi(pageSize)
|
||
pageNumInt, _ := strconv.Atoi(pageNum)
|
||
if pageSizeInt > 0 && pageNumInt > 0 {
|
||
start := (pageNumInt - 1) * pageSizeInt
|
||
end := pageNumInt * pageSizeInt
|
||
if start > len(data) {
|
||
start = len(data)
|
||
}
|
||
if end > len(data) {
|
||
end = len(data)
|
||
}
|
||
output = map[string]any{
|
||
"total": len(data),
|
||
"list": data[start:end],
|
||
"pageSize": pageSizeInt,
|
||
"pageNum": pageNumInt,
|
||
}
|
||
}
|
||
}
|
||
rw.Header().Set("Content-Type", Conditional(isYaml, "text/yaml", "application/json"))
|
||
if isYaml {
|
||
if err := yaml.NewEncoder(rw).Encode(output); err != nil {
|
||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||
}
|
||
} else if isJson {
|
||
if err := json.NewEncoder(rw).Encode(APIResult{
|
||
Code: 0,
|
||
Data: output,
|
||
Message: "ok",
|
||
}); err != nil {
|
||
json.NewEncoder(rw).Encode(APIError{
|
||
Code: APIErrorJSONEncode,
|
||
Message: err.Error(),
|
||
})
|
||
}
|
||
} else {
|
||
if err := json.NewEncoder(rw).Encode(output); err != nil {
|
||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||
}
|
||
}
|
||
}
|
||
func ReturnFetchValue[T any](fetch func() T, rw http.ResponseWriter, r *http.Request) {
|
||
query := r.URL.Query()
|
||
isYaml := query.Get("format") == "yaml"
|
||
isJson := query.Get("format") == "json"
|
||
tickDur, err := time.ParseDuration(query.Get("interval"))
|
||
if err != nil {
|
||
tickDur = time.Second
|
||
}
|
||
if r.Header.Get("Accept") == "text/event-stream" {
|
||
NewSSE(rw, r.Context(), func(sse *SSE) {
|
||
tick := time.NewTicker(tickDur)
|
||
defer tick.Stop()
|
||
writer := Conditional(isYaml, sse.WriteYAML, sse.WriteJSON)
|
||
err := writer(fetch())
|
||
for range tick.C {
|
||
if err = writer(fetch()); err != nil {
|
||
fmt.Println(err)
|
||
return
|
||
}
|
||
}
|
||
})
|
||
} else {
|
||
data := fetch()
|
||
rw.Header().Set("Content-Type", Conditional(isYaml, "text/yaml", "application/json"))
|
||
if isYaml {
|
||
if err := yaml.NewEncoder(rw).Encode(data); err != nil {
|
||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||
}
|
||
} else if isJson {
|
||
if err := json.NewEncoder(rw).Encode(APIResult{
|
||
Code: 0,
|
||
Data: data,
|
||
Message: "ok",
|
||
}); err != nil {
|
||
json.NewEncoder(rw).Encode(APIError{
|
||
Code: APIErrorJSONEncode,
|
||
Message: err.Error(),
|
||
})
|
||
}
|
||
} else {
|
||
t := reflect.TypeOf(data)
|
||
switch t.Kind() {
|
||
case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
|
||
rw.Header().Set("Content-Type", "text/plain")
|
||
fmt.Fprint(rw, data)
|
||
default:
|
||
if err := json.NewEncoder(rw).Encode(data); err != nil {
|
||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ListenUDP resolves address ("host:port"), binds a UDP socket to it, and sets
// both the socket's read and write buffer sizes to networkBuffer bytes.
// On any error, the socket (if one was opened) is closed and a nil conn is
// returned, so callers never have to clean up after a failed call.
func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {
	addr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return nil, err
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return nil, err
	}
	if err = conn.SetReadBuffer(networkBuffer); err != nil {
		conn.Close() // don't leak the bound socket
		return nil, err
	}
	if err = conn.SetWriteBuffer(networkBuffer); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}
|
||
|
||
// CORS wraps next with permissive cross-origin headers, including the
// Cross-Origin-Resource-Policy (CORP) header.
//
// The request's Origin is echoed back in Access-Control-Allow-Origin (or "*"
// when there is none), together with Allow-Credentials and Allow-Private-Network.
// OPTIONS (preflight) requests are answered here with just these headers and
// are not forwarded. Other requests are passed to next when it is non-nil.
//
// NOTE(review): echoing any Origin together with
// Access-Control-Allow-Credentials: true lets every site make credentialed
// requests. Confirm this is intended for the deployment.
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		header.Set("Access-Control-Allow-Credentials", "true")
		header.Set("Cross-Origin-Resource-Policy", "cross-origin")
		header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token,Authorization")
		header.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS")
		header.Set("Access-Control-Allow-Private-Network", "true")
		// The allowed origin depends on the request's Origin header. Without
		// Vary, a shared cache could replay one origin's response headers to a
		// different origin.
		header.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); origin != "" {
			header.Set("Access-Control-Allow-Origin", origin)
		} else {
			header.Set("Access-Control-Allow-Origin", "*")
		}
		if next != nil && r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
		}
	})
}
|
||
|
||
// BasicAuth wraps next with HTTP Basic authentication against the single
// username u and password p.
//
// Requests whose credentials match are passed to next (if non-nil). Requests
// with a missing/invalid Authorization header or wrong credentials get a
// WWW-Authenticate challenge and a 401 Unauthorized response.
func BasicAuth(u, p string, next http.Handler) http.Handler {
	// Compare SHA-256 digests rather than raw strings: digests always have the
	// same length, so subtle.ConstantTimeCompare does not leak the length of
	// the expected credentials.
	wantUser := sha256.Sum256([]byte(u))
	wantPass := sha256.Sum256([]byte(p))

	authorized := func(r *http.Request) bool {
		user, pass, ok := r.BasicAuth()
		if !ok {
			return false
		}
		gotUser := sha256.Sum256([]byte(user))
		gotPass := sha256.Sum256([]byte(pass))
		// Evaluate both comparisons before combining them, so the check takes
		// the same time whichever credential is wrong.
		userOK := subtle.ConstantTimeCompare(gotUser[:], wantUser[:])
		passOK := subtle.ConstantTimeCompare(gotPass[:], wantPass[:])
		return userOK&passOK == 1
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !authorized(r) {
			w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		if next != nil {
			next.ServeHTTP(w, r)
		}
	})
}
|
||
|
||
// ipReg matches a dotted-quad IPv4 address whose four octets are each 0-255.
var ipReg = regexp.MustCompile(`^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`)

// privateIPReg matches dotted-quad IPv4 addresses in the RFC 1918 private
// ranges 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16. The network prefix is
// matched once and followed by the remaining 3, 2 or 2 octets respectively.
var privateIPReg = regexp.MustCompile(`^(10(\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}|172\.(1[6-9]|2[0-9]|3[01])(\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){2}|192\.168(\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){2})$`)

// routes maps an address (or an address prefix up to its last '.') to the
// address GetPublicIP should return for it. It is filled by initRoutes.
var routes map[string]string

// PublicIP is the externally visible address discovered by initRoutes.
var PublicIP string
|
||
|
||
// IsPrivateIP reports whether ip is a textual IP address in a private range:
// RFC 1918 for IPv4 (10/8, 172.16/12, 192.168/16) or RFC 4193 for IPv6 (fc00::/7).
// Strings that are not valid IP addresses (including "") report false.
// Loopback and link-local addresses are not considered private.
func IsPrivateIP(ip string) bool {
	parsed := net.ParseIP(ip)
	return parsed != nil && parsed.IsPrivate()
}
|
||
|
||
func initRoutes() {
|
||
PublicIP = myip.ExternalIP()
|
||
for k, v := range myip.LocalAndInternalIPs() {
|
||
routes[k] = v
|
||
if lastdot := strings.LastIndex(k, "."); lastdot >= 0 {
|
||
routes[k[0:lastdot]] = k
|
||
}
|
||
}
|
||
initRoutesWait.Done()
|
||
}
|
||
|
||
var initRoutesWait sync.WaitGroup
|
||
|
||
func init() {
|
||
routes = make(map[string]string)
|
||
initRoutesWait.Add(1)
|
||
go initRoutes()
|
||
}
|
||
|
||
func GetPublicIP(ip string) string {
|
||
initRoutesWait.Wait()
|
||
if ip == "" {
|
||
return PublicIP
|
||
}
|
||
if publicIP, ok := routes[ip]; ok {
|
||
return publicIP
|
||
}
|
||
return ip
|
||
}
|