add files to git

This commit is contained in:
mcfx
2020-09-07 18:36:16 +08:00
parent 2f13bca07e
commit 2e621fc26f
24 changed files with 3064 additions and 0 deletions

39
README.md Normal file
View File

@@ -0,0 +1,39 @@
# Grass
![](https://github.com/k4yt3x/flowerhd/raw/master/PNG/%E8%8D%89.PNG)
Grass is a Peer-to-Peer VPN.
## Usage
Clone this repo, run `go build main.go` in the bin directory.
Sample config:
```json
{
"Id": 114514, // some random number
"Key": "some_key111", // pre-shared key
"IPv4": "10.56.0.2", // ipv4 for this client, ipv6 support will be added soon
"IPv4Gateway": "10.56.0.1", // ipv4 gateway, in fact useless
"IPv4Mask": "255.255.255.0",
"IPv4DNS": "8.8.8.8,8.8.4.4", // seems only useful on windows
"Listen": [
{
"LocalAddr": "0.0.0.0",
"NetworkAddr": "192.168.254.1", // your ip on external network
"Port": 14514
},
{
"LocalAddr": "[::]",
"NetworkAddr": "[fe80:114:514:1919:810::233]",
"Port": 14515
}
],
"BootstrapNodes": ["11.4.5.14:1919"],
"CheckClientsInterval": 1,
"PingInterval": 20,
"CheckConnInterval": 5,
"Debug": false,
"ConnectNewPeer": true, // if false, the client will not discover new clients
"StartTun": true // if false, the client will not start tunnel interface
}
```

44
bin/main.go Normal file
View File

@@ -0,0 +1,44 @@
package main
import (
"flag"
"io/ioutil"
"log"
"os"
"os/signal"
"syscall"
"github.com/mcfx0/grass"
)
func main() {
var configFile string
flag.StringVar(&configFile, "config", "config.json", "config file name")
flag.Parse()
data, err := ioutil.ReadFile(configFile)
if err != nil {
log.Fatal(err)
}
config, err := grass.UnmarshalConfig(data)
if err != nil {
log.Fatal(err)
}
client := grass.Client{Config: config}
err = client.Start()
if err != nil {
log.Fatal(err)
}
ch := make(chan os.Signal, 1)
signal.Notify(ch,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)
s := <-ch
switch s {
default:
client.Stop()
}
}

577
client.go Normal file
View File

@@ -0,0 +1,577 @@
package grass
import (
"bufio"
"encoding/binary"
"fmt"
"log"
"math/rand"
"net"
"strconv"
"sync"
"time"
"github.com/mcfx0/grass/network"
"github.com/mcfx0/grass/network/tun"
)
const MAX_PEER_CONNS = 8
type Conn struct {
Conn net.Conn
Reader *bufio.Reader
Writer Writer
Latency []uint32
tmpKeys *TmpKeySet
peerTmpKeys *TmpKeySet
WriteCh chan QueuedPacket
WriterStopCh chan bool
}
type Peer struct {
Id uint32
Conns []*Conn
ConnInfo []string
ConnsMutex *sync.RWMutex
LastSendClients OtherClientsInfo
}
type ListenInfo struct {
LocalAddr string
NetworkAddr string
Port uint16
}
type ClientConfig struct {
Id uint32
Key []byte
IPv4 uint32
IPv4Gateway uint32
IPv4Mask uint32
IPv4DNS []uint32
Listen []ListenInfo
BootstrapNodes []string
CheckClientsInterval int
PingInterval int
CheckConnInterval int
Debug bool
ConnectNewPeer bool
StartTun bool
}
type Client struct {
Id uint32
IPv4 uint32
Key []byte
Peers map[uint32]*Peer
PeerMutex *sync.RWMutex
OClients OtherClients
LastCheckClients OtherClientsInfo
Tun *network.TunHandler
Config ClientConfig
ConnInfo string
KnownConnInfo map[string]uint32
KnownConnInfoMutex *sync.Mutex
running bool
}
func (peer *Peer) AvgLatency() uint32 {
peer.ConnsMutex.RLock()
tot := 0
fail := 0
sum := 0
for i := 0; i < len(peer.Conns); i++ {
tmp := peer.Conns[i].Latency
for j := 0; j < len(tmp); j++ {
if tmp[j] == 0 {
fail++
} else {
sum += int(tmp[j])
}
tot++
}
}
peer.ConnsMutex.RUnlock()
if tot == fail {
return 100000000
}
res := float64(sum) / float64(tot-fail) / float64(tot-fail) * float64(tot)
resn := int(res)
if resn < 2 {
return 2
}
if resn > 100000000 {
return 100000000
}
return uint32(resn)
}
func (peer *Peer) SendPacket(tp uint8, data []byte, client *Client) error {
// a naive implementation, need to be improved
peer.ConnsMutex.RLock()
defer peer.ConnsMutex.RUnlock()
if len(peer.Conns) == 0 {
return fmt.Errorf("no usable connections")
}
//return WriteRawPacket(tp, data, peer.Conns[rand.Intn(len(peer.Conns))], client)
peer.Conns[rand.Intn(len(peer.Conns))].WriteCh <- QueuedPacket{Type: tp, Data: data}
return nil
}
func (peer *Peer) SendPacketV2(tp uint8, data []byte, client *Client, maxQlen int) bool {
peer.ConnsMutex.RLock()
defer peer.ConnsMutex.RUnlock()
if len(peer.Conns) == 0 {
return false
}
for i := 0; i < len(peer.Conns); i++ {
//fmt.Printf("%d ", len(peer.Conns[i].WriteCh))
if len(peer.Conns[i].WriteCh) <= maxQlen {
peer.Conns[i].WriteCh <- QueuedPacket{Type: tp, Data: data}
return true
}
}
return false
}
func (client *Client) HandleConn(rawConn net.Conn, connInfo string) error {
var conn Conn
conn.Conn = rawConn
conn.Reader = bufio.NewReader(rawConn)
conn.Writer = NewWriter(rawConn)
conn.tmpKeys = NewTmpKeySet(120)
conn.peerTmpKeys = NewTmpKeySet(120)
conn.WriteCh = make(chan QueuedPacket, 300)
conn.WriterStopCh = make(chan bool, 10)
var pkt Packet
pkt.Type = GrassPacketInfo
pkt.Content = make([]byte, 4)
binary.LittleEndian.PutUint32(pkt.Content, client.Id)
err := WritePacket(&pkt, &conn, client)
if err != nil {
return err
}
var peerId uint32
for {
err := ReadPacket(&pkt, &conn, client)
if err != nil {
return err
}
if pkt.Type == GrassPacketInfo {
if len(pkt.Content) != 4 {
return fmt.Errorf("Unexpected length of info packet")
}
peerId = binary.LittleEndian.Uint32(pkt.Content)
break
}
}
log.Printf("Connected to peer %08x\n", peerId)
if connInfo != "" {
client.KnownConnInfoMutex.Lock()
client.KnownConnInfo[connInfo] = peerId
client.KnownConnInfoMutex.Unlock()
}
client.PeerMutex.Lock()
peer, ok := client.Peers[peerId]
if !ok {
peer = &Peer{Id: peerId, ConnsMutex: &sync.RWMutex{}}
peer.LastSendClients = make(OtherClientsInfo)
if client.Peers == nil {
client.Peers = make(map[uint32]*Peer)
}
client.Peers[peerId] = peer
}
client.PeerMutex.Unlock()
peer.ConnsMutex.Lock()
peer.Conns = append(peer.Conns, &conn)
peer.ConnsMutex.Unlock()
working := true
var pingKey uint32 = 0
var pingTime int64 = 0
go func() {
for working {
changed := false
if pingTime != 0 {
conn.Latency = append(conn.Latency, 0)
changed = true
}
if len(conn.Latency) > 10 {
conn.Latency = conn.Latency[len(conn.Latency)-10:]
changed = true
}
if changed {
go client.OClients.UpdateLatencyAdd(peerId, peer.AvgLatency())
}
pingKey = rand.Uint32()
pingTime = time.Now().UnixNano()
tmp := make([]byte, 4)
binary.LittleEndian.PutUint32(tmp, pingKey)
if err := WriteRawPacket(GrassPacketPing, tmp, &conn, client); err != nil && client.Config.Debug {
log.Printf("Error in ping thread: %v\n", err)
return
}
peer.ConnsMutex.RLock()
wait := int(client.Config.PingInterval) * len(peer.Conns)
peer.ConnsMutex.RUnlock()
if len(conn.Latency) < 10 {
wait /= 10
}
wait += rand.Intn(10) - 5
if wait < 0 {
wait = 0
}
time.Sleep(time.Second * time.Duration(wait))
}
}()
go func() {
for {
select {
case pkt := <-conn.WriteCh:
WriteRawPacket(pkt.Type, pkt.Data, &conn, client)
case <-conn.WriterStopCh:
return
}
}
}()
closeConn := func() {
working = false
peer.ConnsMutex.Lock()
i := 0
for ; i < len(peer.Conns); i++ {
if peer.Conns[i] == &conn {
break
}
}
if i != len(peer.Conns) {
peer.Conns[i] = peer.Conns[len(peer.Conns)-1]
peer.Conns = peer.Conns[:len(peer.Conns)-1]
}
peer.ConnsMutex.Unlock()
if len(peer.Conns) == 0 {
client.PeerMutex.Lock()
delete(client.Peers, peer.Id)
client.PeerMutex.Unlock()
client.OClients.ClearClients(peerId)
}
log.Printf("Disconnected from peer %08x\n", peerId)
rawConn.Close()
conn.WriterStopCh <- true
}
for {
err := ReadPacket(&pkt, &conn, client)
if err != nil {
closeConn()
return err
}
if client.Config.Debug {
log.Printf("Packet from %08x: %d %v\n", peerId, pkt.Type, pkt.Content)
}
switch pkt.Type {
case GrassPacketIPv4:
tmp, err := DecodePacketIPv4(pkt.Content)
if err != nil {
return fmt.Errorf("Error in IPv4 packet: %v", err)
}
go client.HandlePacketIPv4(tmp)
case GrassPacketPing:
if len(pkt.Content) != 4 {
return fmt.Errorf("Unexpected length of ping packet")
}
if err = WriteRawPacket(GrassPacketPong, pkt.Content, &conn, client); err != nil {
return err
}
case GrassPacketInfo:
continue
case GrassPacketPong:
if len(pkt.Content) != 4 {
return fmt.Errorf("Unexpected length of pong packet")
}
tKey := binary.LittleEndian.Uint32(pkt.Content)
if tKey == pingKey {
latency := uint32((time.Now().UnixNano() - pingTime) / 1000)
if latency == 0 {
latency = 1
}
conn.Latency = append(conn.Latency, latency)
pingKey = 0
pingTime = 0
go client.OClients.UpdateLatencyAdd(peerId, peer.AvgLatency())
}
case GrassPacketClients:
tmp, err := DecodePacketClients(pkt.Content)
if client.Config.Debug {
log.Printf("Packet Clients from %08x: %v\n", peerId, tmp)
}
if err != nil {
return fmt.Errorf("Error in Clients packet: %v", err)
}
go client.OClients.UpdateClients(tmp, peerId, peer.AvgLatency())
case GrassPacketFarPing:
tmp, err := DecodePacketFarPing(pkt.Content)
if err != nil {
return fmt.Errorf("Error in FarPing packet: %v", err)
}
go client.HandlePacketFarPing(tmp)
case GrassPacketExpired:
continue
default:
closeConn()
return fmt.Errorf("Unknown packet type")
}
}
return nil
}
func (client *Client) CheckClients() {
lstLatency := make(map[uint32]uint32)
latencyIncreaseCnt := make(map[uint32]int)
for client.running {
time.Sleep(time.Second * time.Duration(client.Config.CheckClientsInterval))
var peers []*Peer
client.PeerMutex.RLock()
for _, pr := range client.Peers {
peers = append(peers, pr)
}
client.PeerMutex.RUnlock()
tmp := make(OtherClientsInfo)
/*if client.Config.Debug {
fmt.Println("peers:")
}*/
for i := 0; i < len(peers); i++ {
/*if client.Config.Debug {
fmt.Printf("id=%08x lat=%v avgLat=%d\n", peers[i].Id, peers[i].Conns[0].Latency, peers[i].AvgLatency())
}*/
tmp[peers[i].Id] = OtherClientInfo{ConnInfo: "", IPv4: 0, Latency: peers[i].AvgLatency()}
}
tmp[client.Id] = OtherClientInfo{ConnInfo: client.ConnInfo, IPv4: client.IPv4, Latency: 1}
tmpt, useful := GenDiffPacketClients(tmp, client.LastCheckClients)
client.OClients.UpdateClients(tmpt, client.Id, 0)
client.LastCheckClients = tmp
tmp = make(OtherClientsInfo)
var ocs []*OtherClient
client.OClients.M.Lock()
for _, oc := range client.OClients.C {
ocs = append(ocs, oc)
}
client.OClients.M.Unlock()
if client.Config.Debug {
log.Println("clients:")
}
for i := 0; i < len(ocs); i++ {
if client.Config.Debug {
log.Printf("id=%08x conn=%s ipv4=%08x lat=%v latSort=%v\n", ocs[i].Id, ocs[i].ConnInfoStr, ocs[i].IPv4, ocs[i].Latency, ocs[i].LatencySorted)
}
lat := ocs[i].GetLatency()
if llat, ok := lstLatency[ocs[i].Id]; ok {
if lat > llat {
t, ok := latencyIncreaseCnt[ocs[i].Id]
if !ok {
t = 0
}
t++
if t < 10 {
latencyIncreaseCnt[ocs[i].Id] = t
} else {
go func(id uint32) {
var pkt PacketFarPing
pkt.SrcNode = client.Id
pkt.DstNode = id
pkt.Status = FarPingRequest
pkt.HistoryNodes = make([]uint32, 0)
client.HandlePacketFarPing(pkt)
}(ocs[i].Id)
latencyIncreaseCnt[ocs[i].Id] = 0
}
} else if lat < llat {
latencyIncreaseCnt[ocs[i].Id] = 0
}
}
lstLatency[ocs[i].Id] = lat
if !client.OClients.BannedClients.HasKey(ocs[i].Id) {
tmp[ocs[i].Id] = OtherClientInfo{ConnInfo: ocs[i].ConnInfoStr, IPv4: ocs[i].IPv4, Latency: lat}
}
}
for i := 0; i < len(peers); i++ {
if rand.Intn(120) == 23 {
peers[i].LastSendClients = make(OtherClientsInfo)
}
tmpt, useful = GenDiffPacketClients(tmp, peers[i].LastSendClients)
peers[i].LastSendClients = tmp
if useful {
err := peers[i].SendPacket(GrassPacketClients, EncodePacketClients(tmpt), client)
if err != nil {
log.Printf("Error while sending Clients packet to %08x: %v\n", peers[i].Id, err)
}
}
}
if client.Config.ConnectNewPeer {
client.KnownConnInfoMutex.Lock()
for i := 0; i < len(ocs); i++ {
tc := ocs[i].ConnInfo
if ocs[i].Id == client.Id {
continue
}
for j := 0; j < len(tc); j++ {
client.KnownConnInfo[tc[j]] = ocs[i].Id
}
}
client.KnownConnInfoMutex.Unlock()
}
}
}
func (client *Client) CheckConn() {
for client.running {
time.Sleep(time.Second * time.Duration(client.Config.CheckConnInterval))
tmap := make(map[uint32][]string)
addConn := func(id uint32, ci string) {
_, ok := tmap[id]
if !ok {
tmap[id] = nil
}
tmap[id] = append(tmap[id], ci)
}
ct := make([]struct {
connInfo string
cid uint32
}, 0)
client.KnownConnInfoMutex.Lock()
for connInfo, cid := range client.KnownConnInfo {
ct = append(ct, struct {
connInfo string
cid uint32
}{connInfo: connInfo, cid: cid})
}
client.KnownConnInfoMutex.Unlock()
for i := 0; i < len(ct); i++ {
connInfo := ct[i].connInfo
cid := ct[i].cid
client.PeerMutex.RLock()
peer, ok := client.Peers[cid]
if !ok {
addConn(cid, connInfo)
client.PeerMutex.RUnlock()
continue
}
client.PeerMutex.RUnlock()
peer.ConnsMutex.RLock()
//log.Printf("conninfo: %v %v %v\n", connInfo, cid, peer)
if len(peer.Conns) < MAX_PEER_CONNS {
addConn(cid, connInfo)
}
peer.ConnsMutex.RUnlock()
}
for _, cis := range tmap {
go func(s string) {
conn, err := net.Dial("tcp", s)
if err != nil {
return
}
err = client.HandleConn(conn, s)
if err != nil && client.Config.Debug {
log.Printf("Error during handling conn: %v\n", err)
}
}(cis[rand.Intn(len(cis))])
}
}
}
func (client *Client) Start() error {
rand.Seed(time.Now().UnixNano())
log.Println("Starting...")
client.Id = client.Config.Id
client.IPv4 = client.Config.IPv4
client.Key = client.Config.Key
client.PeerMutex = &sync.RWMutex{}
client.KnownConnInfoMutex = &sync.Mutex{}
client.KnownConnInfo = make(map[string]uint32)
client.LastCheckClients = make(OtherClientsInfo)
client.OClients.Init(client.Config.CheckClientsInterval)
client.running = true
client.ConnInfo = ""
if client.Config.StartTun {
tmpIP := Uint32ToIP(client.IPv4).String()
tmpGW := Uint32ToIP(client.Config.IPv4Gateway).String()
tmpMask := Uint32ToIP(client.Config.IPv4Mask).String()
tmpDNS := make([]string, 0)
for i := 0; i < len(client.Config.IPv4DNS); i++ {
tmpDNS = append(tmpDNS, Uint32ToIP(client.Config.IPv4DNS[i]).String())
}
tun, err := tun.OpenTunDevice("tun0", tmpIP, tmpGW, tmpMask, tmpDNS)
if err != nil {
return err
}
client.Tun = network.New(tun, client.HandleTunIPv4)
go client.Tun.Run()
}
for i := 0; i < len(client.Config.Listen); i++ {
go func(s ListenInfo) {
listenAddr := s.LocalAddr + ":" + strconv.Itoa(int(s.Port))
if client.ConnInfo != "" {
client.ConnInfo += ","
}
client.ConnInfo += s.NetworkAddr + ":" + strconv.Itoa(int(s.Port))
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
log.Printf("Error during listening %s: %v\n", listenAddr, err)
return
}
for client.running {
conn, err := listener.Accept()
if err != nil {
continue
}
go func() {
err := client.HandleConn(conn, "")
if err != nil && client.Config.Debug {
log.Printf("Error during handling conn: %v\n", err)
}
}()
}
}(client.Config.Listen[i])
}
for i := 0; i < len(client.Config.BootstrapNodes); i++ {
go func(s string) {
conn, err := net.Dial("tcp", s)
if err != nil {
log.Println(err)
return
}
err = client.HandleConn(conn, s)
if err != nil && client.Config.Debug {
log.Printf("Error during handling conn: %v\n", err)
}
}(client.Config.BootstrapNodes[i])
}
go client.CheckClients()
go client.CheckConn()
return nil
}
func (client *Client) Stop() {
client.running = false
log.Println("Shutting down...")
// close connections
}

