mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-05 15:26:57 +08:00
625 lines
13 KiB
Go
625 lines
13 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/gopacket"
|
|
"github.com/google/gopacket/layers"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/wencaiwulue/kubevpn/pkg/config"
|
|
"github.com/wencaiwulue/kubevpn/pkg/util"
|
|
)
|
|
|
|
// Sizing knobs for the tun packet pipeline.
const (
	// MaxConn is not referenced by the code visible in this part of the file.
	MaxConn = 1
	// MaxThread is the number of header-parsing goroutines started per Device
	// and per Peer.
	MaxThread = 10
	// MaxSize is the buffer capacity of each packet channel created for a
	// Device or a Peer.
	MaxSize = 1000
)
|
|
|
|
type tunHandler struct {
|
|
chain *Chain
|
|
node *Node
|
|
routeNAT *NAT
|
|
// map[srcIP]net.Conn
|
|
routeConnNAT *sync.Map
|
|
chExit chan error
|
|
}
|
|
|
|
// NAT is a concurrency-safe routing table. It maps a destination IP (keyed by
// its string form) to the peer addresses through which that IP is reachable.
type NAT struct {
	lock   *sync.RWMutex
	routes map[string][]net.Addr
}

// NewNAT returns an empty, ready-to-use NAT.
func NewNAT() *NAT {
	return &NAT{
		lock:   &sync.RWMutex{},
		routes: map[string][]net.Addr{},
	}
}

// containsLocked reports whether routes[key] already holds an address whose
// string form is target. The caller must hold n.lock (read or write).
func (n *NAT) containsLocked(key, target string) bool {
	for _, a := range n.routes[key] {
		if a != nil && a.String() == target {
			return true
		}
	}
	return false
}

// dropLocked removes every address whose string form is target from
// routes[key] and returns how many were removed. A destination left without
// addresses is deleted so the table does not accumulate empty entries.
// The caller must hold n.lock for writing.
func (n *NAT) dropLocked(key, target string) (removed int) {
	list := n.routes[key]
	kept := list[:0]
	for _, a := range list {
		if a != nil && a.String() == target {
			removed++
			continue
		}
		kept = append(kept, a)
	}
	if removed == 0 {
		return 0
	}
	// Clear the now-unused tail so the backing array does not pin removed
	// addresses.
	for i := len(kept); i < len(list); i++ {
		list[i] = nil
	}
	if len(kept) == 0 {
		delete(n.routes, key)
	} else {
		n.routes[key] = kept
	}
	return removed
}

// RemoveAddr drops addr from every destination and returns the number of
// entries removed. A nil addr removes nothing.
func (n *NAT) RemoveAddr(addr net.Addr) (count int) {
	if addr == nil {
		return 0
	}
	target := addr.String()

	n.lock.Lock()
	defer n.lock.Unlock()
	// Deleting or reassigning existing keys while ranging over a map is safe.
	for key := range n.routes {
		count += n.dropLocked(key, target)
	}
	return count
}

// LoadOrStore registers addr as a route to the destination IP `to` unless an
// address with the same string form is already registered for it.
//
// It returns addr and load=true when the route already existed, and addr and
// load=false when it was added. A nil addr is never stored; it yields
// (nil, false).
func (n *NAT) LoadOrStore(to net.IP, addr net.Addr) (result net.Addr, load bool) {
	if addr == nil {
		return nil, false
	}
	key, target := to.String(), addr.String()

	// Fast path: most calls hit an existing route and only need the read lock.
	n.lock.RLock()
	found := n.containsLocked(key, target)
	n.lock.RUnlock()
	if found {
		return addr, true
	}

	n.lock.Lock()
	defer n.lock.Unlock()
	// Re-check under the write lock: another goroutine may have stored the same
	// route between the two critical sections.
	if n.containsLocked(key, target) {
		return addr, true
	}
	n.routes[key] = append(n.routes[key], addr)
	return addr, false
}

// RouteTo returns one of the addresses registered for ip, chosen at random so
// traffic is spread over all of them, or nil when ip has no route.
func (n *NAT) RouteTo(ip net.IP) net.Addr {
	n.lock.RLock()
	defer n.lock.RUnlock()

	addrList := n.routes[ip.String()]
	if len(addrList) == 0 {
		return nil
	}
	// for load balance
	return addrList[rand.Intn(len(addrList))]
}

