diff --git a/cmd/main.go b/cmd/main.go index 7abe26c..6d6d7c5 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,6 +1,5 @@ package main -import "C" import ( "flag" "fmt" diff --git a/component/stats/session/session.go b/component/stats/session/session.go index 02d7b6d..4320a8d 100644 --- a/component/stats/session/session.go +++ b/component/stats/session/session.go @@ -10,7 +10,6 @@ import ( "sync/atomic" "github.com/xjasonlyu/tun2socks/common/queue" - "github.com/xjasonlyu/tun2socks/component/stats" C "github.com/xjasonlyu/tun2socks/constant" "github.com/xjasonlyu/tun2socks/log" ) @@ -28,7 +27,7 @@ type simpleSessionStater struct { completedSessionQueue *queue.Queue } -func NewSimpleSessionStater() stats.SessionStater { +func NewSimpleSessionStater() *simpleSessionStater { return &simpleSessionStater{ completedSessionQueue: queue.New(maxCompletedSessions), } @@ -36,21 +35,21 @@ func NewSimpleSessionStater() stats.SessionStater { func (s *simpleSessionStater) sessionStatsHandler(resp http.ResponseWriter, req *http.Request) { // Slice of active sessions - var activeSessions []*stats.Session + var activeSessions []*C.Session s.activeSessionMap.Range(func(key, value interface{}) bool { - activeSessions = append(activeSessions, value.(*stats.Session)) + activeSessions = append(activeSessions, value.(*C.Session)) return true }) // Slice of completed sessions - var completedSessions []*stats.Session + var completedSessions []*C.Session for _, item := range s.completedSessionQueue.Copy() { - if sess, ok := item.(*stats.Session); ok { + if sess, ok := item.(*C.Session); ok { completedSessions = append(completedSessions, sess) } } - tablePrint := func(w io.Writer, sessions []*stats.Session) { + tablePrint := func(w io.Writer, sessions []*C.Session) { // Sort by session start time. sort.Slice(sessions, func(i, j int) bool { return sessions[i].SessionStart.Sub(sessions[j].SessionStart) < 0 @@ -114,13 +113,13 @@ func (s *simpleSessionStater) Stop() error { return s.server.Close() } -func (s *simpleSessionStater) AddSession(key interface{}, session *stats.Session) { +func (s *simpleSessionStater) AddSession(key interface{}, session *C.Session) { s.activeSessionMap.Store(key, session) } -func (s *simpleSessionStater) GetSession(key interface{}) *stats.Session { +func (s *simpleSessionStater) GetSession(key interface{}) *C.Session { if sess, ok := s.activeSessionMap.Load(key); ok { - return sess.(*stats.Session) + return sess.(*C.Session) } return nil } diff --git a/component/stats/stats.go b/component/stats/stats.go index a56ffa1..095798d 100644 --- a/component/stats/stats.go +++ b/component/stats/stats.go @@ -1,102 +1,13 @@ package stats import ( - "net" - "sync" - "sync/atomic" - "time" + C "github.com/xjasonlyu/tun2socks/constant" ) type SessionStater interface { Start() error Stop() error - AddSession(key interface{}, session *Session) - GetSession(key interface{}) *Session + AddSession(key interface{}, session *C.Session) + GetSession(key interface{}) *C.Session RemoveSession(key interface{}) } - -type Session struct { - Process string - Network string - DialerAddr string - ClientAddr string - TargetAddr string - UploadBytes int64 - DownloadBytes int64 - SessionStart time.Time - SessionClose time.Time -} - -// Track SessionConn -type SessionConn struct { - net.Conn - once sync.Once - session *Session -} - -func NewSessionConn(conn net.Conn, session *Session) net.Conn { - return &SessionConn{ - Conn: conn, - session: session, - } -} - -func (c *SessionConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if n > 0 { - atomic.AddInt64(&c.session.DownloadBytes, int64(n)) - } - return -} - -func (c *SessionConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if n > 0 { - atomic.AddInt64(&c.session.UploadBytes, int64(n)) - } - return -} - -func (c *SessionConn) Close() error { - c.once.Do(func() { - c.session.SessionClose = time.Now() - }) - return c.Conn.Close() -} - -// Track SessionPacketConn -type SessionPacketConn struct { - net.PacketConn - once sync.Once - session *Session -} - -func NewSessionPacketConn(conn net.PacketConn, session *Session) net.PacketConn { - return &SessionPacketConn{ - PacketConn: conn, - session: session, - } -} - -func (c *SessionPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, addr, err = c.PacketConn.ReadFrom(b) - if n > 0 { - atomic.AddInt64(&c.session.DownloadBytes, int64(n)) - } - return -} - -func (c *SessionPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - n, err = c.PacketConn.WriteTo(b, addr) - if n > 0 { - atomic.AddInt64(&c.session.UploadBytes, int64(n)) - } - return -} - -func (c *SessionPacketConn) Close() error { - c.once.Do(func() { - c.session.SessionClose = time.Now() - }) - return c.PacketConn.Close() -} diff --git a/constant/session.go b/constant/session.go new file mode 100644 index 0000000..1562f3a --- /dev/null +++ b/constant/session.go @@ -0,0 +1,94 @@ +package constant + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +type Session struct { + Process string + Network string + DialerAddr string + ClientAddr string + TargetAddr string + UploadBytes int64 + DownloadBytes int64 + SessionStart time.Time + SessionClose time.Time +} + +// Track SessionConn +type SessionConn struct { + net.Conn + once sync.Once + session *Session +} + +func NewSessionConn(conn net.Conn, session *Session) net.Conn { + return &SessionConn{ + Conn: conn, + session: session, + } +} + +func (c *SessionConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if n > 0 { + atomic.AddInt64(&c.session.DownloadBytes, int64(n)) + } + return +} + +func (c *SessionConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if n > 0 { + atomic.AddInt64(&c.session.UploadBytes, int64(n)) + } + return +} + +func (c *SessionConn) Close() error { + c.once.Do(func() { + c.session.SessionClose = time.Now() + }) + return c.Conn.Close() +} + +// Track SessionPacketConn +type SessionPacketConn struct { + net.PacketConn + once sync.Once + session *Session +} + +func NewSessionPacketConn(conn net.PacketConn, session *Session) net.PacketConn { + return &SessionPacketConn{ + PacketConn: conn, + session: session, + } +} + +func (c *SessionPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(b) + if n > 0 { + atomic.AddInt64(&c.session.DownloadBytes, int64(n)) + } + return +} + +func (c *SessionPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + n, err = c.PacketConn.WriteTo(b, addr) + if n > 0 { + atomic.AddInt64(&c.session.UploadBytes, int64(n)) + } + return +} + +func (c *SessionPacketConn) Close() error { + c.once.Do(func() { + c.session.SessionClose = time.Now() + }) + return c.PacketConn.Close() +} diff --git a/proxy/tcp.go b/proxy/tcp.go index d818bd7..42400d2 100644 --- a/proxy/tcp.go +++ b/proxy/tcp.go @@ -11,6 +11,7 @@ import ( "github.com/xjasonlyu/tun2socks/common/pool" "github.com/xjasonlyu/tun2socks/component/dns" "github.com/xjasonlyu/tun2socks/component/stats" + C "github.com/xjasonlyu/tun2socks/constant" "github.com/xjasonlyu/tun2socks/core" "github.com/xjasonlyu/tun2socks/log" ) @@ -104,7 +105,7 @@ func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error { // Get name of the process. var process = lsof.GetProcessName(localConn.LocalAddr()) if h.sessionStater != nil { - sess := &stats.Session{ + sess := &C.Session{ Process: process, Network: localConn.LocalAddr().Network(), DialerAddr: remoteConn.LocalAddr().String(), @@ -116,7 +117,7 @@ func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error { } h.sessionStater.AddSession(localConn, sess) - remoteConn = stats.NewSessionConn(remoteConn, sess) + remoteConn = C.NewSessionConn(remoteConn, sess) } // Set keepalive diff --git a/proxy/udp.go b/proxy/udp.go index ffe387e..3530f92 100644 --- a/proxy/udp.go +++ b/proxy/udp.go @@ -11,6 +11,7 @@ import ( "github.com/xjasonlyu/tun2socks/common/pool" "github.com/xjasonlyu/tun2socks/component/dns" "github.com/xjasonlyu/tun2socks/component/stats" + C "github.com/xjasonlyu/tun2socks/constant" "github.com/xjasonlyu/tun2socks/core" "github.com/xjasonlyu/tun2socks/log" ) @@ -82,7 +83,7 @@ func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { // Get name of the process. var process = lsof.GetProcessName(conn.LocalAddr()) if h.sessionStater != nil { - sess := &stats.Session{ + sess := &C.Session{ Process: process, Network: conn.LocalAddr().Network(), DialerAddr: remoteConn.LocalAddr().String(), @@ -94,7 +95,7 @@ func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { } h.sessionStater.AddSession(conn, sess) - remoteConn = stats.NewSessionPacketConn(remoteConn, sess) + remoteConn = C.NewSessionPacketConn(remoteConn, sess) } h.remoteAddrMap.Store(conn, remoteAddr)