225
clients.go Normal file
View File

@@ -0,0 +1,225 @@
package grass
import (
"sort"
"strings"
"sync"
)
type LatencyPair struct {
Id uint32
Latency uint32
}
type pair struct {
a uint32
b uint32
}
type OtherClient struct {
Id uint32
IPv4 uint32
Latency map[uint32]uint32 // connId->latency
LatencyAdd map[uint32]uint32
LatencySorted []LatencyPair
LatencyMutex *sync.RWMutex
ConnInfo []string
ConnInfoStr string
}
type OtherClients struct {
C map[uint32]*OtherClient
M *sync.RWMutex
IPv4Map map[uint32]uint32
BannedClients *TmpKeySet
}
type pairArray []pair
func (s pairArray) Len() int {
return len(s)
}
func (s pairArray) Less(i, j int) bool {
return s[i].a < s[j].a
}
func (s pairArray) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (oc *OtherClient) SortLatency() {
oc.LatencyMutex.RLock()
var tmp pairArray
for vid, la := range oc.Latency {
tmp = append(tmp, pair{a: la, b: vid})
}
oc.LatencyMutex.RUnlock()
sort.Sort(tmp)
var tmp2 []LatencyPair
for i := 0; i < len(tmp); i++ {
tmp2 = append(tmp2, LatencyPair{Id: tmp[i].b, Latency: tmp[i].a})
}
oc.LatencyMutex.Lock()
oc.LatencySorted = tmp2
oc.LatencyMutex.Unlock()
}
func (oc *OtherClient) GetLatency() uint32 {
oc.LatencyMutex.RLock()
defer oc.LatencyMutex.RUnlock()
if len(oc.Latency) == 0 {
return 100000000
}
return oc.LatencySorted[0].Latency
}
func (c *OtherClients) Init(checkInterval int) {
c.C = make(map[uint32]*OtherClient)
c.M = &sync.RWMutex{}
c.IPv4Map = make(map[uint32]uint32)
c.BannedClients = NewTmpKeySet(uint(checkInterval * 10))
}
func (c *OtherClients) IPv4ToId(ipv4 uint32) uint32 {
c.M.RLock()
defer c.M.RUnlock()
res, ok := c.IPv4Map[ipv4]
if ok {
return res
}
return 0
}
func (c *OtherClients) UpdateClient(id, ipv4, viaId, latency, latencyAdd uint32) {
c.M.Lock()
oc, ok := c.C[id]
if !ok {
oc = &OtherClient{Id: id, IPv4: ipv4, Latency: make(map[uint32]uint32), LatencyAdd: make(map[uint32]uint32), LatencyMutex: &sync.RWMutex{}}
c.C[id] = oc
if ipv4 != 0 {
c.IPv4Map[ipv4] = id
}
}
c.M.Unlock()
if oc.IPv4 == 0 && ipv4 != 0 {
oc.IPv4 = ipv4
c.M.Lock()
c.IPv4Map[ipv4] = id
c.M.Unlock()
}
if viaId == 0 {
return
}
oc.LatencyMutex.Lock()
if latency == 0 {
delete(oc.Latency, viaId)
delete(oc.LatencyAdd, viaId)
} else if !c.BannedClients.HasKey(id) {
oc.Latency[viaId] = latency
oc.LatencyAdd[viaId] = latencyAdd
}
oc.LatencyMutex.Unlock()
oc.SortLatency()
}
func (c *OtherClients) UpdateClientConnInfo(id uint32, connInfo string) {
if c.BannedClients.HasKey(id) {
return
}
c.M.Lock()
oc, ok := c.C[id]
if !ok {
oc = &OtherClient{Id: id, IPv4: 0, Latency: make(map[uint32]uint32), LatencyAdd: make(map[uint32]uint32), LatencyMutex: &sync.RWMutex{}}
c.C[id] = oc
}
c.M.Unlock()
if connInfo == "" {
oc.ConnInfo = make([]string, 0)
} else {
oc.ConnInfo = strings.Split(connInfo, ",")
}
oc.ConnInfoStr = connInfo
}
func (c *OtherClients) UpdateClients(pkt PacketClients, viaId, addLatency uint32) {
tmp := make(map[uint32]struct{})
for i := 0; i < int(pkt.Count); i++ {
if pkt.ConnInfo[i].Ok {
c.UpdateClientConnInfo(pkt.Ids[i], pkt.ConnInfo[i].Val)
}
tid := viaId
tlat := pkt.Latency[i].Val + addLatency
if !pkt.Latency[i].Ok {
tid = 0
tlat = 0
}
c.UpdateClient(pkt.Ids[i], pkt.IPv4[i].Val, tid, tlat, addLatency)
tmp[pkt.Ids[i]] = struct{}{}
}
c.M.Lock()
ocs := make([]*OtherClient, 0)
for id, oc := range c.C {
_, ok := tmp[id]
if !ok {
ocs = append(ocs, oc)
}
}
c.M.Unlock()
var dels []uint32
for i := 0; i < len(ocs); i++ {
oc := ocs[i]
oc.LatencyMutex.Lock()
_, ok := oc.Latency[viaId]
if !ok {
oc.LatencyMutex.Unlock()
} else {
delete(oc.Latency, viaId)
oc.LatencyMutex.Unlock()
oc.SortLatency()
if len(oc.Latency) == 0 {
dels = append(dels, oc.Id)
}
}
}
c.M.Lock()
for i := 0; i < len(dels); i++ {
oc, ok := c.C[dels[i]]
if ok {
_, ok = c.IPv4Map[oc.IPv4]
if ok {
delete(c.IPv4Map, oc.IPv4)
}
delete(c.C, dels[i])
}
}
c.M.Unlock()
}
func (c *OtherClients) UpdateLatencyAdd(viaId, addLatency uint32) {
c.M.Lock()
ocs := make([]*OtherClient, 0)
for _, oc := range c.C {
ocs = append(ocs, oc)
}
c.M.Unlock()
for i := 0; i < len(ocs); i++ {
oc := ocs[i]
oc.LatencyMutex.Lock()
_, ok := oc.Latency[viaId]
if !ok {
oc.LatencyMutex.Unlock()
} else {
oc.Latency[viaId] = oc.Latency[viaId] - oc.LatencyAdd[viaId] + addLatency
oc.LatencyAdd[viaId] = addLatency
oc.LatencyMutex.Unlock()
oc.SortLatency()
}
}
}
func (c *OtherClients) ClearClients(viaId uint32) {
var pkt PacketClients
pkt.Count = 0
c.UpdateClients(pkt, viaId, 0)
}

