参数改为 netip.Addr

This commit is contained in:
xmdhs
2024-01-24 23:58:25 +08:00
parent b16821e2d2
commit 84bd5482dd
3 changed files with 39 additions and 34 deletions

22
main.go
View File

@@ -57,8 +57,10 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
laddrPort := netip.AddrPortFrom(netip.MustParseAddr(localAddr), uint16(portu))
for { for {
err := openPort(ctx, target, localAddr, uint16(portu), stun, func(s netip.AddrPort) { err := openPort(ctx, target, laddrPort, stun, func(s netip.AddrPort) {
fmt.Println(s) fmt.Println(s)
if comm != "" { if comm != "" {
c := exec.CommandContext(ctx, comm, localAddr, port, s.Addr().String(), strconv.Itoa(int(s.Port()))) c := exec.CommandContext(ctx, comm, localAddr, port, s.Addr().String(), strconv.Itoa(int(s.Port())))
@@ -78,19 +80,19 @@ func main() {
} }
} }
func openPort(ctx context.Context, target, localAddr string, portu uint16, func openPort(ctx context.Context, target string, laddr netip.AddrPort,
stun string, finish func(netip.AddrPort), udp bool, testserver bool) error { stun string, finish func(netip.AddrPort), udp bool, testserver bool) error {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
if target != "" { if target != "" {
var forward func(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) var forward func(ctx context.Context, laddr netip.AddrPort, target string, log func(string)) (io.Closer, error)
if udp { if udp {
forward = natmap.ForwardUdp forward = natmap.ForwardUdp
} else { } else {
forward = natmap.Forward forward = natmap.Forward
} }
l, err := forward(ctx, portu, target, func(s string) { l, err := forward(ctx, laddr, target, func(s string) {
log.Println(s) log.Println(s)
}) })
if err != nil { if err != nil {
@@ -99,21 +101,21 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16,
defer l.Close() defer l.Close()
} }
if testserver { if testserver {
l, err := testServer(ctx, portu) l, err := testServer(ctx, laddr)
if err != nil { if err != nil {
return fmt.Errorf("openPort: %w", err) return fmt.Errorf("openPort: %w", err)
} }
defer l.Close() defer l.Close()
} }
errCh := make(chan error, 1) errCh := make(chan error, 1)
var nmap func(ctx context.Context, stunAddr string, host string, port uint16, log func(error)) (*natmap.Map, netip.AddrPort, error) var nmap func(ctx context.Context, stunAddr string, laddr netip.AddrPort, log func(error)) (*natmap.Map, netip.AddrPort, error)
if udp { if udp {
nmap = natmap.NatMapUdp nmap = natmap.NatMapUdp
} else { } else {
nmap = natmap.NatMap nmap = natmap.NatMap
} }
m, s, err := nmap(ctx, stun, localAddr, uint16(portu), func(s error) { m, s, err := nmap(ctx, stun, laddr, func(s error) {
cancel() cancel()
select { select {
case errCh <- s: case errCh <- s:
@@ -134,16 +136,16 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16,
return nil return nil
} }
func testServer(ctx context.Context, port uint16) (net.Listener, error) { func testServer(ctx context.Context, laddr netip.AddrPort) (net.Listener, error) {
s := http.Server{ s := http.Server{
ReadTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second,
Addr: "0.0.0.0:" + strconv.FormatUint(uint64(port), 10), Addr: laddr.String(),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok")) w.Write([]byte("ok"))
}), }),
} }
l, err := reuse.Listen(ctx, "tcp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10)) l, err := reuse.Listen(ctx, "tcp", laddr.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("testServer: %w", err) return nil, fmt.Errorf("testServer: %w", err)
} }

View File

@@ -8,7 +8,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"strconv"
"time" "time"
"github.com/xmdhs/natupnp/reuse" "github.com/xmdhs/natupnp/reuse"
@@ -20,21 +19,20 @@ type Map struct {
cancel func() cancel func()
} }
func getPubulicPort(ctx context.Context, stunAddr string, host string, port uint16, isTcp bool) (netip.AddrPort, error) { func getPubulicPort(ctx context.Context, stunAddr string, laddr netip.AddrPort, isTcp bool) (netip.AddrPort, error) {
var ( var (
upnpP = "TCP" upnpP = "TCP"
dialP = "tcp4" dialP = "tcp"
) )
if !isTcp { if !isTcp {
upnpP = "UDP" upnpP = "UDP"
dialP = "udp4" dialP = "udp"
} }
err := upnp.AddPortMapping(ctx, "", laddr.Port(), upnpP, laddr.Port(), laddr.Addr().String(), true, "github.com/xmdhs/natupnp", 0)
err := upnp.AddPortMapping(ctx, "", port, upnpP, port, host, true, "github.com/xmdhs/natupnp", 0)
if err != nil { if err != nil {
return netip.AddrPort{}, fmt.Errorf("getPubulicPort: %w", err) return netip.AddrPort{}, fmt.Errorf("getPubulicPort: %w", err)
} }
stunConn, err := reuse.DialContext(ctx, dialP, "0.0.0.0:"+strconv.Itoa(int(port)), stunAddr) stunConn, err := reuse.DialContext(ctx, dialP, laddr.String(), stunAddr)
if err != nil { if err != nil {
return netip.AddrPort{}, fmt.Errorf("getPubulicPort: %w", err) return netip.AddrPort{}, fmt.Errorf("getPubulicPort: %w", err)
} }
@@ -47,16 +45,16 @@ func getPubulicPort(ctx context.Context, stunAddr string, host string, port uint
return netip.AddrPortFrom(addr, uint16(mapAddr.Port)), nil return netip.AddrPortFrom(addr, uint16(mapAddr.Port)), nil
} }
func NatMap(ctx context.Context, stunAddr string, host string, port uint16, log func(error)) (*Map, netip.AddrPort, error) { func NatMap(ctx context.Context, stunAddr string, laddr netip.AddrPort, log func(error)) (*Map, netip.AddrPort, error) {
m := Map{} m := Map{}
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel m.cancel = cancel
mapAddr, err := getPubulicPort(ctx, stunAddr, host, port, true) mapAddr, err := getPubulicPort(ctx, stunAddr, laddr, true)
if err != nil { if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("NatMap: %w", err) return nil, netip.AddrPort{}, fmt.Errorf("NatMap: %w", err)
} }
go keepalive(ctx, port, log) go keepalive(ctx, laddr, log)
return &m, mapAddr, nil return &m, mapAddr, nil
} }
@@ -65,17 +63,19 @@ func (m Map) Close() error {
return nil return nil
} }
func keepalive(ctx context.Context, port uint16, log func(error)) { func keepalive(ctx context.Context, laddr netip.AddrPort, log func(error)) {
tr := http.DefaultTransport.(*http.Transport).Clone() tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return reuse.DialContext(ctx, "tcp", "0.0.0.0:"+strconv.Itoa(int(port)), addr) return reuse.DialContext(ctx, "tcp", laddr.String(), addr)
} }
tr.Proxy = nil
c := http.Client{Transport: tr, Timeout: 5 * time.Second} c := http.Client{Transport: tr, Timeout: 5 * time.Second}
for { for {
func() { func() {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
reqs, err := http.NewRequestWithContext(ctx, "GET", "http://connect.rom.miui.com/generate_204", nil) reqs, err := http.NewRequestWithContext(ctx, "HEAD", "http://www.gstatic.com/generate_204", nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -100,14 +100,14 @@ func keepalive(ctx context.Context, port uint16, log func(error)) {
func GetLocalAddr() (net.Addr, error) { func GetLocalAddr() (net.Addr, error) {
l, err := net.Dial("udp4", "223.5.5.5:53") l, err := net.Dial("udp4", "223.5.5.5:53")
if err != nil { if err != nil {
return nil, fmt.Errorf("getLocal: %w", err) return nil, fmt.Errorf("GetLocalAddr: %w", err)
} }
defer l.Close() defer l.Close()
return l.LocalAddr(), nil return l.LocalAddr(), nil
} }
func Forward(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) { func Forward(ctx context.Context, laddr netip.AddrPort, target string, log func(string)) (io.Closer, error) {
l, err := reuse.Listen(ctx, "tcp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10)) l, err := reuse.Listen(ctx, "tcp", laddr.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("Forward: %w", err) return nil, fmt.Errorf("Forward: %w", err)
} }

View File

@@ -6,32 +6,35 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"time" "time"
"github.com/xmdhs/natupnp/reuse" "github.com/xmdhs/natupnp/reuse"
) )
func NatMapUdp(ctx context.Context, stunAddr string, host string, port uint16, log func(error)) (*Map, netip.AddrPort, error) { func NatMapUdp(ctx context.Context, stunAddr string, laddr netip.AddrPort, log func(error)) (*Map, netip.AddrPort, error) {
m := Map{} m := Map{}
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel m.cancel = cancel
mapAddr, err := getPubulicPort(ctx, stunAddr, host, port, false) mapAddr, err := getPubulicPort(ctx, stunAddr, laddr, false)
if err != nil { if err != nil {
return nil, netip.AddrPort{}, fmt.Errorf("NatMap: %w", err) return nil, netip.AddrPort{}, fmt.Errorf("NatMap: %w", err)
} }
go keepaliveUDP(ctx, port, log) go keepaliveUDP(ctx, laddr, log)
return &m, mapAddr, nil return &m, mapAddr, nil
} }
func keepaliveUDP(ctx context.Context, port uint16, log func(error)) { func keepaliveUDP(ctx context.Context, laddr netip.AddrPort, log func(error)) {
raddr := "223.5.5.5:53"
if laddr.Addr().Is6() {
raddr = "[2400:3200::1]:53"
}
r := net.Resolver{ r := net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(context context.Context, network, address string) (net.Conn, error) { Dial: func(context context.Context, network, address string) (net.Conn, error) {
conn, err := reuse.DialContext(context, "udp", "0.0.0.0:"+strconv.Itoa(int(port)), "223.5.5.5:53") conn, err := reuse.DialContext(context, "udp", laddr.String(), raddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -67,8 +70,8 @@ func (l logger) Println(v ...any) {
l.log(build.String()) l.log(build.String())
} }
func ForwardUdp(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) { func ForwardUdp(ctx context.Context, laddr netip.AddrPort, target string, log func(string)) (io.Closer, error) {
lc, err := reuse.ListenPacket(ctx, "udp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10)) lc, err := reuse.ListenPacket(ctx, "udp", laddr.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }