Files
SugarDB/server/utils.go
Kelvin Clement Mwinuka d82a6a98d7 Use UDP dial to find default address for outbound traffic.
Set default bindAddr if it has not been explicitly provided by the user.
2023-07-29 01:33:33 +08:00

206 lines
3.6 KiB
Go

package main
import (
"bufio"
"bytes"
"encoding/csv"
"errors"
"fmt"
"math"
"math/big"
"net"
"reflect"
"strings"
"time"
"github.com/sethvargo/go-retry"
"github.com/tidwall/resp"
)
// OK is the canned success reply: the RESP simple string "+OK\r\n" followed
// by one extra "\n".
// NOTE(review): ReadMessage in this file stops at the first empty line, so the
// trailing "\n" presumably acts as the end-of-message marker — confirm.
const OK = "+OK\r\n\n"
// Command is the interface a command handler must satisfy.
// NOTE(review): no implementation is visible in this file; the per-method
// notes below are inferred from names and signatures only — confirm them
// against the concrete handlers.
type Command interface {
// Name presumably returns the handler's identifying name.
Name() string
// Commands presumably lists the command keywords this handler responds to.
Commands() []string
// Description presumably returns human-readable help text for the handler.
Description() string
// HandleCommand receives the command as a slice of strings, the server, and a
// buffered writer. NOTE(review): presumably cmd is the tokenized command and
// the reply is written to conn — confirm.
HandleCommand(cmd []string, server *Server, conn *bufio.Writer)
}
// Contains reports whether elem occurs in arr, comparing elements with ==.
// A nil or empty arr never contains anything.
func Contains[T comparable](arr []T, elem T) bool {
	for i := 0; i < len(arr); i++ {
		if arr[i] == elem {
			return true
		}
	}
	return false
}
// ContainsMutual reports whether arr1 and arr2 have at least one element in
// common, comparing with ==.
//
// On a match it returns true together with the first element of arr1 (in arr1
// order) that also appears in arr2. When there is no match it returns false
// and the zero value of T; the second result carries no meaning in that case.
func ContainsMutual[T comparable](arr1 []T, arr2 []T) (bool, T) {
	for _, a := range arr1 {
		for _, b := range arr2 {
			if a == b {
				return true, a
			}
		}
	}
	// arr1 may be empty, so indexing arr1[0] for a filler value would panic.
	var zero T
	return false, zero
}
// IsInteger reports whether n has no fractional part.
// Infinities and NaN are not integers: their fractional part is NaN, which
// never compares equal to zero.
func IsInteger(n float64) bool {
	_, frac := math.Modf(n)
	return frac == 0
}
// AdaptType converts s to the most specific value it can represent:
//   - s itself (string) when it does not parse as a base-10 number,
//   - an int64 when the parsed number is integral,
//   - a *big.Float (256-bit precision) otherwise.
//
// The accuracy reported by Int64 is discarded, so an integral value outside
// the int64 range comes back as the nearest int64 limit.
func AdaptType(s string) interface{} {
	// big.ToNearestEven is the zero RoundingMode, i.e. the same mode the
	// conversion RoundingMode(big.Exact) evaluates to.
	parsed, _, err := big.ParseFloat(s, 10, 256, big.ToNearestEven)
	if err != nil {
		return s
	}
	if !parsed.IsInt() {
		return parsed
	}
	whole, _ := parsed.Int64()
	return whole
}
// IncrBy returns num + by.
//
// Both operands must be an int or a float64 (checked by reflect kind);
// anything else, including nil, yields an error. The sum is computed in
// float64. A whole-number result that fits in an int is returned as int;
// every other result is returned as float64.
func IncrBy(num interface{}, by interface{}) (interface{}, error) {
	n, ok := incrOperand(num)
	if !ok {
		return nil, errors.New("can only increment number")
	}
	b, ok := incrOperand(by)
	if !ok {
		return nil, errors.New("can only increment by number")
	}

	res := n + b

	// Only narrow to int when the value is whole and in range: converting an
	// out-of-range float64 to int gives an unspecified result. -math.MinInt
	// is exactly one past the largest int and is exactly representable.
	if math.Mod(res, 1.0) == 0 && res >= math.MinInt && res < -math.MinInt {
		return int(res), nil
	}
	return res, nil
}