252
clients_packet.go Normal file
View File

@@ -0,0 +1,252 @@
package grass
import (
"encoding/binary"
"fmt"
)
type PacketClientsString struct {
Val string
Ok bool
}
type PacketClientsStringArr []PacketClientsString
type PacketClientsUint32 struct {
Val uint32
Ok bool
}
type PacketClientsUint32Arr []PacketClientsUint32
func (a PacketClientsStringArr) SetOk(n int) {
a[n].Ok = true
}
func (a PacketClientsStringArr) GetOk(n int) bool {
return a[n].Ok
}
func (a PacketClientsUint32Arr) SetOk(n int) {
a[n].Ok = true
}
func (a PacketClientsUint32Arr) GetOk(n int) bool {
return a[n].Ok
}
type PacketClientsOkTypeArr interface {
SetOk(int)
GetOk(int) bool
}
type PacketClients struct {
Count uint64
Ids []uint32
ConnInfo PacketClientsStringArr
IPv4 PacketClientsUint32Arr
Latency PacketClientsUint32Arr
}
func DecodePacketClients(buf []byte) (PacketClients, error) {
var pkt PacketClients
var n int
pkt.Count, n = binary.Uvarint(buf)
if n > 0 {
buf = buf[n:]
} else if n == 0 {
return pkt, fmt.Errorf("incomplete packet")
} else {
return pkt, fmt.Errorf("value larger than 64 bits")
}
n = int(pkt.Count)
pkt.Ids = make([]uint32, n)
pkt.ConnInfo = make([]PacketClientsString, n)
pkt.IPv4 = make([]PacketClientsUint32, n)
pkt.Latency = make([]PacketClientsUint32, n)
if len(buf) < n*4 {
return pkt, fmt.Errorf("incomplete packet")
}
for i := 0; i < n; i++ {
pkt.Ids[i] = binary.LittleEndian.Uint32(buf[:4])
buf = buf[4:]
}
n8 := (n + 7) / 8
readOkType := func(s PacketClientsOkTypeArr) error {
if len(buf) < n8 {
return fmt.Errorf("incomplete packet")
}
for i := 0; i < n; i++ {
if (int(buf[i>>3]) >> (i & 7) & 1) == 1 {
s.SetOk(i)
}
}
return nil
}
readString := func() (string, error) {
tlen, nt := binary.Uvarint(buf)
if nt > 0 {
buf = buf[nt:]
} else if nt == 0 {
return "", fmt.Errorf("incomplete packet")
} else {
return "", fmt.Errorf("value larger than 64 bits")
}
if len(buf) < int(tlen) {
return "", fmt.Errorf("incomplete packet")
}
defer func() { buf = buf[tlen:] }()
return string(buf[:tlen]), nil
}
readOkType(pkt.ConnInfo)
buf = buf[n8:]
var err error
for i := 0; i < n; i++ {
if pkt.ConnInfo[i].Ok {
pkt.ConnInfo[i].Val, err = readString()
if err != nil {
return pkt, err
}
}
}
readOkType(pkt.IPv4)
buf = buf[n8:]
for i := 0; i < n; i++ {
if pkt.IPv4[i].Ok {
if len(buf) < 4 {
return pkt, fmt.Errorf("incomplete packet")
}
pkt.IPv4[i].Val = binary.LittleEndian.Uint32(buf[:4])
buf = buf[4:]
}
}
readOkType(pkt.Latency)
buf = buf[n8:]
for i := 0; i < n; i++ {
if pkt.Latency[i].Ok {
if len(buf) < 4 {
return pkt, fmt.Errorf("incomplete packet")
}
pkt.Latency[i].Val = binary.LittleEndian.Uint32(buf[:4])
buf = buf[4:]
}
}
if len(buf) != 0 {
return pkt, fmt.Errorf("extra data after packet")
}
return pkt, nil
}
func appendUvarint(s []byte, x uint64) []byte {
n := len(s)
s = append(s, make([]byte, binary.MaxVarintLen64)...)
n += binary.PutUvarint(s[n:], x)
return s[:n]
}
func EncodePacketClients(pkt PacketClients) []byte {
var buf []byte
buf = appendUvarint(buf, pkt.Count)
n := int(pkt.Count)
tmpU32 := make([]byte, 4)
for i := 0; i < n; i++ {
binary.LittleEndian.PutUint32(tmpU32, pkt.Ids[i])
buf = append(buf, tmpU32...)
}
n8 := (n + 7) / 8
writeOkType := func(s PacketClientsOkTypeArr) {
tmp := make([]byte, n8)
for i := 0; i < n; i++ {
if s.GetOk(i) {
tmp[i>>3] = tmp[i>>3] | byte(1<<(i&7))
}
}
buf = append(buf, tmp...)
}
writeOkType(pkt.ConnInfo)
for i := 0; i < n; i++ {
if pkt.ConnInfo[i].Ok {
buf = appendUvarint(buf, uint64(len(pkt.ConnInfo[i].Val)))
buf = append(buf, pkt.ConnInfo[i].Val...)
}
}
writeOkType(pkt.IPv4)
for i := 0; i < n; i++ {
if pkt.IPv4[i].Ok {
binary.LittleEndian.PutUint32(tmpU32, pkt.IPv4[i].Val)
buf = append(buf, tmpU32...)
}
}
writeOkType(pkt.Latency)
for i := 0; i < n; i++ {
if pkt.Latency[i].Ok {
binary.LittleEndian.PutUint32(tmpU32, pkt.Latency[i].Val)
buf = append(buf, tmpU32...)
}
}
return buf
}
type OtherClientInfo struct {
ConnInfo string
IPv4 uint32
Latency uint32
}
type OtherClientsInfo map[uint32]OtherClientInfo
func GenDiffPacketClients(cur, old OtherClientsInfo) (PacketClients, bool) {
var pkt PacketClients
n := len(cur)
pkt.Count = uint64(n)
pkt.Ids = make([]uint32, n)
pkt.ConnInfo = make([]PacketClientsString, n)
pkt.IPv4 = make([]PacketClientsUint32, n)
pkt.Latency = make([]PacketClientsUint32, n)
useful := len(cur) != len(old)
i := 0
for id, oc := range cur {
pkt.Ids[i] = id
oc2, ok := old[id]
useful = useful || !ok
if ok {
if oc.ConnInfo != oc2.ConnInfo && oc.ConnInfo != "" {
pkt.ConnInfo[i].Ok = true
pkt.ConnInfo[i].Val = oc.ConnInfo
}
if oc.IPv4 != oc2.IPv4 && oc.IPv4 != 0 {
pkt.IPv4[i].Ok = true
pkt.IPv4[i].Val = oc.IPv4
}
if oc.Latency != oc2.Latency {
pkt.Latency[i].Ok = true
pkt.Latency[i].Val = oc.Latency
}
} else {
if oc.ConnInfo != "" {
pkt.ConnInfo[i].Ok = true
pkt.ConnInfo[i].Val = oc.ConnInfo
}
if oc.IPv4 != 0 {
pkt.IPv4[i].Ok = true
pkt.IPv4[i].Val = oc.IPv4
}
pkt.Latency[i].Ok = true
pkt.Latency[i].Val = oc.Latency
}
useful = useful || pkt.ConnInfo[i].Ok || pkt.IPv4[i].Ok || pkt.Latency[i].Ok
i += 1
}
return pkt, useful
}

62
config.go Normal file
View File

