Files
monibuca/pkg/util/net.go
2025-09-16 14:03:22 +08:00

321 lines
8.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package util
import (
"crypto/sha256"
"crypto/subtle"
"encoding/json"
"fmt"
"net"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"time"
myip "github.com/husanpao/ip"
"gopkg.in/yaml.v3"
)
// FetchValue wraps t in a fetcher: the returned function ignores any outside
// state and yields the same t on every call. It adapts a fixed value to APIs
// such as ReturnFetchValue that expect a func() T.
func FetchValue[T any](t T) func() T {
	constant := func() T {
		return t
	}
	return constant
}
// API result codes carried in the "code" field of APIError and APIResult.
//
// The values are written out explicitly. They are exactly what the earlier
// iota-based declarations produced. APIErrorDecode is 4001, not 4000, because
// iota was already 1 on its line: APIErrorNone occupied index 0 of the same
// const group. Clients may depend on these numbers, so they must not be
// renumbered.
const (
	// Success.
	APIErrorNone = 0

	// 400x: request decoding / parsing problems.
	APIErrorDecode     = 4001
	APIErrorQueryParse = 4002
	APIErrorNoBody     = 4003

	// 404x: requested resource does not exist.
	APIErrorNotFound     = 4040
	APIErrorNoStream     = 4041
	APIErrorNoConfig     = 4042
	APIErrorNoPusher     = 4043
	APIErrorNoSubscriber = 4044
	APIErrorNoSEI        = 4045

	// 500x: server-side failures.
	APIErrorInternal   = 5000
	APIErrorJSONEncode = 5001
	APIErrorPublish    = 5002
	APIErrorSave       = 5003
	APIErrorOpen       = 5004
)
type (
	// APIError is the JSON body written by ReturnError when the request asks
	// for format=json. Field order matters: callers build it positionally.
	APIError struct {
		Code    int    `json:"code"`
		Message string `json:"msg"`
	}

	// APIResult is the JSON envelope used for format=json responses that
	// carry a payload: Code and Message describe the outcome, Data holds the
	// encoded value.
	APIResult struct {
		Code    int    `json:"code"`
		Data    any    `json:"data"`
		Message string `json:"msg"`
	}
)
// ReturnValue writes the fixed value v to rw using ReturnFetchValue's format
// handling. The fetcher always yields the same v, so an SSE client receives
// that identical value on every tick.
func ReturnValue(v any, rw http.ResponseWriter, r *http.Request) {
	ReturnFetchValue(func() any { return v }, rw, r)
}
// ReturnOK reports success: it delegates to ReturnError with code 0 and the
// message "ok".
func ReturnOK(rw http.ResponseWriter, r *http.Request) {
	const okCode, okMsg = 0, "ok"
	ReturnError(okCode, okMsg, rw, r)
}
// ReturnError writes a result code and message to rw.
//
// With the query parameter format=json, the body is an APIError JSON object
// with Content-Type application/json. No status is written, so the HTTP
// status is the implicit 200 and the outcome travels in the "code" field.
//
// Otherwise msg is sent as plain text via http.Error, with the HTTP status
// derived from code:
//   - 0: 200 OK
//   - 4040..4049 (code/10 == 404): 404 Not Found
//   - 5000 and above: 500 Internal Server Error
//   - anything else: 400 Bad Request
func ReturnError(code int, msg string, rw http.ResponseWriter, r *http.Request) {
	if r.URL.Query().Get("format") == "json" {
		rw.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(rw).Encode(APIError{Code: code, Message: msg}); err != nil {
			// Best-effort fallback describing the encoding failure; there is
			// nothing further to do if this write fails as well.
			_ = json.NewEncoder(rw).Encode(APIError{
				Code:    APIErrorJSONEncode,
				Message: err.Error(),
			})
		}
		return
	}
	switch {
	case code == 0:
		http.Error(rw, msg, http.StatusOK)
	case code/10 == 404:
		http.Error(rw, msg, http.StatusNotFound)
	case code >= 5000:
		// ">=" so APIErrorInternal (5000) itself maps to 500; the previous ">"
		// sent it to 400.
		http.Error(rw, msg, http.StatusInternalServerError)
	default:
		http.Error(rw, msg, http.StatusBadRequest)
	}
}
// ReturnFetchList writes the slice produced by fetch to rw.
//
// Query parameters read from r:
//   - format: "yaml" encodes the payload as YAML; "json" wraps it in an
//     APIResult envelope (code 0, msg "ok"); anything else encodes the bare
//     payload as JSON.
//   - pageSize, pageNum: when both are present and parse to positive integers,
//     the payload becomes a map with "total" (full length), "list" (the
//     requested 1-based page, possibly empty), "pageSize" and "pageNum".
//     Otherwise the whole slice is sent.
//
// fetch is called exactly once.
func ReturnFetchList[T any](fetch func() []T, rw http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	isYaml := query.Get("format") == "yaml"
	isJson := query.Get("format") == "json"
	pageSize := query.Get("pageSize")
	pageNum := query.Get("pageNum")
	data := fetch()
	var output any = data
	if pageSize != "" && pageNum != "" {
		// Atoi errors are deliberately ignored. Syntax errors yield 0, which
		// fails the positivity check and leaves the list unpaged. Out-of-range
		// values saturate to +/-MaxInt, which the bounds code below tolerates.
		pageSizeInt, _ := strconv.Atoi(pageSize)
		pageNumInt, _ := strconv.Atoi(pageNum)
		if pageSizeInt > 0 && pageNumInt > 0 {
			total := len(data)
			// The query values are client-controlled, so never multiply them
			// unchecked: (pageNum-1)*pageSize may overflow int and produce a
			// negative slice bound (a panic). Clamp via division first.
			start := total
			if pageNumInt-1 <= total/pageSizeInt {
				start = (pageNumInt - 1) * pageSizeInt // <= total, cannot overflow
			}
			end := total
			if pageSizeInt < total-start {
				end = start + pageSizeInt
			}
			output = map[string]any{
				"total":    total,
				"list":     data[start:end],
				"pageSize": pageSizeInt,
				"pageNum":  pageNumInt,
			}
		}
	}
	rw.Header().Set("Content-Type", Conditional(isYaml, "text/yaml", "application/json"))
	switch {
	case isYaml:
		if err := yaml.NewEncoder(rw).Encode(output); err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	case isJson:
		if err := json.NewEncoder(rw).Encode(APIResult{
			Code:    0,
			Data:    output,
			Message: "ok",
		}); err != nil {
			// Fall back to an APIError describing the encoding failure.
			json.NewEncoder(rw).Encode(APIError{
				Code:    APIErrorJSONEncode,
				Message: err.Error(),
			})
		}
	default:
		if err := json.NewEncoder(rw).Encode(output); err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	}
}
// ReturnFetchValue writes the value produced by fetch to rw.
//
// Server-sent events: if the Accept header is exactly "text/event-stream",
// the response is streamed through NewSSE. fetch is called once immediately,
// then again on every tick of the "interval" query parameter. That parameter
// is parsed with time.ParseDuration; missing, invalid or non-positive values
// fall back to one second. Each value is written with sse.WriteYAML when
// format=yaml and with sse.WriteJSON otherwise. Streaming stops on the first
// write error (which is printed) or when the request context is done.
//
// One-shot: otherwise fetch is called once and encoded according to the
// "format" query parameter:
//   - "yaml": YAML body, Content-Type text/yaml.
//   - "json": APIResult envelope (code 0, msg "ok"), Content-Type
//     application/json.
//   - anything else: string, integer and float kinds are written as
//     text/plain via fmt.Fprint; every other value, including nil, is encoded
//     as bare JSON.
func ReturnFetchValue[T any](fetch func() T, rw http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	isYaml := query.Get("format") == "yaml"
	isJson := query.Get("format") == "json"
	// time.NewTicker panics on a non-positive duration, so a client-supplied
	// "0s" or "-1s" must fall back to the default just like a parse error.
	tickDur, err := time.ParseDuration(query.Get("interval"))
	if err != nil || tickDur <= 0 {
		tickDur = time.Second
	}
	if r.Header.Get("Accept") == "text/event-stream" {
		NewSSE(rw, r.Context(), func(sse *SSE) {
			writer := Conditional(isYaml, sse.WriteYAML, sse.WriteJSON)
			// Send the first event right away. If even that fails there is
			// no point in starting the ticker (the error used to be ignored).
			if err := writer(fetch()); err != nil {
				fmt.Println(err)
				return
			}
			tick := time.NewTicker(tickDur)
			defer tick.Stop()
			for {
				select {
				case <-r.Context().Done():
					// e.g. the client disconnected; stop polling fetch.
					return
				case <-tick.C:
					if err := writer(fetch()); err != nil {
						fmt.Println(err)
						return
					}
				}
			}
		})
		return
	}
	data := fetch()
	rw.Header().Set("Content-Type", Conditional(isYaml, "text/yaml", "application/json"))
	switch {
	case isYaml:
		if err := yaml.NewEncoder(rw).Encode(data); err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	case isJson:
		if err := json.NewEncoder(rw).Encode(APIResult{
			Code:    0,
			Data:    data,
			Message: "ok",
		}); err != nil {
			// Fall back to an APIError describing the encoding failure.
			json.NewEncoder(rw).Encode(APIError{
				Code:    APIErrorJSONEncode,
				Message: err.Error(),
			})
		}
	default:
		// reflect.ValueOf(nil).Kind() is reflect.Invalid, whereas
		// reflect.TypeOf(nil) returns a nil Type whose Kind method panics.
		// A nil value (e.g. ReturnValue(nil, ...)) now falls through to JSON "null".
		switch reflect.ValueOf(data).Kind() {
		case reflect.String,
			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
			reflect.Float32, reflect.Float64:
			rw.Header().Set("Content-Type", "text/plain")
			fmt.Fprint(rw, data)
		default:
			if err := json.NewEncoder(rw).Encode(data); err != nil {
				http.Error(rw, err.Error(), http.StatusInternalServerError)
			}
		}
	}
}
// ListenUDP resolves address (host:port), opens a UDP socket bound to it, and
// sets both the operating system's receive and send buffer sizes to
// networkBuffer bytes.
//
// On success the caller owns the returned connection and must Close it. On
// any error the conn is nil, and a socket that was already opened is closed
// rather than leaked.
func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {
	addr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return nil, err
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return nil, err
	}
	if err = conn.SetReadBuffer(networkBuffer); err != nil {
		conn.Close()
		return nil, err
	}
	if err = conn.SetWriteBuffer(networkBuffer); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}
