Files
pg/disco/udp/stun.go
2025-03-08 13:05:37 +08:00

101 lines
2.4 KiB
Go

package udp
import (
"context"
"fmt"
"log/slog"
"net"
"os"
"sync"
"time"
"github.com/sigcn/pg/stun"
)
// stunRoundTripper pairs outgoing STUN requests with the responses that come
// back for them, using the STUN transaction ID as the correlation key.
type stunRoundTripper struct {
	// doInit makes sure stunResponseMap is allocated exactly once.
	doInit sync.Once

	// stunResponseMapMutex guards every access to stunResponseMap.
	stunResponseMapMutex sync.RWMutex
	// stunResponseMap holds one channel per waiting request; the key is the
	// STUN txid converted to a string.
	stunResponseMap map[string]chan stunResponse
}
// init lazily allocates stunResponseMap. It may be called any number of
// times; only the first call allocates, later calls leave the map untouched.
func (rt *stunRoundTripper) init() {
	allocate := func() {
		rt.stunResponseMap = map[string]chan stunResponse{}
	}
	rt.doInit.Do(allocate)
}
// roundTrip writes the request built by stun.Request for a fresh transaction
// ID to stunServer over udpConn, then waits for recvResponse to hand over the
// response carrying the same transaction ID.
//
// On success it returns the address contained in that response. Errors:
//   - "resolve stun addr: ..." when stunServer cannot be resolved,
//   - "write udp: ..." when the request cannot be written,
//   - os.ErrDeadlineExceeded when no response arrives within 3 seconds,
//   - ctx.Err() when ctx ends first.
func (rt *stunRoundTripper) roundTrip(ctx context.Context, udpConn *net.UDPConn, stunServer string) (*net.UDPAddr, error) {
	const responseTimeout = 3 * time.Second

	rt.init()

	txID := stun.NewTxID()
	key := string(txID[:])

	// Capacity 1: recvResponse delivers with a non-blocking send, so an
	// unbuffered channel would drop a response that arrives before this
	// function reaches the select below.
	ch := make(chan stunResponse, 1)

	rt.stunResponseMapMutex.Lock()
	rt.stunResponseMap[key] = ch
	rt.stunResponseMapMutex.Unlock()

	// Always deregister the waiter, whatever the outcome; otherwise the map
	// grows by one entry per call. The channel is deliberately not closed:
	// a concurrent recvResponse may still hold it, and sending on a closed
	// channel panics. Once unreachable it is simply garbage collected.
	defer func() {
		rt.stunResponseMapMutex.Lock()
		delete(rt.stunResponseMap, key)
		rt.stunResponseMapMutex.Unlock()
	}()

	uaddr, err := rt.resolveUDPAddr(ctx, stunServer)
	if err != nil {
		return nil, fmt.Errorf("resolve stun addr: %w", err)
	}

	if _, err := udpConn.WriteToUDP(stun.Request(txID), uaddr); err != nil {
		return nil, fmt.Errorf("write udp: %w", err)
	}

	timeout := time.NewTimer(responseTimeout)
	defer timeout.Stop()

	select {
	case r := <-ch:
		return r.addr, nil
	case <-timeout.C:
		return nil, os.ErrDeadlineExceeded
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// resolveUDPAddr resolves addr via net.ResolveUDPAddr("udp", addr) but stops
// waiting as soon as ctx is done.
//
// It returns the resolver's result unchanged when resolution finishes first,
// or (nil, ctx.Err()) when ctx ends first. In the latter case the resolver
// goroutine keeps running until net.ResolveUDPAddr returns; its result is
// then discarded.
func (rt *stunRoundTripper) resolveUDPAddr(ctx context.Context, addr string) (*net.UDPAddr, error) {
	type result struct {
		addr *net.UDPAddr
		err  error
	}

	// The result travels over a channel rather than through variables shared
	// with the goroutine, so returning early on cancellation cannot race
	// with the resolver writing its outcome. Capacity 1 lets the goroutine
	// deliver and exit even when nobody is receiving any more.
	done := make(chan result, 1)
	go func() {
		resolved, err := net.ResolveUDPAddr("udp", addr)
		done <- result{addr: resolved, err: err}
	}()

	select {
	case r := <-done:
		return r.addr, r.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// recvResponse handles one received packet b that came from peerAddr.
//
// The packet is parsed with stun.ParseResponse. If its transaction ID belongs
// to a registered waiter, the address from the response is resolved to a
// *net.UDPAddr and offered to that waiter without blocking. Parse and resolve
// failures are logged and the packet is dropped; a packet whose transaction
// ID has no waiter is only logged.
func (rt *stunRoundTripper) recvResponse(b []byte, peerAddr net.Addr) {
	rt.init()

	txid, saddr, parseErr := stun.ParseResponse(b)
	if parseErr != nil {
		slog.Error("[STUN] ParseResponse", "stun", peerAddr, "err", fmt.Errorf("parse: %w", parseErr))
		return
	}
	key := string(txid[:])

	// Hold the read lock only for the lookup itself.
	rt.stunResponseMapMutex.RLock()
	waiter, pending := rt.stunResponseMap[key]
	rt.stunResponseMapMutex.RUnlock()

	if !pending {
		slog.Log(context.Background(), -2, "[STUN] RecvResponse", "from", peerAddr)
		return
	}

	pubAddr, resolveErr := net.ResolveUDPAddr("udp", saddr.String())
	if resolveErr != nil {
		slog.Error("[STUN] ParseResponse", "stun", peerAddr, "err", fmt.Errorf("resolve udp addr: %w", resolveErr))
		return
	}

	// A send on a closed channel panics; swallow that in case the waiter's
	// channel has already been closed by the time we get here.
	defer func() { recover() }()

	// Non-blocking hand-off: if the waiter is not ready to take the value,
	// the response is dropped.
	select {
	case waiter <- stunResponse{txid: key, addr: pubAddr}:
	default:
	}
	slog.Log(context.Background(), -2, "[STUN] RecvResponse", "from", peerAddr, "pub_addr", pubAddr)
}