// Remove drops addr from the routes of the single destination ip. It is a
// no-op when addr is nil or ip has no routes.
func (n *NAT) Remove(ip net.IP, addr net.Addr) {
	if addr == nil {
		return
	}
	target := addr.String()

	n.lock.Lock()
	defer n.lock.Unlock()
	n.dropLocked(ip.String(), target)
}

// Range calls f once per destination with its address list. f runs while the
// read lock is held, so it must not call methods that take the write lock
// (LoadOrStore, Remove, RemoveAddr) and must not modify or retain v.
func (n *NAT) Range(f func(key string, v []net.Addr)) {
	n.lock.RLock()
	defer n.lock.RUnlock()
	for k, v := range n.routes {
		f(k, v)
	}
}
|
|
|
|
// TunHandler creates a handler for tun tunnel.
|
|
func TunHandler(chain *Chain, node *Node) Handler {
|
|
return &tunHandler{
|
|
chain: chain,
|
|
node: node,
|
|
routeNAT: RouteNAT,
|
|
routeConnNAT: RouteConnNAT,
|
|
chExit: make(chan error, 1),
|
|
}
|
|
}
|
|
|
|
func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) {
|
|
if h.node.Remote != "" {
|
|
h.HandleClient(ctx, tun)
|
|
} else {
|
|
h.HandleServer(ctx, tun)
|
|
}
|
|
}
|
|
|
|
func (h tunHandler) printRoute() {
|
|
for {
|
|
select {
|
|
case <-time.Tick(time.Second * 5):
|
|
var i int
|
|
var sb strings.Builder
|
|
h.routeNAT.Range(func(key string, value []net.Addr) {
|
|
i++
|
|
var s []string
|
|
for _, addr := range value {
|
|
if addr != nil {
|
|
s = append(s, addr.String())
|
|
}
|
|
}
|
|
if len(s) != 0 {
|
|
sb.WriteString(fmt.Sprintf("to: %s, route: %s\n", key, strings.Join(s, " ")))
|
|
}
|
|
})
|
|
log.Debug(sb.String())
|
|
log.Debug(i)
|
|
}
|
|
}
|
|
}
|
|
|
|
type Device struct {
|
|
tun net.Conn
|
|
thread int
|
|
|
|
tunInboundRaw chan *DataElem
|
|
tunInbound chan *DataElem
|
|
tunOutbound chan *DataElem
|
|
|
|
// your main logic
|
|
tunInboundHandler func(tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem)
|
|
|
|
chExit chan error
|
|
}
|
|
|
|
func (d *Device) readFromTun() {
|
|
for {
|
|
b := config.LPool.Get().([]byte)[:]
|
|
n, err := d.tun.Read(b[:])
|
|
if err != nil {
|
|
select {
|
|
case d.chExit <- err:
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
d.tunInboundRaw <- &DataElem{
|
|
data: b[:],
|
|
length: n,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Device) writeToTun() {
|
|
for e := range d.tunOutbound {
|
|
_, err := d.tun.Write(e.data[:e.length])
|
|
config.LPool.Put(e.data[:])
|
|
if err != nil {
|
|
select {
|
|
case d.chExit <- err:
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Device) parseIPHeader() {
|
|
for e := range d.tunInboundRaw {
|
|
if util.IsIPv4(e.data[:e.length]) {
|
|
// ipv4.ParseHeader
|
|
b := e.data[:e.length]
|
|
e.src = net.IPv4(b[12], b[13], b[14], b[15])
|
|
e.dst = net.IPv4(b[16], b[17], b[18], b[19])
|
|
} else if util.IsIPv6(e.data[:e.length]) {
|
|
// ipv6.ParseHeader
|
|
e.src = e.data[:e.length][8:24]
|
|
e.dst = e.data[:e.length][24:40]
|
|
} else {
|
|
log.Errorf("[tun-packet] unknown packet")
|
|
continue
|
|
}
|
|
|
|
log.Debugf("[tun] %s --> %s, length: %d", e.src, e.dst, e.length)
|
|
d.tunInbound <- e
|
|
}
|
|
}
|
|
|
|
func (d *Device) Close() {
|
|
d.tun.Close()
|
|
}
|
|
|
|
func heartbeats(tun net.Conn, in chan<- *DataElem) {
|
|
conn, err := util.GetTunDeviceByConn(tun)
|
|
if err != nil {
|
|
log.Errorf("get tun device error: %s", err.Error())
|
|
return
|
|
}
|
|
srcIPv4, srcIPv6, err := util.GetLocalTunIP(conn.Name)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if config.RouterIP.To4().Equal(srcIPv4) {
|
|
return
|
|
}
|
|
if config.RouterIP6.To4().Equal(srcIPv6) {
|
|
return
|
|
}
|
|
|
|
var bytes []byte
|
|
var bytes6 []byte
|
|
|
|
ticker := time.NewTicker(time.Second * 5)
|
|
defer ticker.Stop()
|
|
|
|
for ; true; <-ticker.C {
|
|
for i := 0; i < 4; i++ {
|
|
if bytes == nil {
|
|
bytes, err = genICMPPacket(srcIPv4, config.RouterIP)
|
|
if err != nil {
|
|
log.Errorf("generate ipv4 packet error: %s", err.Error())
|
|
continue
|
|
}
|
|
}
|
|
if bytes6 == nil {
|
|
bytes6, err = genICMPPacketIPv6(srcIPv6, config.RouterIP6)
|
|
if err != nil {
|
|
log.Errorf("generate ipv6 packet error: %s", err.Error())
|
|
continue
|
|
}
|
|
}
|
|
for index, i2 := range [][]byte{bytes, bytes6} {
|
|
data := config.LPool.Get().([]byte)[:]
|
|
length := copy(data, i2)
|
|
var src, dst net.IP
|
|
if index == 0 {
|
|
src, dst = srcIPv4, config.RouterIP
|
|
} else {
|
|
src, dst = srcIPv6, config.RouterIP6
|
|
}
|
|
in <- &DataElem{
|
|
data: data[:],
|
|
length: length,
|
|
src: src,
|
|
dst: dst,
|
|
}
|
|
}
|
|
time.Sleep(time.Second)
|
|
}
|
|
}
|
|
}
|
|
|
|
func genICMPPacket(src net.IP, dst net.IP) ([]byte, error) {
|
|
buf := gopacket.NewSerializeBuffer()
|
|
icmpLayer := layers.ICMPv4{
|
|
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
|
Id: 3842,
|
|
Seq: 1,
|
|
}
|
|
ipLayer := layers.IPv4{
|
|
Version: 4,
|
|
SrcIP: src,
|
|
DstIP: dst,
|
|
Protocol: layers.IPProtocolICMPv4,
|
|
Flags: layers.IPv4DontFragment,
|
|
TTL: 64,
|
|
IHL: 5,
|
|
Id: 55664,
|
|
}
|
|
opts := gopacket.SerializeOptions{
|
|
FixLengths: true,
|
|
ComputeChecksums: true,
|
|
}
|
|
err := gopacket.SerializeLayers(buf, opts, &ipLayer, &icmpLayer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize icmp packet, err: %v", err)
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func genICMPPacketIPv6(src net.IP, dst net.IP) ([]byte, error) {
|
|
buf := gopacket.NewSerializeBuffer()
|
|
icmpLayer := layers.ICMPv6{
|
|
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
|
|
}
|
|
ipLayer := layers.IPv6{
|
|
Version: 6,
|
|
SrcIP: src,
|
|
DstIP: dst,
|
|
NextHeader: layers.IPProtocolICMPv6,
|
|
HopLimit: 255,
|
|
}
|
|
opts := gopacket.SerializeOptions{
|
|
FixLengths: true,
|
|
}
|
|
err := gopacket.SerializeLayers(buf, opts, &ipLayer, &icmpLayer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to serialize icmp6 packet, err: %v", err)
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func (d *Device) Start(ctx context.Context) {
|
|
go d.readFromTun()
|
|
for i := 0; i < d.thread; i++ {
|
|
go d.parseIPHeader()
|
|
}
|
|
go d.tunInboundHandler(d.tunInbound, d.tunOutbound)
|
|
go d.writeToTun()
|
|
go heartbeats(d.tun, d.tunInbound)
|
|
|
|
select {
|
|
case err := <-d.chExit:
|
|
log.Errorf("device exit: %s", err.Error())
|
|
return
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
|
|
func (d *Device) SetTunInboundHandler(handler func(tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem)) {
|
|
d.tunInboundHandler = handler
|
|
}
|
|
|
|
func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) {
|
|
go h.printRoute()
|
|
|
|
device := &Device{
|
|
tun: tun,
|
|
thread: MaxThread,
|
|
tunInboundRaw: make(chan *DataElem, MaxSize),
|
|
tunInbound: make(chan *DataElem, MaxSize),
|
|
tunOutbound: make(chan *DataElem, MaxSize),
|
|
chExit: h.chExit,
|
|
}
|
|
device.SetTunInboundHandler(func(tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem) {
|
|
for {
|
|
packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", h.node.Addr)
|
|
if err != nil {
|
|
log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err)
|
|
return
|
|
}
|
|
err = transportTun(ctx, tunInbound, tunOutbound, packetConn, h.routeNAT, h.routeConnNAT)
|
|
if err != nil {
|
|
log.Debugf("[tun] %s: %v", tun.LocalAddr(), err)
|
|
}
|
|
}
|
|
})
|
|
|
|
defer device.Close()
|
|
device.Start(ctx)
|
|
}
|
|
|
|
// DataElem is one IP packet travelling through the tun pipeline: a backing
// buffer, the number of valid bytes in it, and the addresses parsed from its
// header.
type DataElem struct {
	data   []byte // backing buffer; only data[:length] is packet content
	length int    // number of valid bytes at the start of data
	src    net.IP
	dst    net.IP
}

// NewDataElem wraps data (used as-is, not copied) into a DataElem with the
// given valid length and addresses.
func NewDataElem(data []byte, length int, src net.IP, dst net.IP) *DataElem {
	elem := new(DataElem)
	elem.data, elem.length = data, length
	elem.src, elem.dst = src, dst
	return elem
}

// Data returns the whole backing buffer, not only its first Length bytes.
func (d *DataElem) Data() []byte {
	buf := d.data
	return buf
}

// Length returns the number of valid packet bytes in the buffer.
func (d *DataElem) Length() int {
	n := d.length
	return n
}
|
|
|
|
// udpElem is one packet received from a peer, together with the addresses
// parsed from its IP header.
type udpElem struct {
	// from is the remote address the datagram was read from. It is left unset
	// for packets taken from the TCP packet channel.
	from net.Addr

	data   []byte // backing buffer; only data[:length] is packet content
	length int    // number of valid bytes at the start of data

	src net.IP
	dst net.IP
}
|
|
|
|
type Peer struct {
|
|
conn net.PacketConn
|
|
thread int
|
|
|
|
connInbound chan *udpElem
|
|
parsedConnInfo chan *udpElem
|
|
|
|
tunInbound <-chan *DataElem
|
|
tunOutbound chan<- *DataElem
|
|
|
|
routeNAT *NAT
|
|
// map[srcIP]net.Conn
|
|
// routeConnNAT sync.Map
|
|
routeConnNAT *sync.Map
|
|
|
|
errChan chan error
|
|
}
|
|
|
|
func (p *Peer) sendErr(err error) {
|
|
select {
|
|
case p.errChan <- err:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (p *Peer) readFromConn() {
|
|
for {
|
|
b := config.LPool.Get().([]byte)[:]
|
|
n, srcAddr, err := p.conn.ReadFrom(b[:])
|
|
if err != nil {
|
|
p.sendErr(err)
|
|
return
|
|
}
|
|
p.connInbound <- &udpElem{
|
|
from: srcAddr,
|
|
data: b[:],
|
|
length: n,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Peer) readFromTCPConn() {
|
|
for packet := range Chan {
|
|
u := &udpElem{
|
|
data: packet.Data[:],
|
|
length: int(packet.DataLength),
|
|
}
|
|
b := packet.Data
|
|
if util.IsIPv4(packet.Data) {
|
|
// ipv4.ParseHeader
|
|
u.src = net.IPv4(b[12], b[13], b[14], b[15])
|
|
u.dst = net.IPv4(b[16], b[17], b[18], b[19])
|
|
} else if util.IsIPv6(packet.Data) {
|
|
// ipv6.ParseHeader
|
|
u.src = b[8:24]
|
|
u.dst = b[24:40]
|
|
} else {
|
|
log.Errorf("[tun-conn] unknown packet")
|
|
continue
|
|
}
|
|
log.Debugf("[tcpserver] udp-tun %s >>> %s length: %d", u.src, u.dst, u.length)
|
|
p.parsedConnInfo <- u
|
|
}
|
|
}
|
|
|
|
func (p *Peer) parseHeader() {
|
|
var firstIPv4, firstIPv6 = true, true
|
|
for e := range p.connInbound {
|
|
b := e.data[:e.length]
|
|
if util.IsIPv4(e.data[:e.length]) {
|
|
// ipv4.ParseHeader
|
|
e.src = net.IPv4(b[12], b[13], b[14], b[15])
|
|
e.dst = net.IPv4(b[16], b[17], b[18], b[19])
|
|
} else if util.IsIPv6(e.data[:e.length]) {
|
|
// ipv6.ParseHeader
|
|
e.src = b[:e.length][8:24]
|
|
e.dst = b[:e.length][24:40]
|
|
} else {
|
|
log.Errorf("[tun] unknown packet")
|
|
continue
|
|
}
|
|
|
|
if firstIPv4 || firstIPv6 {
|
|
if util.IsIPv4(e.data[:e.length]) {
|
|
firstIPv4 = false
|
|
} else {
|
|
firstIPv6 = false
|
|
}
|
|
if _, loaded := p.routeNAT.LoadOrStore(e.src, e.from); loaded {
|
|
log.Debugf("[tun] find route: %s -> %s", e.src, e.from)
|
|
} else {
|
|
log.Debugf("[tun] new route: %s -> %s", e.src, e.from)
|
|
}
|
|
}
|
|
p.parsedConnInfo <- e
|
|
}
|
|
}
|
|
|
|
func (p *Peer) routePeer() {
|
|
for e := range p.parsedConnInfo {
|
|
if routeToAddr := p.routeNAT.RouteTo(e.dst); routeToAddr != nil {
|
|
log.Debugf("[tun] find route: %s -> %s", e.dst, routeToAddr)
|
|
_, err := p.conn.WriteTo(e.data[:e.length], routeToAddr)
|
|
config.LPool.Put(e.data[:])
|
|
if err != nil {
|
|
p.sendErr(err)
|
|
return
|
|
}
|
|
} else if conn, ok := p.routeConnNAT.Load(e.dst.String()); ok {
|
|
dgram := newDatagramPacket(e.data[:e.length])
|
|
if err := dgram.Write(conn.(net.Conn)); err != nil {
|
|
log.Debugf("[tcpserver] udp-tun %s <- %s : %s", conn.(net.Conn).RemoteAddr(), dgram.Addr(), err)
|
|
p.sendErr(err)
|
|
return
|
|
}
|
|
config.LPool.Put(e.data[:])
|
|
} else {
|
|
p.tunOutbound <- &DataElem{
|
|
data: e.data,
|
|
length: e.length,
|
|
src: e.src,
|
|
dst: e.dst,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Peer) routeTUN() {
|
|
for e := range p.tunInbound {
|
|
if addr := p.routeNAT.RouteTo(e.dst); addr != nil {
|
|
log.Debugf("[tun] find route: %s -> %s", e.dst, addr)
|
|
_, err := p.conn.WriteTo(e.data[:e.length], addr)
|
|
config.LPool.Put(e.data[:])
|
|
if err != nil {
|
|
log.Debugf("[tun] can not route: %s -> %s", e.dst, addr)
|
|
p.sendErr(err)
|
|
return
|
|
}
|
|
} else if conn, ok := p.routeConnNAT.Load(e.dst.String()); ok {
|
|
dgram := newDatagramPacket(e.data[:e.length])
|
|
err := dgram.Write(conn.(net.Conn))
|
|
config.LPool.Put(e.data[:])
|
|
if err != nil {
|
|
log.Debugf("[tcpserver] udp-tun %s <- %s : %s", conn.(net.Conn).RemoteAddr(), dgram.Addr(), err)
|
|
p.sendErr(err)
|
|
return
|
|
}
|
|
} else {
|
|
config.LPool.Put(e.data[:])
|
|
log.Debug(fmt.Errorf("[tun] no route for %s -> %s", e.src, e.dst))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Peer) Start() {
|
|
go p.readFromConn()
|
|
go p.readFromTCPConn()
|
|
for i := 0; i < p.thread; i++ {
|
|
go p.parseHeader()
|
|
}
|
|
go p.routePeer()
|
|
go p.routeTUN()
|
|
}
|
|
|
|
func (p *Peer) Close() {
|
|
p.conn.Close()
|
|
}
|
|
|
|
func transportTun(ctx context.Context, tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem, packetConn net.PacketConn, nat *NAT, connNAT *sync.Map) error {
|
|
p := &Peer{
|
|
conn: packetConn,
|
|
thread: MaxThread,
|
|
connInbound: make(chan *udpElem, MaxSize),
|
|
parsedConnInfo: make(chan *udpElem, MaxSize),
|
|
tunInbound: tunInbound,
|
|
tunOutbound: tunOutbound,
|
|
routeNAT: nat,
|
|
routeConnNAT: connNAT,
|
|
errChan: make(chan error, 2),
|
|
}
|
|
|
|
defer p.Close()
|
|
p.Start()
|
|
|
|
select {
|
|
case err := <-p.errChan:
|
|
log.Errorf(err.Error())
|
|
return err
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|