// incrOperand converts an int or float64 operand to float64. ok is false for
// any other kind of value, including a nil interface.
func incrOperand(v interface{}) (val float64, ok bool) {
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Int:
		return float64(rv.Int()), true
	case reflect.Float64:
		return rv.Float(), true
	default:
		return 0, false
	}
}
// Filter returns the elements of arr for which test returns true, in their
// original order. The result is nil when no element passes.
func Filter[T comparable](arr []T, test func(elem T) bool) (res []T) {
	for i := range arr {
		if keep := test(arr[i]); !keep {
			continue
		}
		res = append(res, arr[i])
	}
	return res
}
// tokenize splits one command line into fields, using a CSV reader whose
// separator is a single space. Double-quoted fields may therefore contain
// spaces. Only the first record of comm is read; the reader's result and
// error are returned unchanged.
func tokenize(comm string) ([]string, error) {
	reader := csv.NewReader(strings.NewReader(comm))
	reader.Comma = ' '
	fields, err := reader.Read()
	return fields, err
}
// Encode turns a space-separated command line into a RESP array of bulk
// strings followed by one extra "\n".
//
// The first token is upper-cased. Each bulk-string length is the byte length
// of the text actually written, so it stays correct when upper-casing changes
// the encoded size of a non-ASCII token.
//
// It returns the error "could not parse command" when comm cannot be
// tokenized.
func Encode(comm string) (string, error) {
	tokens, err := tokenize(comm)
	if err != nil {
		return "", errors.New("could not parse command")
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "*%d\r\n", len(tokens))
	for i, token := range tokens {
		if i == 0 {
			// Upper-case before measuring so header and payload agree.
			token = strings.ToUpper(token)
		}
		fmt.Fprintf(&sb, "$%d\r\n%s\r\n", len(token), token)
	}
	sb.WriteString("\n")
	return sb.String(), nil
}
// Decode reads a single RESP value from raw and flattens it to strings.
//
// A SimpleString, Integer or Error value yields a one-element slice holding
// its string form. An Array yields the string form of each element. Any other
// value type yields an empty, non-nil slice. A read failure is returned as-is
// with a nil slice.
func Decode(raw string) ([]string, error) {
	reader := resp.NewReader(bytes.NewBufferString(raw))

	value, _, err := reader.ReadValue()
	if err != nil {
		return nil, err
	}

	switch value.Type().String() {
	case "SimpleString", "Integer", "Error":
		return []string{value.String()}, nil
	case "Array":
		items := []string{}
		for _, elem := range value.Array() {
			items = append(items, elem.String())
		}
		return items, nil
	default:
		return []string{}, nil
	}
}
// ReadMessage reads lines from r until it meets an empty line, which marks
// the end of the message. The lines read before the terminator are joined
// with "\r\n" and returned with a trailing "\r\n".
//
// Any read error, including io.EOF before the terminating empty line, is
// returned with an empty message.
func ReadMessage(r *bufio.ReadWriter) (message string, err error) {
	var (
		lines   [][]byte // completed lines, each an independent copy
		current []byte   // the line being assembled from ReadLine fragments
	)
	for {
		chunk, more, readErr := r.ReadLine()
		if readErr != nil {
			return "", readErr
		}
		// ReadLine's slice is only valid until the next read, so copy it
		// into our own storage right away.
		current = append(current, chunk...)
		if more {
			// The line is longer than the reader's buffer; keep collecting
			// fragments before deciding anything about it.
			continue
		}
		if len(current) == 0 {
			// End of message.
			break
		}
		lines = append(lines, current)
		current = nil
	}
	return fmt.Sprintf("%s\r\n", bytes.Join(lines, []byte("\r\n"))), nil
}
// RetryBackoff wraps b with the optional limits whose argument is greater
// than zero. The wrappers are applied in a fixed order, each around the
// result of the previous one: max retries, jitter, capped duration, max
// duration. With every limit at zero, b is returned unchanged.
func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {
	decorators := []struct {
		enabled bool
		apply   func(next retry.Backoff) retry.Backoff
	}{
		{maxRetries > 0, func(next retry.Backoff) retry.Backoff {
			return retry.WithMaxRetries(maxRetries, next)
		}},
		{jitter > 0, func(next retry.Backoff) retry.Backoff {
			return retry.WithJitter(jitter, next)
		}},
		{cappedDuration > 0, func(next retry.Backoff) retry.Backoff {
			return retry.WithCappedDuration(cappedDuration, next)
		}},
		{maxDuration > 0, func(next retry.Backoff) retry.Backoff {
			return retry.WithMaxDuration(maxDuration, next)
		}},
	}

	wrapped := b
	for _, d := range decorators {
		if d.enabled {
			wrapped = d.apply(wrapped)
		}
	}
	return wrapped
}
// GetIPAddress returns the local IP address chosen for a UDP dial to
// 8.8.8.8:80, without the port.
//
// It returns an error when the dial fails or when the local address cannot
// be split into host and port.
func GetIPAddress() (string, error) {
	conn, err := net.Dial("udp", "8.8.8.8:80")
	if err != nil {
		return "", err
	}
	defer conn.Close()

	// SplitHostPort also copes with bracketed IPv6 literals, which a plain
	// split on ":" would mangle.
	host, _, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		return "", err
	}
	return host, nil
}