@@ -0,0 +1,62 @@
package grass
import (
"crypto/sha256"
"encoding/json"
"fmt"
"net"
"strings"
)
type Config struct {
Id uint32
Key string
IPv4 string
IPv4Gateway string
IPv4Mask string
IPv4DNS string
Listen []ListenInfo
BootstrapNodes []string
CheckClientsInterval int
PingInterval int
CheckConnInterval int
Debug bool
ConnectNewPeer bool
StartTun bool
}
func UnmarshalConfig(s []byte) (ClientConfig, error) {
var c Config
var r ClientConfig
err := json.Unmarshal(s, &c)
if err != nil {
return r, err
}
r.Id = c.Id
if c.Id == 0 {
return r, fmt.Errorf("Id cannot be zero")
}
thash := sha256.Sum256([]byte(c.Key))
r.Key = thash[:16]
r.IPv4 = IPToUint32(net.ParseIP(c.IPv4))
if c.IPv4 != "" {
r.IPv4Gateway = IPToUint32(net.ParseIP(c.IPv4Gateway))
r.IPv4Mask = IPToUint32(net.ParseIP(c.IPv4Mask))
tdns := strings.Split(c.IPv4DNS, ",")
r.IPv4DNS = make([]uint32, 0)
for i := 0; i < len(tdns); i++ {
r.IPv4DNS = append(r.IPv4DNS, IPToUint32(net.ParseIP(tdns[i])))
}
}
r.Listen = c.Listen
r.BootstrapNodes = c.BootstrapNodes
r.CheckClientsInterval = c.CheckClientsInterval
r.PingInterval = c.PingInterval
r.CheckConnInterval = c.CheckConnInterval
r.Debug = c.Debug
r.ConnectNewPeer = c.ConnectNewPeer
r.StartTun = c.StartTun
return r, nil
}

136
far_ping.go Normal file
View File

@@ -0,0 +1,136 @@
package grass
import (
"encoding/binary"
"fmt"
)
const (
FarPingRequest uint32 = 0
FarPingReply uint32 = 1
FarPingCircuit uint32 = 2
)
type PacketFarPing struct {
SrcNode uint32
DstNode uint32
Status uint32
HistoryNodes []uint32
}
func DecodePacketFarPing(buf []byte) (PacketFarPing, error) {
var pkt PacketFarPing
if len(buf) < 12 {
return pkt, fmt.Errorf("incomplete packet")
}
pkt.SrcNode = binary.LittleEndian.Uint32(buf[:4])
pkt.DstNode = binary.LittleEndian.Uint32(buf[4:8])
pkt.Status = binary.LittleEndian.Uint32(buf[8:12])
buf = buf[12:]
nh, n := binary.Uvarint(buf)
if n > 0 {
buf = buf[n:]
} else if n == 0 {
return pkt, fmt.Errorf("incomplete packet")
} else {
return pkt, fmt.Errorf("value larger than 64 bits")
}
if len(buf) < 4*int(nh) {
return pkt, fmt.Errorf("incomplete packet")
}
pkt.HistoryNodes = make([]uint32, int(nh))
for i := 0; i < int(nh); i++ {
pkt.HistoryNodes[i] = binary.LittleEndian.Uint32(buf[:4])
buf = buf[4:]
}
return pkt, nil
}
func EncodePacketFarPing(pkt PacketFarPing) []byte {
buf := make([]byte, 12)
binary.LittleEndian.PutUint32(buf[:4], pkt.SrcNode)
binary.LittleEndian.PutUint32(buf[4:8], pkt.DstNode)
binary.LittleEndian.PutUint32(buf[8:12], pkt.Status)
buf = appendUvarint(buf, uint64(len(pkt.HistoryNodes)))
tmpU32 := make([]byte, 4)
for i := 0; i < len(pkt.HistoryNodes); i++ {
binary.LittleEndian.PutUint32(tmpU32, pkt.HistoryNodes[i])
buf = append(buf, tmpU32...)
}
return buf
}
func (client *Client) HandlePacketFarPing(pkt PacketFarPing) error {
if pkt.DstNode == client.Id {
if pkt.Status == FarPingReply {
return nil
}
if pkt.Status == FarPingRequest {
var reply PacketFarPing
reply.DstNode = pkt.SrcNode
reply.SrcNode = client.Id
reply.Status = FarPingReply
reply.HistoryNodes = make([]uint32, 0)
return client.HandlePacketFarPing(reply)
}
if pkt.Status == FarPingCircuit {
client.OClients.BannedClients.AddKey(pkt.SrcNode)
return nil
}
return nil
}
client.OClients.M.RLock()
oc, ok := client.OClients.C[pkt.DstNode]
if !ok {
client.OClients.M.RUnlock()
return fmt.Errorf("target not reachable")
}
client.OClients.M.RUnlock()
oc.LatencyMutex.RLock()
n := len(oc.LatencySorted)
if n == 0 {
oc.LatencyMutex.RUnlock()
return fmt.Errorf("target not reachable")
}
nh := len(pkt.HistoryNodes)
var nextId uint32
for i := 0; i < n; i++ {
tid := oc.LatencySorted[i].Id
vis := false
for j := 0; j < nh; j++ {
if tid == pkt.HistoryNodes[j] {
vis = true
}
}
if !vis {
nextId = tid
break
}
}
if nextId == 0 {
var reply PacketFarPing
reply.DstNode = pkt.SrcNode
reply.SrcNode = pkt.DstNode
//reply.SrcNode = client.Id
reply.Status = FarPingCircuit
reply.HistoryNodes = make([]uint32, 0)
oc.LatencyMutex.RUnlock()
return client.HandlePacketFarPing(reply)
}
if nextId == client.Id {
nextId = pkt.DstNode
}
oc.LatencyMutex.RUnlock()
client.PeerMutex.RLock()
peer, ok := client.Peers[nextId]
client.PeerMutex.RUnlock()
if !ok {
return fmt.Errorf("target conn closed unexpectedly")
}
pkt.HistoryNodes = append(pkt.HistoryNodes, client.Id)
return peer.SendPacket(GrassPacketFarPing, EncodePacketFarPing(pkt), client)
}

1
grass.go Normal file
View File

@@ -0,0 +1 @@
package grass

179
ipv4.go Normal file
View File

@@ -0,0 +1,179 @@
package grass
import (
"encoding/binary"
"fmt"
"net"
"github.com/mcfx0/grass/network"
"github.com/mcfx0/grass/network/packet"
)
func Uint32ToIP(intIP uint32) net.IP {
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, intIP)
return net.IP(bytes)
}
func IPToUint32(ip net.IP) uint32 {
return binary.BigEndian.Uint32([]byte(ip.To4()))
}
type PacketIPv4 struct {
TargetNode uint32
TTL uint8
HistoryNodes []uint32
Data []byte
IPPacket *packet.IPv4
}
func (pkt *PacketIPv4) getIPPacket() error {
pkt.IPPacket = packet.NewIPv4()
return packet.ParseIPv4(pkt.Data, pkt.IPPacket)
}
func DecodePacketIPv4(buf []byte) (PacketIPv4, error) {
var pkt PacketIPv4
if len(buf) < 5 {
return pkt, fmt.Errorf("incomplete packet")
}
pkt.TargetNode = binary.LittleEndian.Uint32(buf[:4])
pkt.TTL = uint8(buf[4])
buf = buf[5:]
nh, n := binary.Uvarint(buf)
if n > 0 {
buf = buf[n:]
} else if n == 0 {
return pkt, fmt.Errorf("incomplete packet")
} else {
return pkt, fmt.Errorf("value larger than 64 bits")
}
if len(buf) < 4*int(nh) {
return pkt, fmt.Errorf("incomplete packet")
}
pkt.HistoryNodes = make([]uint32, int(nh))
for i := 0; i < int(nh); i++ {
pkt.HistoryNodes[i] = binary.LittleEndian.Uint32(buf[:4])
buf = buf[4:]
}
pkt.Data = buf
pkt.IPPacket = nil
return pkt, nil
}
func EncodePacketIPv4(pkt PacketIPv4) []byte {
buf := make([]byte, 5)
binary.LittleEndian.PutUint32(buf, pkt.TargetNode)
buf[4] = byte(pkt.TTL)
buf = appendUvarint(buf, uint64(len(pkt.HistoryNodes)))
tmpU32 := make([]byte, 4)
for i := 0; i < len(pkt.HistoryNodes); i++ {
binary.LittleEndian.PutUint32(tmpU32, pkt.HistoryNodes[i])
buf = append(buf, tmpU32...)
}
buf = append(buf, pkt.Data...)
return buf
}
func (client *Client) HandleFinalIPv4(pkt PacketIPv4) error {
if err := pkt.getIPPacket(); err != nil {
return err
}
if pkt.IPPacket.Protocol == packet.IPProtocolICMPv4 {
var icmp4src packet.ICMPv4
if err := packet.ParseICMPv4(pkt.IPPacket.Payload, &icmp4src); err != nil {
return err
}
if icmp4src.Type == 8 { // send icmp reply
icmp4 := network.ResponseICMPv4Packet(Uint32ToIP(client.IPv4), pkt.IPPacket.SrcIP, 0, 0, icmp4src.Payload)
return client.HandleTunIPv4(icmp4.Wire, icmp4.Ip)
}
}
// TTL should be edited here, but I'm too lazy now
client.Tun.WriteCh <- pkt.Data
/*pkt.IPPacket.TTL = pkt.TTL
packets := network.GenFragments(pkt.IPPacket, 0, pkt.IPPacket.Payload)
for i := 0; i < len(packets); i++ {
client.Tun.WriteCh <- packets[i]
}*/
return nil
}
func (client *Client) HandlePacketIPv4(pkt PacketIPv4) error {
if pkt.TargetNode == client.Id {
return client.HandleFinalIPv4(pkt)
}
if pkt.TTL == 0 {
if err := pkt.getIPPacket(); err != nil {
return err
}
payload := append([]byte{0, 0, 0, 0}, pkt.Data...)
icmp4 := network.ResponseICMPv4Packet(Uint32ToIP(client.IPv4), pkt.IPPacket.SrcIP, 11, 0, payload)
return client.HandleTunIPv4(icmp4.Wire, icmp4.Ip)
}
client.OClients.M.RLock()
oc, ok := client.OClients.C[pkt.TargetNode]
if !ok {
client.OClients.M.RUnlock()
return fmt.Errorf("target not reachable")
}
client.OClients.M.RUnlock()
oc.LatencyMutex.RLock()
n := len(oc.LatencySorted)
if n == 0 {
oc.LatencyMutex.RUnlock()
return fmt.Errorf("target not reachable")
}
nh := len(pkt.HistoryNodes)
qa := make([]uint32, 0)
qb := make([]uint32, 0)
for i := 0; i < n; i++ {
tid := oc.LatencySorted[i].Id
vis := false
for j := 0; j < nh; j++ {
if tid == pkt.HistoryNodes[j] {
vis = true
}
}
if !vis {
qa = append(qa, tid)
} else {
qb = append(qb, tid)
}
}
oc.LatencyMutex.RUnlock()
qa = append(qa, qb...)
pkt.TTL--
pkt.HistoryNodes = append(pkt.HistoryNodes, client.Id)
for mqlen := 2; mqlen <= 128; mqlen *= 2 {
for i := 0; i < len(qa); i++ {
client.PeerMutex.RLock()
peer, ok := client.Peers[qa[i]]
client.PeerMutex.RUnlock()
//fmt.Printf("%d", peer.Id)
if ok {
if peer.SendPacketV2(GrassPacketIPv4, EncodePacketIPv4(pkt), client, mqlen) {
return nil
}
}
}
}
return fmt.Errorf("failed to send out packet")
}
func (client *Client) HandleTunIPv4(data []byte, pkt *packet.IPv4) error {
var spkt PacketIPv4
tmp := client.OClients.IPv4ToId(IPToUint32(pkt.DstIP))
if tmp == 0 {
return fmt.Errorf("target not reachable")
}
spkt.TargetNode = tmp
spkt.TTL = pkt.TTL
spkt.Data = data
spkt.IPPacket = pkt
return client.HandlePacketIPv4(spkt)
}

