mirror of
https://github.com/sigcn/pg.git
synced 2025-10-04 17:22:39 +08:00
101 lines
2.4 KiB
Go
101 lines
2.4 KiB
Go
package udp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sigcn/pg/stun"
|
|
)
|
|
|
|
type stunRoundTripper struct {
|
|
stunResponseMapMutex sync.RWMutex
|
|
doInit sync.Once
|
|
|
|
stunResponseMap map[string]chan stunResponse // key is stun txid
|
|
}
|
|
|
|
func (rt *stunRoundTripper) init() {
|
|
rt.doInit.Do(func() {
|
|
rt.stunResponseMap = make(map[string]chan stunResponse)
|
|
})
|
|
}
|
|
|
|
func (rt *stunRoundTripper) roundTrip(ctx context.Context, udpConn *net.UDPConn, stunServer string) (*net.UDPAddr, error) {
|
|
rt.init()
|
|
txID := stun.NewTxID()
|
|
ch := make(chan stunResponse)
|
|
defer close(ch)
|
|
rt.stunResponseMapMutex.Lock()
|
|
rt.stunResponseMap[string(txID[:])] = ch
|
|
rt.stunResponseMapMutex.Unlock()
|
|
|
|
uaddr, err := rt.resolveUDPAddr(ctx, stunServer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve stun addr: %w", err)
|
|
}
|
|
_, err = udpConn.WriteToUDP(stun.Request(txID), uaddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("write udp: %w", err)
|
|
}
|
|
|
|
timeout := time.NewTimer(3 * time.Second)
|
|
defer timeout.Stop()
|
|
select {
|
|
case r := <-ch:
|
|
return r.addr, nil
|
|
case <-timeout.C:
|
|
return nil, os.ErrDeadlineExceeded
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (c *stunRoundTripper) resolveUDPAddr(ctx context.Context, addr string) (udpAddr *net.UDPAddr, err error) {
|
|
var closeOnce sync.Once
|
|
ch := make(chan struct{})
|
|
defer closeOnce.Do(func() { close(ch) })
|
|
go func() {
|
|
udpAddr, err = net.ResolveUDPAddr("udp", addr)
|
|
closeOnce.Do(func() { close(ch) })
|
|
}()
|
|
select {
|
|
case <-ch:
|
|
return
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (c *stunRoundTripper) recvResponse(b []byte, peerAddr net.Addr) {
|
|
c.init()
|
|
txid, saddr, err := stun.ParseResponse(b)
|
|
if err != nil {
|
|
slog.Error("[STUN] ParseResponse", "stun", peerAddr, "err", fmt.Errorf("parse: %w", err))
|
|
return
|
|
}
|
|
c.stunResponseMapMutex.RLock()
|
|
if r, ok := c.stunResponseMap[string(txid[:])]; ok {
|
|
c.stunResponseMapMutex.RUnlock()
|
|
addr, err := net.ResolveUDPAddr("udp", saddr.String())
|
|
if err != nil {
|
|
slog.Error("[STUN] ParseResponse", "stun", peerAddr, "err", fmt.Errorf("resolve udp addr: %w", err))
|
|
return
|
|
}
|
|
resp := stunResponse{txid: string(txid[:]), addr: addr}
|
|
defer func() { recover() }()
|
|
select {
|
|
case r <- resp:
|
|
default:
|
|
}
|
|
slog.Log(context.Background(), -2, "[STUN] RecvResponse", "from", peerAddr, "pub_addr", addr)
|
|
return
|
|
}
|
|
c.stunResponseMapMutex.RUnlock()
|
|
slog.Log(context.Background(), -2, "[STUN] RecvResponse", "from", peerAddr)
|
|
}
|