mirror of
https://github.com/mcfx/grass.git
synced 2025-12-24 13:18:18 +08:00
add files to git
This commit is contained in:
39
README.md
Normal file
39
README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Grass
|
||||

|
||||
|
||||
Grass is a Peer-to-Peer VPN.
|
||||
|
||||
## Usage
|
||||
Clone this repo, run `go build main.go` in the bin directory.
|
||||
|
||||
Sample config:
|
||||
|
||||
```json
|
||||
{
|
||||
"Id": 114514, // some random number
|
||||
"Key": "some_key111", // pre-shared key
|
||||
"IPv4": "10.56.0.2", // ipv4 for this client, ipv6 support will be added soon
|
||||
"IPv4Gateway": "10.56.0.1", // ipv4 gateway, in fact useless
|
||||
"IPv4Mask": "255.255.255.0",
|
||||
"IPv4DNS": "8.8.8.8,8.8.4.4", // seems only useful on windows
|
||||
"Listen": [
|
||||
{
|
||||
"LocalAddr": "0.0.0.0",
|
||||
"NetworkAddr": "192.168.254.1", // your ip on external network
|
||||
"Port": 14514
|
||||
},
|
||||
{
|
||||
"LocalAddr": "[::]",
|
||||
"NetworkAddr": "[fe80:114:514:1919:810::233]",
|
||||
"Port": 14515
|
||||
}
|
||||
],
|
||||
"BootstrapNodes": ["11.4.5.14:1919"],
|
||||
"CheckClientsInterval": 1,
|
||||
"PingInterval": 20,
|
||||
"CheckConnInterval": 5,
|
||||
"Debug": false,
|
||||
"ConnectNewPeer": true, // if false, the client will not discover new clients
|
||||
"StartTun": true // if false, the client will not start tunnel interface
|
||||
}
|
||||
```
|
||||
44
bin/main.go
Normal file
44
bin/main.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mcfx0/grass"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var configFile string
|
||||
flag.StringVar(&configFile, "config", "config.json", "config file name")
|
||||
flag.Parse()
|
||||
data, err := ioutil.ReadFile(configFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
config, err := grass.UnmarshalConfig(data)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client := grass.Client{Config: config}
|
||||
|
||||
err = client.Start()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch,
|
||||
syscall.SIGHUP,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
syscall.SIGQUIT)
|
||||
s := <-ch
|
||||
switch s {
|
||||
default:
|
||||
client.Stop()
|
||||
}
|
||||
}
|
||||
577
client.go
Normal file
577
client.go
Normal file
@@ -0,0 +1,577 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mcfx0/grass/network"
|
||||
"github.com/mcfx0/grass/network/tun"
|
||||
)
|
||||
|
||||
const MAX_PEER_CONNS = 8
|
||||
|
||||
type Conn struct {
|
||||
Conn net.Conn
|
||||
Reader *bufio.Reader
|
||||
Writer Writer
|
||||
Latency []uint32
|
||||
tmpKeys *TmpKeySet
|
||||
peerTmpKeys *TmpKeySet
|
||||
WriteCh chan QueuedPacket
|
||||
WriterStopCh chan bool
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
Id uint32
|
||||
Conns []*Conn
|
||||
ConnInfo []string
|
||||
ConnsMutex *sync.RWMutex
|
||||
LastSendClients OtherClientsInfo
|
||||
}
|
||||
|
||||
type ListenInfo struct {
|
||||
LocalAddr string
|
||||
NetworkAddr string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
Id uint32
|
||||
Key []byte
|
||||
IPv4 uint32
|
||||
IPv4Gateway uint32
|
||||
IPv4Mask uint32
|
||||
IPv4DNS []uint32
|
||||
Listen []ListenInfo
|
||||
BootstrapNodes []string
|
||||
CheckClientsInterval int
|
||||
PingInterval int
|
||||
CheckConnInterval int
|
||||
Debug bool
|
||||
ConnectNewPeer bool
|
||||
StartTun bool
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
Id uint32
|
||||
IPv4 uint32
|
||||
Key []byte
|
||||
Peers map[uint32]*Peer
|
||||
PeerMutex *sync.RWMutex
|
||||
OClients OtherClients
|
||||
LastCheckClients OtherClientsInfo
|
||||
Tun *network.TunHandler
|
||||
Config ClientConfig
|
||||
ConnInfo string
|
||||
KnownConnInfo map[string]uint32
|
||||
KnownConnInfoMutex *sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
func (peer *Peer) AvgLatency() uint32 {
|
||||
peer.ConnsMutex.RLock()
|
||||
tot := 0
|
||||
fail := 0
|
||||
sum := 0
|
||||
for i := 0; i < len(peer.Conns); i++ {
|
||||
tmp := peer.Conns[i].Latency
|
||||
for j := 0; j < len(tmp); j++ {
|
||||
if tmp[j] == 0 {
|
||||
fail++
|
||||
} else {
|
||||
sum += int(tmp[j])
|
||||
}
|
||||
tot++
|
||||
}
|
||||
}
|
||||
peer.ConnsMutex.RUnlock()
|
||||
if tot == fail {
|
||||
return 100000000
|
||||
}
|
||||
res := float64(sum) / float64(tot-fail) / float64(tot-fail) * float64(tot)
|
||||
resn := int(res)
|
||||
if resn < 2 {
|
||||
return 2
|
||||
}
|
||||
if resn > 100000000 {
|
||||
return 100000000
|
||||
}
|
||||
return uint32(resn)
|
||||
}
|
||||
|
||||
func (peer *Peer) SendPacket(tp uint8, data []byte, client *Client) error {
|
||||
// a naive implementation, need to be improved
|
||||
peer.ConnsMutex.RLock()
|
||||
defer peer.ConnsMutex.RUnlock()
|
||||
if len(peer.Conns) == 0 {
|
||||
return fmt.Errorf("no usable connections")
|
||||
}
|
||||
//return WriteRawPacket(tp, data, peer.Conns[rand.Intn(len(peer.Conns))], client)
|
||||
peer.Conns[rand.Intn(len(peer.Conns))].WriteCh <- QueuedPacket{Type: tp, Data: data}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendPacketV2(tp uint8, data []byte, client *Client, maxQlen int) bool {
|
||||
peer.ConnsMutex.RLock()
|
||||
defer peer.ConnsMutex.RUnlock()
|
||||
if len(peer.Conns) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(peer.Conns); i++ {
|
||||
//fmt.Printf("%d ", len(peer.Conns[i].WriteCh))
|
||||
if len(peer.Conns[i].WriteCh) <= maxQlen {
|
||||
peer.Conns[i].WriteCh <- QueuedPacket{Type: tp, Data: data}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (client *Client) HandleConn(rawConn net.Conn, connInfo string) error {
|
||||
var conn Conn
|
||||
conn.Conn = rawConn
|
||||
conn.Reader = bufio.NewReader(rawConn)
|
||||
conn.Writer = NewWriter(rawConn)
|
||||
conn.tmpKeys = NewTmpKeySet(120)
|
||||
conn.peerTmpKeys = NewTmpKeySet(120)
|
||||
conn.WriteCh = make(chan QueuedPacket, 300)
|
||||
conn.WriterStopCh = make(chan bool, 10)
|
||||
|
||||
var pkt Packet
|
||||
pkt.Type = GrassPacketInfo
|
||||
pkt.Content = make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(pkt.Content, client.Id)
|
||||
err := WritePacket(&pkt, &conn, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var peerId uint32
|
||||
for {
|
||||
err := ReadPacket(&pkt, &conn, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pkt.Type == GrassPacketInfo {
|
||||
if len(pkt.Content) != 4 {
|
||||
return fmt.Errorf("Unexpected length of info packet")
|
||||
}
|
||||
peerId = binary.LittleEndian.Uint32(pkt.Content)
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Printf("Connected to peer %08x\n", peerId)
|
||||
if connInfo != "" {
|
||||
client.KnownConnInfoMutex.Lock()
|
||||
client.KnownConnInfo[connInfo] = peerId
|
||||
client.KnownConnInfoMutex.Unlock()
|
||||
}
|
||||
|
||||
client.PeerMutex.Lock()
|
||||
peer, ok := client.Peers[peerId]
|
||||
if !ok {
|
||||
peer = &Peer{Id: peerId, ConnsMutex: &sync.RWMutex{}}
|
||||
peer.LastSendClients = make(OtherClientsInfo)
|
||||
if client.Peers == nil {
|
||||
client.Peers = make(map[uint32]*Peer)
|
||||
}
|
||||
client.Peers[peerId] = peer
|
||||
}
|
||||
client.PeerMutex.Unlock()
|
||||
peer.ConnsMutex.Lock()
|
||||
peer.Conns = append(peer.Conns, &conn)
|
||||
peer.ConnsMutex.Unlock()
|
||||
|
||||
working := true
|
||||
var pingKey uint32 = 0
|
||||
var pingTime int64 = 0
|
||||
|
||||
go func() {
|
||||
for working {
|
||||
changed := false
|
||||
if pingTime != 0 {
|
||||
conn.Latency = append(conn.Latency, 0)
|
||||
changed = true
|
||||
}
|
||||
if len(conn.Latency) > 10 {
|
||||
conn.Latency = conn.Latency[len(conn.Latency)-10:]
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
go client.OClients.UpdateLatencyAdd(peerId, peer.AvgLatency())
|
||||
}
|
||||
pingKey = rand.Uint32()
|
||||
pingTime = time.Now().UnixNano()
|
||||
tmp := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(tmp, pingKey)
|
||||
if err := WriteRawPacket(GrassPacketPing, tmp, &conn, client); err != nil && client.Config.Debug {
|
||||
log.Printf("Error in ping thread: %v\n", err)
|
||||
return
|
||||
}
|
||||
peer.ConnsMutex.RLock()
|
||||
wait := int(client.Config.PingInterval) * len(peer.Conns)
|
||||
peer.ConnsMutex.RUnlock()
|
||||
if len(conn.Latency) < 10 {
|
||||
wait /= 10
|
||||
}
|
||||
wait += rand.Intn(10) - 5
|
||||
if wait < 0 {
|
||||
wait = 0
|
||||
}
|
||||
time.Sleep(time.Second * time.Duration(wait))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case pkt := <-conn.WriteCh:
|
||||
WriteRawPacket(pkt.Type, pkt.Data, &conn, client)
|
||||
case <-conn.WriterStopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
closeConn := func() {
|
||||
working = false
|
||||
peer.ConnsMutex.Lock()
|
||||
i := 0
|
||||
for ; i < len(peer.Conns); i++ {
|
||||
if peer.Conns[i] == &conn {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i != len(peer.Conns) {
|
||||
peer.Conns[i] = peer.Conns[len(peer.Conns)-1]
|
||||
peer.Conns = peer.Conns[:len(peer.Conns)-1]
|
||||
}
|
||||
peer.ConnsMutex.Unlock()
|
||||
if len(peer.Conns) == 0 {
|
||||
client.PeerMutex.Lock()
|
||||
delete(client.Peers, peer.Id)
|
||||
client.PeerMutex.Unlock()
|
||||
client.OClients.ClearClients(peerId)
|
||||
}
|
||||
log.Printf("Disconnected from peer %08x\n", peerId)
|
||||
rawConn.Close()
|
||||
conn.WriterStopCh <- true
|
||||
}
|
||||
|
||||
for {
|
||||
err := ReadPacket(&pkt, &conn, client)
|
||||
if err != nil {
|
||||
closeConn()
|
||||
return err
|
||||
}
|
||||
if client.Config.Debug {
|
||||
log.Printf("Packet from %08x: %d %v\n", peerId, pkt.Type, pkt.Content)
|
||||
}
|
||||
switch pkt.Type {
|
||||
case GrassPacketIPv4:
|
||||
tmp, err := DecodePacketIPv4(pkt.Content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error in IPv4 packet: %v", err)
|
||||
}
|
||||
go client.HandlePacketIPv4(tmp)
|
||||
case GrassPacketPing:
|
||||
if len(pkt.Content) != 4 {
|
||||
return fmt.Errorf("Unexpected length of ping packet")
|
||||
}
|
||||
if err = WriteRawPacket(GrassPacketPong, pkt.Content, &conn, client); err != nil {
|
||||
return err
|
||||
}
|
||||
case GrassPacketInfo:
|
||||
continue
|
||||
case GrassPacketPong:
|
||||
if len(pkt.Content) != 4 {
|
||||
return fmt.Errorf("Unexpected length of pong packet")
|
||||
}
|
||||
tKey := binary.LittleEndian.Uint32(pkt.Content)
|
||||
if tKey == pingKey {
|
||||
latency := uint32((time.Now().UnixNano() - pingTime) / 1000)
|
||||
if latency == 0 {
|
||||
latency = 1
|
||||
}
|
||||
conn.Latency = append(conn.Latency, latency)
|
||||
pingKey = 0
|
||||
pingTime = 0
|
||||
go client.OClients.UpdateLatencyAdd(peerId, peer.AvgLatency())
|
||||
}
|
||||
case GrassPacketClients:
|
||||
tmp, err := DecodePacketClients(pkt.Content)
|
||||
if client.Config.Debug {
|
||||
log.Printf("Packet Clients from %08x: %v\n", peerId, tmp)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error in Clients packet: %v", err)
|
||||
}
|
||||
go client.OClients.UpdateClients(tmp, peerId, peer.AvgLatency())
|
||||
case GrassPacketFarPing:
|
||||
tmp, err := DecodePacketFarPing(pkt.Content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error in FarPing packet: %v", err)
|
||||
}
|
||||
go client.HandlePacketFarPing(tmp)
|
||||
case GrassPacketExpired:
|
||||
continue
|
||||
default:
|
||||
closeConn()
|
||||
return fmt.Errorf("Unknown packet type")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) CheckClients() {
|
||||
lstLatency := make(map[uint32]uint32)
|
||||
latencyIncreaseCnt := make(map[uint32]int)
|
||||
for client.running {
|
||||
time.Sleep(time.Second * time.Duration(client.Config.CheckClientsInterval))
|
||||
var peers []*Peer
|
||||
client.PeerMutex.RLock()
|
||||
for _, pr := range client.Peers {
|
||||
peers = append(peers, pr)
|
||||
}
|
||||
client.PeerMutex.RUnlock()
|
||||
|
||||
tmp := make(OtherClientsInfo)
|
||||
/*if client.Config.Debug {
|
||||
fmt.Println("peers:")
|
||||
}*/
|
||||
for i := 0; i < len(peers); i++ {
|
||||
/*if client.Config.Debug {
|
||||
fmt.Printf("id=%08x lat=%v avgLat=%d\n", peers[i].Id, peers[i].Conns[0].Latency, peers[i].AvgLatency())
|
||||
}*/
|
||||
tmp[peers[i].Id] = OtherClientInfo{ConnInfo: "", IPv4: 0, Latency: peers[i].AvgLatency()}
|
||||
}
|
||||
tmp[client.Id] = OtherClientInfo{ConnInfo: client.ConnInfo, IPv4: client.IPv4, Latency: 1}
|
||||
tmpt, useful := GenDiffPacketClients(tmp, client.LastCheckClients)
|
||||
client.OClients.UpdateClients(tmpt, client.Id, 0)
|
||||
client.LastCheckClients = tmp
|
||||
|
||||
tmp = make(OtherClientsInfo)
|
||||
var ocs []*OtherClient
|
||||
client.OClients.M.Lock()
|
||||
for _, oc := range client.OClients.C {
|
||||
ocs = append(ocs, oc)
|
||||
}
|
||||
client.OClients.M.Unlock()
|
||||
if client.Config.Debug {
|
||||
log.Println("clients:")
|
||||
}
|
||||
for i := 0; i < len(ocs); i++ {
|
||||
if client.Config.Debug {
|
||||
log.Printf("id=%08x conn=%s ipv4=%08x lat=%v latSort=%v\n", ocs[i].Id, ocs[i].ConnInfoStr, ocs[i].IPv4, ocs[i].Latency, ocs[i].LatencySorted)
|
||||
}
|
||||
lat := ocs[i].GetLatency()
|
||||
if llat, ok := lstLatency[ocs[i].Id]; ok {
|
||||
if lat > llat {
|
||||
t, ok := latencyIncreaseCnt[ocs[i].Id]
|
||||
if !ok {
|
||||
t = 0
|
||||
}
|
||||
t++
|
||||
if t < 10 {
|
||||
latencyIncreaseCnt[ocs[i].Id] = t
|
||||
} else {
|
||||
go func(id uint32) {
|
||||
var pkt PacketFarPing
|
||||
pkt.SrcNode = client.Id
|
||||
pkt.DstNode = id
|
||||
pkt.Status = FarPingRequest
|
||||
pkt.HistoryNodes = make([]uint32, 0)
|
||||
client.HandlePacketFarPing(pkt)
|
||||
}(ocs[i].Id)
|
||||
latencyIncreaseCnt[ocs[i].Id] = 0
|
||||
}
|
||||
} else if lat < llat {
|
||||
latencyIncreaseCnt[ocs[i].Id] = 0
|
||||
}
|
||||
}
|
||||
lstLatency[ocs[i].Id] = lat
|
||||
if !client.OClients.BannedClients.HasKey(ocs[i].Id) {
|
||||
tmp[ocs[i].Id] = OtherClientInfo{ConnInfo: ocs[i].ConnInfoStr, IPv4: ocs[i].IPv4, Latency: lat}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(peers); i++ {
|
||||
if rand.Intn(120) == 23 {
|
||||
peers[i].LastSendClients = make(OtherClientsInfo)
|
||||
}
|
||||
tmpt, useful = GenDiffPacketClients(tmp, peers[i].LastSendClients)
|
||||
peers[i].LastSendClients = tmp
|
||||
if useful {
|
||||
err := peers[i].SendPacket(GrassPacketClients, EncodePacketClients(tmpt), client)
|
||||
if err != nil {
|
||||
log.Printf("Error while sending Clients packet to %08x: %v\n", peers[i].Id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if client.Config.ConnectNewPeer {
|
||||
client.KnownConnInfoMutex.Lock()
|
||||
for i := 0; i < len(ocs); i++ {
|
||||
tc := ocs[i].ConnInfo
|
||||
if ocs[i].Id == client.Id {
|
||||
continue
|
||||
}
|
||||
for j := 0; j < len(tc); j++ {
|
||||
client.KnownConnInfo[tc[j]] = ocs[i].Id
|
||||
}
|
||||
}
|
||||
client.KnownConnInfoMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) CheckConn() {
|
||||
for client.running {
|
||||
time.Sleep(time.Second * time.Duration(client.Config.CheckConnInterval))
|
||||
|
||||
tmap := make(map[uint32][]string)
|
||||
|
||||
addConn := func(id uint32, ci string) {
|
||||
_, ok := tmap[id]
|
||||
if !ok {
|
||||
tmap[id] = nil
|
||||
}
|
||||
tmap[id] = append(tmap[id], ci)
|
||||
}
|
||||
|
||||
ct := make([]struct {
|
||||
connInfo string
|
||||
cid uint32
|
||||
}, 0)
|
||||
client.KnownConnInfoMutex.Lock()
|
||||
for connInfo, cid := range client.KnownConnInfo {
|
||||
ct = append(ct, struct {
|
||||
connInfo string
|
||||
cid uint32
|
||||
}{connInfo: connInfo, cid: cid})
|
||||
}
|
||||
client.KnownConnInfoMutex.Unlock()
|
||||
for i := 0; i < len(ct); i++ {
|
||||
connInfo := ct[i].connInfo
|
||||
cid := ct[i].cid
|
||||
client.PeerMutex.RLock()
|
||||
peer, ok := client.Peers[cid]
|
||||
if !ok {
|
||||
addConn(cid, connInfo)
|
||||
client.PeerMutex.RUnlock()
|
||||
continue
|
||||
}
|
||||
client.PeerMutex.RUnlock()
|
||||
peer.ConnsMutex.RLock()
|
||||
//log.Printf("conninfo: %v %v %v\n", connInfo, cid, peer)
|
||||
if len(peer.Conns) < MAX_PEER_CONNS {
|
||||
addConn(cid, connInfo)
|
||||
}
|
||||
peer.ConnsMutex.RUnlock()
|
||||
}
|
||||
|
||||
for _, cis := range tmap {
|
||||
go func(s string) {
|
||||
conn, err := net.Dial("tcp", s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = client.HandleConn(conn, s)
|
||||
if err != nil && client.Config.Debug {
|
||||
log.Printf("Error during handling conn: %v\n", err)
|
||||
}
|
||||
}(cis[rand.Intn(len(cis))])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) Start() error {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
log.Println("Starting...")
|
||||
client.Id = client.Config.Id
|
||||
client.IPv4 = client.Config.IPv4
|
||||
client.Key = client.Config.Key
|
||||
|
||||
client.PeerMutex = &sync.RWMutex{}
|
||||
client.KnownConnInfoMutex = &sync.Mutex{}
|
||||
client.KnownConnInfo = make(map[string]uint32)
|
||||
client.LastCheckClients = make(OtherClientsInfo)
|
||||
client.OClients.Init(client.Config.CheckClientsInterval)
|
||||
client.running = true
|
||||
client.ConnInfo = ""
|
||||
|
||||
if client.Config.StartTun {
|
||||
tmpIP := Uint32ToIP(client.IPv4).String()
|
||||
tmpGW := Uint32ToIP(client.Config.IPv4Gateway).String()
|
||||
tmpMask := Uint32ToIP(client.Config.IPv4Mask).String()
|
||||
tmpDNS := make([]string, 0)
|
||||
for i := 0; i < len(client.Config.IPv4DNS); i++ {
|
||||
tmpDNS = append(tmpDNS, Uint32ToIP(client.Config.IPv4DNS[i]).String())
|
||||
}
|
||||
tun, err := tun.OpenTunDevice("tun0", tmpIP, tmpGW, tmpMask, tmpDNS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.Tun = network.New(tun, client.HandleTunIPv4)
|
||||
go client.Tun.Run()
|
||||
}
|
||||
|
||||
for i := 0; i < len(client.Config.Listen); i++ {
|
||||
go func(s ListenInfo) {
|
||||
listenAddr := s.LocalAddr + ":" + strconv.Itoa(int(s.Port))
|
||||
if client.ConnInfo != "" {
|
||||
client.ConnInfo += ","
|
||||
}
|
||||
client.ConnInfo += s.NetworkAddr + ":" + strconv.Itoa(int(s.Port))
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
log.Printf("Error during listening %s: %v\n", listenAddr, err)
|
||||
return
|
||||
}
|
||||
for client.running {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
err := client.HandleConn(conn, "")
|
||||
if err != nil && client.Config.Debug {
|
||||
log.Printf("Error during handling conn: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}(client.Config.Listen[i])
|
||||
}
|
||||
|
||||
for i := 0; i < len(client.Config.BootstrapNodes); i++ {
|
||||
go func(s string) {
|
||||
conn, err := net.Dial("tcp", s)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
err = client.HandleConn(conn, s)
|
||||
if err != nil && client.Config.Debug {
|
||||
log.Printf("Error during handling conn: %v\n", err)
|
||||
}
|
||||
}(client.Config.BootstrapNodes[i])
|
||||
}
|
||||
|
||||
go client.CheckClients()
|
||||
go client.CheckConn()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) Stop() {
|
||||
client.running = false
|
||||
log.Println("Shutting down...")
|
||||
// close connections
|
||||
}
|
||||
225
clients.go
Normal file
225
clients.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type LatencyPair struct {
|
||||
Id uint32
|
||||
Latency uint32
|
||||
}
|
||||
|
||||
type pair struct {
|
||||
a uint32
|
||||
b uint32
|
||||
}
|
||||
|
||||
type OtherClient struct {
|
||||
Id uint32
|
||||
IPv4 uint32
|
||||
Latency map[uint32]uint32 // connId->latency
|
||||
LatencyAdd map[uint32]uint32
|
||||
LatencySorted []LatencyPair
|
||||
LatencyMutex *sync.RWMutex
|
||||
ConnInfo []string
|
||||
ConnInfoStr string
|
||||
}
|
||||
|
||||
type OtherClients struct {
|
||||
C map[uint32]*OtherClient
|
||||
M *sync.RWMutex
|
||||
IPv4Map map[uint32]uint32
|
||||
BannedClients *TmpKeySet
|
||||
}
|
||||
|
||||
type pairArray []pair
|
||||
|
||||
func (s pairArray) Len() int {
|
||||
return len(s)
|
||||
}
|
||||
|
||||
func (s pairArray) Less(i, j int) bool {
|
||||
return s[i].a < s[j].a
|
||||
}
|
||||
|
||||
func (s pairArray) Swap(i, j int) {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
|
||||
func (oc *OtherClient) SortLatency() {
|
||||
oc.LatencyMutex.RLock()
|
||||
var tmp pairArray
|
||||
for vid, la := range oc.Latency {
|
||||
tmp = append(tmp, pair{a: la, b: vid})
|
||||
}
|
||||
oc.LatencyMutex.RUnlock()
|
||||
sort.Sort(tmp)
|
||||
var tmp2 []LatencyPair
|
||||
for i := 0; i < len(tmp); i++ {
|
||||
tmp2 = append(tmp2, LatencyPair{Id: tmp[i].b, Latency: tmp[i].a})
|
||||
}
|
||||
oc.LatencyMutex.Lock()
|
||||
oc.LatencySorted = tmp2
|
||||
oc.LatencyMutex.Unlock()
|
||||
}
|
||||
|
||||
func (oc *OtherClient) GetLatency() uint32 {
|
||||
oc.LatencyMutex.RLock()
|
||||
defer oc.LatencyMutex.RUnlock()
|
||||
if len(oc.Latency) == 0 {
|
||||
return 100000000
|
||||
}
|
||||
return oc.LatencySorted[0].Latency
|
||||
}
|
||||
|
||||
func (c *OtherClients) Init(checkInterval int) {
|
||||
c.C = make(map[uint32]*OtherClient)
|
||||
c.M = &sync.RWMutex{}
|
||||
c.IPv4Map = make(map[uint32]uint32)
|
||||
c.BannedClients = NewTmpKeySet(uint(checkInterval * 10))
|
||||
}
|
||||
|
||||
func (c *OtherClients) IPv4ToId(ipv4 uint32) uint32 {
|
||||
c.M.RLock()
|
||||
defer c.M.RUnlock()
|
||||
res, ok := c.IPv4Map[ipv4]
|
||||
if ok {
|
||||
return res
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *OtherClients) UpdateClient(id, ipv4, viaId, latency, latencyAdd uint32) {
|
||||
c.M.Lock()
|
||||
oc, ok := c.C[id]
|
||||
if !ok {
|
||||
oc = &OtherClient{Id: id, IPv4: ipv4, Latency: make(map[uint32]uint32), LatencyAdd: make(map[uint32]uint32), LatencyMutex: &sync.RWMutex{}}
|
||||
c.C[id] = oc
|
||||
if ipv4 != 0 {
|
||||
c.IPv4Map[ipv4] = id
|
||||
}
|
||||
}
|
||||
c.M.Unlock()
|
||||
if oc.IPv4 == 0 && ipv4 != 0 {
|
||||
oc.IPv4 = ipv4
|
||||
c.M.Lock()
|
||||
c.IPv4Map[ipv4] = id
|
||||
c.M.Unlock()
|
||||
}
|
||||
if viaId == 0 {
|
||||
return
|
||||
}
|
||||
oc.LatencyMutex.Lock()
|
||||
if latency == 0 {
|
||||
delete(oc.Latency, viaId)
|
||||
delete(oc.LatencyAdd, viaId)
|
||||
} else if !c.BannedClients.HasKey(id) {
|
||||
oc.Latency[viaId] = latency
|
||||
oc.LatencyAdd[viaId] = latencyAdd
|
||||
}
|
||||
oc.LatencyMutex.Unlock()
|
||||
oc.SortLatency()
|
||||
}
|
||||
|
||||
func (c *OtherClients) UpdateClientConnInfo(id uint32, connInfo string) {
|
||||
if c.BannedClients.HasKey(id) {
|
||||
return
|
||||
}
|
||||
c.M.Lock()
|
||||
oc, ok := c.C[id]
|
||||
if !ok {
|
||||
oc = &OtherClient{Id: id, IPv4: 0, Latency: make(map[uint32]uint32), LatencyAdd: make(map[uint32]uint32), LatencyMutex: &sync.RWMutex{}}
|
||||
c.C[id] = oc
|
||||
}
|
||||
c.M.Unlock()
|
||||
if connInfo == "" {
|
||||
oc.ConnInfo = make([]string, 0)
|
||||
} else {
|
||||
oc.ConnInfo = strings.Split(connInfo, ",")
|
||||
}
|
||||
oc.ConnInfoStr = connInfo
|
||||
}
|
||||
|
||||
func (c *OtherClients) UpdateClients(pkt PacketClients, viaId, addLatency uint32) {
|
||||
tmp := make(map[uint32]struct{})
|
||||
for i := 0; i < int(pkt.Count); i++ {
|
||||
if pkt.ConnInfo[i].Ok {
|
||||
c.UpdateClientConnInfo(pkt.Ids[i], pkt.ConnInfo[i].Val)
|
||||
}
|
||||
tid := viaId
|
||||
tlat := pkt.Latency[i].Val + addLatency
|
||||
if !pkt.Latency[i].Ok {
|
||||
tid = 0
|
||||
tlat = 0
|
||||
}
|
||||
c.UpdateClient(pkt.Ids[i], pkt.IPv4[i].Val, tid, tlat, addLatency)
|
||||
tmp[pkt.Ids[i]] = struct{}{}
|
||||
}
|
||||
c.M.Lock()
|
||||
ocs := make([]*OtherClient, 0)
|
||||
for id, oc := range c.C {
|
||||
_, ok := tmp[id]
|
||||
if !ok {
|
||||
ocs = append(ocs, oc)
|
||||
}
|
||||
}
|
||||
c.M.Unlock()
|
||||
var dels []uint32
|
||||
for i := 0; i < len(ocs); i++ {
|
||||
oc := ocs[i]
|
||||
oc.LatencyMutex.Lock()
|
||||
_, ok := oc.Latency[viaId]
|
||||
if !ok {
|
||||
oc.LatencyMutex.Unlock()
|
||||
} else {
|
||||
delete(oc.Latency, viaId)
|
||||
oc.LatencyMutex.Unlock()
|
||||
oc.SortLatency()
|
||||
if len(oc.Latency) == 0 {
|
||||
dels = append(dels, oc.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.M.Lock()
|
||||
for i := 0; i < len(dels); i++ {
|
||||
oc, ok := c.C[dels[i]]
|
||||
if ok {
|
||||
_, ok = c.IPv4Map[oc.IPv4]
|
||||
if ok {
|
||||
delete(c.IPv4Map, oc.IPv4)
|
||||
}
|
||||
delete(c.C, dels[i])
|
||||
}
|
||||
}
|
||||
c.M.Unlock()
|
||||
}
|
||||
|
||||
func (c *OtherClients) UpdateLatencyAdd(viaId, addLatency uint32) {
|
||||
c.M.Lock()
|
||||
ocs := make([]*OtherClient, 0)
|
||||
for _, oc := range c.C {
|
||||
ocs = append(ocs, oc)
|
||||
}
|
||||
c.M.Unlock()
|
||||
for i := 0; i < len(ocs); i++ {
|
||||
oc := ocs[i]
|
||||
oc.LatencyMutex.Lock()
|
||||
_, ok := oc.Latency[viaId]
|
||||
if !ok {
|
||||
oc.LatencyMutex.Unlock()
|
||||
} else {
|
||||
oc.Latency[viaId] = oc.Latency[viaId] - oc.LatencyAdd[viaId] + addLatency
|
||||
oc.LatencyAdd[viaId] = addLatency
|
||||
oc.LatencyMutex.Unlock()
|
||||
oc.SortLatency()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OtherClients) ClearClients(viaId uint32) {
|
||||
var pkt PacketClients
|
||||
pkt.Count = 0
|
||||
c.UpdateClients(pkt, viaId, 0)
|
||||
}
|
||||
252
clients_packet.go
Normal file
252
clients_packet.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type PacketClientsString struct {
|
||||
Val string
|
||||
Ok bool
|
||||
}
|
||||
|
||||
type PacketClientsStringArr []PacketClientsString
|
||||
|
||||
type PacketClientsUint32 struct {
|
||||
Val uint32
|
||||
Ok bool
|
||||
}
|
||||
|
||||
type PacketClientsUint32Arr []PacketClientsUint32
|
||||
|
||||
func (a PacketClientsStringArr) SetOk(n int) {
|
||||
a[n].Ok = true
|
||||
}
|
||||
|
||||
func (a PacketClientsStringArr) GetOk(n int) bool {
|
||||
return a[n].Ok
|
||||
}
|
||||
|
||||
func (a PacketClientsUint32Arr) SetOk(n int) {
|
||||
a[n].Ok = true
|
||||
}
|
||||
|
||||
func (a PacketClientsUint32Arr) GetOk(n int) bool {
|
||||
return a[n].Ok
|
||||
}
|
||||
|
||||
type PacketClientsOkTypeArr interface {
|
||||
SetOk(int)
|
||||
GetOk(int) bool
|
||||
}
|
||||
|
||||
type PacketClients struct {
|
||||
Count uint64
|
||||
Ids []uint32
|
||||
ConnInfo PacketClientsStringArr
|
||||
IPv4 PacketClientsUint32Arr
|
||||
Latency PacketClientsUint32Arr
|
||||
}
|
||||
|
||||
func DecodePacketClients(buf []byte) (PacketClients, error) {
|
||||
var pkt PacketClients
|
||||
var n int
|
||||
pkt.Count, n = binary.Uvarint(buf)
|
||||
if n > 0 {
|
||||
buf = buf[n:]
|
||||
} else if n == 0 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
} else {
|
||||
return pkt, fmt.Errorf("value larger than 64 bits")
|
||||
}
|
||||
n = int(pkt.Count)
|
||||
pkt.Ids = make([]uint32, n)
|
||||
pkt.ConnInfo = make([]PacketClientsString, n)
|
||||
pkt.IPv4 = make([]PacketClientsUint32, n)
|
||||
pkt.Latency = make([]PacketClientsUint32, n)
|
||||
if len(buf) < n*4 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
pkt.Ids[i] = binary.LittleEndian.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
}
|
||||
|
||||
n8 := (n + 7) / 8
|
||||
readOkType := func(s PacketClientsOkTypeArr) error {
|
||||
if len(buf) < n8 {
|
||||
return fmt.Errorf("incomplete packet")
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
if (int(buf[i>>3]) >> (i & 7) & 1) == 1 {
|
||||
s.SetOk(i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
readString := func() (string, error) {
|
||||
tlen, nt := binary.Uvarint(buf)
|
||||
if nt > 0 {
|
||||
buf = buf[nt:]
|
||||
} else if nt == 0 {
|
||||
return "", fmt.Errorf("incomplete packet")
|
||||
} else {
|
||||
return "", fmt.Errorf("value larger than 64 bits")
|
||||
}
|
||||
if len(buf) < int(tlen) {
|
||||
return "", fmt.Errorf("incomplete packet")
|
||||
}
|
||||
defer func() { buf = buf[tlen:] }()
|
||||
return string(buf[:tlen]), nil
|
||||
}
|
||||
|
||||
readOkType(pkt.ConnInfo)
|
||||
buf = buf[n8:]
|
||||
var err error
|
||||
for i := 0; i < n; i++ {
|
||||
if pkt.ConnInfo[i].Ok {
|
||||
pkt.ConnInfo[i].Val, err = readString()
|
||||
if err != nil {
|
||||
return pkt, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
readOkType(pkt.IPv4)
|
||||
buf = buf[n8:]
|
||||
for i := 0; i < n; i++ {
|
||||
if pkt.IPv4[i].Ok {
|
||||
if len(buf) < 4 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
pkt.IPv4[i].Val = binary.LittleEndian.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
}
|
||||
}
|
||||
|
||||
readOkType(pkt.Latency)
|
||||
buf = buf[n8:]
|
||||
for i := 0; i < n; i++ {
|
||||
if pkt.Latency[i].Ok {
|
||||
if len(buf) < 4 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
pkt.Latency[i].Val = binary.LittleEndian.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
}
|
||||
}
|
||||
if len(buf) != 0 {
|
||||
return pkt, fmt.Errorf("extra data after packet")
|
||||
}
|
||||
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func appendUvarint(s []byte, x uint64) []byte {
|
||||
n := len(s)
|
||||
s = append(s, make([]byte, binary.MaxVarintLen64)...)
|
||||
n += binary.PutUvarint(s[n:], x)
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
func EncodePacketClients(pkt PacketClients) []byte {
|
||||
var buf []byte
|
||||
buf = appendUvarint(buf, pkt.Count)
|
||||
n := int(pkt.Count)
|
||||
tmpU32 := make([]byte, 4)
|
||||
for i := 0; i < n; i++ {
|
||||
binary.LittleEndian.PutUint32(tmpU32, pkt.Ids[i])
|
||||
buf = append(buf, tmpU32...)
|
||||
}
|
||||
|
||||
n8 := (n + 7) / 8
|
||||
writeOkType := func(s PacketClientsOkTypeArr) {
|
||||
tmp := make([]byte, n8)
|
||||
for i := 0; i < n; i++ {
|
||||
if s.GetOk(i) {
|
||||
tmp[i>>3] = tmp[i>>3] | byte(1<<(i&7))
|
||||
}
|
||||
}
|
||||
buf = append(buf, tmp...)
|
||||
}
|
||||
|
||||
writeOkType(pkt.ConnInfo)
|
||||
for i := 0; i < n; i++ {
|
||||
if pkt.ConnInfo[i].Ok {
|
||||
buf = appendUvarint(buf, uint64(len(pkt.ConnInfo[i].Val)))
|
||||
buf = append(buf, pkt.ConnInfo[i].Val...)
|
||||
}
|
||||
}
|
||||
|
||||
writeOkType(pkt.IPv4)
|
||||
for i := 0; i < n; i++ {
|
||||
if pkt.IPv4[i].Ok {
|
||||
binary.LittleEndian.PutUint32(tmpU32, pkt.IPv4[i].Val)
|
||||
buf = append(buf, tmpU32...)
|
||||
}
|
||||
}
|
||||
|
||||
writeOkType(pkt.Latency)
|
||||
for i := 0; i < n; i++ {
|
||||
if pkt.Latency[i].Ok {
|
||||
binary.LittleEndian.PutUint32(tmpU32, pkt.Latency[i].Val)
|
||||
buf = append(buf, tmpU32...)
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
type OtherClientInfo struct {
|
||||
ConnInfo string
|
||||
IPv4 uint32
|
||||
Latency uint32
|
||||
}
|
||||
|
||||
type OtherClientsInfo map[uint32]OtherClientInfo
|
||||
|
||||
func GenDiffPacketClients(cur, old OtherClientsInfo) (PacketClients, bool) {
|
||||
var pkt PacketClients
|
||||
n := len(cur)
|
||||
pkt.Count = uint64(n)
|
||||
pkt.Ids = make([]uint32, n)
|
||||
pkt.ConnInfo = make([]PacketClientsString, n)
|
||||
pkt.IPv4 = make([]PacketClientsUint32, n)
|
||||
pkt.Latency = make([]PacketClientsUint32, n)
|
||||
useful := len(cur) != len(old)
|
||||
i := 0
|
||||
for id, oc := range cur {
|
||||
pkt.Ids[i] = id
|
||||
oc2, ok := old[id]
|
||||
useful = useful || !ok
|
||||
if ok {
|
||||
if oc.ConnInfo != oc2.ConnInfo && oc.ConnInfo != "" {
|
||||
pkt.ConnInfo[i].Ok = true
|
||||
pkt.ConnInfo[i].Val = oc.ConnInfo
|
||||
}
|
||||
if oc.IPv4 != oc2.IPv4 && oc.IPv4 != 0 {
|
||||
pkt.IPv4[i].Ok = true
|
||||
pkt.IPv4[i].Val = oc.IPv4
|
||||
}
|
||||
if oc.Latency != oc2.Latency {
|
||||
pkt.Latency[i].Ok = true
|
||||
pkt.Latency[i].Val = oc.Latency
|
||||
}
|
||||
} else {
|
||||
if oc.ConnInfo != "" {
|
||||
pkt.ConnInfo[i].Ok = true
|
||||
pkt.ConnInfo[i].Val = oc.ConnInfo
|
||||
}
|
||||
if oc.IPv4 != 0 {
|
||||
pkt.IPv4[i].Ok = true
|
||||
pkt.IPv4[i].Val = oc.IPv4
|
||||
}
|
||||
pkt.Latency[i].Ok = true
|
||||
pkt.Latency[i].Val = oc.Latency
|
||||
}
|
||||
useful = useful || pkt.ConnInfo[i].Ok || pkt.IPv4[i].Ok || pkt.Latency[i].Ok
|
||||
i += 1
|
||||
}
|
||||
return pkt, useful
|
||||
}
|
||||
62
config.go
Normal file
62
config.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Id uint32
|
||||
Key string
|
||||
IPv4 string
|
||||
IPv4Gateway string
|
||||
IPv4Mask string
|
||||
IPv4DNS string
|
||||
Listen []ListenInfo
|
||||
BootstrapNodes []string
|
||||
CheckClientsInterval int
|
||||
PingInterval int
|
||||
CheckConnInterval int
|
||||
Debug bool
|
||||
ConnectNewPeer bool
|
||||
StartTun bool
|
||||
}
|
||||
|
||||
func UnmarshalConfig(s []byte) (ClientConfig, error) {
|
||||
var c Config
|
||||
var r ClientConfig
|
||||
err := json.Unmarshal(s, &c)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
r.Id = c.Id
|
||||
if c.Id == 0 {
|
||||
return r, fmt.Errorf("Id cannot be zero")
|
||||
}
|
||||
thash := sha256.Sum256([]byte(c.Key))
|
||||
r.Key = thash[:16]
|
||||
|
||||
r.IPv4 = IPToUint32(net.ParseIP(c.IPv4))
|
||||
if c.IPv4 != "" {
|
||||
r.IPv4Gateway = IPToUint32(net.ParseIP(c.IPv4Gateway))
|
||||
r.IPv4Mask = IPToUint32(net.ParseIP(c.IPv4Mask))
|
||||
tdns := strings.Split(c.IPv4DNS, ",")
|
||||
r.IPv4DNS = make([]uint32, 0)
|
||||
for i := 0; i < len(tdns); i++ {
|
||||
r.IPv4DNS = append(r.IPv4DNS, IPToUint32(net.ParseIP(tdns[i])))
|
||||
}
|
||||
}
|
||||
r.Listen = c.Listen
|
||||
r.BootstrapNodes = c.BootstrapNodes
|
||||
r.CheckClientsInterval = c.CheckClientsInterval
|
||||
r.PingInterval = c.PingInterval
|
||||
r.CheckConnInterval = c.CheckConnInterval
|
||||
r.Debug = c.Debug
|
||||
r.ConnectNewPeer = c.ConnectNewPeer
|
||||
r.StartTun = c.StartTun
|
||||
return r, nil
|
||||
}
|
||||
136
far_ping.go
Normal file
136
far_ping.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
FarPingRequest uint32 = 0
|
||||
FarPingReply uint32 = 1
|
||||
FarPingCircuit uint32 = 2
|
||||
)
|
||||
|
||||
type PacketFarPing struct {
|
||||
SrcNode uint32
|
||||
DstNode uint32
|
||||
Status uint32
|
||||
HistoryNodes []uint32
|
||||
}
|
||||
|
||||
func DecodePacketFarPing(buf []byte) (PacketFarPing, error) {
|
||||
var pkt PacketFarPing
|
||||
if len(buf) < 12 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
pkt.SrcNode = binary.LittleEndian.Uint32(buf[:4])
|
||||
pkt.DstNode = binary.LittleEndian.Uint32(buf[4:8])
|
||||
pkt.Status = binary.LittleEndian.Uint32(buf[8:12])
|
||||
buf = buf[12:]
|
||||
|
||||
nh, n := binary.Uvarint(buf)
|
||||
if n > 0 {
|
||||
buf = buf[n:]
|
||||
} else if n == 0 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
} else {
|
||||
return pkt, fmt.Errorf("value larger than 64 bits")
|
||||
}
|
||||
|
||||
if len(buf) < 4*int(nh) {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
pkt.HistoryNodes = make([]uint32, int(nh))
|
||||
for i := 0; i < int(nh); i++ {
|
||||
pkt.HistoryNodes[i] = binary.LittleEndian.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
}
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func EncodePacketFarPing(pkt PacketFarPing) []byte {
|
||||
buf := make([]byte, 12)
|
||||
binary.LittleEndian.PutUint32(buf[:4], pkt.SrcNode)
|
||||
binary.LittleEndian.PutUint32(buf[4:8], pkt.DstNode)
|
||||
binary.LittleEndian.PutUint32(buf[8:12], pkt.Status)
|
||||
buf = appendUvarint(buf, uint64(len(pkt.HistoryNodes)))
|
||||
tmpU32 := make([]byte, 4)
|
||||
for i := 0; i < len(pkt.HistoryNodes); i++ {
|
||||
binary.LittleEndian.PutUint32(tmpU32, pkt.HistoryNodes[i])
|
||||
buf = append(buf, tmpU32...)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func (client *Client) HandlePacketFarPing(pkt PacketFarPing) error {
|
||||
if pkt.DstNode == client.Id {
|
||||
if pkt.Status == FarPingReply {
|
||||
return nil
|
||||
}
|
||||
if pkt.Status == FarPingRequest {
|
||||
var reply PacketFarPing
|
||||
reply.DstNode = pkt.SrcNode
|
||||
reply.SrcNode = client.Id
|
||||
reply.Status = FarPingReply
|
||||
reply.HistoryNodes = make([]uint32, 0)
|
||||
return client.HandlePacketFarPing(reply)
|
||||
}
|
||||
if pkt.Status == FarPingCircuit {
|
||||
client.OClients.BannedClients.AddKey(pkt.SrcNode)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
client.OClients.M.RLock()
|
||||
oc, ok := client.OClients.C[pkt.DstNode]
|
||||
if !ok {
|
||||
client.OClients.M.RUnlock()
|
||||
return fmt.Errorf("target not reachable")
|
||||
}
|
||||
client.OClients.M.RUnlock()
|
||||
oc.LatencyMutex.RLock()
|
||||
n := len(oc.LatencySorted)
|
||||
if n == 0 {
|
||||
oc.LatencyMutex.RUnlock()
|
||||
return fmt.Errorf("target not reachable")
|
||||
}
|
||||
nh := len(pkt.HistoryNodes)
|
||||
var nextId uint32
|
||||
for i := 0; i < n; i++ {
|
||||
tid := oc.LatencySorted[i].Id
|
||||
vis := false
|
||||
for j := 0; j < nh; j++ {
|
||||
if tid == pkt.HistoryNodes[j] {
|
||||
vis = true
|
||||
}
|
||||
}
|
||||
if !vis {
|
||||
nextId = tid
|
||||
break
|
||||
}
|
||||
}
|
||||
if nextId == 0 {
|
||||
var reply PacketFarPing
|
||||
reply.DstNode = pkt.SrcNode
|
||||
reply.SrcNode = pkt.DstNode
|
||||
//reply.SrcNode = client.Id
|
||||
reply.Status = FarPingCircuit
|
||||
reply.HistoryNodes = make([]uint32, 0)
|
||||
oc.LatencyMutex.RUnlock()
|
||||
return client.HandlePacketFarPing(reply)
|
||||
}
|
||||
if nextId == client.Id {
|
||||
nextId = pkt.DstNode
|
||||
}
|
||||
oc.LatencyMutex.RUnlock()
|
||||
|
||||
client.PeerMutex.RLock()
|
||||
peer, ok := client.Peers[nextId]
|
||||
client.PeerMutex.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("target conn closed unexpectedly")
|
||||
}
|
||||
pkt.HistoryNodes = append(pkt.HistoryNodes, client.Id)
|
||||
return peer.SendPacket(GrassPacketFarPing, EncodePacketFarPing(pkt), client)
|
||||
}
|
||||
179
ipv4.go
Normal file
179
ipv4.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/mcfx0/grass/network"
|
||||
"github.com/mcfx0/grass/network/packet"
|
||||
)
|
||||
|
||||
func Uint32ToIP(intIP uint32) net.IP {
|
||||
bytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(bytes, intIP)
|
||||
return net.IP(bytes)
|
||||
}
|
||||
|
||||
func IPToUint32(ip net.IP) uint32 {
|
||||
return binary.BigEndian.Uint32([]byte(ip.To4()))
|
||||
}
|
||||
|
||||
type PacketIPv4 struct {
|
||||
TargetNode uint32
|
||||
TTL uint8
|
||||
HistoryNodes []uint32
|
||||
Data []byte
|
||||
IPPacket *packet.IPv4
|
||||
}
|
||||
|
||||
func (pkt *PacketIPv4) getIPPacket() error {
|
||||
pkt.IPPacket = packet.NewIPv4()
|
||||
return packet.ParseIPv4(pkt.Data, pkt.IPPacket)
|
||||
}
|
||||
|
||||
func DecodePacketIPv4(buf []byte) (PacketIPv4, error) {
|
||||
var pkt PacketIPv4
|
||||
if len(buf) < 5 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
pkt.TargetNode = binary.LittleEndian.Uint32(buf[:4])
|
||||
pkt.TTL = uint8(buf[4])
|
||||
buf = buf[5:]
|
||||
|
||||
nh, n := binary.Uvarint(buf)
|
||||
if n > 0 {
|
||||
buf = buf[n:]
|
||||
} else if n == 0 {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
} else {
|
||||
return pkt, fmt.Errorf("value larger than 64 bits")
|
||||
}
|
||||
|
||||
if len(buf) < 4*int(nh) {
|
||||
return pkt, fmt.Errorf("incomplete packet")
|
||||
}
|
||||
pkt.HistoryNodes = make([]uint32, int(nh))
|
||||
for i := 0; i < int(nh); i++ {
|
||||
pkt.HistoryNodes[i] = binary.LittleEndian.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
}
|
||||
pkt.Data = buf
|
||||
pkt.IPPacket = nil
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func EncodePacketIPv4(pkt PacketIPv4) []byte {
|
||||
buf := make([]byte, 5)
|
||||
binary.LittleEndian.PutUint32(buf, pkt.TargetNode)
|
||||
buf[4] = byte(pkt.TTL)
|
||||
buf = appendUvarint(buf, uint64(len(pkt.HistoryNodes)))
|
||||
tmpU32 := make([]byte, 4)
|
||||
for i := 0; i < len(pkt.HistoryNodes); i++ {
|
||||
binary.LittleEndian.PutUint32(tmpU32, pkt.HistoryNodes[i])
|
||||
buf = append(buf, tmpU32...)
|
||||
}
|
||||
buf = append(buf, pkt.Data...)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (client *Client) HandleFinalIPv4(pkt PacketIPv4) error {
|
||||
if err := pkt.getIPPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if pkt.IPPacket.Protocol == packet.IPProtocolICMPv4 {
|
||||
var icmp4src packet.ICMPv4
|
||||
if err := packet.ParseICMPv4(pkt.IPPacket.Payload, &icmp4src); err != nil {
|
||||
return err
|
||||
}
|
||||
if icmp4src.Type == 8 { // send icmp reply
|
||||
icmp4 := network.ResponseICMPv4Packet(Uint32ToIP(client.IPv4), pkt.IPPacket.SrcIP, 0, 0, icmp4src.Payload)
|
||||
return client.HandleTunIPv4(icmp4.Wire, icmp4.Ip)
|
||||
}
|
||||
}
|
||||
// TTL should be edited here, but I'm too lazy now
|
||||
client.Tun.WriteCh <- pkt.Data
|
||||
/*pkt.IPPacket.TTL = pkt.TTL
|
||||
packets := network.GenFragments(pkt.IPPacket, 0, pkt.IPPacket.Payload)
|
||||
for i := 0; i < len(packets); i++ {
|
||||
client.Tun.WriteCh <- packets[i]
|
||||
}*/
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) HandlePacketIPv4(pkt PacketIPv4) error {
|
||||
if pkt.TargetNode == client.Id {
|
||||
return client.HandleFinalIPv4(pkt)
|
||||
}
|
||||
if pkt.TTL == 0 {
|
||||
if err := pkt.getIPPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
payload := append([]byte{0, 0, 0, 0}, pkt.Data...)
|
||||
icmp4 := network.ResponseICMPv4Packet(Uint32ToIP(client.IPv4), pkt.IPPacket.SrcIP, 11, 0, payload)
|
||||
return client.HandleTunIPv4(icmp4.Wire, icmp4.Ip)
|
||||
}
|
||||
|
||||
client.OClients.M.RLock()
|
||||
oc, ok := client.OClients.C[pkt.TargetNode]
|
||||
if !ok {
|
||||
client.OClients.M.RUnlock()
|
||||
return fmt.Errorf("target not reachable")
|
||||
}
|
||||
client.OClients.M.RUnlock()
|
||||
oc.LatencyMutex.RLock()
|
||||
n := len(oc.LatencySorted)
|
||||
if n == 0 {
|
||||
oc.LatencyMutex.RUnlock()
|
||||
return fmt.Errorf("target not reachable")
|
||||
}
|
||||
nh := len(pkt.HistoryNodes)
|
||||
qa := make([]uint32, 0)
|
||||
qb := make([]uint32, 0)
|
||||
for i := 0; i < n; i++ {
|
||||
tid := oc.LatencySorted[i].Id
|
||||
vis := false
|
||||
for j := 0; j < nh; j++ {
|
||||
if tid == pkt.HistoryNodes[j] {
|
||||
vis = true
|
||||
}
|
||||
}
|
||||
if !vis {
|
||||
qa = append(qa, tid)
|
||||
} else {
|
||||
qb = append(qb, tid)
|
||||
}
|
||||
}
|
||||
oc.LatencyMutex.RUnlock()
|
||||
qa = append(qa, qb...)
|
||||
|
||||
pkt.TTL--
|
||||
pkt.HistoryNodes = append(pkt.HistoryNodes, client.Id)
|
||||
for mqlen := 2; mqlen <= 128; mqlen *= 2 {
|
||||
for i := 0; i < len(qa); i++ {
|
||||
client.PeerMutex.RLock()
|
||||
peer, ok := client.Peers[qa[i]]
|
||||
client.PeerMutex.RUnlock()
|
||||
//fmt.Printf("%d", peer.Id)
|
||||
if ok {
|
||||
if peer.SendPacketV2(GrassPacketIPv4, EncodePacketIPv4(pkt), client, mqlen) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to send out packet")
|
||||
}
|
||||
|
||||
func (client *Client) HandleTunIPv4(data []byte, pkt *packet.IPv4) error {
|
||||
var spkt PacketIPv4
|
||||
tmp := client.OClients.IPv4ToId(IPToUint32(pkt.DstIP))
|
||||
if tmp == 0 {
|
||||
return fmt.Errorf("target not reachable")
|
||||
}
|
||||
spkt.TargetNode = tmp
|
||||
spkt.TTL = pkt.TTL
|
||||
spkt.Data = data
|
||||
spkt.IPPacket = pkt
|
||||
return client.HandlePacketIPv4(spkt)
|
||||
}
|
||||
1
network/README.md
Normal file
1
network/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Most files in this directory are copied from https://github.com/yinghuocho/gotun2socks, with slightly modifies.
|
||||
80
network/icmp.go
Normal file
80
network/icmp.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/mcfx0/grass/network/packet"
|
||||
)
|
||||
|
||||
type Icmp4Packet struct {
|
||||
Ip *packet.IPv4
|
||||
Icmp4 *packet.ICMPv4
|
||||
MtuBuf []byte
|
||||
Wire []byte
|
||||
}
|
||||
|
||||
var (
|
||||
Icmp4PacketPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &Icmp4Packet{}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func newICMPv4Packet() *Icmp4Packet {
|
||||
return Icmp4PacketPool.Get().(*Icmp4Packet)
|
||||
}
|
||||
|
||||
func releaseICMPv4Packet(pkt *Icmp4Packet) {
|
||||
packet.ReleaseIPv4(pkt.Ip)
|
||||
packet.ReleaseICMPv4(pkt.Icmp4)
|
||||
if pkt.MtuBuf != nil {
|
||||
releaseBuffer(pkt.MtuBuf)
|
||||
}
|
||||
pkt.MtuBuf = nil
|
||||
pkt.Wire = nil
|
||||
Icmp4PacketPool.Put(pkt)
|
||||
}
|
||||
|
||||
func ResponseICMPv4Packet(local net.IP, remote net.IP, Type uint8, Code uint8, respPayload []byte) *Icmp4Packet {
|
||||
ipid := packet.IPID()
|
||||
|
||||
ip := packet.NewIPv4()
|
||||
icmp4 := packet.NewICMPv4()
|
||||
|
||||
ip.Version = 4
|
||||
ip.Id = ipid
|
||||
ip.SrcIP = make(net.IP, len(local))
|
||||
copy(ip.SrcIP, local)
|
||||
ip.DstIP = make(net.IP, len(remote))
|
||||
copy(ip.DstIP, remote)
|
||||
ip.TTL = 64
|
||||
ip.Protocol = packet.IPProtocolICMPv4
|
||||
|
||||
icmp4.Type = Type
|
||||
icmp4.Code = Code
|
||||
icmp4.Payload = respPayload
|
||||
|
||||
pkt := newICMPv4Packet()
|
||||
pkt.Ip = ip
|
||||
pkt.Icmp4 = icmp4
|
||||
|
||||
pkt.MtuBuf = newBuffer()
|
||||
payloadL := len(icmp4.Payload)
|
||||
payloadStart := MTU - payloadL
|
||||
icmp4HL := 4
|
||||
icmp4Start := payloadStart - icmp4HL
|
||||
pseduoStart := icmp4Start - packet.IPv4_PSEUDO_LENGTH
|
||||
ip.PseudoHeader(pkt.MtuBuf[pseduoStart:icmp4Start], packet.IPProtocolICMPv4, icmp4HL+payloadL)
|
||||
icmp4.Serialize(pkt.MtuBuf[icmp4Start:payloadStart], pkt.MtuBuf[icmp4Start:payloadStart], icmp4.Payload)
|
||||
if payloadL != 0 {
|
||||
copy(pkt.MtuBuf[payloadStart:], icmp4.Payload)
|
||||
}
|
||||
ipHL := ip.HeaderLength()
|
||||
ipStart := icmp4Start - ipHL
|
||||
ip.Serialize(pkt.MtuBuf[ipStart:icmp4Start], icmp4HL+(MTU-payloadStart))
|
||||
pkt.Wire = pkt.MtuBuf[ipStart:]
|
||||
|
||||
return pkt
|
||||
}
|
||||
102
network/ip.go
Normal file
102
network/ip.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/mcfx0/grass/network/packet"
|
||||
)
|
||||
|
||||
type ipPacket struct {
|
||||
ip *packet.IPv4
|
||||
mtuBuf []byte
|
||||
wire []byte
|
||||
}
|
||||
|
||||
var (
|
||||
frags = make(map[uint16]*ipPacket)
|
||||
)
|
||||
|
||||
func procFragment(ip *packet.IPv4, raw []byte) (bool, *packet.IPv4, []byte) {
|
||||
exist, ok := frags[ip.Id]
|
||||
if !ok {
|
||||
if ip.Flags&0x1 == 0 {
|
||||
return false, nil, nil
|
||||
}
|
||||
// first
|
||||
//log.Printf("first fragment of IPID %d", ip.Id)
|
||||
dup := make([]byte, len(raw))
|
||||
copy(dup, raw)
|
||||
clone := &packet.IPv4{}
|
||||
packet.ParseIPv4(dup, clone)
|
||||
frags[ip.Id] = &ipPacket{
|
||||
ip: clone,
|
||||
wire: dup,
|
||||
}
|
||||
return false, clone, dup
|
||||
} else {
|
||||
exist.wire = append(exist.wire, ip.Payload...)
|
||||
packet.ParseIPv4(exist.wire, exist.ip)
|
||||
|
||||
last := false
|
||||
if ip.Flags&0x1 == 0 {
|
||||
//log.Printf("last fragment of IPID %d", ip.Id)
|
||||
last = true
|
||||
} else {
|
||||
//log.Printf("continue fragment of IPID %d", ip.Id)
|
||||
}
|
||||
|
||||
return last, exist.ip, exist.wire
|
||||
}
|
||||
}
|
||||
|
||||
func GenFragments(first *packet.IPv4, offset uint16, data []byte) []*ipPacket {
|
||||
var ret []*ipPacket
|
||||
for {
|
||||
frag := packet.NewIPv4()
|
||||
|
||||
frag.Version = 4
|
||||
frag.Id = first.Id
|
||||
frag.SrcIP = make(net.IP, len(first.SrcIP))
|
||||
copy(frag.SrcIP, first.SrcIP)
|
||||
frag.DstIP = make(net.IP, len(first.DstIP))
|
||||
copy(frag.DstIP, first.DstIP)
|
||||
frag.TTL = first.TTL
|
||||
frag.Protocol = first.Protocol
|
||||
frag.FragOffset = offset
|
||||
if len(data) <= MTU-20 {
|
||||
frag.Payload = data
|
||||
} else {
|
||||
frag.Flags = 1
|
||||
offset += (MTU - 20) / 8
|
||||
frag.Payload = data[:MTU-20]
|
||||
data = data[MTU-20:]
|
||||
}
|
||||
|
||||
pkt := &ipPacket{ip: frag}
|
||||
pkt.mtuBuf = newBuffer()
|
||||
|
||||
payloadL := len(frag.Payload)
|
||||
payloadStart := MTU - payloadL
|
||||
if payloadL != 0 {
|
||||
copy(pkt.mtuBuf[payloadStart:], frag.Payload)
|
||||
}
|
||||
ipHL := frag.HeaderLength()
|
||||
ipStart := payloadStart - ipHL
|
||||
frag.Serialize(pkt.mtuBuf[ipStart:payloadStart], payloadL)
|
||||
pkt.wire = pkt.mtuBuf[ipStart:]
|
||||
ret = append(ret, pkt)
|
||||
|
||||
if frag.Flags == 0 {
|
||||
return ret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func releaseIPPacket(pkt *ipPacket) {
|
||||
packet.ReleaseIPv4(pkt.ip)
|
||||
if pkt.mtuBuf != nil {
|
||||
releaseBuffer(pkt.mtuBuf)
|
||||
}
|
||||
pkt.mtuBuf = nil
|
||||
pkt.wire = nil
|
||||
}
|
||||
21
network/mtubuf.go
Normal file
21
network/mtubuf.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
bufPool *sync.Pool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, MTU)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func newBuffer() []byte {
|
||||
return bufPool.Get().([]byte)
|
||||
}
|
||||
|
||||
func releaseBuffer(buf []byte) {
|
||||
bufPool.Put(buf)
|
||||
}
|
||||
89
network/network.go
Normal file
89
network/network.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/mcfx0/grass/network/packet"
|
||||
)
|
||||
|
||||
const (
|
||||
MTU = 1500
|
||||
)
|
||||
|
||||
type TunHandler struct {
|
||||
dev io.ReadWriteCloser
|
||||
|
||||
writerStopCh chan bool
|
||||
WriteCh chan interface{}
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
handler func([]byte, *packet.IPv4) error
|
||||
}
|
||||
|
||||
func New(dev io.ReadWriteCloser, handler func([]byte, *packet.IPv4) error) *TunHandler {
|
||||
th := &TunHandler{
|
||||
dev: dev,
|
||||
writerStopCh: make(chan bool, 10),
|
||||
WriteCh: make(chan interface{}, 10000),
|
||||
handler: handler,
|
||||
}
|
||||
return th
|
||||
}
|
||||
|
||||
func (th *TunHandler) Run() {
|
||||
// writer
|
||||
go func() {
|
||||
th.wg.Add(1)
|
||||
defer th.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case pkt := <-th.WriteCh:
|
||||
switch pkt.(type) {
|
||||
case *ipPacket:
|
||||
ip := pkt.(*ipPacket)
|
||||
th.dev.Write(ip.wire)
|
||||
releaseIPPacket(ip)
|
||||
case []byte:
|
||||
th.dev.Write(pkt.([]byte))
|
||||
}
|
||||
case <-th.writerStopCh:
|
||||
log.Printf("quit tun2handler writer")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// reader
|
||||
var buf [MTU]byte
|
||||
var ip packet.IPv4
|
||||
|
||||
th.wg.Add(1)
|
||||
defer th.wg.Done()
|
||||
for {
|
||||
n, e := th.dev.Read(buf[:])
|
||||
if e != nil {
|
||||
// TODO: stop at critical error
|
||||
log.Printf("read packet error: %s", e)
|
||||
return
|
||||
}
|
||||
data := buf[:n]
|
||||
//log.Println(data)
|
||||
e = packet.ParseIPv4(data, &ip)
|
||||
if e != nil {
|
||||
log.Printf("error to parse IPv4: %s", e)
|
||||
continue
|
||||
}
|
||||
|
||||
//go th.handler(data, &ip)
|
||||
th.handler(data, &ip)
|
||||
/*go func() {
|
||||
err := th.handler(data, &ip)
|
||||
if err != nil {
|
||||
log.Printf("ipv4 error: %v\n", err)
|
||||
}
|
||||
}()*/
|
||||
}
|
||||
}
|
||||
39
network/packet/common.go
Normal file
39
network/packet/common.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package packet
|
||||
|
||||
var lotsOfZeros [1024]byte
|
||||
|
||||
//func Checksum(data []byte) uint16 {
|
||||
// var csum uint32
|
||||
// length := len(data) - 1
|
||||
// for i := 0; i < length; i += 2 {
|
||||
// csum += uint32(data[i]) << 8
|
||||
// csum += uint32(data[i+1])
|
||||
// }
|
||||
// if len(data)%2 == 1 {
|
||||
// csum += uint32(data[length]) << 8
|
||||
// }
|
||||
// for csum > 0xffff {
|
||||
// csum = (csum >> 16) + (csum & 0xffff)
|
||||
// }
|
||||
// return ^uint16(csum + (csum >> 16))
|
||||
//}
|
||||
|
||||
func Checksum(fields ...[]byte) uint16 {
|
||||
var csum uint32
|
||||
for _, field := range fields {
|
||||
length := len(field) - 1
|
||||
for i := 0; i < length; i += 2 {
|
||||
csum += uint32(field[i]) << 8
|
||||
csum += uint32(field[i+1])
|
||||
}
|
||||
if len(field)%2 == 1 {
|
||||
// only last field may have odd number of bytes
|
||||
csum += uint32(field[length]) << 8
|
||||
}
|
||||
}
|
||||
|
||||
for csum > 0xffff {
|
||||
csum = (csum >> 16) + (csum & 0xffff)
|
||||
}
|
||||
return ^uint16(csum + (csum >> 16))
|
||||
}
|
||||
65
network/packet/icmp.go
Normal file
65
network/packet/icmp.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ICMPv4 struct {
|
||||
Type uint8
|
||||
Code uint8
|
||||
Checksum uint16
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
var (
|
||||
icmp4Pool *sync.Pool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &ICMPv4{}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func NewICMPv4() *ICMPv4 {
|
||||
var zero ICMPv4
|
||||
icmp4 := icmp4Pool.Get().(*ICMPv4)
|
||||
*icmp4 = zero
|
||||
return icmp4
|
||||
}
|
||||
|
||||
func ReleaseICMPv4(icmp4 *ICMPv4) {
|
||||
// clear internal slice references
|
||||
icmp4.Payload = nil
|
||||
icmp4Pool.Put(icmp4)
|
||||
}
|
||||
|
||||
func ParseICMPv4(pkt []byte, icmp4 *ICMPv4) error {
|
||||
if len(pkt) < 4 {
|
||||
return fmt.Errorf("payload too small for ICMPv4: %d bytes", len(pkt))
|
||||
}
|
||||
|
||||
icmp4.Type = uint8(pkt[0])
|
||||
icmp4.Code = uint8(pkt[1])
|
||||
icmp4.Checksum = binary.BigEndian.Uint16(pkt[2:4])
|
||||
if len(pkt) > 4 {
|
||||
icmp4.Payload = pkt[4:]
|
||||
} else {
|
||||
icmp4.Payload = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (icmp4 *ICMPv4) Serialize(hdr []byte, ckFields ...[]byte) error {
|
||||
if len(hdr) != 4 {
|
||||
return fmt.Errorf("incorrect buffer size: %d buffer given, 4 needed", len(hdr))
|
||||
}
|
||||
hdr[0] = byte(icmp4.Type)
|
||||
hdr[1] = byte(icmp4.Code)
|
||||
hdr[2] = 0
|
||||
hdr[3] = 0
|
||||
icmp4.Checksum = Checksum(ckFields...)
|
||||
binary.BigEndian.PutUint16(hdr[2:], icmp4.Checksum)
|
||||
return nil
|
||||
}
|
||||
246
network/packet/ip4.go
Normal file
246
network/packet/ip4.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type IPv4Option struct {
|
||||
OptionType uint8
|
||||
OptionLength uint8
|
||||
OptionData []byte
|
||||
}
|
||||
|
||||
// IPProtocol is an enumeration of IP protocol values, and acts as a decoder
|
||||
// for any type it supports.
|
||||
type IPProtocol uint8
|
||||
|
||||
const (
|
||||
IPProtocolIPv6HopByHop IPProtocol = 0
|
||||
IPProtocolICMPv4 IPProtocol = 1
|
||||
IPProtocolIGMP IPProtocol = 2
|
||||
IPProtocolIPv4 IPProtocol = 4
|
||||
IPProtocolTCP IPProtocol = 6
|
||||
IPProtocolUDP IPProtocol = 17
|
||||
IPProtocolRUDP IPProtocol = 27
|
||||
IPProtocolIPv6 IPProtocol = 41
|
||||
IPProtocolIPv6Routing IPProtocol = 43
|
||||
IPProtocolIPv6Fragment IPProtocol = 44
|
||||
IPProtocolGRE IPProtocol = 47
|
||||
IPProtocolESP IPProtocol = 50
|
||||
IPProtocolAH IPProtocol = 51
|
||||
IPProtocolICMPv6 IPProtocol = 58
|
||||
IPProtocolNoNextHeader IPProtocol = 59
|
||||
IPProtocolIPv6Destination IPProtocol = 60
|
||||
IPProtocolIPIP IPProtocol = 94
|
||||
IPProtocolEtherIP IPProtocol = 97
|
||||
IPProtocolSCTP IPProtocol = 132
|
||||
IPProtocolUDPLite IPProtocol = 136
|
||||
IPProtocolMPLSInIP IPProtocol = 137
|
||||
|
||||
IPv4_PSEUDO_LENGTH int = 12
|
||||
)
|
||||
|
||||
type IPv4 struct {
|
||||
Version uint8
|
||||
IHL uint8
|
||||
TOS uint8
|
||||
Length uint16
|
||||
Id uint16
|
||||
Flags uint8
|
||||
FragOffset uint16
|
||||
TTL uint8
|
||||
Protocol IPProtocol
|
||||
Checksum uint16
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
Options []IPv4Option
|
||||
Padding []byte
|
||||
Payload []byte
|
||||
|
||||
headerLength int
|
||||
}
|
||||
|
||||
var (
|
||||
ipv4Pool *sync.Pool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &IPv4{}
|
||||
},
|
||||
}
|
||||
|
||||
globalIPID uint32
|
||||
)
|
||||
|
||||
func ReleaseIPv4(ip4 *IPv4) {
|
||||
// clear internal slice references
|
||||
ip4.SrcIP = nil
|
||||
ip4.DstIP = nil
|
||||
ip4.Options = nil
|
||||
ip4.Padding = nil
|
||||
ip4.Payload = nil
|
||||
|
||||
ipv4Pool.Put(ip4)
|
||||
}
|
||||
|
||||
func NewIPv4() *IPv4 {
|
||||
var zero IPv4
|
||||
ip4 := ipv4Pool.Get().(*IPv4)
|
||||
*ip4 = zero
|
||||
return ip4
|
||||
}
|
||||
|
||||
func IPID() uint16 {
|
||||
return uint16(atomic.AddUint32(&globalIPID, 1) & 0x0000ffff)
|
||||
}
|
||||
|
||||
func ParseIPv4(pkt []byte, ip4 *IPv4) error {
|
||||
flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
|
||||
|
||||
ip4.Version = uint8(pkt[0]) >> 4
|
||||
ip4.IHL = uint8(pkt[0]) & 0x0F
|
||||
ip4.TOS = pkt[1]
|
||||
ip4.Length = binary.BigEndian.Uint16(pkt[2:4])
|
||||
ip4.Id = binary.BigEndian.Uint16(pkt[4:6])
|
||||
ip4.Flags = uint8(flagsfrags >> 13)
|
||||
ip4.FragOffset = flagsfrags & 0x1FFF
|
||||
ip4.TTL = pkt[8]
|
||||
ip4.Protocol = IPProtocol(pkt[9])
|
||||
ip4.Checksum = binary.BigEndian.Uint16(pkt[10:12])
|
||||
ip4.SrcIP = pkt[12:16]
|
||||
ip4.DstIP = pkt[16:20]
|
||||
|
||||
if ip4.Length < 20 {
|
||||
return fmt.Errorf("Invalid (too small) IP length (%d < 20)", ip4.Length)
|
||||
}
|
||||
if ip4.IHL < 5 {
|
||||
return fmt.Errorf("Invalid (too small) IP header length (%d < 5)", ip4.IHL)
|
||||
}
|
||||
if int(ip4.IHL*4) > int(ip4.Length) {
|
||||
return fmt.Errorf("Invalid IP header length > IP length (%d > %d)", ip4.IHL, ip4.Length)
|
||||
}
|
||||
if int(ip4.IHL)*4 > len(pkt) {
|
||||
return fmt.Errorf("Not all IP header bytes available")
|
||||
}
|
||||
ip4.Payload = pkt[ip4.IHL*4:]
|
||||
rest := pkt[20 : ip4.IHL*4]
|
||||
// Pull out IP options
|
||||
for len(rest) > 0 {
|
||||
if ip4.Options == nil {
|
||||
// Pre-allocate to avoid growing the slice too much.
|
||||
ip4.Options = make([]IPv4Option, 0, 4)
|
||||
}
|
||||
opt := IPv4Option{OptionType: rest[0]}
|
||||
switch opt.OptionType {
|
||||
case 0: // End of options
|
||||
opt.OptionLength = 1
|
||||
ip4.Options = append(ip4.Options, opt)
|
||||
ip4.Padding = rest[1:]
|
||||
break
|
||||
case 1: // 1 byte padding
|
||||
opt.OptionLength = 1
|
||||
default:
|
||||
opt.OptionLength = rest[1]
|
||||
opt.OptionData = rest[2:opt.OptionLength]
|
||||
}
|
||||
if len(rest) >= int(opt.OptionLength) {
|
||||
rest = rest[opt.OptionLength:]
|
||||
} else {
|
||||
return fmt.Errorf("IP option length exceeds remaining IP header size, option type %v length %v", opt.OptionType, opt.OptionLength)
|
||||
}
|
||||
ip4.Options = append(ip4.Options, opt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *IPv4) PseudoHeader(buf []byte, proto IPProtocol, dataLen int) error {
|
||||
if len(buf) != IPv4_PSEUDO_LENGTH {
|
||||
return fmt.Errorf("incorrect buffer size: %d buffer given, %d needed", len(buf), IPv4_PSEUDO_LENGTH)
|
||||
}
|
||||
copy(buf[0:4], ip.SrcIP)
|
||||
copy(buf[4:8], ip.DstIP)
|
||||
buf[8] = 0
|
||||
buf[9] = byte(proto)
|
||||
binary.BigEndian.PutUint16(buf[10:], uint16(dataLen))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *IPv4) HeaderLength() int {
|
||||
if ip.headerLength == 0 {
|
||||
optionLength := uint8(0)
|
||||
for _, opt := range ip.Options {
|
||||
switch opt.OptionType {
|
||||
case 0:
|
||||
// this is the end of option lists
|
||||
optionLength++
|
||||
case 1:
|
||||
// this is the padding
|
||||
optionLength++
|
||||
default:
|
||||
optionLength += opt.OptionLength
|
||||
|
||||
}
|
||||
}
|
||||
// make sure the options are aligned to 32 bit boundary
|
||||
if (optionLength % 4) != 0 {
|
||||
optionLength += 4 - (optionLength % 4)
|
||||
}
|
||||
ip.IHL = 5 + (optionLength / 4)
|
||||
ip.headerLength = int(optionLength) + 20
|
||||
}
|
||||
return ip.headerLength
|
||||
}
|
||||
|
||||
func (ip *IPv4) flagsfrags() (ff uint16) {
|
||||
ff |= uint16(ip.Flags) << 13
|
||||
ff |= ip.FragOffset
|
||||
return
|
||||
}
|
||||
|
||||
func (ip *IPv4) Serialize(hdr []byte, dataLen int) error {
|
||||
if len(hdr) != ip.HeaderLength() {
|
||||
return fmt.Errorf("incorrect buffer size: %d buffer given, %d needed", len(hdr), ip.HeaderLength())
|
||||
}
|
||||
hdr[0] = (ip.Version << 4) | ip.IHL
|
||||
hdr[1] = ip.TOS
|
||||
ip.Length = uint16(ip.headerLength + dataLen)
|
||||
binary.BigEndian.PutUint16(hdr[2:], ip.Length)
|
||||
binary.BigEndian.PutUint16(hdr[4:], ip.Id)
|
||||
binary.BigEndian.PutUint16(hdr[6:], ip.flagsfrags())
|
||||
hdr[8] = ip.TTL
|
||||
hdr[9] = byte(ip.Protocol)
|
||||
copy(hdr[12:16], ip.SrcIP)
|
||||
copy(hdr[16:20], ip.DstIP)
|
||||
|
||||
curLocation := 20
|
||||
// Now, we will encode the options
|
||||
for _, opt := range ip.Options {
|
||||
switch opt.OptionType {
|
||||
case 0:
|
||||
// this is the end of option lists
|
||||
hdr[curLocation] = 0
|
||||
curLocation++
|
||||
case 1:
|
||||
// this is the padding
|
||||
hdr[curLocation] = 1
|
||||
curLocation++
|
||||
default:
|
||||
hdr[curLocation] = opt.OptionType
|
||||
hdr[curLocation+1] = opt.OptionLength
|
||||
|
||||
// sanity checking to protect us from buffer overrun
|
||||
if len(opt.OptionData) > int(opt.OptionLength-2) {
|
||||
return fmt.Errorf("option length is smaller than length of option data")
|
||||
}
|
||||
copy(hdr[curLocation+2:curLocation+int(opt.OptionLength)], opt.OptionData)
|
||||
curLocation += int(opt.OptionLength)
|
||||
}
|
||||
}
|
||||
hdr[10] = 0
|
||||
hdr[11] = 0
|
||||
ip.Checksum = Checksum(hdr)
|
||||
binary.BigEndian.PutUint16(hdr[10:], ip.Checksum)
|
||||
return nil
|
||||
}
|
||||
34
network/tun/stop.go
Normal file
34
network/tun/stop.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
var stopMarker = []byte{2, 2, 2, 2, 2, 2, 2, 2}
|
||||
|
||||
// Close of Windows and Linux tun/tap device do not interrupt blocking Read.
|
||||
// sendStopMarker is used to issue a specific packet to notify threads blocking
|
||||
// on Read.
|
||||
func sendStopMarker(src, dst string) {
|
||||
l, _ := net.ResolveUDPAddr("udp", src+":2222")
|
||||
r, _ := net.ResolveUDPAddr("udp", dst+":2222")
|
||||
conn, err := net.DialUDP("udp", l, r)
|
||||
if err != nil {
|
||||
log.Printf("fail to send stopmarker: %s", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.Write(stopMarker)
|
||||
}
|
||||
|
||||
func isStopMarker(pkt []byte, src, dst net.IP) bool {
|
||||
n := len(pkt)
|
||||
// at least should be 20(ip) + 8(udp) + 8(stopmarker)
|
||||
if n < 20+8+8 {
|
||||
return false
|
||||
}
|
||||
return pkt[0]&0xf0 == 0x40 && pkt[9] == 0x11 && src.Equal(pkt[12:16]) &&
|
||||
dst.Equal(pkt[16:20]) && bytes.Compare(pkt[n-8:n], stopMarker) == 0
|
||||
}
|
||||
106
network/tun/tun_darwin.go
Normal file
106
network/tun/tun_darwin.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
appleUTUNCtl = "com.apple.net.utun_control"
|
||||
appleCTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
|
||||
)
|
||||
|
||||
type sockaddrCtl struct {
|
||||
scLen uint8
|
||||
scFamily uint8
|
||||
ssSysaddr uint16
|
||||
scID uint32
|
||||
scUnit uint32
|
||||
scReserved [5]uint32
|
||||
}
|
||||
|
||||
type utunDev struct {
|
||||
f *os.File
|
||||
|
||||
rBuf [2048]byte
|
||||
wBuf [2048]byte
|
||||
}
|
||||
|
||||
func (dev *utunDev) Read(data []byte) (int, error) {
|
||||
n, e := dev.f.Read(dev.rBuf[:])
|
||||
if n > 0 {
|
||||
copy(data, dev.rBuf[4:n])
|
||||
n -= 4
|
||||
}
|
||||
return n, e
|
||||
}
|
||||
|
||||
// one packet, no more than MTU
|
||||
func (dev *utunDev) Write(data []byte) (int, error) {
|
||||
n := copy(dev.wBuf[4:], data)
|
||||
return dev.f.Write(dev.wBuf[:n+4])
|
||||
}
|
||||
|
||||
func (dev *utunDev) Close() error {
|
||||
return dev.f.Close()
|
||||
}
|
||||
|
||||
var sockaddrCtlSize uintptr = 32
|
||||
|
||||
func OpenTunDevice(name, addr, gw, mask string, dns []string) (io.ReadWriteCloser, error) {
|
||||
fd, err := syscall.Socket(syscall.AF_SYSTEM, syscall.SOCK_DGRAM, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ctlInfo = &struct {
|
||||
ctlID uint32
|
||||
ctlName [96]byte
|
||||
}{}
|
||||
copy(ctlInfo.ctlName[:], []byte(appleUTUNCtl))
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(appleCTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("error in syscall.Syscall(syscall.SYS_IOTL, ...): %v", errno)
|
||||
}
|
||||
addrP := unsafe.Pointer(&sockaddrCtl{
|
||||
scLen: uint8(sockaddrCtlSize),
|
||||
scFamily: syscall.AF_SYSTEM,
|
||||
/* #define AF_SYS_CONTROL 2 */
|
||||
ssSysaddr: 2,
|
||||
scID: ctlInfo.ctlID,
|
||||
scUnit: 0,
|
||||
})
|
||||
_, _, errno = syscall.RawSyscall(syscall.SYS_CONNECT, uintptr(fd), uintptr(addrP), uintptr(sockaddrCtlSize))
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("error in syscall.RawSyscall(syscall.SYS_CONNECT, ...): %v", errno)
|
||||
}
|
||||
|
||||
var ifName struct {
|
||||
name [16]byte
|
||||
}
|
||||
ifNameSize := uintptr(16)
|
||||
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
|
||||
2, /* #define SYSPROTO_CONTROL 2 */
|
||||
2, /* #define UTUN_OPT_IFNAME 2 */
|
||||
uintptr(unsafe.Pointer(&ifName)),
|
||||
uintptr(unsafe.Pointer(&ifNameSize)), 0)
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("error in syscall.Syscall6(syscall.SYS_GETSOCKOPT, ...): %v", errno)
|
||||
}
|
||||
cmd := exec.Command("ifconfig", string(ifName.name[:ifNameSize-1]), "inet", addr, gw, "netmask", mask, "mtu", "1300", "up")
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dev := &utunDev{
|
||||
f: os.NewFile(uintptr(fd), string(ifName.name[:ifNameSize-1])),
|
||||
}
|
||||
copy(dev.wBuf[:], []byte{0, 0, 0, 2})
|
||||
return dev, nil
|
||||
}
|
||||
97
network/tun/tun_linux.go
Normal file
97
network/tun/tun_linux.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
IFF_TUN = 0x0001
|
||||
IFF_TAP = 0x0002
|
||||
IFF_NO_PI = 0x1000
|
||||
)
|
||||
|
||||
type ifReq struct {
|
||||
Name [0x10]byte
|
||||
Flags uint16
|
||||
pad [0x28 - 0x10 - 2]byte
|
||||
}
|
||||
|
||||
func OpenTunDevice(name, addr, gw, mask string, dns []string) (io.ReadWriteCloser, error) {
|
||||
file, err := os.OpenFile("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var req ifReq
|
||||
copy(req.Name[:], name)
|
||||
req.Flags = IFF_TUN | IFF_NO_PI
|
||||
log.Printf("openning tun device")
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TUNSETIFF), uintptr(unsafe.Pointer(&req)))
|
||||
if errno != 0 {
|
||||
err = errno
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// config address
|
||||
log.Printf("configuring tun device address")
|
||||
cmd := exec.Command("ifconfig", name, addr, "netmask", mask, "mtu", "1300")
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
file.Close()
|
||||
log.Printf("failed to configure tun device address")
|
||||
return nil, err
|
||||
}
|
||||
syscall.SetNonblock(int(file.Fd()), false)
|
||||
return &tunDev{
|
||||
f: file,
|
||||
addr: addr,
|
||||
addrIP: net.ParseIP(addr).To4(),
|
||||
gw: gw,
|
||||
gwIP: net.ParseIP(gw).To4(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewTunDev(fd uintptr, name string, addr string, gw string) io.ReadWriteCloser {
|
||||
syscall.SetNonblock(int(fd), false)
|
||||
return &tunDev{
|
||||
f: os.NewFile(fd, name),
|
||||
addr: addr,
|
||||
addrIP: net.ParseIP(addr).To4(),
|
||||
gw: gw,
|
||||
gwIP: net.ParseIP(gw).To4(),
|
||||
}
|
||||
}
|
||||
|
||||
type tunDev struct {
|
||||
name string
|
||||
addr string
|
||||
addrIP net.IP
|
||||
gw string
|
||||
gwIP net.IP
|
||||
marker []byte
|
||||
f *os.File
|
||||
}
|
||||
|
||||
func (dev *tunDev) Read(data []byte) (int, error) {
|
||||
n, e := dev.f.Read(data)
|
||||
if e == nil && isStopMarker(data[:n], dev.addrIP, dev.gwIP) {
|
||||
return 0, errors.New("received stop marker")
|
||||
}
|
||||
return n, e
|
||||
}
|
||||
|
||||
func (dev *tunDev) Write(data []byte) (int, error) {
|
||||
return dev.f.Write(data)
|
||||
}
|
||||
|
||||
func (dev *tunDev) Close() error {
|
||||
log.Printf("send stop marker")
|
||||
sendStopMarker(dev.addr, dev.gw)
|
||||
return dev.f.Close()
|
||||
}
|
||||
374
network/tun/tun_windows.go
Normal file
374
network/tun/tun_windows.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
// "encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
TAPWIN32_MAX_REG_SIZE = 256
|
||||
TUNTAP_COMPONENT_ID_0901 = "tap0901"
|
||||
TUNTAP_COMPONENT_ID_0801 = "tap0801"
|
||||
NETWORK_KEY = `SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}`
|
||||
ADAPTER_KEY = `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`
|
||||
)
|
||||
|
||||
func ctl_code(device_type, function, method, access uint32) uint32 {
|
||||
return (device_type << 16) | (access << 14) | (function << 2) | method
|
||||
}
|
||||
|
||||
func tap_control_code(request, method uint32) uint32 {
|
||||
return ctl_code(34, request, method, 0)
|
||||
}
|
||||
|
||||
var (
|
||||
k32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procGetOverlappedResult = k32.NewProc("GetOverlappedResult")
|
||||
TAP_IOCTL_GET_MTU = tap_control_code(3, 0)
|
||||
TAP_IOCTL_SET_MEDIA_STATUS = tap_control_code(6, 0)
|
||||
TAP_IOCTL_CONFIG_TUN = tap_control_code(10, 0)
|
||||
TAP_WIN_IOCTL_CONFIG_DHCP_MASQ = tap_control_code(7, 0)
|
||||
TAP_WIN_IOCTL_CONFIG_DHCP_SET_OPT = tap_control_code(9, 0)
|
||||
)
|
||||
|
||||
func decodeUTF16(b []byte) string {
|
||||
if len(b)%2 != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
l := len(b) / 2
|
||||
u16 := make([]uint16, l)
|
||||
for i := 0; i < l; i += 1 {
|
||||
u16[i] = uint16(b[2*i]) + (uint16(b[2*i+1]) << 8)
|
||||
}
|
||||
return windows.UTF16ToString(u16)
|
||||
}
|
||||
|
||||
func getTuntapName(componentId string) (string, error) {
|
||||
keyName := fmt.Sprintf(NETWORK_KEY+"\\%s\\Connection", componentId)
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
|
||||
if err != nil {
|
||||
key.Close()
|
||||
return "", err
|
||||
}
|
||||
var bufLength uint32 = TAPWIN32_MAX_REG_SIZE
|
||||
buf := make([]byte, bufLength)
|
||||
name, _ := windows.UTF16FromString("Name")
|
||||
var valtype uint32
|
||||
err = windows.RegQueryValueEx(
|
||||
windows.Handle(key),
|
||||
&name[0],
|
||||
nil,
|
||||
&valtype,
|
||||
&buf[0],
|
||||
&bufLength,
|
||||
)
|
||||
if err != nil {
|
||||
key.Close()
|
||||
return "", err
|
||||
}
|
||||
s := decodeUTF16(buf)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func getTuntapComponentId() (string, error) {
|
||||
adapters, err := registry.OpenKey(registry.LOCAL_MACHINE, ADAPTER_KEY, registry.READ)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var i uint32
|
||||
for i = 0; i < 1000; i++ {
|
||||
var name_length uint32 = TAPWIN32_MAX_REG_SIZE
|
||||
buf := make([]uint16, name_length)
|
||||
if err = windows.RegEnumKeyEx(
|
||||
windows.Handle(adapters),
|
||||
i,
|
||||
&buf[0],
|
||||
&name_length,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil); err != nil {
|
||||
return "", err
|
||||
}
|
||||
key_name := windows.UTF16ToString(buf[:])
|
||||
adapter, err := registry.OpenKey(adapters, key_name, registry.READ)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
name, _ := windows.UTF16FromString("ComponentId")
|
||||
name2, _ := windows.UTF16FromString("NetCfgInstanceId")
|
||||
var valtype uint32
|
||||
var component_id = make([]byte, TAPWIN32_MAX_REG_SIZE)
|
||||
var componentLen = uint32(len(component_id))
|
||||
if err = windows.RegQueryValueEx(
|
||||
windows.Handle(adapter),
|
||||
&name[0],
|
||||
nil,
|
||||
&valtype,
|
||||
&component_id[0],
|
||||
&componentLen); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
id := decodeUTF16(component_id)
|
||||
if id == TUNTAP_COMPONENT_ID_0901 || id == TUNTAP_COMPONENT_ID_0801 {
|
||||
var valtype uint32
|
||||
var netCfgInstanceId = make([]byte, TAPWIN32_MAX_REG_SIZE)
|
||||
var netCfgInstanceIdLen = uint32(len(netCfgInstanceId))
|
||||
if err = windows.RegQueryValueEx(
|
||||
windows.Handle(adapter),
|
||||
&name2[0],
|
||||
nil,
|
||||
&valtype,
|
||||
&netCfgInstanceId[0],
|
||||
&netCfgInstanceIdLen); err != nil {
|
||||
return "", err
|
||||
}
|
||||
s := decodeUTF16(netCfgInstanceId)
|
||||
log.Printf("device component id: %s", s)
|
||||
adapter.Close()
|
||||
adapters.Close()
|
||||
return s, nil
|
||||
}
|
||||
adapter.Close()
|
||||
}
|
||||
adapters.Close()
|
||||
return "", errors.New("not found component id")
|
||||
}
|
||||
|
||||
func OpenTunDevice(name, addr, gw, mask string, dns []string) (io.ReadWriteCloser, error) {
|
||||
componentId, err := getTuntapComponentId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
devId, _ := windows.UTF16FromString(fmt.Sprintf(`\\.\Global\%s.tap`, componentId))
|
||||
devName, err := getTuntapName(componentId)
|
||||
log.Printf("device name: %s", devName)
|
||||
// set dhcp with netsh
|
||||
cmd := exec.Command("netsh", "interface", "ip", "set", "address", devName, "dhcp")
|
||||
cmd.Run()
|
||||
cmd = exec.Command("netsh", "interface", "ip", "set", "dns", devName, "dhcp")
|
||||
cmd.Run()
|
||||
// open
|
||||
fd, err := windows.CreateFile(
|
||||
&devId[0],
|
||||
windows.GENERIC_READ|windows.GENERIC_WRITE,
|
||||
windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
|
||||
//windows.FILE_ATTRIBUTE_SYSTEM,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// set addresses with dhcp
|
||||
var returnLen uint32
|
||||
tunAddr := net.ParseIP(addr).To4()
|
||||
tunMask := net.ParseIP(mask).To4()
|
||||
gwAddr := net.ParseIP(gw).To4()
|
||||
addrParam := append(tunAddr, tunMask...)
|
||||
addrParam = append(addrParam, gwAddr...)
|
||||
lease := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lease[:], 86400)
|
||||
addrParam = append(addrParam, lease...)
|
||||
err = windows.DeviceIoControl(
|
||||
fd,
|
||||
TAP_WIN_IOCTL_CONFIG_DHCP_MASQ,
|
||||
&addrParam[0],
|
||||
uint32(len(addrParam)),
|
||||
&addrParam[0],
|
||||
uint32(len(addrParam)),
|
||||
&returnLen,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
windows.Close(fd)
|
||||
return nil, err
|
||||
} else {
|
||||
log.Printf("set %s with net/mask: %s/%s through DHCP", devName, addr, mask)
|
||||
}
|
||||
|
||||
// set dns with dncp
|
||||
dnsParam := []byte{6, 4}
|
||||
primaryDNS := net.ParseIP(dns[0]).To4()
|
||||
dnsParam = append(dnsParam, primaryDNS...)
|
||||
if len(dns) >= 2 {
|
||||
secondaryDNS := net.ParseIP(dns[1]).To4()
|
||||
dnsParam = append(dnsParam, secondaryDNS...)
|
||||
dnsParam[1] += 4
|
||||
}
|
||||
err = windows.DeviceIoControl(
|
||||
fd,
|
||||
TAP_WIN_IOCTL_CONFIG_DHCP_SET_OPT,
|
||||
&dnsParam[0],
|
||||
uint32(len(dnsParam)),
|
||||
&addrParam[0],
|
||||
uint32(len(dnsParam)),
|
||||
&returnLen,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
windows.Close(fd)
|
||||
return nil, err
|
||||
} else {
|
||||
log.Printf("set %s with dns: %s through DHCP", devName, strings.Join(dns, ","))
|
||||
}
|
||||
|
||||
// set connect.
|
||||
inBuffer := []byte("\x01\x00\x00\x00")
|
||||
err = windows.DeviceIoControl(
|
||||
fd,
|
||||
TAP_IOCTL_SET_MEDIA_STATUS,
|
||||
&inBuffer[0],
|
||||
uint32(len(inBuffer)),
|
||||
&inBuffer[0],
|
||||
uint32(len(inBuffer)),
|
||||
&returnLen,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
windows.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
return newWinTapDev(fd, addr, gw), nil
|
||||
}
|
||||
|
||||
type winTapDev struct {
|
||||
fd windows.Handle
|
||||
addr string
|
||||
addrIP net.IP
|
||||
gw string
|
||||
gwIP net.IP
|
||||
rBuf [2048]byte
|
||||
wBuf [2048]byte
|
||||
wInitiated bool
|
||||
rOverlapped windows.Overlapped
|
||||
wOverlapped windows.Overlapped
|
||||
}
|
||||
|
||||
func newWinTapDev(fd windows.Handle, addr string, gw string) *winTapDev {
|
||||
rOverlapped := windows.Overlapped{}
|
||||
rEvent, _ := windows.CreateEvent(nil, 0, 0, nil)
|
||||
rOverlapped.HEvent = windows.Handle(rEvent)
|
||||
|
||||
wOverlapped := windows.Overlapped{}
|
||||
wEvent, _ := windows.CreateEvent(nil, 0, 0, nil)
|
||||
wOverlapped.HEvent = windows.Handle(wEvent)
|
||||
|
||||
dev := &winTapDev{
|
||||
fd: fd,
|
||||
rOverlapped: rOverlapped,
|
||||
wOverlapped: wOverlapped,
|
||||
wInitiated: false,
|
||||
|
||||
addr: addr,
|
||||
addrIP: net.ParseIP(addr).To4(),
|
||||
gw: gw,
|
||||
gwIP: net.ParseIP(gw).To4(),
|
||||
}
|
||||
return dev
|
||||
}
|
||||
|
||||
func (dev *winTapDev) Read(data []byte) (int, error) {
|
||||
for {
|
||||
var done uint32
|
||||
var nr int
|
||||
|
||||
err := windows.ReadFile(dev.fd, dev.rBuf[:], &done, &dev.rOverlapped)
|
||||
if err != nil {
|
||||
if err != windows.ERROR_IO_PENDING {
|
||||
return 0, err
|
||||
} else {
|
||||
windows.WaitForSingleObject(dev.rOverlapped.HEvent, windows.INFINITE)
|
||||
nr, err = getOverlappedResult(dev.fd, &dev.rOverlapped)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nr = int(done)
|
||||
}
|
||||
if nr > 14 {
|
||||
if isStopMarker(dev.rBuf[14:nr], dev.addrIP, dev.gwIP) {
|
||||
return 0, errors.New("received stop marker")
|
||||
}
|
||||
|
||||
// discard IPv6 packets
|
||||
if dev.rBuf[14]&0xf0 == 0x60 {
|
||||
//log.Printf("ipv6 packet")
|
||||
continue
|
||||
} else if dev.rBuf[14]&0xf0 == 0x40 {
|
||||
if !dev.wInitiated {
|
||||
// copy ether header for writing
|
||||
copy(dev.wBuf[:], dev.rBuf[6:12])
|
||||
copy(dev.wBuf[6:], dev.rBuf[0:6])
|
||||
copy(dev.wBuf[12:], dev.rBuf[12:14])
|
||||
dev.wInitiated = true
|
||||
}
|
||||
copy(data, dev.rBuf[14:nr])
|
||||
return nr - 14, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dev *winTapDev) Write(data []byte) (int, error) {
|
||||
var done uint32
|
||||
var nw int
|
||||
|
||||
payloadL := copy(dev.wBuf[14:], data)
|
||||
packetL := payloadL + 14
|
||||
err := windows.WriteFile(dev.fd, dev.wBuf[:packetL], &done, &dev.wOverlapped)
|
||||
if err != nil {
|
||||
if err != windows.ERROR_IO_PENDING {
|
||||
return 0, err
|
||||
} else {
|
||||
windows.WaitForSingleObject(dev.wOverlapped.HEvent, windows.INFINITE)
|
||||
nw, err = getOverlappedResult(dev.fd, &dev.wOverlapped)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nw = int(done)
|
||||
}
|
||||
if nw != packetL {
|
||||
return 0, fmt.Errorf("write %d packet (%d bytes payload), return %d", packetL, payloadL, nw)
|
||||
} else {
|
||||
return payloadL, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
|
||||
var n int
|
||||
r, _, err := syscall.Syscall6(procGetOverlappedResult.Addr(), 4,
|
||||
uintptr(h),
|
||||
uintptr(unsafe.Pointer(overlapped)),
|
||||
uintptr(unsafe.Pointer(&n)), 1, 0, 0)
|
||||
if r == 0 {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (dev *winTapDev) Close() error {
|
||||
log.Printf("close winTap device")
|
||||
sendStopMarker(dev.addr, dev.gw)
|
||||
return windows.Close(dev.fd)
|
||||
}
|
||||
230
packet.go
Normal file
230
packet.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
const PacketMagicNumber uint8 = 114514 % 256
|
||||
|
||||
const (
|
||||
GrassPacketIPv4 uint8 = 0
|
||||
GrassPacketPing uint8 = 1
|
||||
GrassPacketInfo uint8 = 2
|
||||
GrassPacketPong uint8 = 3
|
||||
GrassPacketClients uint8 = 4
|
||||
GrassPacketFarPing uint8 = 5
|
||||
GrassPacketExpired uint8 = 255
|
||||
)
|
||||
|
||||
type Packet struct {
|
||||
TmpKey uint32
|
||||
Timestamp uint16
|
||||
Type uint8
|
||||
CRC32 uint32
|
||||
Length uint32
|
||||
Content []byte
|
||||
}
|
||||
|
||||
type QueuedPacket struct {
|
||||
Type uint8
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(ciphertext)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padtext...)
|
||||
}
|
||||
|
||||
func pkcs5UnPadding(origData []byte, blockSize int) ([]byte, error) {
|
||||
length := len(origData)
|
||||
unpadding := int(origData[length-1])
|
||||
if unpadding == 0 || unpadding > blockSize {
|
||||
return nil, fmt.Errorf("Invalid padding")
|
||||
}
|
||||
for i := length - unpadding; i < length; i++ {
|
||||
if int(origData[i]) != unpadding {
|
||||
return nil, fmt.Errorf("Invalid padding")
|
||||
}
|
||||
}
|
||||
return origData[:length-unpadding], nil
|
||||
}
|
||||
|
||||
func decodePacket(r *Packet, s io.Reader, sharedKey []byte) error {
|
||||
buf := make([]byte, 4)
|
||||
_, err := io.ReadFull(s, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.TmpKey = binary.LittleEndian.Uint32(buf)
|
||||
tmpBlock, _ := aes.NewCipher(sharedKey)
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
key := buffer.Bytes()
|
||||
tmpBlock.Decrypt(key, key)
|
||||
|
||||
buf = make([]byte, 16)
|
||||
_, err = io.ReadFull(s, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cryptor := cipher.NewCBCDecrypter(block, key)
|
||||
cryptor.CryptBlocks(buf, buf)
|
||||
|
||||
if uint8(buf[0]) != PacketMagicNumber {
|
||||
return fmt.Errorf("Packet magic number mismatch")
|
||||
}
|
||||
r.Timestamp = binary.LittleEndian.Uint16(buf[1:3])
|
||||
r.Type = uint8(buf[3])
|
||||
r.CRC32 = binary.LittleEndian.Uint32(buf[4:8])
|
||||
r.Length = binary.LittleEndian.Uint32(buf[8:12])
|
||||
buf[4] = 0
|
||||
buf[5] = 0
|
||||
buf[6] = 0
|
||||
buf[7] = 0
|
||||
hash := crc32.NewIEEE()
|
||||
hash.Write(buf)
|
||||
|
||||
buf = make([]byte, r.Length)
|
||||
n, err := io.ReadFull(s, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nt := n - n%16
|
||||
cryptor.CryptBlocks(buf[:nt], buf[:nt])
|
||||
hash.Write(buf)
|
||||
sum := hash.Sum32()
|
||||
if sum != r.CRC32 {
|
||||
return fmt.Errorf("CRC32 mismatch")
|
||||
}
|
||||
|
||||
r.Content, err = pkcs5UnPadding(buf[:nt], 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePacket(r *Packet, s Writer, sharedKey []byte) error {
|
||||
tmpBlock, _ := aes.NewCipher(sharedKey)
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
|
||||
key := buffer.Bytes()
|
||||
tmpBlock.Decrypt(key, key)
|
||||
|
||||
r.Timestamp = uint16(time.Now().Unix() % 65536)
|
||||
r.CRC32 = 0
|
||||
|
||||
paddedContent := pkcs5Padding(r.Content, 16)
|
||||
nadd := rand.Intn(16)
|
||||
r.Length = uint32(len(paddedContent) + nadd)
|
||||
|
||||
buffer = bytes.NewBuffer(nil)
|
||||
binary.Write(buffer, binary.LittleEndian, PacketMagicNumber)
|
||||
binary.Write(buffer, binary.LittleEndian, r.Timestamp)
|
||||
binary.Write(buffer, binary.LittleEndian, r.Type)
|
||||
binary.Write(buffer, binary.LittleEndian, r.CRC32)
|
||||
binary.Write(buffer, binary.LittleEndian, r.Length)
|
||||
buffer.Write([]byte{0, 0, 0, 0})
|
||||
buf := buffer.Bytes()
|
||||
|
||||
buf = append(buf, paddedContent...)
|
||||
|
||||
tmp := make([]byte, nadd)
|
||||
rand.Read(tmp)
|
||||
buf = append(buf, tmp...)
|
||||
|
||||
r.CRC32 = crc32.ChecksumIEEE(buf)
|
||||
binary.LittleEndian.PutUint32(buf[4:8], r.CRC32)
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cryptor := cipher.NewCBCEncrypter(block, key)
|
||||
nt := 16 + uint(r.Length-r.Length%16)
|
||||
cryptor.CryptBlocks(buf[:nt], buf[:nt])
|
||||
|
||||
buf2 := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf2, r.TmpKey)
|
||||
s.M.Lock()
|
||||
_, err = s.W.Write(buf2)
|
||||
if err != nil {
|
||||
s.M.Unlock()
|
||||
return err
|
||||
}
|
||||
_, err = s.W.Write(buf)
|
||||
if err != nil {
|
||||
s.M.Unlock()
|
||||
return err
|
||||
}
|
||||
err = s.W.Flush()
|
||||
if err != nil {
|
||||
s.M.Unlock()
|
||||
return err
|
||||
}
|
||||
s.M.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadPacket(pkt *Packet, conn *Conn, client *Client) error {
|
||||
if err := decodePacket(pkt, conn.Reader, client.Key); err != nil {
|
||||
return err
|
||||
}
|
||||
if conn.peerTmpKeys.HasKey(pkt.TmpKey) {
|
||||
pkt.Type = GrassPacketExpired
|
||||
return nil
|
||||
}
|
||||
a := time.Now().Unix() % 65536
|
||||
b := int64(pkt.Timestamp)
|
||||
if b-a < 120 && a-b < 120 {
|
||||
conn.peerTmpKeys.AddKey(pkt.TmpKey)
|
||||
return nil
|
||||
}
|
||||
if b+65536-a < 120 || a+65536-b < 120 {
|
||||
conn.peerTmpKeys.AddKey(pkt.TmpKey)
|
||||
return nil
|
||||
}
|
||||
pkt.Type = GrassPacketExpired
|
||||
return nil
|
||||
}
|
||||
|
||||
func WritePacket(pkt *Packet, conn *Conn, client *Client) error {
|
||||
for {
|
||||
pkt.TmpKey = rand.Uint32()
|
||||
if !conn.tmpKeys.HasKey(pkt.TmpKey) {
|
||||
break
|
||||
}
|
||||
}
|
||||
conn.tmpKeys.AddKey(pkt.TmpKey)
|
||||
if err := encodePacket(pkt, conn.Writer, client.Key); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteRawPacket(typ uint8, content []byte, conn *Conn, client *Client) error {
|
||||
var pkt Packet
|
||||
pkt.Type = typ
|
||||
pkt.Content = content
|
||||
return WritePacket(&pkt, conn, client)
|
||||
}
|
||||
46
tmpkey.go
Normal file
46
tmpkey.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package grass
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TmpKey struct {
|
||||
Key uint32
|
||||
Time uint
|
||||
}
|
||||
|
||||
type TmpKeySet struct {
|
||||
Keys []TmpKey
|
||||
Map map[uint32]struct{}
|
||||
Mutex sync.Mutex
|
||||
Expire uint
|
||||
}
|
||||
|
||||
func (s *TmpKeySet) HasKey(key uint32) bool {
|
||||
s.Mutex.Lock()
|
||||
defer s.Mutex.Unlock()
|
||||
tm := uint(time.Now().Unix())
|
||||
if len(s.Keys) > 0 {
|
||||
i := 0
|
||||
for i < len(s.Keys) && s.Keys[i].Time < tm-s.Expire {
|
||||
delete(s.Map, s.Keys[i].Key)
|
||||
i++
|
||||
}
|
||||
s.Keys = s.Keys[i:]
|
||||
}
|
||||
_, ok := s.Map[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *TmpKeySet) AddKey(key uint32) {
|
||||
s.Mutex.Lock()
|
||||
defer s.Mutex.Unlock()
|
||||
tm := uint(time.Now().Unix())
|
||||
s.Keys = append(s.Keys, TmpKey{Key: key, Time: tm})
|
||||
s.Map[key] = struct{}{}
|
||||
}
|
||||
|
||||
func NewTmpKeySet(expire uint) *TmpKeySet {
|
||||
return &TmpKeySet{Expire: expire, Map: make(map[uint32]struct{})}
|
||||
}
|
||||
Reference in New Issue
Block a user