mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-09-27 04:16:06 +08:00
522 lines
11 KiB
Go
522 lines
11 KiB
Go
// Copyright 2024 Kelvin Clement Mwinuka
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package internal
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"cmp"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"reflect"
|
|
"runtime"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/echovault/sugardb/internal/constants"
|
|
"github.com/sethvargo/go-retry"
|
|
"github.com/tidwall/resp"
|
|
)
|
|
|
|
// AdaptType converts a raw string into the most specific scalar it represents.
//
// It returns:
//   - s itself when s is not a valid base-10 number,
//   - an int when the number is integral and fits in an int,
//   - a float64 in every other numeric case (fractional, infinite, or an
//     integral value too large for an int).
func AdaptType(s string) interface{} {
	// Parse with 256 bits of precision so that large integers are not rounded
	// before we get a chance to inspect them.
	n, _, err := big.ParseFloat(s, 10, 256, big.ToNearestEven)
	if err != nil {
		return s
	}

	if n.IsInt() {
		// Int64 saturates at the int64 bounds and reports that through the
		// accuracy value. Only accept the conversion when it is exact and the
		// value also survives the narrowing to int (which may be 32 bits wide).
		if i, acc := n.Int64(); acc == big.Exact && int64(int(i)) == i {
			return int(i)
		}
	}

	f, _ := n.Float64()
	return f
}
|
|
|
|
func Decode(raw []byte) ([]string, error) {
|
|
reader := resp.NewReader(bytes.NewReader(raw))
|
|
|
|
value, _, err := reader.ReadValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var res []string
|
|
for i := 0; i < len(value.Array()); i++ {
|
|
res = append(res, value.Array()[i].String())
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// ReadMessage reads one message from r in chunks of up to 8192 bytes.
//
// Reading stops at io.EOF or after the first read that does not fill the
// chunk completely, which is taken as the end of the message. Any other read
// error is returned with a nil slice. The returned bytes are exactly the bytes
// that were read; nothing is padded or trimmed, so payloads that begin or end
// with NUL bytes are preserved.
func ReadMessage(r io.Reader) ([]byte, error) {
	reader := bufio.NewReader(r)

	var res []byte
	chunk := make([]byte, 8192)

	for {
		n, err := reader.Read(chunk)
		// Keep only the bytes this call actually produced. Per the io.Reader
		// contract they are valid even when err is non-nil.
		res = append(res, chunk[:n]...)
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return nil, err
		}
		// A short read signals that the sender has nothing more to offer for
		// this message right now.
		if n < len(chunk) {
			break
		}
	}

	return res, nil
}
|
|
|
|
func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {
|
|
backoff := b
|
|
if maxRetries > 0 {
|
|
backoff = retry.WithMaxRetries(maxRetries, backoff)
|
|
}
|
|
if jitter > 0 {
|
|
backoff = retry.WithJitter(jitter, backoff)
|
|
}
|
|
if cappedDuration > 0 {
|
|
backoff = retry.WithCappedDuration(cappedDuration, backoff)
|
|
}
|
|
if maxDuration > 0 {
|
|
backoff = retry.WithMaxDuration(maxDuration, backoff)
|
|
}
|
|
return backoff
|
|
}
|
|
|
|
// GetIPAddress returns the local address that is selected when dialing
// 8.8.8.8:80 over UDP, without the port.
//
// It returns an error when the dial fails or when the local address cannot be
// split into host and port. A failure to close the socket is only logged.
func GetIPAddress() (string, error) {
	conn, err := net.Dial("udp", "8.8.8.8:80")
	if err != nil {
		return "", err
	}
	defer func() {
		// Use a dedicated variable so the close error never overwrites the
		// function's own err.
		if closeErr := conn.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	}()

	// SplitHostPort understands bracketed IPv6 literals, unlike splitting on
	// the first colon.
	host, _, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		return "", err
	}
	return host, nil
}
|
|
|
|
func GetSubCommand(command Command, cmd []string) (interface{}, error) {
|
|
if command.SubCommands == nil || len(command.SubCommands) == 0 {
|
|
// If the command has no sub-commands, return nil
|
|
return nil, nil
|
|
}
|
|
if len(cmd) < 2 {
|
|
// If the cmd provided by the user has less than 2 tokens, there's no need to search for a subcommand
|
|
return nil, nil
|
|
}
|
|
for _, subCommand := range command.SubCommands {
|
|
if strings.EqualFold(subCommand.Command, cmd[1]) {
|
|
return subCommand, nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("command %s %s not supported", cmd[0], cmd[1])
|
|
}
|
|
|
|
func IsWriteCommand(command Command, subCommand SubCommand) bool {
|
|
return slices.Contains(append(command.Categories, subCommand.Categories...), constants.WriteCategory)
|
|
}
|
|
|
|
// AbsInt returns n when it is non-negative and -n otherwise.
func AbsInt(n int) int {
	if n >= 0 {
		return n
	}
	return -n
}
|
|
|
|
// ParseMemory returns the number of bytes described by a memory string such
// as "100mb" or "16gb".
//
// The string must be a base-10 integer followed by a two-letter unit
// (kb, mb, gb, tb or pb, case-insensitive); units are powers of 1024.
// An error is returned when the string is too short to hold a unit, when the
// number is malformed or negative, when the unit is unknown, or when the
// result does not fit in a uint64.
func ParseMemory(memory string) (uint64, error) {
	// Guard the slicing below: anything shorter than the unit itself would
	// otherwise panic with an out-of-range index.
	if len(memory) < 2 {
		return 0, fmt.Errorf("invalid memory string %q", memory)
	}

	split := len(memory) - 2
	bytesInt, err := strconv.ParseInt(memory[:split], 10, 64)
	if err != nil {
		return 0, err
	}
	if bytesInt < 0 {
		// A negative amount would wrap to a huge value when converted to uint64.
		return 0, fmt.Errorf("memory value %d must not be negative", bytesInt)
	}

	memUnit := strings.ToLower(memory[split:])
	var multiplier uint64
	switch memUnit {
	case "kb":
		multiplier = 1 << 10
	case "mb":
		multiplier = 1 << 20
	case "gb":
		multiplier = 1 << 30
	case "tb":
		multiplier = 1 << 40
	case "pb":
		multiplier = 1 << 50
	default:
		return 0, fmt.Errorf("memory unit %s not supported, use (kb, mb, gb, tb, pb) ", memUnit)
	}

	value := uint64(bytesInt)
	// Reject products that would overflow uint64 instead of silently wrapping.
	if value > ^uint64(0)/multiplier {
		return 0, fmt.Errorf("memory value %q overflows 64 bits", memory)
	}
	return value * multiplier, nil
}
|
|
|
|
// IsMaxMemoryExceeded reports whether memUsed has reached maxMemory.
//
// A maxMemory of 0 always yields false. When memUsed is at or above the limit,
// a garbage collection is forced before returning, so that reclaimable memory
// is released before the caller starts deleting keys.
func IsMaxMemoryExceeded(memUsed int64, maxMemory uint64) bool {
	used := uint64(memUsed)

	// Either no limit applies (maxMemory == 0) or we are still below it.
	if maxMemory == 0 || used < maxMemory {
		return false
	}

	// At or above the limit: collect garbage first so keys are not evicted
	// for memory that could simply be reclaimed.
	runtime.GC()

	// The comparison uses the memUsed argument as passed in; it is not
	// re-measured after the collection.
	return used >= maxMemory
}
|
|
|
|
// FilterExpiredKeys filters out keys that are already expired, so they are not persisted.
|
|
func FilterExpiredKeys(now time.Time, state map[int]map[string]KeyData) map[int]map[string]KeyData {
|
|
for database, data := range state {
|
|
var keysToDelete []string
|
|
for k, v := range data {
|
|
// Skip keys with no expiry time.
|
|
if v.ExpireAt == (time.Time{}) {
|
|
continue
|
|
}
|
|
// If the key is already expired, mark it for deletion.
|
|
if v.ExpireAt.Before(now) {
|
|
keysToDelete = append(keysToDelete, k)
|
|
}
|
|
}
|
|
for _, key := range keysToDelete {
|
|
delete(state[database], key)
|
|
}
|
|
}
|
|
return state
|
|
}
|
|
|
|
// CompareLex returns -1 when s2 is lexicographically greater than s1,
// 0 if they're equal and 1 if s2 is lexicographically less than s1.
//
// The comparison is byte-wise; a string that is a proper prefix of the other
// sorts first.
func CompareLex(s1 string, s2 string) int {
	// cmp.Compare on strings is the byte-wise lexicographic order. Only a
	// prefix (not an arbitrary substring) makes the shorter string smaller.
	return cmp.Compare(s1, s2)
}
|
|
|
|
// EncodeCommand serialises cmd as an array of length-prefixed strings:
// "*<count>\r\n" followed by "$<byte length>\r\n<token>\r\n" for every token.
//
// The output is built by appending to a single byte slice, which avoids
// re-copying the accumulated string on every token.
func EncodeCommand(cmd []string) []byte {
	res := fmt.Appendf(nil, "*%d\r\n", len(cmd))
	for _, token := range cmd {
		res = fmt.Appendf(res, "$%d\r\n%s\r\n", len(token), token)
	}
	return res
}
|
|
|
|
func ParseNilResponse(b []byte) (bool, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return v.IsNull(), nil
|
|
}
|
|
|
|
func ParseStringResponse(b []byte) (string, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return v.String(), nil
|
|
}
|
|
|
|
func ParseIntegerResponse(b []byte) (int, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return v.Integer(), nil
|
|
}
|
|
|
|
func ParseFloatResponse(b []byte) (float64, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return v.Float(), nil
|
|
}
|
|
|
|
func ParseBooleanResponse(b []byte) (bool, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return v.Bool(), nil
|
|
}
|
|
|
|
func ParseStringArrayResponse(b []byte) ([]string, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if v.IsNull() {
|
|
return []string{}, nil
|
|
}
|
|
|
|
if v.Type().String() == "BulkString" {
|
|
return []string{v.String()}, nil
|
|
}
|
|
|
|
arr := make([]string, len(v.Array()))
|
|
for i, e := range v.Array() {
|
|
if e.IsNull() {
|
|
arr[i] = ""
|
|
continue
|
|
}
|
|
arr[i] = e.String()
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
func ParseNestedStringArrayResponse(b []byte) ([][]string, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if v.IsNull() {
|
|
return [][]string{}, nil
|
|
}
|
|
arr := make([][]string, len(v.Array()))
|
|
for i, e1 := range v.Array() {
|
|
if e1.IsNull() {
|
|
arr[i] = []string{}
|
|
continue
|
|
}
|
|
entry := make([]string, len(e1.Array()))
|
|
for j, e2 := range e1.Array() {
|
|
entry[j] = e2.String()
|
|
}
|
|
arr[i] = entry
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
func ParseIntegerArrayResponse(b []byte) ([]int, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if v.IsNull() {
|
|
return []int{}, nil
|
|
}
|
|
arr := make([]int, len(v.Array()))
|
|
for i, e := range v.Array() {
|
|
if e.IsNull() {
|
|
arr[i] = 0
|
|
continue
|
|
}
|
|
arr[i] = e.Integer()
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
func ParseBooleanArrayResponse(b []byte) ([]bool, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if v.IsNull() {
|
|
return []bool{}, nil
|
|
}
|
|
arr := make([]bool, len(v.Array()))
|
|
for i, e := range v.Array() {
|
|
if e.IsNull() {
|
|
arr[i] = false
|
|
continue
|
|
}
|
|
arr[i] = e.Bool()
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
// CompareNestedStringArrays reports whether got and want contain the same
// inner slices, regardless of order: every entry of want must be deeply equal
// to some entry of got, and vice versa. Duplicates are not counted, and a nil
// inner slice is not equal to an empty one (reflect.DeepEqual semantics).
func CompareNestedStringArrays(got [][]string, want [][]string) bool {
	// covers reports whether each entry of items has a deeply-equal entry in pool.
	covers := func(pool, items [][]string) bool {
		for _, item := range items {
			found := slices.ContainsFunc(pool, func(candidate []string) bool {
				return reflect.DeepEqual(item, candidate)
			})
			if !found {
				return false
			}
		}
		return true
	}
	return covers(got, want) && covers(want, got)
}
|
|
|
|
// GetFreePort asks the operating system for an unused TCP port by listening
// on "localhost:0", returns the port that was assigned, and closes the
// listener before returning. Resolution and listen errors are returned with
// port 0.
func GetFreePort() (int, error) {
	loopback, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}

	listener, err := net.ListenTCP("tcp", loopback)
	if err != nil {
		return 0, err
	}
	// The listener only exists to learn the port; a close error is ignored.
	defer func() {
		_ = listener.Close()
	}()

	bound := listener.Addr().(*net.TCPAddr)
	return bound.Port, nil
}
|
|
|
|
// GetConnection dials addr:port over TCP, retrying while the connection is
// refused, for at most 10 seconds.
//
// It returns the connection on success, the dial error when dialing fails for
// any reason other than "connection refused", and a "connection timeout"
// error when no connection could be established within the time limit.
//
// The retry loop runs in the calling goroutine, so nothing is left running
// (and no connection is leaked) once the function returns.
func GetConnection(addr string, port int) (net.Conn, error) {
	const (
		timeout       = 10 * time.Second
		retryInterval = 10 * time.Millisecond
	)

	target := fmt.Sprintf("%s:%d", addr, port)
	deadline := time.Now().Add(timeout)

	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return nil, errors.New("connection timeout")
		}

		// Bound each attempt by the time that is left overall.
		conn, err := net.DialTimeout("tcp", target, remaining)
		if err == nil {
			return conn, nil
		}
		if !time.Now().Before(deadline) {
			return nil, errors.New("connection timeout")
		}
		// errors.Is walks the whole error chain, so no type assertion (which
		// would panic on unexpected error types) is needed.
		if !errors.Is(err, syscall.ECONNREFUSED) {
			return nil, err
		}

		// Connection refused: wait briefly before the next attempt rather
		// than spinning.
		time.Sleep(retryInterval)
	}
}
|
|
|
|
// GetTLSConnection dials addr:port over TCP and performs a TLS handshake
// using config, retrying while the connection is refused, for at most
// 10 seconds.
//
// It returns the connection on success, the dial or handshake error when the
// attempt fails for any reason other than "connection refused", and a
// "connection timeout" error when no connection could be established within
// the time limit. On failure the returned net.Conn is a plain nil.
//
// The retry loop runs in the calling goroutine, so nothing is left running
// (and no connection is leaked) once the function returns.
func GetTLSConnection(addr string, port int, config *tls.Config) (net.Conn, error) {
	const (
		timeout       = 10 * time.Second
		retryInterval = 10 * time.Millisecond
	)

	target := fmt.Sprintf("%s:%d", addr, port)
	deadline := time.Now().Add(timeout)

	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return nil, errors.New("connection timeout")
		}

		// The dialer timeout bounds both the TCP connect and the handshake by
		// the time that is left overall.
		conn, err := tls.DialWithDialer(&net.Dialer{Timeout: remaining}, "tcp", target, config)
		if err == nil {
			return conn, nil
		}
		if !time.Now().Before(deadline) {
			return nil, errors.New("connection timeout")
		}
		// Handshake failures are not *net.OpError values; errors.Is inspects
		// the error chain without a type assertion that could panic.
		if !errors.Is(err, syscall.ECONNREFUSED) {
			return nil, err
		}

		// Connection refused: wait briefly before the next attempt rather
		// than spinning.
		time.Sleep(retryInterval)
	}
}
|
|
|
|
func ParseInteger64ArrayResponse(b []byte) ([]int64, error) {
|
|
r := resp.NewReader(bytes.NewReader(b))
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if v.IsNull() {
|
|
return []int64{}, nil
|
|
}
|
|
arr := make([]int64, len(v.Array()))
|
|
for i, e := range v.Array() {
|
|
if e.IsNull() {
|
|
arr[i] = 0
|
|
continue
|
|
}
|
|
|
|
val, err := strconv.ParseInt(e.String(), 10, 64)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
arr[i] = val
|
|
}
|
|
return arr, nil
|
|
}
|