diff --git a/common/statistic.go b/common/statistic.go new file mode 100644 index 0000000..125b4f8 --- /dev/null +++ b/common/statistic.go @@ -0,0 +1,208 @@ +package common + +import ( + "sync" + "time" +) + +type Statistic struct { + rx, tx, count int + timestamp time.Time +} + +type aggregatedStatistic struct { + Statistic + duration time.Duration +} + +func (s *Statistic) add(new Statistic) { + s.count += new.count + s.rx += new.rx + s.tx += new.tx + if new.timestamp.After(s.timestamp) { + s.timestamp = new.timestamp + } +} + +type statisticWindow struct { + aggregatedStatistic + startTime int64 + endTime int64 + mux *sync.Mutex +} + +func getStatisticWindow(startTime time.Time, duration time.Duration) statisticWindow { + win := statisticWindow{ + startTime: startTime.Unix(), + endTime: startTime.Add(duration).Unix(), + mux: &sync.Mutex{}, + } + win.rx = 0 + win.tx = 0 + win.count = 0 + win.duration = duration + win.timestamp = startTime + return win +} + +func (sw *statisticWindow) aggregateStatistic(statistic Statistic) { + if sw.startTime <= statistic.timestamp.Unix() && statistic.timestamp.Unix() < sw.endTime { + sw.mux.Lock() + sw.add(statistic) + sw.mux.Unlock() + } +} + +func (sw statisticWindow) getFinalResult() aggregatedStatistic { + statistic := aggregatedStatistic{ + duration: sw.duration, + } + statistic.rx = sw.rx + statistic.tx = sw.tx + statistic.timestamp = sw.timestamp + statistic.count = sw.count + return statistic +} + +func (sw statisticWindow) getTotal() (rx, tx int) { + var crx, ctx int + sw.mux.Lock() + crx = sw.rx + ctx = sw.tx + sw.mux.Unlock() + return crx, ctx +} + +func (sw statisticWindow) getAverageByTime() (rx, tx int) { + var arx, atx int + sw.mux.Lock() + arx = sw.rx / int(sw.startTime-sw.timestamp.Unix()) + atx = sw.tx / int(sw.startTime-sw.timestamp.Unix()) + sw.mux.Unlock() + return arx, atx +} + +func (sw statisticWindow) getAverageByCount() (rx, tx int) { + var arx, atx int + sw.mux.Lock() + arx = sw.rx / sw.count + atx = sw.tx / sw.count + sw.mux.Unlock() + return arx, atx +} + +type TunnelStatistic struct { + history map[int64]aggregatedStatistic + name string + windowDuration time.Duration + currentWin statisticWindow + mux *sync.Mutex +} + +func GetTunnelStatistic(name string, duration time.Duration, historyDuration time.Duration) *TunnelStatistic { + ts := TunnelStatistic{ + history: make(map[int64]aggregatedStatistic), + name: name, + windowDuration: duration, + mux: &sync.Mutex{}, + } + historySize := int(historyDuration.Seconds() / ts.windowDuration.Seconds()) + ts.currentWin = getStatisticWindow(time.Now(), ts.windowDuration) + go func() { + for range time.Tick(ts.windowDuration) { + ts.mux.Lock() + c := ts.currentWin + if len(ts.history) >= historySize { + var t int64 + for t, _ = range ts.history { + break + } + delete(ts.history, t) + } + ts.history[c.startTime] = c.getFinalResult() + ts.currentWin = getStatisticWindow(time.Now(), ts.windowDuration) + ts.mux.Unlock() + } + }() + return &ts +} + +func (ts *TunnelStatistic) AddStatistic(statistic Statistic) { + ts.mux.Lock() + ts.currentWin.add(statistic) + ts.mux.Unlock() +} + +func (ts TunnelStatistic) GetCurrentTotal() (rx, tx int) { + var crx, ctx int + ts.mux.Lock() + crx, ctx = ts.currentWin.getTotal() + ts.mux.Unlock() + return crx, ctx +} + +func (ts TunnelStatistic) GetCurrentAverageByTime() (rx, tx int) { + var arx, atx int + ts.mux.Lock() + arx, atx = ts.currentWin.getAverageByTime() + ts.mux.Unlock() + return arx, atx +} + +func (ts TunnelStatistic) GetCurrentAverageByCount() (rx, tx int) { + var arx, atx int + ts.mux.Lock() + arx, atx = ts.currentWin.getAverageByCount() + ts.mux.Unlock() + return arx, atx +} + +func (ts TunnelStatistic) GetTotal(start, end time.Time) (arx, atx, count int) { + arx = 0 + atx = 0 + count = 0 + startUnix := start.Unix() + endUnix := end.Unix() + for t, s := range ts.history { + if startUnix <= t && t < endUnix { + arx += s.rx + atx += s.tx + count += s.count + } + } + return arx, atx, count +} + +func (ts TunnelStatistic) GetAverageByTime(start, end time.Time) (arx, atx int) { + arx = 0 + atx = 0 + startUnix := start.Unix() + endUnix := end.Unix() + for t, s := range ts.history { + if startUnix <= t && t < endUnix { + arx += s.rx + atx += s.tx + } + } + arx = arx / int((startUnix-endUnix)/int64(ts.windowDuration.Seconds())) + atx = atx / int((startUnix-endUnix)/int64(ts.windowDuration.Seconds())) + return arx, atx +} + +func (ts TunnelStatistic) GetAverageByCount(start, end time.Time) (arx, atx int) { + arx = 0 + atx = 0 + count := 0 + startUnix := start.Unix() + endUnix := end.Unix() + for t, s := range ts.history { + if startUnix <= t && t < endUnix { + arx += s.rx + atx += s.tx + count += s.count + } + } + arx = arx / count + atx = atx / count + return arx, atx +} diff --git a/stunning.go b/stunning.go index 22c23b5..4971abe 100644 --- a/stunning.go +++ b/stunning.go @@ -1,7 +1,7 @@ package stunning import ( - "github.com/songgao/water" + "encoding/json" "github.com/hbahadorzadeh/stunning/common" icommon "github.com/hbahadorzadeh/stunning/interface/common" socksiface "github.com/hbahadorzadeh/stunning/interface/socks" @@ -13,11 +13,28 @@ import ( tcptun "github.com/hbahadorzadeh/stunning/tunnel/tcp" tlstun "github.com/hbahadorzadeh/stunning/tunnel/tls" udptun "github.com/hbahadorzadeh/stunning/tunnel/udp" + "github.com/songgao/water" + "io/ioutil" "log" + "os" + "time" ) +type TunnelConfig struct { + Cert string + Connect string + DeviceName string + InterfaceType string + Key string + Listen string + Mtu string + ServerType string + ServiceMode string +} + type Tunnel interface { ListenAndServer() + IsAlive() bool } type TunnelCommon struct { Tunnel @@ -62,6 +79,14 @@ func (t TunnelServer) ListenAndServer() { } } +func (t TunnelServer) IsAlive() bool { + if &t.tunnelServer != nil { + return !t.tunnelServer.Closed() + } else { + return false + } +} + func (t TunnelClient) ListenAndServer() { if &t.interfaceClient != nil { defer t.interfaceClient.Close() @@ -71,12 +96,22 @@ func (t TunnelClient) ListenAndServer() { } } -func TunnelFactory(conf map[string]string) TunnelCommon { +func readConfig(confFile string) map[string]TunnelConfig { + confStruct := make(map[string]TunnelConfig) + data, err := ioutil.ReadFile(confFile) + if err != nil { + panic(err) + } + json.Unmarshal(data, confStruct) + return confStruct +} + +func TunnelFactory(name string, conf TunnelConfig) TunnelCommon { var tun TunnelCommon - if sorc, exist := conf["service_mode"]; exist { + if sorc := conf.ServiceMode; sorc != "" { if common.TunnelMode(sorc) == common.CLIENT { ttun := TunnelClient{} - if stype, exist := conf["server_type"]; exist { + if stype := conf.ServerType; stype != "" { switch common.TunnelType(stype) { case common.HTTP_TUN: ttun.tunnelClient = httptun.GetHttpDialer() @@ -94,18 +129,18 @@ func TunnelFactory(conf map[string]string) TunnelCommon { ttun.tunnelClient = tlstun.GetTlsDialer() break default: - log.Panicf("Invalid server type(%s).", stype) + log.Panicf("Conf `%s`: Invalid server type(%s).", name, stype) } - saddr, sexist := conf["connect"] - if !sexist { - log.Panicf("Service connect address not specified.") + saddr := conf.Connect + if saddr == "" { + log.Panicf("Conf `%s`: Service connect address not specified.", name) } - caddr, cexist := conf["listen"] - if !cexist { - log.Panicf("Service listen address not specified.") + caddr := conf.Listen + if caddr == "" { + log.Panicf("Conf `%s`: Service listen address not specified.", name) } - if itype, exist := conf["interface_type"]; exist { + if itype := conf.InterfaceType; itype != "" { switch common.InterfaceType(itype) { case common.SOCKS_IFACE: ttun.interfaceClient = socksiface.GetSocksClient(caddr, saddr, ttun.tunnelClient) @@ -114,136 +149,197 @@ func TunnelFactory(conf map[string]string) TunnelCommon { ttun.interfaceClient = tcpiface.GetTcpClient(caddr, saddr, ttun.tunnelClient) break case common.TUN_IFACE: - imtu, exist := conf["mtu"] - if !exist { + imtu := conf.Mtu + if imtu == "" { imtu = "1500" } - iname, exist := conf["devname"] - if !exist { + iname := conf.DeviceName + if iname == "" { iname = "tun" } - conf := tuniface.TunConfig{ + tconf := tuniface.TunConfig{ DevType: water.TUN, Address: caddr, Name: iname, MTU: imtu, } - tuniface.GetTunIfaceClient(conf, saddr, ttun.tunnelClient) + tuniface.GetTunIfaceClient(tconf, saddr, ttun.tunnelClient) break case common.UDP_IFACE: case common.SERIAL_IFACE: default: - log.Panicf("Invalid interface type (%s)", itype) + log.Panicf("Conf `%s`: Invalid interface type (%s)", name, itype) } } } } else if common.TunnelMode(sorc) == common.SERVER { ttun := TunnelServer{} - if stype, exist := conf["server_type"]; exist { - if saddr, exist := conf["listen"]; exist { + if stype := conf.ServerType; stype != "" { + if saddr := conf.Listen; saddr != "" { switch common.TunnelType(stype) { case common.HTTP_TUN: tServer, err := httptun.StartHttpServer(saddr) if err != nil { - log.Panicf("Failed to create tunnel server.\n%v", err) + log.Panicf("Conf `%s`: Failed to create tunnel server.\n%v", name, err) } ttun.tunnelServer = tServer break case common.HTTPS_TUN: - scert, cexist := conf["cert"] - skey, kexist := conf["key"] - if cexist && kexist { - tServer, err := httpstun.StartHttpsServer(scert, skey, saddr) - if err != nil { - log.Panicf("Failed to create tunnel server.\n%v", err) - } - ttun.tunnelServer = tServer - } else { - log.Panicf("Key or Cert not defiend") + scert := conf.Cert + skey := conf.Key + if scert == "" { + log.Panicf("Conf `%s`: Cert not defiend", name) + } else if _, err := os.Stat(scert); os.IsNotExist(err) { + log.Panicf("Conf `%s`: Cert file not exist", name) } + if skey == "" { + log.Panicf("Conf `%s`: Key not defiend", name) + } else if _, err := os.Stat(skey); os.IsNotExist(err) { + log.Panicf("Conf `%s`: Key file not exist", name) + } + + tServer, err := httpstun.StartHttpsServer(scert, skey, saddr) + if err != nil { + log.Panicf("Conf `%s`: Failed to create tunnel server.\n%v", name, err) + } + ttun.tunnelServer = tServer break case common.TCP_TUN: tServer, err := tcptun.StartTcpServer(saddr) if err != nil { - log.Panicf("Failed to create tunnel server.\n%v", err) + log.Panicf("Conf `%s`: Failed to create tunnel server.\n%v", name, err) } ttun.tunnelServer = tServer break case common.UDP_TUN: tServer, err := udptun.StartUdpServer(saddr) if err != nil { - log.Panicf("Failed to create tunnel server.\n%v", err) + log.Panicf("Conf `%s`: Failed to create tunnel server.\n%v", name, err) } ttun.tunnelServer = tServer break case common.TLS_TUN: - scert, cexist := conf["cert"] - skey, kexist := conf["key"] - if cexist && kexist { - tServer, err := tlstun.StartTlsServer(scert, skey, saddr) - if err != nil { - log.Panicf("Failed to create tunnel server.\n%v", err) - } - ttun.tunnelServer = tServer - } else { - log.Panicf("Key or Cert not defiend") + scert := conf.Cert + skey := conf.Key + if scert == "" { + log.Panicf("Conf `%s`: Cert not defiend", name) + } else if _, err := os.Stat(scert); os.IsNotExist(err) { + log.Panicf("Conf `%s`: Cert file not exist", name) } + if skey == "" { + log.Panicf("Conf `%s`: Key not defiend", name) + } else if _, err := os.Stat(skey); os.IsNotExist(err) { + log.Panicf("Conf `%s`: Key file not exist", name) + } + + tServer, err := tlstun.StartTlsServer(scert, skey, saddr) + if err != nil { + log.Panicf("Conf `%s`: Failed to create tunnel server.\n%v", name, err) + } + ttun.tunnelServer = tServer break default: - log.Panicf("Invalid server type(%s).", stype) + log.Panicf("Conf `%s`: Invalid server type(%s).", name, stype) } } else { - log.Panicf("Service listen address not specified.") + log.Panicf("Conf `%s`: Service listen address not specified.", name) } - if itype, exist := conf["interface_type"]; exist { + if itype := conf.InterfaceType; itype != "" { switch common.InterfaceType(itype) { case common.SOCKS_IFACE: ttun.interfaceServer = socksiface.GetSocksServer() break case common.TCP_IFACE: - if iaddr, exist := conf["connect"]; exist { + if iaddr := conf.Connect; iaddr != "" { ttun.interfaceServer = tcpiface.GetTcpServer(iaddr) } else { - log.Panicf("Service connect address not specified.") + log.Panicf("Conf `%s`: Service connect address not specified.", name) } break case common.TUN_IFACE: - iaddr, exist := conf["connect"] - if !exist { - log.Panicf("Service connect address not specified.") + iaddr := conf.Connect + if iaddr == "" { + log.Panicf("Conf `%s`: Service connect address not specified.", name) } - imtu, exist := conf["mtu"] - if !exist { + imtu := conf.Mtu + if imtu == "" { imtu = "1500" } - iname, exist := conf["devname"] - if !exist { + iname := conf.DeviceName + if iname == "" { iname = "tun" } - conf := tuniface.TunConfig{ + tconf := tuniface.TunConfig{ DevType: water.TUN, Address: iaddr, Name: iname, MTU: imtu, } - ttun.interfaceServer = tuniface.GetTunIface(conf) + ttun.interfaceServer = tuniface.GetTunIface(tconf) break case common.UDP_IFACE: case common.SERIAL_IFACE: default: - log.Panicf("Invalid interface type (%s)", itype) + log.Panicf("Conf `%s`: Invalid interface type (%s)", name, itype) } } } else { - log.Panicf("Server type not defined.") + log.Panicf("Conf `%s`: Server type not defined.", name) } } else { - log.Panicf("Invalid service mode(%s).", sorc) + log.Panicf("Conf `%s`: Invalid service mode(%s).", name, sorc) } } else { - log.Panicf("Service mode not specified.") + log.Panicf("Conf `%s`: Service mode not specified.", name) } tun = TunnelCommon{} return tun } + +func main() { + argsWithoutProg := os.Args[1:] + var confFile string + for i := 0; i < len(argsWithoutProg); i++ { + arg := argsWithoutProg[i] + if arg[:8] == "--config=" { + confFile = arg[8:] + } else if arg == "-c" && i+1 <= len(argsWithoutProg) { + i++ + arg := argsWithoutProg[i] + confFile = arg + } else if arg[:2] == "-c" && len(arg) > 2 { + confFile = arg[2:] + } + } + if confFile != "" { + confsMap := readConfig(confFile) + tunsMap := make(map[string]TunnelCommon) + for name, conf := range confsMap { + tun := TunnelFactory(name, conf) + tunsMap[name] = tun + tun.ListenAndServer() + } + + for { + for name, tun := range tunsMap { + if !tun.IsAlive() { + log.Printf("Tunnel `%s` is down!", name) + go func() { + conf, exist := confsMap[name] + if exist { + tun := TunnelFactory(name, conf) + tunsMap[name] = tun + tun.ListenAndServer() + } else { + log.Printf("config not found for tunnel `%s` for recreation!", name) + } + }() + } + } + time.Sleep(time.Second) + } + } else { + log.Panicf("Config file not defiend") + } +}