// Mirror of https://github.com/bhpike65/go-stun.git
// (synced 2025-09-26 19:11:12 +08:00; 266 lines, 6.9 KiB, Go)
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/hex"
|
|
"flag"
|
|
"fmt"
|
|
"github.com/bhpike65/go-stun/stun"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
)
|
|
|
|
// Role indices for the four (address, port) combinations the server can
// answer from. Bit 1 selects the address (0 = primaryAddr, 1 = alterAddr)
// and bit 0 selects the port (0 = primaryPort, 1 = alterPort), so a
// CHANGE-REQUEST is applied by XOR-ing a role with 0x02 (change IP) and/or
// 0x01 (change port).
const (
	typePP  = iota // primaryAddr:primaryPort
	typePA         // primaryAddr:alterPort
	typeAP         // alterAddr:primaryPort
	typeAA         // alterAddr:alterPort
	typeMax        // number of roles; length of roleSet
)
|
|
|
|
var roleSet [typeMax]*net.UDPConn
|
|
|
|
var logger *log.Logger
|
|
|
|
// Example invocations:
//
// Both addresses served by this process:
//
//	./stunserver --primary-addr 1.1.1.1 --alt-addr 2.2.2.2 --primary-port 3478 --alt-port 3479
//
// Master whose alternative address lives on a slave reachable at 2.2.2.2:12345:
//
//	./stunserver --slaveserver 2.2.2.2:12345 --primary-addr 1.1.1.1 --primary-port 3478 --alt-port 3479
//
// Slave accepting forwarded requests on 2.2.2.2:12345:
//
//	./stunserver --slave --slaveserver 2.2.2.2:12345 --primary-port 3478 --alt-port 3479
var primaryAddr = flag.String("primary-addr", "", "STUN server primary address")
var alterAddr = flag.String("alt-addr", "", "STUN server alternative address")
var primaryPort = flag.Int("primary-port", 3478, "primary port")
var alterPort = flag.Int("alt-port", 3479, "alternative port")
var slaveServer = flag.String("slaveserver", "", "slave STUN server which has alternative Ip")

var isSlave = flag.Bool("slave", false, "this is a slave stun server")
var public = flag.Bool("public", true, "primaryAddr and alternativeAddr must be public ip address")

// slaveChan queues "ip:port|hex(transactionId)\n" lines destined for the
// slave. It stays nil unless main starts this process as a master.
var slaveChan chan *string

