Remove endpoint

This commit is contained in:
shawnlu
2024-06-08 17:13:09 +08:00
parent b0d37148bf
commit 83fd5f45f3
5 changed files with 165 additions and 179 deletions

View File

@@ -2,81 +2,33 @@ package server
import (
"fmt"
"time"
"net"
"github.com/lucheng0127/virtuallan/pkg/users"
log "github.com/sirupsen/logrus"
)
type Endpoints struct {
IP string
Addr string
Beat chan string
}
var EPMap map[string]*Endpoints
func init() {
EPMap = make(map[string]*Endpoints, 1024)
}
func (ep *Endpoints) Close() {
client, ok := UPool[ep.Addr]
if ok {
delete(users.UserEPMap, client.User)
client.Close()
}
delete(EPMap, ep.IP)
}
func (ep *Endpoints) Countdown() {
for {
select {
case <-ep.Beat:
continue
case <-time.After(50 * time.Second):
log.Infof("endpoint %s with raddr %s don't get keepalive pkt for long time, close it\n", ep.IP, ep.Addr)
ep.Close()
return
}
}
}
func GetOrCreateEp(ip, raddr string) (*Endpoints, error) {
ep, ok := EPMap[ip]
if ok {
if ep.Addr != raddr {
return nil, fmt.Errorf("ip %s used by other endpoint", ip)
}
return ep, nil
}
ep = new(Endpoints)
ep.IP = ip
ep.Addr = raddr
ep.Beat = make(chan string)
EPMap[ip] = ep
go ep.Countdown()
return ep, nil
}
func HandleKeepalive(ip, raddr string) error {
func HandleKeepalive(ipAddr, raddr string, svc *Server) error {
c, ok := UPool[raddr]
if !ok {
return fmt.Errorf("unauthed client")
}
log.Debugf("handle keepalive pkt for %s with ip %s", raddr, ip)
log.Debugf("handle keepalive pkt for %s with ip %s", raddr, ipAddr)
ep, err := GetOrCreateEp(ip, raddr)
if err != nil {
return err
// TODO: IP maybe conflict, use dhcp
ip := net.ParseIP(ipAddr).To4()
if !svc.IPInPool(ip) {
c.IP = ip
ipIdx := svc.IdxFromIP(ip)
svc.MLock.Lock()
svc.UsedIP = append(svc.UsedIP, ipIdx)
svc.MLock.Unlock()
go c.Countdown()
}
c.IP = ip
ep.Beat <- "ok"
c.Beat <- "ok"
return nil
}

View File

@@ -1,12 +1,14 @@
package server
import (
"fmt"
"net"
"os"
"sync"
"time"
"github.com/lucheng0127/virtuallan/pkg/packet"
"github.com/lucheng0127/virtuallan/pkg/users"
"github.com/lucheng0127/virtuallan/pkg/utils"
log "github.com/sirupsen/logrus"
"github.com/songgao/water"
@@ -19,8 +21,10 @@ type UClient struct {
NetToIface chan *packet.VLPkt
Once sync.Once
User string
IP string
IP net.IP
Login string
Beat chan string
Svc *Server
}
var UPool map[string]*UClient
@@ -40,9 +44,24 @@ func (client *UClient) Close() {
log.Error(err)
}
client.Svc.ReleaseIP(client.IP)
delete(users.UserEPMap, client.User)
delete(UPool, client.RAddr.String())
}
func (client *UClient) Countdown() {
for {
select {
case <-client.Beat:
continue
case <-time.After(50 * time.Second):
log.Infof("endpoint %s with raddr %s don't get keepalive pkt for long time, close it\n", client.IP, client.RAddr.String())
client.Close()
return
}
}
}
func (client *UClient) Handle() {
go func() {
for {
@@ -123,6 +142,116 @@ func (svc *Server) CreateClientForAddr(addr *net.UDPAddr, conn *net.UDPConn) (*U
client.NetToIface = make(chan *packet.VLPkt, 1024)
client.Login = time.Now().Format("2006-01-02 15:04:05")
client.Once = sync.Once{}
client.Beat = make(chan string)
client.Svc = svc
UPool[addr.String()] = client
return client, nil
}
func (svc *Server) ListenAndServe() error {
if !utils.ValidatePort(svc.Port) {
return fmt.Errorf("invalidate port %d", svc.Port)
}
addr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("0.0.0.0:%d", svc.Port))
if err != nil {
return err
}
ln, err := net.ListenUDP("udp4", addr)
if err != nil {
return err
}
defer ln.Close()
for {
// Max vlpkt len 1502 = 1500(max ethernet pkt) + 2(vlheader)
// for encrypted data len should be n*16(aes block size) + 16(key len)
// so buf len should be 94 * 16 + 16 = 1520
var buf [65535]byte
n, addr, err := ln.ReadFromUDP(buf[:])
if err != nil {
return err
}
if n < 2 {
continue
}
// For wrong AES key, will return pkt to nill or unsupported pkt error, just skip
pkt, err := packet.Decode(buf[:n])
if pkt == nil {
continue
}
if err != nil {
if utils.IsUnsupportedPkt(err) {
log.Warn(err)
continue
}
log.Error("parse packet ", err)
}
// TODO(shawnlu): Add close conn
switch pkt.Type {
case packet.P_AUTH:
u, p := pkt.VLBody.(*packet.AuthBody).Parse()
// Check user logged
if _, ok := users.UserEPMap[u]; ok {
svc.SendResponse(ln, packet.RSP_USER_LOGGED, addr)
continue
}
// Auth user
err = users.ValidateUser(svc.userDb, u, p)
if err != nil {
log.Warn(err)
svc.SendResponse(ln, packet.RSP_AUTH_REQUIRED, addr)
continue
}
users.UserEPMap[u] = addr.String()
log.Infof("client %s login to %s succeed\n", addr.String(), u)
// Create client for authed addr
client, err := svc.CreateClientForAddr(addr, ln)
if err != nil {
log.Errorf("create authed client %s\n", err.Error())
}
client.User = u
log.Infof("client %s auth succeed", addr.String())
case packet.P_KEEPALIVE:
// Handle keepalive
err = HandleKeepalive(pkt.VLBody.(*packet.KeepaliveBody).Parse(), addr.String(), svc)
if err != nil {
if utils.IsUnauthedErr(err) {
continue
}
svc.SendResponse(ln, packet.RSP_IP_CONFLICET, addr)
log.Warnf("heartbeat from %s %s, send ip conflicet response", addr.String(), err.Error())
}
case packet.P_RAW:
// Get authed client from UPool
client, ok := UPool[addr.String()]
if !ok {
svc.SendResponse(ln, packet.RSP_AUTH_REQUIRED, addr)
continue
}
go client.HandleOnce()
client.NetToIface <- pkt
default:
log.Debug("unknow stream, do nothing")
continue
}
}
}

View File

@@ -10,8 +10,6 @@ import (
"sync"
"github.com/lucheng0127/virtuallan/pkg/config"
"github.com/lucheng0127/virtuallan/pkg/packet"
"github.com/lucheng0127/virtuallan/pkg/users"
"github.com/lucheng0127/virtuallan/pkg/utils"
log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
@@ -28,6 +26,15 @@ type Server struct {
MLock sync.Mutex
}
func NewServer() *Server {
svc := new(Server)
svc.UsedIP = make([]int, 0)
svc.MLock = sync.Mutex{}
return svc
}
func (svc *Server) SetupLan() error {
// Create bridge
la := netlink.NewLinkAttrs()
@@ -48,112 +55,6 @@ func (svc *Server) SetupLan() error {
return nil
}
func (svc *Server) ListenAndServe() error {
if !utils.ValidatePort(svc.Port) {
return fmt.Errorf("invalidate port %d", svc.Port)
}
addr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("0.0.0.0:%d", svc.Port))
if err != nil {
return err
}
ln, err := net.ListenUDP("udp4", addr)
if err != nil {
return err
}
defer ln.Close()
for {
// Max vlpkt len 1502 = 1500(max ethernet pkt) + 2(vlheader)
// for encrypted data len should be n*16(aes block size) + 16(key len)
// so buf len should be 94 * 16 + 16 = 1520
var buf [65535]byte
n, addr, err := ln.ReadFromUDP(buf[:])
if err != nil {
return err
}
if n < 2 {
continue
}
// For wrong AES key, will return pkt to nill or unsupported pkt error, just skip
pkt, err := packet.Decode(buf[:n])
if pkt == nil {
continue
}
if err != nil {
if utils.IsUnsupportedPkt(err) {
log.Warn(err)
continue
}
log.Error("parse packet ", err)
}
// TODO(shawnlu): Add close conn
switch pkt.Type {
case packet.P_AUTH:
u, p := pkt.VLBody.(*packet.AuthBody).Parse()
// Check user logged
if _, ok := users.UserEPMap[u]; ok {
svc.SendResponse(ln, packet.RSP_USER_LOGGED, addr)
continue
}
// Auth user
err = users.ValidateUser(svc.userDb, u, p)
if err != nil {
log.Warn(err)
svc.SendResponse(ln, packet.RSP_AUTH_REQUIRED, addr)
continue
}
users.UserEPMap[u] = addr.String()
log.Infof("client %s login to %s succeed\n", addr.String(), u)
// Create client for authed addr
client, err := svc.CreateClientForAddr(addr, ln)
if err != nil {
log.Errorf("create authed client %s\n", err.Error())
}
client.User = u
log.Infof("client %s auth succeed", addr.String())
case packet.P_KEEPALIVE:
// Handle keepalive
err = HandleKeepalive(pkt.VLBody.(*packet.KeepaliveBody).Parse(), addr.String())
if err != nil {
if utils.IsUnauthedErr(err) {
continue
}
svc.SendResponse(ln, packet.RSP_IP_CONFLICET, addr)
log.Warnf("heartbeat from %s %s, send ip conflicet response", addr.String(), err.Error())
}
case packet.P_RAW:
// Get authed client from UPool
client, ok := UPool[addr.String()]
if !ok {
svc.SendResponse(ln, packet.RSP_AUTH_REQUIRED, addr)
continue
}
go client.HandleOnce()
client.NetToIface <- pkt
default:
log.Debug("unknow stream, do nothing")
continue
}
}
}
func (svc *Server) Teardown() {
err := utils.DelLinkByName(svc.Bridge)
if err != nil {
@@ -169,12 +70,9 @@ func (svc *Server) HandleSignal(sigChan chan os.Signal) {
svc.Teardown()
}
// TODO(shawnlu): Add dhcp
func Run(cCtx *cli.Context) error {
// New server and do cfg parse
svc := new(Server)
svc.UsedIP = make([]int, 0)
svc.MLock = sync.Mutex{}
svc := NewServer()
cfgDir := cCtx.String("config-dir")
cfg, err := config.LoadConfigFile(config.GetCfgPath(cfgDir))
@@ -206,7 +104,10 @@ func Run(cCtx *cli.Context) error {
// Run web server
if svc.ServerConfig.WebConfig.Enable {
webSvc := &webServe{port: svc.ServerConfig.WebConfig.Port}
webSvc := &webServe{
port: svc.ServerConfig.WebConfig.Port,
svc: svc,
}
go webSvc.Serve()
log.Info("run web server on port ", webSvc.port)
}

View File

@@ -11,6 +11,7 @@ import (
type webServe struct {
port int
svc *Server
}
type EpEntry struct {
@@ -35,7 +36,7 @@ func listEpEntries(c *gin.Context) {
User: user,
Addr: addr,
Iface: c.Iface.Name(),
IP: c.IP,
IP: c.IP.String(),
Login: c.Login,
})
}

View File

@@ -8,6 +8,9 @@ import (
"os"
)
// Logged users
// Key: username
// Value: remote address string
var UserEPMap map[string]string
func init() {