From 8e585855fef45c8e006e27654c48c3baadeb185d Mon Sep 17 00:00:00 2001 From: spiritlhl <103393591+spiritLHLS@users.noreply.github.com> Date: Tue, 22 Jul 2025 14:46:41 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=9B=9E=E9=80=80=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci.yaml | 2 +- cmd/main.go | 3 +- model/model.go | 7 +- stuncheck/checktype.go | 4 +- stuncheck/stuncheck.go | 162 ++++++++++++++------------------------ 5 files changed, 66 insertions(+), 112 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a1197da..80921f8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -24,7 +24,7 @@ jobs: run: | git config --global user.name 'github-actions' git config --global user.email 'github-actions@github.com' - TAG="v0.0.6-$(date +'%Y%m%d%H%M%S')" + TAG="v0.0.5-$(date +'%Y%m%d%H%M%S')" git tag $TAG git push origin $TAG echo "TAG=$TAG" >> $GITHUB_ENV diff --git a/cmd/main.go b/cmd/main.go index fb31b86..1475544 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -83,7 +83,6 @@ func main() { gostunFlag.StringVar(&model.AddrStr, "server", "stun.voipgate.com:3478", "Specify STUN server address") gostunFlag.BoolVar(&model.EnableLoger, "e", true, "Enable logging functionality") gostunFlag.StringVar(&model.IPVersion, "type", "ipv4", "Specify ip test version: ipv4, ipv6 or both") - gostunFlag.StringVar(&model.TransmissionProtocol, "protocol", "udp", "Specify transmission protocol: udp, tcp, or tls") gostunFlag.Parse(os.Args[1:]) if help { fmt.Printf("Usage: %s [options]\n", os.Args[0]) @@ -164,4 +163,4 @@ func main() { model.IPVersion = originalIPVersion res := stuncheck.CheckType() fmt.Printf("NAT Type: %s\n", res) -} +} \ No newline at end of file diff --git a/model/model.go b/model/model.go index 9c3f35b..8cad0ec 100644 --- a/model/model.go +++ b/model/model.go @@ -2,18 +2,17 @@ package model import "github.com/pion/logging" -const GoStunVersion = "v0.0.6" +const GoStunVersion = "v0.0.5" var ( AddrStr = "stun.voipgate.com:3478" - Timeout = 5 + Timeout = 3 Verbose = 0 Log logging.LeveledLogger NatMappingBehavior string NatFilteringBehavior string EnableLoger = true IPVersion = "ipv4" - TransmissionProtocol = "udp" ) func GetDefaultServers(IPVersion string) []string { @@ -57,4 +56,4 @@ func GetDefaultServers(IPVersion string) []string { "stun.f.haeder.net:3478", } } -} +} \ No newline at end of file diff --git a/stuncheck/checktype.go b/stuncheck/checktype.go index 964f1dd..62fbd08 100644 --- a/stuncheck/checktype.go +++ b/stuncheck/checktype.go @@ -6,6 +6,8 @@ import ( "github.com/oneclickvirt/gostun/model" ) +// CheckType +// Summarize the NAT type func CheckType() string { var result string if model.NatMappingBehavior != "" && model.NatFilteringBehavior != "" { @@ -28,4 +30,4 @@ func CheckType() string { result = "Inconclusive" } return result -} +} \ No newline at end of file diff --git a/stuncheck/stuncheck.go b/stuncheck/stuncheck.go index 0d0ee6b..996a8c7 100644 --- a/stuncheck/stuncheck.go +++ b/stuncheck/stuncheck.go @@ -1,7 +1,6 @@ package stuncheck import ( - "crypto/tls" "errors" "net" "time" @@ -10,13 +9,14 @@ import ( "github.com/pion/stun/v2" ) +// From https://github.com/pion/stun/blob/master/cmd/stun-nat-behaviour/main.go + type stunServerConn struct { - conn net.Conn + conn net.PacketConn LocalAddr net.Addr RemoteAddr *net.UDPAddr OtherAddr *net.UDPAddr messageChan chan *stun.Message - protocol string } func (c *stunServerConn) Close() error { @@ -45,30 +45,15 @@ func isIPv6Address(addr string) bool { func getNetworkType(addrStr string) string { switch model.IPVersion { case "ipv6": - if model.TransmissionProtocol == "tcp" { - return "tcp6" - } return "udp6" case "ipv4": - if model.TransmissionProtocol == "tcp" { - return "tcp4" - } return "udp4" case "both": if isIPv6Address(addrStr) { - if model.TransmissionProtocol == "tcp" { - return "tcp6" - } return "udp6" } - if model.TransmissionProtocol == "tcp" { - return "tcp4" - } return "udp4" } - if model.TransmissionProtocol == "tcp" { - return "tcp4" - } return "udp4" } @@ -83,7 +68,8 @@ func getCurrentProtocol(addrStr string) string { return "ipv4" } -func MappingTests(addrStr string) error { +// RFC 5780 implementation (current) +func MappingTests(addrStr string) error { //nolint:cyclop currentProtocol := getCurrentProtocol(addrStr) mapTestConn, err := connect(addrStr) if err != nil { @@ -171,7 +157,8 @@ func MappingTests(addrStr string) error { return mapTestConn.Close() } -func FilteringTests(addrStr string) error { +// RFC 5780 implementation (current) +func FilteringTests(addrStr string) error { //nolint:cyclop currentProtocol := getCurrentProtocol(addrStr) mapTestConn, err := connect(addrStr) if err != nil { @@ -241,6 +228,7 @@ func FilteringTests(addrStr string) error { return mapTestConn.Close() } +// RFC 5389/8489 implementation - basic STUN binding request func MappingTestsRFC5389(addrStr string) error { currentProtocol := getCurrentProtocol(addrStr) mapTestConn, err := connect(addrStr) @@ -269,10 +257,12 @@ func MappingTestsRFC5389(addrStr string) error { if model.EnableLoger { model.Log.Infof("[%s] RFC5389: Received XOR-MAPPED-ADDRESS: %v", currentProtocol, resps.xorAddr) } + // Simple classification based on whether we're behind NAT if resps.xorAddr.String() == mapTestConn.LocalAddr.String() { model.NatMappingBehavior = "endpoint independent (no NAT)" model.NatFilteringBehavior = "endpoint independent" } else { + // Can't determine exact type with RFC5389, so use conservative estimate model.NatMappingBehavior = "address and port dependent" model.NatFilteringBehavior = "address and port dependent" } @@ -283,6 +273,7 @@ func MappingTestsRFC5389(addrStr string) error { return nil } +// RFC 3489 implementation - classic STUN func MappingTestsRFC3489(addrStr string) error { currentProtocol := getCurrentProtocol(addrStr) mapTestConn, err := connect(addrStr) @@ -296,6 +287,7 @@ func MappingTestsRFC3489(addrStr string) error { if model.EnableLoger { model.Log.Infof("[%s] RFC3489: Test I - Basic binding request", currentProtocol) } + // Test I: Basic binding request request := stun.MustBuild(stun.TransactionID, stun.BindingRequest) resp, err := mapTestConn.roundTrip(request, mapTestConn.RemoteAddr) if err != nil { @@ -303,6 +295,7 @@ func MappingTestsRFC3489(addrStr string) error { } resps1 := parse(resp) var mappedAddr *net.UDPAddr + // Try XOR-MAPPED-ADDRESS first, then MAPPED-ADDRESS if resps1.xorAddr != nil { mappedAddr, _ = net.ResolveUDPAddr("udp", resps1.xorAddr.String()) } else if resps1.mappedAddr != nil { @@ -317,35 +310,26 @@ func MappingTestsRFC3489(addrStr string) error { if model.EnableLoger { model.Log.Infof("[%s] RFC3489: Received mapped address: %v", currentProtocol, mappedAddr) } - localAddr := mapTestConn.LocalAddr - if model.TransmissionProtocol == "tcp" || model.TransmissionProtocol == "tls" { - localTCP := localAddr.(*net.TCPAddr) - if mappedAddr.IP.Equal(localTCP.IP) && mappedAddr.Port == localTCP.Port { - model.NatMappingBehavior = "endpoint independent (no NAT)" - model.NatFilteringBehavior = "endpoint independent" - if model.EnableLoger { - model.Log.Warnf("[%s] RFC3489: No NAT detected", currentProtocol) - } - return nil - } - } else { - localUDP := localAddr.(*net.UDPAddr) - if mappedAddr.IP.Equal(localUDP.IP) && mappedAddr.Port == localUDP.Port { - model.NatMappingBehavior = "endpoint independent (no NAT)" - model.NatFilteringBehavior = "endpoint independent" - if model.EnableLoger { - model.Log.Warnf("[%s] RFC3489: No NAT detected", currentProtocol) - } - return nil + // Check if we're behind NAT + localUDP, _ := mapTestConn.LocalAddr.(*net.UDPAddr) + if mappedAddr.IP.Equal(localUDP.IP) && mappedAddr.Port == localUDP.Port { + // No NAT + model.NatMappingBehavior = "endpoint independent (no NAT)" + model.NatFilteringBehavior = "endpoint independent" + if model.EnableLoger { + model.Log.Warnf("[%s] RFC3489: No NAT detected", currentProtocol) } + return nil } + // Test II: Binding request with change IP and Port if model.EnableLoger { model.Log.Infof("[%s] RFC3489: Test II - Request with change IP and Port", currentProtocol) } request2 := stun.MustBuild(stun.TransactionID, stun.BindingRequest) - request2.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x06}) + request2.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x06}) // Change both IP and port resp2, err2 := mapTestConn.roundTrip(request2, mapTestConn.RemoteAddr) if err2 == nil && resp2 != nil { + // Full cone NAT model.NatMappingBehavior = "endpoint independent" model.NatFilteringBehavior = "endpoint independent" if model.EnableLoger { @@ -353,13 +337,15 @@ func MappingTestsRFC3489(addrStr string) error { } return nil } + // Test III: Binding request with change port only if model.EnableLoger { model.Log.Infof("[%s] RFC3489: Test III - Request with change Port only", currentProtocol) } request3 := stun.MustBuild(stun.TransactionID, stun.BindingRequest) - request3.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x02}) + request3.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x02}) // Change port only resp3, err3 := mapTestConn.roundTrip(request3, mapTestConn.RemoteAddr) if err3 == nil && resp3 != nil { + // Restricted cone NAT model.NatMappingBehavior = "endpoint independent" model.NatFilteringBehavior = "address dependent" if model.EnableLoger { @@ -367,6 +353,8 @@ func MappingTestsRFC3489(addrStr string) error { } return nil } + // If we get here, we need to do additional tests for symmetric vs port restricted + // For simplicity in RFC3489, we'll classify remaining as Port Restricted or Symmetric model.NatMappingBehavior = "address and port dependent" model.NatFilteringBehavior = "address and port dependent" if model.EnableLoger { @@ -413,7 +401,12 @@ func parse(msg *stun.Message) (ret struct { } for _, attr := range msg.Attributes { switch attr.Type { - case stun.AttrXORMappedAddress, stun.AttrOtherAddress, stun.AttrResponseOrigin, stun.AttrMappedAddress, stun.AttrSoftware: + case + stun.AttrXORMappedAddress, + stun.AttrOtherAddress, + stun.AttrResponseOrigin, + stun.AttrMappedAddress, + stun.AttrSoftware: break //nolint:staticcheck default: if model.EnableLoger { @@ -430,50 +423,27 @@ func connect(addrStr string) (*stunServerConn, error) { model.Log.Infof("[%s] Connecting to STUN server: %s", currentProtocol, addrStr) } networkType := getNetworkType(addrStr) - var conn net.Conn - var localAddr net.Addr - var err error - switch model.TransmissionProtocol { - case "tcp": - conn, err = net.Dial(networkType, addrStr) - if err != nil { - return nil, err + addr, err := net.ResolveUDPAddr(networkType, addrStr) + if err != nil { + if model.EnableLoger { + model.Log.Warnf("[%s] Error resolving address: %s", currentProtocol, err) } - localAddr = conn.LocalAddr() - case "tls": - config := &tls.Config{InsecureSkipVerify: true} - conn, err = tls.Dial(networkType[:3], addrStr, config) - if err != nil { - return nil, err - } - localAddr = conn.LocalAddr() - default: - _, err := net.ResolveUDPAddr(networkType, addrStr) - if err != nil { - if model.EnableLoger { - model.Log.Warnf("[%s] Error resolving address: %s", currentProtocol, err) - } - return nil, err - } - udpConn, err := net.ListenUDP(networkType, nil) - if err != nil { - return nil, err - } - conn = udpConn - localAddr = udpConn.LocalAddr() + return nil, err + } + c, err := net.ListenUDP(networkType, nil) + if err != nil { + return nil, err } if model.EnableLoger { - model.Log.Infof("[%s] Local address: %s", currentProtocol, localAddr.String()) - model.Log.Infof("[%s] Remote address: %s", currentProtocol, addrStr) + model.Log.Infof("[%s] Local address: %s", currentProtocol, c.LocalAddr()) + model.Log.Infof("[%s] Remote address: %s", currentProtocol, addr.String()) } - remoteAddr, _ := net.ResolveUDPAddr("udp", addrStr) - mChan := listen(conn) + mChan := listen(c) return &stunServerConn{ - conn: conn, - LocalAddr: localAddr, - RemoteAddr: remoteAddr, + conn: c, + LocalAddr: c.LocalAddr(), + RemoteAddr: addr, messageChan: mChan, - protocol: model.TransmissionProtocol, }, nil } @@ -486,17 +456,7 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess model.Log.Debugf("\t%v (l=%v)", attr, attr.Length) } } - var err error - switch c.protocol { - case "tcp", "tls": - _, err = c.conn.Write(msg.Raw) - default: - if udpConn, ok := c.conn.(*net.UDPConn); ok { - _, err = udpConn.WriteTo(msg.Raw, addr) - } else { - _, err = c.conn.Write(msg.Raw) - } - } + _, err := c.conn.WriteTo(msg.Raw, addr) if err != nil { if model.EnableLoger { model.Log.Warnf("Error sending request to %v", addr) @@ -517,22 +477,15 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess } } -func listen(conn net.Conn) (messages chan *stun.Message) { +// taken from https://github.com/pion/stun/blob/master/cmd/stun-traversal/main.go +func listen(conn *net.UDPConn) (messages chan *stun.Message) { messages = make(chan *stun.Message) go func() { - defer close(messages) for { buf := make([]byte, 1024) - var n int - var addr net.Addr - var err error - if udpConn, ok := conn.(*net.UDPConn); ok { - n, addr, err = udpConn.ReadFromUDP(buf) - } else { - n, err = conn.Read(buf) - addr = conn.RemoteAddr() - } + n, addr, err := conn.ReadFromUDP(buf) if err != nil { + close(messages) return } if model.EnableLoger { @@ -546,10 +499,11 @@ func listen(conn net.Conn) (messages chan *stun.Message) { if model.EnableLoger { model.Log.Infof("Error decoding message: %v", err) } + close(messages) return } messages <- m } }() return -} +} \ No newline at end of file