// lanNets lists private address ranges: the RFC 1918 IPv4 blocks and the
// IPv6 unique-local block fc00::/7. Address auto-detection skips addresses
// inside any of them when -public is set.
var lanNets = []*net.IPNet{
	{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
	{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
	{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
	// "fc00::" (not "fc00", which is not a valid IP and parses to nil).
	{IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)},
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
logFile, err := os.OpenFile("./slave.log", os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
|
|
if err != nil {
|
|
fmt.Println("failed to create slave.log: ", err.Error())
|
|
os.Exit(-1)
|
|
}
|
|
logger = log.New(logFile, "", log.Llongfile|log.LstdFlags)
|
|
|
|
if *primaryAddr == "" || *alterAddr == "" {
|
|
addrs, err := net.InterfaceAddrs()
|
|
if err != nil {
|
|
logger.Fatal("get interface addrs error: ", err.Error())
|
|
}
|
|
for _, a := range addrs {
|
|
if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
|
|
if *public {
|
|
for _, lan := range lanNets {
|
|
if ipnet.IP.To4() != nil && !lan.Contains(ipnet.IP) {
|
|
if *primaryAddr == "" {
|
|
*primaryAddr = ipnet.IP.String()
|
|
} else if *alterAddr == "" && *primaryAddr != ipnet.IP.String() {
|
|
*alterAddr = ipnet.IP.String()
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
if ipnet.IP.To4() != nil {
|
|
if *primaryAddr == "" {
|
|
*primaryAddr = ipnet.IP.String()
|
|
} else if *alterAddr == "" {
|
|
*alterAddr = ipnet.IP.String()
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
roleSet[typePP], err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*primaryAddr), Port: *primaryPort})
|
|
if err != nil {
|
|
logger.Fatal("listen on PP failed")
|
|
}
|
|
roleSet[typePA], err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*primaryAddr), Port: *alterPort})
|
|
if err != nil {
|
|
logger.Fatal("listen on PA failed")
|
|
}
|
|
|
|
if *isSlave && *alterAddr != "" {
|
|
*alterAddr = ""
|
|
}
|
|
var aaAddr *net.UDPAddr
|
|
if *alterAddr == "" {
|
|
if *isSlave == false {
|
|
if *slaveServer != "" {
|
|
slaveAddr, err := net.ResolveTCPAddr("tcp", *slaveServer)
|
|
if err != nil {
|
|
logger.Fatal("slave server resolve failed")
|
|
}
|
|
slaveChan = make(chan *string, 128)
|
|
go slaveClientWorker(slaveAddr)
|
|
aaAddr = &net.UDPAddr{IP: slaveAddr.IP, Port: *alterPort}
|
|
}
|
|
} else if *slaveServer != "" {
|
|
slaveAddr, err := net.ResolveTCPAddr("tcp", *slaveServer)
|
|
if err != nil {
|
|
logger.Fatal("slave server resolve failed")
|
|
}
|
|
go slaveWorker(slaveAddr)
|
|
}
|
|
} else {
|
|
aaAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", *alterAddr, *alterPort))
|
|
if err != nil {
|
|
logger.Fatalf("alterAddr %s:%d resolve failed", *alterAddr, alterPort)
|
|
}
|
|
roleSet[typeAP], err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*alterAddr), Port: *primaryPort})
|
|
if err != nil {
|
|
logger.Fatal("listen on PP failed")
|
|
}
|
|
roleSet[typeAA], err = net.ListenUDP("udp", aaAddr)
|
|
if err != nil {
|
|
logger.Fatal("listen on PA failed")
|
|
}
|
|
|
|
go startStunServer(typeAP, roleSet[typeAP], nil)
|
|
go startStunServer(typeAA, roleSet[typeAA], nil)
|
|
}
|
|
|
|
go startStunServer(typePA, roleSet[typePA], nil)
|
|
startStunServer(typePP, roleSet[typePP], aaAddr)
|
|
}
|
|
|
|
func startStunServer(role int, conn *net.UDPConn, other *net.UDPAddr) {
|
|
buf := make([]byte, 1500)
|
|
for {
|
|
n, remote, err := conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
logger.Println("receive Error: ", err)
|
|
}
|
|
var req stun.StunMessageReq
|
|
if err = req.Unmarshal(buf[:n]); err != nil {
|
|
logger.Println("receive error req: ", err.Error())
|
|
continue
|
|
}
|
|
otherRole := role
|
|
if req.ChangeIp {
|
|
otherRole ^= 0x02
|
|
}
|
|
if req.ChangePort {
|
|
otherRole ^= 0x01
|
|
}
|
|
if otherRole != role {
|
|
if slaveChan != nil {
|
|
info := fmt.Sprintf("%s|%x\n", remote.String(), req.TransacrtonId)
|
|
go sendToSlave(&info)
|
|
continue
|
|
} else if *alterAddr != "" && roleSet[otherRole] != nil {
|
|
if err = req.RespondTo(roleSet[otherRole], remote, nil); err != nil {
|
|
logger.Printf("respond to %s failed %s", remote, err.Error())
|
|
}
|
|
} else {
|
|
//ignore
|
|
continue
|
|
}
|
|
} else {
|
|
if err = req.RespondTo(conn, remote, other); err != nil {
|
|
logger.Printf("respond to %s failed %s", remote, err.Error())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func sendToSlave(info *string) {
|
|
//ip:port|transactionId\n
|
|
slaveChan <- info
|
|
}
|
|
|
|
func slaveClientWorker(slaveServer *net.TCPAddr) {
|
|
|
|
for {
|
|
conn, err := net.DialTCP("tcp", nil, slaveServer)
|
|
if err != nil {
|
|
logger.Fatal("Dial slave server failed:", err.Error())
|
|
}
|
|
|
|
for {
|
|
data := <-slaveChan
|
|
conn.SetNoDelay(true)
|
|
_, err = conn.Write([]byte(*data))
|
|
if err != nil {
|
|
fmt.Println("Write to slave server failed:", err.Error())
|
|
conn.Close()
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func slaveWorker(slaveServer *net.TCPAddr) {
|
|
l, err := net.ListenTCP("tcp", slaveServer)
|
|
if err != nil {
|
|
logger.Fatal("slave tcp listen error: ", err.Error())
|
|
os.Exit(-1)
|
|
}
|
|
defer l.Close()
|
|
|
|
for {
|
|
conn, err := l.AcceptTCP()
|
|
if err != nil {
|
|
logger.Fatal("slave accept error:", err.Error())
|
|
continue
|
|
}
|
|
go slaveProcessRequest(conn)
|
|
}
|
|
}
|
|
|
|
func slaveProcessRequest(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
reader := bufio.NewReaderSize(conn, 128)
|
|
for {
|
|
data, err := reader.ReadString('\n')
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
logger.Println("read from tcp socket failed:", err.Error())
|
|
break
|
|
}
|
|
data = strings.TrimRight(data, "\n")
|
|
logger.Println("slave get: ", data)
|
|
infos := strings.Split(data, "|")
|
|
if len(infos) != 2 {
|
|
logger.Print("receive error slave data: ", data)
|
|
continue
|
|
}
|
|
addr := infos[0]
|
|
tid, err := hex.DecodeString(infos[1])
|
|
if err != nil || len(tid) != 12 {
|
|
logger.Print("receive error slave data: ", data)
|
|
continue
|
|
}
|
|
remote, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
logger.Print("receive error slave data: ", data)
|
|
continue
|
|
}
|
|
req := stun.NewBindRequest(tid)
|
|
req.RespondTo(roleSet[typePP], remote, nil)
|
|
}
|
|
}
|