1
network/README.md Normal file
View File

@@ -0,0 +1 @@
Most files in this directory are copied from https://github.com/yinghuocho/gotun2socks, with slightly modifies.

80
network/icmp.go Normal file
View File

@@ -0,0 +1,80 @@
package network
import (
"net"
"sync"
"github.com/mcfx0/grass/network/packet"
)
type Icmp4Packet struct {
Ip *packet.IPv4
Icmp4 *packet.ICMPv4
MtuBuf []byte
Wire []byte
}
var (
Icmp4PacketPool = &sync.Pool{
New: func() interface{} {
return &Icmp4Packet{}
},
}
)
func newICMPv4Packet() *Icmp4Packet {
return Icmp4PacketPool.Get().(*Icmp4Packet)
}
func releaseICMPv4Packet(pkt *Icmp4Packet) {
packet.ReleaseIPv4(pkt.Ip)
packet.ReleaseICMPv4(pkt.Icmp4)
if pkt.MtuBuf != nil {
releaseBuffer(pkt.MtuBuf)
}
pkt.MtuBuf = nil
pkt.Wire = nil
Icmp4PacketPool.Put(pkt)
}
func ResponseICMPv4Packet(local net.IP, remote net.IP, Type uint8, Code uint8, respPayload []byte) *Icmp4Packet {
ipid := packet.IPID()
ip := packet.NewIPv4()
icmp4 := packet.NewICMPv4()
ip.Version = 4
ip.Id = ipid
ip.SrcIP = make(net.IP, len(local))
copy(ip.SrcIP, local)
ip.DstIP = make(net.IP, len(remote))
copy(ip.DstIP, remote)
ip.TTL = 64
ip.Protocol = packet.IPProtocolICMPv4
icmp4.Type = Type
icmp4.Code = Code
icmp4.Payload = respPayload
pkt := newICMPv4Packet()
pkt.Ip = ip
pkt.Icmp4 = icmp4
pkt.MtuBuf = newBuffer()
payloadL := len(icmp4.Payload)
payloadStart := MTU - payloadL
icmp4HL := 4
icmp4Start := payloadStart - icmp4HL
pseduoStart := icmp4Start - packet.IPv4_PSEUDO_LENGTH
ip.PseudoHeader(pkt.MtuBuf[pseduoStart:icmp4Start], packet.IPProtocolICMPv4, icmp4HL+payloadL)
icmp4.Serialize(pkt.MtuBuf[icmp4Start:payloadStart], pkt.MtuBuf[icmp4Start:payloadStart], icmp4.Payload)
if payloadL != 0 {
copy(pkt.MtuBuf[payloadStart:], icmp4.Payload)
}
ipHL := ip.HeaderLength()
ipStart := icmp4Start - ipHL
ip.Serialize(pkt.MtuBuf[ipStart:icmp4Start], icmp4HL+(MTU-payloadStart))
pkt.Wire = pkt.MtuBuf[ipStart:]
return pkt
}

102
network/ip.go Normal file
View File

@@ -0,0 +1,102 @@
package network
import (
"net"
"github.com/mcfx0/grass/network/packet"
)
type ipPacket struct {
ip *packet.IPv4
mtuBuf []byte
wire []byte
}
var (
frags = make(map[uint16]*ipPacket)
)
func procFragment(ip *packet.IPv4, raw []byte) (bool, *packet.IPv4, []byte) {
exist, ok := frags[ip.Id]
if !ok {
if ip.Flags&0x1 == 0 {
return false, nil, nil
}
// first
//log.Printf("first fragment of IPID %d", ip.Id)
dup := make([]byte, len(raw))
copy(dup, raw)
clone := &packet.IPv4{}
packet.ParseIPv4(dup, clone)
frags[ip.Id] = &ipPacket{
ip: clone,
wire: dup,
}
return false, clone, dup
} else {
exist.wire = append(exist.wire, ip.Payload...)
packet.ParseIPv4(exist.wire, exist.ip)
last := false
if ip.Flags&0x1 == 0 {
//log.Printf("last fragment of IPID %d", ip.Id)
last = true
} else {
//log.Printf("continue fragment of IPID %d", ip.Id)
}
return last, exist.ip, exist.wire
}
}
func GenFragments(first *packet.IPv4, offset uint16, data []byte) []*ipPacket {
var ret []*ipPacket
for {
frag := packet.NewIPv4()
frag.Version = 4
frag.Id = first.Id
frag.SrcIP = make(net.IP, len(first.SrcIP))
copy(frag.SrcIP, first.SrcIP)
frag.DstIP = make(net.IP, len(first.DstIP))
copy(frag.DstIP, first.DstIP)
frag.TTL = first.TTL
frag.Protocol = first.Protocol
frag.FragOffset = offset
if len(data) <= MTU-20 {
frag.Payload = data
} else {
frag.Flags = 1
offset += (MTU - 20) / 8
frag.Payload = data[:MTU-20]
data = data[MTU-20:]
}
pkt := &ipPacket{ip: frag}
pkt.mtuBuf = newBuffer()
payloadL := len(frag.Payload)
payloadStart := MTU - payloadL
if payloadL != 0 {
copy(pkt.mtuBuf[payloadStart:], frag.Payload)
}
ipHL := frag.HeaderLength()
ipStart := payloadStart - ipHL
frag.Serialize(pkt.mtuBuf[ipStart:payloadStart], payloadL)
pkt.wire = pkt.mtuBuf[ipStart:]
ret = append(ret, pkt)
if frag.Flags == 0 {
return ret
}
}
}
func releaseIPPacket(pkt *ipPacket) {
packet.ReleaseIPv4(pkt.ip)
if pkt.mtuBuf != nil {
releaseBuffer(pkt.mtuBuf)
}
pkt.mtuBuf = nil
pkt.wire = nil
}

21
network/mtubuf.go Normal file
View File

@@ -0,0 +1,21 @@
package network
import (
"sync"
)
var (
bufPool *sync.Pool = &sync.Pool{
New: func() interface{} {
return make([]byte, MTU)
},
}
)
func newBuffer() []byte {
return bufPool.Get().([]byte)
}
func releaseBuffer(buf []byte) {
bufPool.Put(buf)
}

89
network/network.go Normal file
View File

@@ -0,0 +1,89 @@
package network
import (
"io"
"log"
"sync"
"github.com/mcfx0/grass/network/packet"
)
const (
MTU = 1500
)
type TunHandler struct {
dev io.ReadWriteCloser
writerStopCh chan bool
WriteCh chan interface{}
wg sync.WaitGroup
handler func([]byte, *packet.IPv4) error
}
func New(dev io.ReadWriteCloser, handler func([]byte, *packet.IPv4) error) *TunHandler {
th := &TunHandler{
dev: dev,
writerStopCh: make(chan bool, 10),
WriteCh: make(chan interface{}, 10000),
handler: handler,
}
return th
}
func (th *TunHandler) Run() {
// writer
go func() {
th.wg.Add(1)
defer th.wg.Done()
for {
select {
case pkt := <-th.WriteCh:
switch pkt.(type) {
case *ipPacket:
ip := pkt.(*ipPacket)
th.dev.Write(ip.wire)
releaseIPPacket(ip)
case []byte:
th.dev.Write(pkt.([]byte))
}
case <-th.writerStopCh:
log.Printf("quit tun2handler writer")
return
}
}
}()
// reader
var buf [MTU]byte
var ip packet.IPv4
th.wg.Add(1)
defer th.wg.Done()
for {
n, e := th.dev.Read(buf[:])
if e != nil {
// TODO: stop at critical error
log.Printf("read packet error: %s", e)
return
}
data := buf[:n]
//log.Println(data)
e = packet.ParseIPv4(data, &ip)
if e != nil {
log.Printf("error to parse IPv4: %s", e)
continue
}
//go th.handler(data, &ip)
th.handler(data, &ip)
/*go func() {
err := th.handler(data, &ip)
if err != nil {
log.Printf("ipv4 error: %v\n", err)
}
}()*/
}
}

39
network/packet/common.go Normal file
View File

@@ -0,0 +1,39 @@
package packet
var lotsOfZeros [1024]byte
//func Checksum(data []byte) uint16 {
// var csum uint32
// length := len(data) - 1
// for i := 0; i < length; i += 2 {
// csum += uint32(data[i]) << 8
// csum += uint32(data[i+1])
// }
// if len(data)%2 == 1 {
// csum += uint32(data[length]) << 8
// }
// for csum > 0xffff {
// csum = (csum >> 16) + (csum & 0xffff)
// }
// return ^uint16(csum + (csum >> 16))
//}
func Checksum(fields ...[]byte) uint16 {
var csum uint32
for _, field := range fields {
length := len(field) - 1
for i := 0; i < length; i += 2 {
csum += uint32(field[i]) << 8
csum += uint32(field[i+1])
}
if len(field)%2 == 1 {
// only last field may have odd number of bytes
csum += uint32(field[length]) << 8
}
}
for csum > 0xffff {
csum = (csum >> 16) + (csum & 0xffff)
}
return ^uint16(csum + (csum >> 16))
}

65
network/packet/icmp.go Normal file
View File

@@ -0,0 +1,65 @@
package packet
import (
"encoding/binary"
"fmt"
"sync"
)
type ICMPv4 struct {
Type uint8
Code uint8
Checksum uint16
Payload []byte
}
var (
icmp4Pool *sync.Pool = &sync.Pool{
New: func() interface{} {
return &ICMPv4{}
},
}
)
func NewICMPv4() *ICMPv4 {
var zero ICMPv4
icmp4 := icmp4Pool.Get().(*ICMPv4)
*icmp4 = zero
return icmp4
}
func ReleaseICMPv4(icmp4 *ICMPv4) {
// clear internal slice references
icmp4.Payload = nil
icmp4Pool.Put(icmp4)
}
func ParseICMPv4(pkt []byte, icmp4 *ICMPv4) error {
if len(pkt) < 4 {
return fmt.Errorf("payload too small for ICMPv4: %d bytes", len(pkt))
}
icmp4.Type = uint8(pkt[0])
icmp4.Code = uint8(pkt[1])
icmp4.Checksum = binary.BigEndian.Uint16(pkt[2:4])
if len(pkt) > 4 {
icmp4.Payload = pkt[4:]
} else {
icmp4.Payload = nil
}
return nil
}
func (icmp4 *ICMPv4) Serialize(hdr []byte, ckFields ...[]byte) error {
if len(hdr) != 4 {
return fmt.Errorf("incorrect buffer size: %d buffer given, 4 needed", len(hdr))
}
hdr[0] = byte(icmp4.Type)
hdr[1] = byte(icmp4.Code)
hdr[2] = 0
hdr[3] = 0
icmp4.Checksum = Checksum(ckFields...)
binary.BigEndian.PutUint16(hdr[2:], icmp4.Checksum)
return nil
}

