diff --git a/engine/engine.go b/engine/engine.go index afd6d91..84ec49a 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -2,6 +2,7 @@ package engine import ( "errors" + "sync" "github.com/xjasonlyu/tun2socks/v2/component/dialer" "github.com/xjasonlyu/tun2socks/v2/core" @@ -16,6 +17,8 @@ import ( ) var ( + _engineMu sync.Mutex + // _defaultKey holds the default key for the engine. _defaultKey *Key @@ -45,10 +48,13 @@ func Stop() { // Insert loads *Key to the default engine. func Insert(k *Key) { + _engineMu.Lock() _defaultKey = k + _engineMu.Unlock() } func start() error { + _engineMu.Lock() if _defaultKey == nil { return errors.New("empty key") } @@ -62,10 +68,12 @@ func start() error { return err } } + _engineMu.Unlock() return nil } func stop() (err error) { + _engineMu.Lock() if _defaultDevice != nil { err = _defaultDevice.Close() } @@ -73,6 +81,7 @@ func stop() (err error) { _defaultStack.Close() _defaultStack.Wait() } + _engineMu.Unlock() return err } @@ -107,6 +116,17 @@ func restAPI(k *Key) error { } host, token := u.Host, u.User.String() + restapi.SetStatsFunc(func() tcpip.Stats { + _engineMu.Lock() + defer _engineMu.Unlock() + + // default stack is not initialized. + if _defaultStack == nil { + return tcpip.Stats{} + } + return _defaultStack.Stats() + }) + go func() { if err := restapi.Start(host, token); err != nil { log.Warnf("[RESTAPI] failed to start: %v", err) diff --git a/restapi/errors.go b/restapi/errors.go index 8f65081..1dfaca6 100644 --- a/restapi/errors.go +++ b/restapi/errors.go @@ -1,8 +1,9 @@ package restapi var ( - ErrUnauthorized = newError("Unauthorized") - ErrBadRequest = newError("Body invalid") + ErrBadRequest = newError("Body invalid") + ErrUnauthorized = newError("Unauthorized") + ErrUninitialized = newError("Uninitialized") ) var _ error = (*HTTPError)(nil) diff --git a/restapi/netstats.go b/restapi/netstats.go index aad3dc1..d0dc8c2 100644 --- a/restapi/netstats.go +++ b/restapi/netstats.go @@ -1,4 +1,84 @@ package restapi -// TODO: Network statistic support. -func init() {} +import ( + "bytes" + "encoding/json" + "net/http" + "reflect" + "time" + + "github.com/go-chi/render" + "github.com/gorilla/websocket" + "gvisor.dev/gvisor/pkg/tcpip" +) + +var _statsFunc func() tcpip.Stats + +func SetStatsFunc(s func() tcpip.Stats) { + _statsFunc = s +} + +func getNetStats(w http.ResponseWriter, r *http.Request) { + if _statsFunc == nil { + render.Status(r, http.StatusInternalServerError) + render.JSON(w, r, ErrUninitialized) + return + } + + snapshot := func() any { + s := _statsFunc() + return dump(reflect.ValueOf(&s).Elem()) + } + + if !websocket.IsWebSocketUpgrade(r) { + render.JSON(w, r, snapshot()) + return + } + + conn, err := _upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + tick := time.NewTicker(time.Second) + defer tick.Stop() + + buf := &bytes.Buffer{} + for range tick.C { + buf.Reset() + + if err = json.NewEncoder(buf).Encode(snapshot()); err != nil { + break + } + + if err = conn.WriteMessage(websocket.TextMessage, buf.Bytes()); err != nil { + break + } + } +} + +func dump(value reflect.Value) map[string]any { + numField := value.NumField() + structure := make(map[string]any, numField) + + for i := 0; i < numField; i++ { + field := value.Type().Field(i) + value := value.Field(i) + + switch v := value.Addr().Interface().(type) { + case **tcpip.StatCounter: + structure[field.Name] = (*v).Value() + case **tcpip.IntegralStatCounterMap: + counterMap := make(map[uint64]uint64) + for _, k := range (*v).Keys() { + if counter, ok := (*v).Get(k); ok { + counterMap[k] = counter.Value() + } + } + structure[field.Name] = counterMap + default: + structure[field.Name] = dump(value) + } + } + return structure +} diff --git a/restapi/server.go b/restapi/server.go index a8e1180..57d26e3 100644 --- a/restapi/server.go +++ b/restapi/server.go @@ -43,6 +43,7 @@ func Start(addr, token string) error { r.Get("/logs", getLogs) r.Get("/traffic", traffic) r.Get("/version", version) + r.Get("/netstats", getNetStats) r.Mount("/connections", connectionRouter()) })