From a6bfacb0c2a108991b7f2c4aeccc241ddef9e719 Mon Sep 17 00:00:00 2001 From: lucheng Date: Fri, 9 Aug 2024 12:30:26 +0800 Subject: [PATCH] Fix client ip by get client ip base on username hash Count username hash, get client ip base on it, it will make ip fixed base on client login username. --- pkg/server/conn_linux.go | 8 ++-- pkg/server/ipmgr_linux.go | 38 ++++++++++++++- pkg/server/ipmgr_linux_test.go | 84 ++++++++++++++++++++++++++++++++++ pkg/utils/hash.go | 12 +++++ pkg/utils/hash_test.go | 55 ++++++++++++++++++++++ 5 files changed, 193 insertions(+), 4 deletions(-) create mode 100644 pkg/server/ipmgr_linux_test.go create mode 100644 pkg/utils/hash.go create mode 100644 pkg/utils/hash_test.go diff --git a/pkg/server/conn_linux.go b/pkg/server/conn_linux.go index 20507d4..5996887 100644 --- a/pkg/server/conn_linux.go +++ b/pkg/server/conn_linux.go @@ -175,14 +175,14 @@ func (svc *Server) OfferIPToClient(conn *net.UDPConn, ip string, raddr *net.UDPA return nil } -func (svc *Server) CreateClientForAddr(addr *net.UDPAddr, conn *net.UDPConn) (*UClient, error) { +func (svc *Server) CreateClientForAddr(addr *net.UDPAddr, conn *net.UDPConn, username string) (*UClient, error) { iface, err := utils.NewTap(svc.Bridge) if err != nil { return nil, err } // Pop a ip for client - ip, err := svc.PopIPFromPool() + ip, err := svc.IPForUser(username) if err != nil { return nil, err } @@ -199,6 +199,8 @@ func (svc *Server) CreateClientForAddr(addr *net.UDPAddr, conn *net.UDPConn) (*U client.Svc = svc client.IP = ip + log.Infof("new client remote addr %s ip %s login at %s\n", client.RAddr.String(), client.IP.String(), client.Login) + UPool[addr.String()] = client // Monitor client heartbeat @@ -276,7 +278,7 @@ func (svc *Server) ListenAndServe() error { log.Infof("client %s login to %s succeed\n", addr.String(), u) // Create client for authed addr - client, err := svc.CreateClientForAddr(addr, ln) + client, err := svc.CreateClientForAddr(addr, ln, u) if err != nil { log.Errorf("create authed client %s\n", err.Error()) svc.SendResponse(ln, packet.RSP_INTERNAL_ERR, addr) diff --git a/pkg/server/ipmgr_linux.go b/pkg/server/ipmgr_linux.go index 9feb677..5bab298 100644 --- a/pkg/server/ipmgr_linux.go +++ b/pkg/server/ipmgr_linux.go @@ -57,7 +57,43 @@ func (svc *Server) IdxFromIP(ip net.IP) int { return int(ipInt - ipStartInt) } -// TODO(shawnlu): Implement it with hash map, username as input to count hash +func (svc *Server) IPForUser(username string) (net.IP, error) { + idx := utils.IdxFromString(svc.IPCount, username) + idxEnd := idx - 1 + + for idx < svc.IPCount { + if idx == idxEnd { + // Checked the last idx + if svc.IPIdxInPool(idx) { + break + } + + svc.MLock.Lock() + svc.UsedIP = append(svc.UsedIP, idx) + svc.MLock.Unlock() + return svc.IPFromIdx(idx), nil + } + + if svc.IPIdxInPool(idx) { + idx += 1 + + if idx == svc.IPCount { + // Check from zero + idx = 0 + } + + continue + } + + svc.MLock.Lock() + svc.UsedIP = append(svc.UsedIP, idx) + svc.MLock.Unlock() + return svc.IPFromIdx(idx), nil + } + + return nil, errors.New("run out of ip") +} + func (svc *Server) PopIPFromPool() (net.IP, error) { for idx := 0; idx < svc.IPCount; idx++ { if svc.IPIdxInPool(idx) { diff --git a/pkg/server/ipmgr_linux_test.go b/pkg/server/ipmgr_linux_test.go new file mode 100644 index 0000000..8cf5df6 --- /dev/null +++ b/pkg/server/ipmgr_linux_test.go @@ -0,0 +1,84 @@ +package server + +import ( + "net" + "reflect" + "sync" + "testing" + + "bou.ke/monkey" + "github.com/lucheng0127/virtuallan/pkg/utils" +) + +func TestServer_IPForUser(t *testing.T) { + type args struct { + username string + } + tests := []struct { + name string + args args + want net.IP + wantErr bool + patchFunc interface{} + targetFunc interface{} + }{ + { + name: "idx 1-1", + args: args{username: "whocares"}, + want: net.ParseIP("192.168.123.101").To4(), + wantErr: false, + patchFunc: utils.IdxFromString, + targetFunc: func(int, string) int { return 1 }, + }, + { + name: "idx 1-2", + args: args{username: "whocares"}, + want: net.ParseIP("192.168.123.102").To4(), + wantErr: false, + patchFunc: utils.IdxFromString, + targetFunc: func(int, string) int { return 1 }, + }, + { + name: "idx 1-3", + args: args{username: "whocares"}, + want: net.ParseIP("192.168.123.100").To4(), + wantErr: false, + patchFunc: utils.IdxFromString, + targetFunc: func(int, string) int { return 1 }, + }, + { + name: "idx 1-4", + args: args{username: "whocares"}, + want: nil, + wantErr: true, + patchFunc: utils.IdxFromString, + targetFunc: func(int, string) int { return 1 }, + }, + } + + ipStart := net.ParseIP("192.168.123.100").To4() + svc := &Server{ + UsedIP: make([]int, 0), + IPStart: ipStart, + IPCount: 3, + MLock: sync.Mutex{}, + Routes: make(map[string]string), + } + + for _, tt := range tests { + if tt.targetFunc != nil { + monkey.Patch(tt.patchFunc, tt.targetFunc) + } + + t.Run(tt.name, func(t *testing.T) { + got, err := svc.IPForUser(tt.args.username) + if (err != nil) != tt.wantErr { + t.Errorf("Server.IPForUser() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Server.IPForUser() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go new file mode 100644 index 0000000..0967bc8 --- /dev/null +++ b/pkg/utils/hash.go @@ -0,0 +1,12 @@ +package utils + +import ( + "hash/fnv" +) + +func IdxFromString(step int, str string) int { + h := fnv.New32a() + h.Write([]byte(str)) + + return int(h.Sum32() % uint32(step)) +} diff --git a/pkg/utils/hash_test.go b/pkg/utils/hash_test.go new file mode 100644 index 0000000..a6a0494 --- /dev/null +++ b/pkg/utils/hash_test.go @@ -0,0 +1,55 @@ +package utils + +import "testing" + +func TestIdxFromString(t *testing.T) { + type args struct { + step int + str string + } + tests := []struct { + name string + args args + want int + }{ + { + name: "step 100 user1", + want: 32, + args: args{ + step: 100, + str: "shawn", + }, + }, + { + name: "step 100 user2", + want: 59, + args: args{ + step: 100, + str: "guest", + }, + }, + { + name: "step 30 user1", + want: 2, + args: args{ + step: 30, + str: "shawn", + }, + }, + { + name: "step 30 user2", + want: 9, + args: args{ + step: 30, + str: "guest", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IdxFromString(tt.args.step, tt.args.str); got != tt.want { + t.Errorf("IdxFromString() = %v, want %v", got, tt.want) + } + }) + } +}