246
network/packet/ip4.go Normal file
View File

@@ -0,0 +1,246 @@
package packet
import (
"encoding/binary"
"fmt"
"net"
"sync"
"sync/atomic"
)
type IPv4Option struct {
OptionType uint8
OptionLength uint8
OptionData []byte
}
// IPProtocol is an enumeration of IP protocol values, and acts as a decoder
// for any type it supports.
type IPProtocol uint8
const (
IPProtocolIPv6HopByHop IPProtocol = 0
IPProtocolICMPv4 IPProtocol = 1
IPProtocolIGMP IPProtocol = 2
IPProtocolIPv4 IPProtocol = 4
IPProtocolTCP IPProtocol = 6
IPProtocolUDP IPProtocol = 17
IPProtocolRUDP IPProtocol = 27
IPProtocolIPv6 IPProtocol = 41
IPProtocolIPv6Routing IPProtocol = 43
IPProtocolIPv6Fragment IPProtocol = 44
IPProtocolGRE IPProtocol = 47
IPProtocolESP IPProtocol = 50
IPProtocolAH IPProtocol = 51
IPProtocolICMPv6 IPProtocol = 58
IPProtocolNoNextHeader IPProtocol = 59
IPProtocolIPv6Destination IPProtocol = 60
IPProtocolIPIP IPProtocol = 94
IPProtocolEtherIP IPProtocol = 97
IPProtocolSCTP IPProtocol = 132
IPProtocolUDPLite IPProtocol = 136
IPProtocolMPLSInIP IPProtocol = 137
IPv4_PSEUDO_LENGTH int = 12
)
type IPv4 struct {
Version uint8
IHL uint8
TOS uint8
Length uint16
Id uint16
Flags uint8
FragOffset uint16
TTL uint8
Protocol IPProtocol
Checksum uint16
SrcIP net.IP
DstIP net.IP
Options []IPv4Option
Padding []byte
Payload []byte
headerLength int
}
var (
ipv4Pool *sync.Pool = &sync.Pool{
New: func() interface{} {
return &IPv4{}
},
}
globalIPID uint32
)
func ReleaseIPv4(ip4 *IPv4) {
// clear internal slice references
ip4.SrcIP = nil
ip4.DstIP = nil
ip4.Options = nil
ip4.Padding = nil
ip4.Payload = nil
ipv4Pool.Put(ip4)
}
func NewIPv4() *IPv4 {
var zero IPv4
ip4 := ipv4Pool.Get().(*IPv4)
*ip4 = zero
return ip4
}
func IPID() uint16 {
return uint16(atomic.AddUint32(&globalIPID, 1) & 0x0000ffff)
}
func ParseIPv4(pkt []byte, ip4 *IPv4) error {
flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
ip4.Version = uint8(pkt[0]) >> 4
ip4.IHL = uint8(pkt[0]) & 0x0F
ip4.TOS = pkt[1]
ip4.Length = binary.BigEndian.Uint16(pkt[2:4])
ip4.Id = binary.BigEndian.Uint16(pkt[4:6])
ip4.Flags = uint8(flagsfrags >> 13)
ip4.FragOffset = flagsfrags & 0x1FFF
ip4.TTL = pkt[8]
ip4.Protocol = IPProtocol(pkt[9])
ip4.Checksum = binary.BigEndian.Uint16(pkt[10:12])
ip4.SrcIP = pkt[12:16]
ip4.DstIP = pkt[16:20]
if ip4.Length < 20 {
return fmt.Errorf("Invalid (too small) IP length (%d < 20)", ip4.Length)
}
if ip4.IHL < 5 {
return fmt.Errorf("Invalid (too small) IP header length (%d < 5)", ip4.IHL)
}
if int(ip4.IHL*4) > int(ip4.Length) {
return fmt.Errorf("Invalid IP header length > IP length (%d > %d)", ip4.IHL, ip4.Length)
}
if int(ip4.IHL)*4 > len(pkt) {
return fmt.Errorf("Not all IP header bytes available")
}
ip4.Payload = pkt[ip4.IHL*4:]
rest := pkt[20 : ip4.IHL*4]
// Pull out IP options
for len(rest) > 0 {
if ip4.Options == nil {
// Pre-allocate to avoid growing the slice too much.
ip4.Options = make([]IPv4Option, 0, 4)
}
opt := IPv4Option{OptionType: rest[0]}
switch opt.OptionType {
case 0: // End of options
opt.OptionLength = 1
ip4.Options = append(ip4.Options, opt)
ip4.Padding = rest[1:]
break
case 1: // 1 byte padding
opt.OptionLength = 1
default:
opt.OptionLength = rest[1]
opt.OptionData = rest[2:opt.OptionLength]
}
if len(rest) >= int(opt.OptionLength) {
rest = rest[opt.OptionLength:]
} else {
return fmt.Errorf("IP option length exceeds remaining IP header size, option type %v length %v", opt.OptionType, opt.OptionLength)
}
ip4.Options = append(ip4.Options, opt)
}
return nil
}
func (ip *IPv4) PseudoHeader(buf []byte, proto IPProtocol, dataLen int) error {
if len(buf) != IPv4_PSEUDO_LENGTH {
return fmt.Errorf("incorrect buffer size: %d buffer given, %d needed", len(buf), IPv4_PSEUDO_LENGTH)
}
copy(buf[0:4], ip.SrcIP)
copy(buf[4:8], ip.DstIP)
buf[8] = 0
buf[9] = byte(proto)
binary.BigEndian.PutUint16(buf[10:], uint16(dataLen))
return nil
}
func (ip *IPv4) HeaderLength() int {
if ip.headerLength == 0 {
optionLength := uint8(0)
for _, opt := range ip.Options {
switch opt.OptionType {
case 0:
// this is the end of option lists
optionLength++
case 1:
// this is the padding
optionLength++
default:
optionLength += opt.OptionLength
}
}
// make sure the options are aligned to 32 bit boundary
if (optionLength % 4) != 0 {
optionLength += 4 - (optionLength % 4)
}
ip.IHL = 5 + (optionLength / 4)
ip.headerLength = int(optionLength) + 20
}
return ip.headerLength
}
func (ip *IPv4) flagsfrags() (ff uint16) {
ff |= uint16(ip.Flags) << 13
ff |= ip.FragOffset
return
}
func (ip *IPv4) Serialize(hdr []byte, dataLen int) error {
if len(hdr) != ip.HeaderLength() {
return fmt.Errorf("incorrect buffer size: %d buffer given, %d needed", len(hdr), ip.HeaderLength())
}
hdr[0] = (ip.Version << 4) | ip.IHL
hdr[1] = ip.TOS
ip.Length = uint16(ip.headerLength + dataLen)
binary.BigEndian.PutUint16(hdr[2:], ip.Length)
binary.BigEndian.PutUint16(hdr[4:], ip.Id)
binary.BigEndian.PutUint16(hdr[6:], ip.flagsfrags())
hdr[8] = ip.TTL
hdr[9] = byte(ip.Protocol)
copy(hdr[12:16], ip.SrcIP)
copy(hdr[16:20], ip.DstIP)
curLocation := 20
// Now, we will encode the options
for _, opt := range ip.Options {
switch opt.OptionType {
case 0:
// this is the end of option lists
hdr[curLocation] = 0
curLocation++
case 1:
// this is the padding
hdr[curLocation] = 1
curLocation++
default:
hdr[curLocation] = opt.OptionType
hdr[curLocation+1] = opt.OptionLength
// sanity checking to protect us from buffer overrun
if len(opt.OptionData) > int(opt.OptionLength-2) {
return fmt.Errorf("option length is smaller than length of option data")
}
copy(hdr[curLocation+2:curLocation+int(opt.OptionLength)], opt.OptionData)
curLocation += int(opt.OptionLength)
}
}
hdr[10] = 0
hdr[11] = 0
ip.Checksum = Checksum(hdr)
binary.BigEndian.PutUint16(hdr[10:], ip.Checksum)
return nil
}

34
network/tun/stop.go Normal file
View File

@@ -0,0 +1,34 @@
package tun
import (
"bytes"
"log"
"net"
)
var stopMarker = []byte{2, 2, 2, 2, 2, 2, 2, 2}
// Close of Windows and Linux tun/tap device do not interrupt blocking Read.
// sendStopMarker is used to issue a specific packet to notify threads blocking
// on Read.
func sendStopMarker(src, dst string) {
l, _ := net.ResolveUDPAddr("udp", src+":2222")
r, _ := net.ResolveUDPAddr("udp", dst+":2222")
conn, err := net.DialUDP("udp", l, r)
if err != nil {
log.Printf("fail to send stopmarker: %s", err)
return
}
defer conn.Close()
conn.Write(stopMarker)
}
func isStopMarker(pkt []byte, src, dst net.IP) bool {
n := len(pkt)
// at least should be 20(ip) + 8(udp) + 8(stopmarker)
if n < 20+8+8 {
return false
}
return pkt[0]&0xf0 == 0x40 && pkt[9] == 0x11 && src.Equal(pkt[12:16]) &&
dst.Equal(pkt[16:20]) && bytes.Compare(pkt[n-8:n], stopMarker) == 0
}

106
network/tun/tun_darwin.go Normal file
View File