// CORS wraps next with cross-origin headers, including
// Cross-Origin-Resource-Policy (CORP) and the Private Network Access opt-in.
//
// The request's Origin is echoed back in Access-Control-Allow-Origin, or "*"
// is used when the request has no Origin header. Because that header
// therefore depends on the request, "Vary: Origin" is added so shared caches
// never replay one origin's response to another. OPTIONS requests (CORS
// preflights) get only these headers and are not passed to next. A nil next
// is treated as an empty handler.
//
// SECURITY NOTE(review): echoing any Origin together with
// Access-Control-Allow-Credentials: true lets every website make credentialed
// requests to this server. Consider an origin allow-list if the wrapped
// handlers rely on cookies or HTTP authentication.
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		header.Set("Access-Control-Allow-Credentials", "true")
		header.Set("Cross-Origin-Resource-Policy", "cross-origin")
		header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token,Authorization")
		header.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS")
		header.Set("Access-Control-Allow-Private-Network", "true")
		// Add (not Set) so any Vary values set by outer middleware survive.
		header.Add("Vary", "Origin")
		origin := r.Header["Origin"]
		if len(origin) == 0 {
			header.Set("Access-Control-Allow-Origin", "*")
		} else {
			header.Set("Access-Control-Allow-Origin", origin[0])
		}
		if next != nil && r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
		}
	})
}
// BasicAuth wraps next with HTTP Basic authentication against the single
// username u and password p.
//
// Requests with matching credentials are passed to next. If next is nil,
// nothing is written and the implicit 200 is sent. All other requests (no or
// malformed Authorization header, wrong username or password) get a
// WWW-Authenticate challenge and a 401 Unauthorized response.
func BasicAuth(u, p string, next http.Handler) http.Handler {
	// Compare SHA-256 digests rather than the raw strings. Both sides then
	// have the same fixed length, so subtle.ConstantTimeCompare cannot leak
	// the credential length through an early length-mismatch return.
	wantUser := sha256.Sum256([]byte(u))
	wantPass := sha256.Sum256([]byte(p))

	reject := func(w http.ResponseWriter) {
		w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user, pass, ok := r.BasicAuth()
		if !ok {
			reject(w)
			return
		}
		gotUser := sha256.Sum256([]byte(user))
		gotPass := sha256.Sum256([]byte(pass))
		// Evaluate both comparisons before branching so response timing does
		// not reveal which of the two credentials was wrong.
		userOK := subtle.ConstantTimeCompare(gotUser[:], wantUser[:]) == 1
		passOK := subtle.ConstantTimeCompare(gotPass[:], wantPass[:]) == 1
		if !userOK || !passOK {
			reject(w)
			return
		}
		if next != nil {
			next.ServeHTTP(w, r)
		}
	})
}
var (
	// ipReg matches a dotted-quad IPv4 address whose octets are each 0-255
	// (leading zeros such as "001" are accepted).
	ipReg = regexp.MustCompile(`^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`)

	// privateIPReg is a private-address pattern.
	// NOTE(review): it requires every dot-separated group to be "10",
	// "192.168" or "172.16".."172.31", so ordinary private addresses such as
	// 192.168.1.1 or 10.0.0.1 do NOT match. Use net.IP.IsPrivate for real checks.
	privateIPReg = regexp.MustCompile(`^((10|192\.168|172\.(1[6-9]|2[0-9]|3[0-1]))\.){3}(10|192\.168|172\.(1[6-9]|2[0-9]|3[0-1]))$`)

	// routes is the lookup table consulted by GetPublicIP. It is allocated in
	// init and filled by initRoutes.
	routes map[string]string

	// PublicIP holds the external address discovered by initRoutes.
	PublicIP string
)
// IsPrivateIP reports whether ip is a textual IP address inside a private
// range. For IPv4 that is RFC 1918 (10.0.0.0/8, 172.16.0.0/12,
// 192.168.0.0/16); for IPv6 it is RFC 4193 (fc00::/7). Strings that do not
// parse as an IP address return false.
func IsPrivateIP(ip string) bool {
	// Previously this used privateIPReg, which demanded that every octet group
	// be "10", "192.168" or "172.16-31". It therefore rejected common private
	// addresses such as 192.168.1.1 and 10.0.0.1. net.IP.IsPrivate implements
	// the RFC ranges directly.
	parsed := net.ParseIP(ip)
	return parsed != nil && parsed.IsPrivate()
}
func initRoutes() {
PublicIP = myip.ExternalIP()
for k, v := range myip.LocalAndInternalIPs() {
routes[k] = v
if lastdot := strings.LastIndex(k, "."); lastdot >= 0 {
routes[k[0:lastdot]] = k
}
}
initRoutesWait.Done()
}
// initRoutesWait is released once initRoutes has populated PublicIP and
// routes; GetPublicIP waits on it before reading either.
var initRoutesWait sync.WaitGroup

// init allocates the routes table and starts route discovery in the
// background, so package initialization does not wait on it.
func init() {
	initRoutesWait.Add(1)
	routes = make(map[string]string)
	go initRoutes()
}
// GetPublicIP translates ip using the table built by initRoutes, first
// blocking until that background initialization has finished.
//
// An empty ip yields PublicIP. An ip present as a key in routes yields the
// mapped value. Anything else is returned unchanged.
func GetPublicIP(ip string) string {
	initRoutesWait.Wait()
	mapped, found := routes[ip]
	switch {
	case ip == "":
		return PublicIP
	case found:
		return mapped
	default:
		return ip
	}
}