@@ -0,0 +1,106 @@
package tun
import (
"fmt"
"io"
"os"
"os/exec"
"syscall"
"unsafe"
)
const (
appleUTUNCtl = "com.apple.net.utun_control"
appleCTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
)
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
type utunDev struct {
f *os.File
rBuf [2048]byte
wBuf [2048]byte
}
func (dev *utunDev) Read(data []byte) (int, error) {
n, e := dev.f.Read(dev.rBuf[:])
if n > 0 {
copy(data, dev.rBuf[4:n])
n -= 4
}
return n, e
}
// one packet, no more than MTU
func (dev *utunDev) Write(data []byte) (int, error) {
n := copy(dev.wBuf[4:], data)
return dev.f.Write(dev.wBuf[:n+4])
}
func (dev *utunDev) Close() error {
return dev.f.Close()
}
var sockaddrCtlSize uintptr = 32
func OpenTunDevice(name, addr, gw, mask string, dns []string) (io.ReadWriteCloser, error) {
fd, err := syscall.Socket(syscall.AF_SYSTEM, syscall.SOCK_DGRAM, 2)
if err != nil {
return nil, err
}
var ctlInfo = &struct {
ctlID uint32
ctlName [96]byte
}{}
copy(ctlInfo.ctlName[:], []byte(appleUTUNCtl))
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(appleCTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
if errno != 0 {
return nil, fmt.Errorf("error in syscall.Syscall(syscall.SYS_IOTL, ...): %v", errno)
}
addrP := unsafe.Pointer(&sockaddrCtl{
scLen: uint8(sockaddrCtlSize),
scFamily: syscall.AF_SYSTEM,
/* #define AF_SYS_CONTROL 2 */
ssSysaddr: 2,
scID: ctlInfo.ctlID,
scUnit: 0,
})
_, _, errno = syscall.RawSyscall(syscall.SYS_CONNECT, uintptr(fd), uintptr(addrP), uintptr(sockaddrCtlSize))
if errno != 0 {
return nil, fmt.Errorf("error in syscall.RawSyscall(syscall.SYS_CONNECT, ...): %v", errno)
}
var ifName struct {
name [16]byte
}
ifNameSize := uintptr(16)
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
if errno != 0 {
return nil, fmt.Errorf("error in syscall.Syscall6(syscall.SYS_GETSOCKOPT, ...): %v", errno)
}
cmd := exec.Command("ifconfig", string(ifName.name[:ifNameSize-1]), "inet", addr, gw, "netmask", mask, "mtu", "1300", "up")
err = cmd.Run()
if err != nil {
syscall.Close(fd)
return nil, err
}
dev := &utunDev{
f: os.NewFile(uintptr(fd), string(ifName.name[:ifNameSize-1])),
}
copy(dev.wBuf[:], []byte{0, 0, 0, 2})
return dev, nil
}

97
network/tun/tun_linux.go Normal file
View File

@@ -0,0 +1,97 @@
package tun
import (
"errors"
"io"
"log"
"net"
"os"
"os/exec"
"syscall"
"unsafe"
)
const (
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
)
type ifReq struct {
Name [0x10]byte
Flags uint16
pad [0x28 - 0x10 - 2]byte
}
func OpenTunDevice(name, addr, gw, mask string, dns []string) (io.ReadWriteCloser, error) {
file, err := os.OpenFile("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
copy(req.Name[:], name)
req.Flags = IFF_TUN | IFF_NO_PI
log.Printf("openning tun device")
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TUNSETIFF), uintptr(unsafe.Pointer(&req)))
if errno != 0 {
err = errno
return nil, err
}
// config address
log.Printf("configuring tun device address")
cmd := exec.Command("ifconfig", name, addr, "netmask", mask, "mtu", "1300")
err = cmd.Run()
if err != nil {
file.Close()
log.Printf("failed to configure tun device address")
return nil, err
}
syscall.SetNonblock(int(file.Fd()), false)
return &tunDev{
f: file,
addr: addr,
addrIP: net.ParseIP(addr).To4(),
gw: gw,
gwIP: net.ParseIP(gw).To4(),
}, nil
}
func NewTunDev(fd uintptr, name string, addr string, gw string) io.ReadWriteCloser {
syscall.SetNonblock(int(fd), false)
return &tunDev{
f: os.NewFile(fd, name),
addr: addr,
addrIP: net.ParseIP(addr).To4(),
gw: gw,
gwIP: net.ParseIP(gw).To4(),
}
}
type tunDev struct {
name string
addr string
addrIP net.IP
gw string
gwIP net.IP
marker []byte
f *os.File
}
func (dev *tunDev) Read(data []byte) (int, error) {
n, e := dev.f.Read(data)
if e == nil && isStopMarker(data[:n], dev.addrIP, dev.gwIP) {
return 0, errors.New("received stop marker")
}
return n, e
}
func (dev *tunDev) Write(data []byte) (int, error) {
return dev.f.Write(data)
}
func (dev *tunDev) Close() error {
log.Printf("send stop marker")
sendStopMarker(dev.addr, dev.gw)
return dev.f.Close()
}

374
network/tun/tun_windows.go Normal file
View File

@@ -0,0 +1,374 @@
package tun
import (
"encoding/binary"
// "encoding/hex"
"errors"
"fmt"
"io"
"log"
"net"
"os/exec"
"strings"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
const (
TAPWIN32_MAX_REG_SIZE = 256
TUNTAP_COMPONENT_ID_0901 = "tap0901"
TUNTAP_COMPONENT_ID_0801 = "tap0801"
NETWORK_KEY = `SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}`
ADAPTER_KEY = `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`
)
func ctl_code(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func tap_control_code(request, method uint32) uint32 {
return ctl_code(34, request, method, 0)
}
var (
k32 = windows.NewLazySystemDLL("kernel32.dll")
procGetOverlappedResult = k32.NewProc("GetOverlappedResult")
TAP_IOCTL_GET_MTU = tap_control_code(3, 0)
TAP_IOCTL_SET_MEDIA_STATUS = tap_control_code(6, 0)
TAP_IOCTL_CONFIG_TUN = tap_control_code(10, 0)
TAP_WIN_IOCTL_CONFIG_DHCP_MASQ = tap_control_code(7, 0)
TAP_WIN_IOCTL_CONFIG_DHCP_SET_OPT = tap_control_code(9, 0)
)
func decodeUTF16(b []byte) string {
if len(b)%2 != 0 {
return ""
}
l := len(b) / 2
u16 := make([]uint16, l)
for i := 0; i < l; i += 1 {
u16[i] = uint16(b[2*i]) + (uint16(b[2*i+1]) << 8)
}
return windows.UTF16ToString(u16)
}
func getTuntapName(componentId string) (string, error) {
keyName := fmt.Sprintf(NETWORK_KEY+"\\%s\\Connection", componentId)
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
if err != nil {
key.Close()
return "", err
}
var bufLength uint32 = TAPWIN32_MAX_REG_SIZE
buf := make([]byte, bufLength)
name, _ := windows.UTF16FromString("Name")
var valtype uint32
err = windows.RegQueryValueEx(
windows.Handle(key),
&name[0],
nil,
&valtype,
&buf[0],
&bufLength,
)
if err != nil {
key.Close()
return "", err
}
s := decodeUTF16(buf)
return s, nil
}
func getTuntapComponentId() (string, error) {
adapters, err := registry.OpenKey(registry.LOCAL_MACHINE, ADAPTER_KEY, registry.READ)
if err != nil {
return "", err
}
var i uint32
for i = 0; i < 1000; i++ {
var name_length uint32 = TAPWIN32_MAX_REG_SIZE
buf := make([]uint16, name_length)
if err = windows.RegEnumKeyEx(
windows.Handle(adapters),
i,
&buf[0],
&name_length,
nil,
nil,
nil,
nil); err != nil {
return "", err
}
key_name := windows.UTF16ToString(buf[:])
adapter, err := registry.OpenKey(adapters, key_name, registry.READ)
if err != nil {
continue
}
name, _ := windows.UTF16FromString("ComponentId")
name2, _ := windows.UTF16FromString("NetCfgInstanceId")
var valtype uint32
var component_id = make([]byte, TAPWIN32_MAX_REG_SIZE)
var componentLen = uint32(len(component_id))
if err = windows.RegQueryValueEx(
windows.Handle(adapter),
&name[0],
nil,
&valtype,
&component_id[0],
&componentLen); err != nil {
continue
}
id := decodeUTF16(component_id)
if id == TUNTAP_COMPONENT_ID_0901 || id == TUNTAP_COMPONENT_ID_0801 {
var valtype uint32
var netCfgInstanceId = make([]byte, TAPWIN32_MAX_REG_SIZE)
var netCfgInstanceIdLen = uint32(len(netCfgInstanceId))
if err = windows.RegQueryValueEx(
windows.Handle(adapter),
&name2[0],
nil,
&valtype,
&netCfgInstanceId[0],
&netCfgInstanceIdLen); err != nil {
return "", err
}
s := decodeUTF16(netCfgInstanceId)
log.Printf("device component id: %s", s)
adapter.Close()
adapters.Close()
return s, nil
}
adapter.Close()
}
adapters.Close()
return "", errors.New("not found component id")
}
func OpenTunDevice(name, addr, gw, mask string, dns []string) (io.ReadWriteCloser, error) {
componentId, err := getTuntapComponentId()
if err != nil {
return nil, err
}
devId, _ := windows.UTF16FromString(fmt.Sprintf(`\\.\Global\%s.tap`, componentId))
devName, err := getTuntapName(componentId)
log.Printf("device name: %s", devName)
// set dhcp with netsh
cmd := exec.Command("netsh", "interface", "ip", "set", "address", devName, "dhcp")
cmd.Run()
cmd = exec.Command("netsh", "interface", "ip", "set", "dns", devName, "dhcp")
cmd.Run()
// open
fd, err := windows.CreateFile(
&devId[0],
windows.GENERIC_READ|windows.GENERIC_WRITE,
windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE,
nil,
windows.OPEN_EXISTING,
windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
//windows.FILE_ATTRIBUTE_SYSTEM,
0,
)
if err != nil {
return nil, err
}
// set addresses with dhcp
var returnLen uint32
tunAddr := net.ParseIP(addr).To4()
tunMask := net.ParseIP(mask).To4()
gwAddr := net.ParseIP(gw).To4()
addrParam := append(tunAddr, tunMask...)
addrParam = append(addrParam, gwAddr...)
lease := make([]byte, 4)
binary.BigEndian.PutUint32(lease[:], 86400)
addrParam = append(addrParam, lease...)
err = windows.DeviceIoControl(
fd,
TAP_WIN_IOCTL_CONFIG_DHCP_MASQ,
&addrParam[0],
uint32(len(addrParam)),
&addrParam[0],
uint32(len(addrParam)),
&returnLen,
nil,
)
if err != nil {
windows.Close(fd)
return nil, err
} else {
log.Printf("set %s with net/mask: %s/%s through DHCP", devName, addr, mask)
}
// set dns with dncp
dnsParam := []byte{6, 4}
primaryDNS := net.ParseIP(dns[0]).To4()
dnsParam = append(dnsParam, primaryDNS...)
if len(dns) >= 2 {
secondaryDNS := net.ParseIP(dns[1]).To4()
dnsParam = append(dnsParam, secondaryDNS...)
dnsParam[1] += 4
}
err = windows.DeviceIoControl(
fd,
TAP_WIN_IOCTL_CONFIG_DHCP_SET_OPT,
&dnsParam[0],
uint32(len(dnsParam)),
&addrParam[0],
uint32(len(dnsParam)),
&returnLen,
nil,
)
if err != nil {
windows.Close(fd)
return nil, err
} else {
log.Printf("set %s with dns: %s through DHCP", devName, strings.Join(dns, ","))
}
// set connect.
inBuffer := []byte("\x01\x00\x00\x00")
err = windows.DeviceIoControl(
fd,
TAP_IOCTL_SET_MEDIA_STATUS,
&inBuffer[0],
uint32(len(inBuffer)),
&inBuffer[0],
uint32(len(inBuffer)),
&returnLen,
nil,
)
if err != nil {
windows.Close(fd)
return nil, err
}
return newWinTapDev(fd, addr, gw), nil
}
type winTapDev struct {
fd windows.Handle
addr string
addrIP net.IP
gw string
gwIP net.IP
rBuf [2048]byte
wBuf [2048]byte
wInitiated bool
rOverlapped windows.Overlapped
wOverlapped windows.Overlapped
}
func newWinTapDev(fd windows.Handle, addr string, gw string) *winTapDev {
rOverlapped := windows.Overlapped{}
rEvent, _ := windows.CreateEvent(nil, 0, 0, nil)
rOverlapped.HEvent = windows.Handle(rEvent)
wOverlapped := windows.Overlapped{}
wEvent, _ := windows.CreateEvent(nil, 0, 0, nil)
wOverlapped.HEvent = windows.Handle(wEvent)
dev := &winTapDev{
fd: fd,
rOverlapped: rOverlapped,
wOverlapped: wOverlapped,
wInitiated: false,
addr: addr,
addrIP: net.ParseIP(addr).To4(),
gw: gw,
gwIP: net.ParseIP(gw).To4(),
}
return dev
}
func (dev *winTapDev) Read(data []byte) (int, error) {
for {
var done uint32
var nr int
err := windows.ReadFile(dev.fd, dev.rBuf[:], &done, &dev.rOverlapped)
if err != nil {
if err != windows.ERROR_IO_PENDING {
return 0, err
} else {
windows.WaitForSingleObject(dev.rOverlapped.HEvent, windows.INFINITE)
nr, err = getOverlappedResult(dev.fd, &dev.rOverlapped)
if err != nil {
return 0, err
}
}
} else {
nr = int(done)
}
if nr > 14 {
if isStopMarker(dev.rBuf[14:nr], dev.addrIP, dev.gwIP) {
return 0, errors.New("received stop marker")
}
// discard IPv6 packets
if dev.rBuf[14]&0xf0 == 0x60 {
//log.Printf("ipv6 packet")
continue
} else if dev.rBuf[14]&0xf0 == 0x40 {
if !dev.wInitiated {
// copy ether header for writing
copy(dev.wBuf[:], dev.rBuf[6:12])
copy(dev.wBuf[6:], dev.rBuf[0:6])
copy(dev.wBuf[12:], dev.rBuf[12:14])
dev.wInitiated = true
}
copy(data, dev.rBuf[14:nr])
return nr - 14, nil
}
}
}
}
func (dev *winTapDev) Write(data []byte) (int, error) {
var done uint32
var nw int
payloadL := copy(dev.wBuf[14:], data)
packetL := payloadL + 14
err := windows.WriteFile(dev.fd, dev.wBuf[:packetL], &done, &dev.wOverlapped)
if err != nil {
if err != windows.ERROR_IO_PENDING {
return 0, err
} else {
windows.WaitForSingleObject(dev.wOverlapped.HEvent, windows.INFINITE)
nw, err = getOverlappedResult(dev.fd, &dev.wOverlapped)
if err != nil {
return 0, err
}
}
} else {
nw = int(done)
}
if nw != packetL {
return 0, fmt.Errorf("write %d packet (%d bytes payload), return %d", packetL, payloadL, nw)
} else {
return payloadL, nil
}
}
func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
var n int
r, _, err := syscall.Syscall6(procGetOverlappedResult.Addr(), 4,
uintptr(h),
uintptr(unsafe.Pointer(overlapped)),
uintptr(unsafe.Pointer(&n)), 1, 0, 0)
if r == 0 {
return n, err
}
return n, nil
}
func (dev *winTapDev) Close() error {
log.Printf("close winTap device")
sendStopMarker(dev.addr, dev.gw)
return windows.Close(dev.fd)
}

230
packet.go Normal file
View File

@@ -0,0 +1,230 @@
package grass
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"hash/crc32"
"io"
"math/rand"
"time"
)
const PacketMagicNumber uint8 = 114514 % 256
const (
GrassPacketIPv4 uint8 = 0
GrassPacketPing uint8 = 1
GrassPacketInfo uint8 = 2
GrassPacketPong uint8 = 3
GrassPacketClients uint8 = 4
GrassPacketFarPing uint8 = 5
GrassPacketExpired uint8 = 255
)
type Packet struct {
TmpKey uint32
Timestamp uint16
Type uint8
CRC32 uint32
Length uint32
Content []byte
}
type QueuedPacket struct {
Type uint8
Data []byte
}
func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
func pkcs5UnPadding(origData []byte, blockSize int) ([]byte, error) {
length := len(origData)
unpadding := int(origData[length-1])
if unpadding == 0 || unpadding > blockSize {
return nil, fmt.Errorf("Invalid padding")
}
for i := length - unpadding; i < length; i++ {
if int(origData[i]) != unpadding {
return nil, fmt.Errorf("Invalid padding")
}
}
return origData[:length-unpadding], nil
}
func decodePacket(r *Packet, s io.Reader, sharedKey []byte) error {
buf := make([]byte, 4)
_, err := io.ReadFull(s, buf)
if err != nil {
return err
}
r.TmpKey = binary.LittleEndian.Uint32(buf)
tmpBlock, _ := aes.NewCipher(sharedKey)
buffer := bytes.NewBuffer(nil)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
key := buffer.Bytes()
tmpBlock.Decrypt(key, key)
buf = make([]byte, 16)
_, err = io.ReadFull(s, buf)
if err != nil {
return err
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
cryptor := cipher.NewCBCDecrypter(block, key)
cryptor.CryptBlocks(buf, buf)
if uint8(buf[0]) != PacketMagicNumber {
return fmt.Errorf("Packet magic number mismatch")
}
r.Timestamp = binary.LittleEndian.Uint16(buf[1:3])
r.Type = uint8(buf[3])
r.CRC32 = binary.LittleEndian.Uint32(buf[4:8])
r.Length = binary.LittleEndian.Uint32(buf[8:12])
buf[4] = 0
buf[5] = 0
buf[6] = 0
buf[7] = 0
hash := crc32.NewIEEE()
hash.Write(buf)
buf = make([]byte, r.Length)
n, err := io.ReadFull(s, buf)
if err != nil {
return err
}
nt := n - n%16
cryptor.CryptBlocks(buf[:nt], buf[:nt])
hash.Write(buf)
sum := hash.Sum32()
if sum != r.CRC32 {
return fmt.Errorf("CRC32 mismatch")
}
r.Content, err = pkcs5UnPadding(buf[:nt], 16)
if err != nil {
return err
}
return nil
}
func encodePacket(r *Packet, s Writer, sharedKey []byte) error {
tmpBlock, _ := aes.NewCipher(sharedKey)
buffer := bytes.NewBuffer(nil)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
binary.Write(buffer, binary.LittleEndian, r.TmpKey)
key := buffer.Bytes()
tmpBlock.Decrypt(key, key)
r.Timestamp = uint16(time.Now().Unix() % 65536)
r.CRC32 = 0
paddedContent := pkcs5Padding(r.Content, 16)
nadd := rand.Intn(16)
r.Length = uint32(len(paddedContent) + nadd)
buffer = bytes.NewBuffer(nil)
binary.Write(buffer, binary.LittleEndian, PacketMagicNumber)
binary.Write(buffer, binary.LittleEndian, r.Timestamp)
binary.Write(buffer, binary.LittleEndian, r.Type)
binary.Write(buffer, binary.LittleEndian, r.CRC32)
binary.Write(buffer, binary.LittleEndian, r.Length)
buffer.Write([]byte{0, 0, 0, 0})
buf := buffer.Bytes()
buf = append(buf, paddedContent...)
tmp := make([]byte, nadd)
rand.Read(tmp)
buf = append(buf, tmp...)
r.CRC32 = crc32.ChecksumIEEE(buf)
binary.LittleEndian.PutUint32(buf[4:8], r.CRC32)
block, err := aes.NewCipher(key)
if err != nil {
return err
}
cryptor := cipher.NewCBCEncrypter(block, key)
nt := 16 + uint(r.Length-r.Length%16)
cryptor.CryptBlocks(buf[:nt], buf[:nt])
buf2 := make([]byte, 4)
binary.LittleEndian.PutUint32(buf2, r.TmpKey)
s.M.Lock()
_, err = s.W.Write(buf2)
if err != nil {
s.M.Unlock()
return err
}
_, err = s.W.Write(buf)
if err != nil {
s.M.Unlock()
return err
}
err = s.W.Flush()
if err != nil {
s.M.Unlock()
return err
}
s.M.Unlock()
return nil
}
func ReadPacket(pkt *Packet, conn *Conn, client *Client) error {
if err := decodePacket(pkt, conn.Reader, client.Key); err != nil {
return err
}
if conn.peerTmpKeys.HasKey(pkt.TmpKey) {
pkt.Type = GrassPacketExpired
return nil
}
a := time.Now().Unix() % 65536
b := int64(pkt.Timestamp)
if b-a < 120 && a-b < 120 {
conn.peerTmpKeys.AddKey(pkt.TmpKey)
return nil
}
if b+65536-a < 120 || a+65536-b < 120 {
conn.peerTmpKeys.AddKey(pkt.TmpKey)
return nil
}
pkt.Type = GrassPacketExpired
return nil
}
func WritePacket(pkt *Packet, conn *Conn, client *Client) error {
for {
pkt.TmpKey = rand.Uint32()
if !conn.tmpKeys.HasKey(pkt.TmpKey) {
break
}
}
conn.tmpKeys.AddKey(pkt.TmpKey)
if err := encodePacket(pkt, conn.Writer, client.Key); err != nil {
return err
}
return nil
}
func WriteRawPacket(typ uint8, content []byte, conn *Conn, client *Client) error {
var pkt Packet
pkt.Type = typ
pkt.Content = content
return WritePacket(&pkt, conn, client)
}

46
tmpkey.go Normal file
View File

@@ -0,0 +1,46 @@
package grass
import (
"sync"
"time"
)
type TmpKey struct {
Key uint32
Time uint
}
type TmpKeySet struct {
Keys []TmpKey
Map map[uint32]struct{}
Mutex sync.Mutex
Expire uint
}
func (s *TmpKeySet) HasKey(key uint32) bool {
s.Mutex.Lock()
defer s.Mutex.Unlock()
tm := uint(time.Now().Unix())
if len(s.Keys) > 0 {
i := 0
for i < len(s.Keys) && s.Keys[i].Time < tm-s.Expire {
delete(s.Map, s.Keys[i].Key)
i++
}
s.Keys = s.Keys[i:]
}
_, ok := s.Map[key]
return ok
}
func (s *TmpKeySet) AddKey(key uint32) {
s.Mutex.Lock()
defer s.Mutex.Unlock()
tm := uint(time.Now().Unix())
s.Keys = append(s.Keys, TmpKey{Key: key, Time: tm})
s.Map[key] = struct{}{}
}
func NewTmpKeySet(expire uint) *TmpKeySet {
return &TmpKeySet{Expire: expire, Map: make(map[uint32]struct{})}
}

19
writer.go Normal file
View File

@@ -0,0 +1,19 @@
package grass
import (
"bufio"
"io"
"sync"
)
type Writer struct {
W *bufio.Writer
M *sync.Mutex
}
func NewWriter(w io.Writer) Writer {
var b Writer
b.W = bufio.NewWriter(w)
b.M = &sync.Mutex{}
return b
}