mirror of
				https://github.com/eolinker/apinto
				synced 2025-10-22 00:09:31 +08:00 
			
		
		
		
	额外参数v2插件、请求体限制插件完成
This commit is contained in:
		| @@ -18,11 +18,13 @@ import ( | ||||
| 	plugin_manager "github.com/eolinker/apinto/drivers/plugin-manager" | ||||
| 	access_log "github.com/eolinker/apinto/drivers/plugins/access-log" | ||||
| 	plugin_app "github.com/eolinker/apinto/drivers/plugins/app" | ||||
| 	body_check "github.com/eolinker/apinto/drivers/plugins/body-check" | ||||
| 	circuit_breaker "github.com/eolinker/apinto/drivers/plugins/circuit-breaker" | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/cors" | ||||
| 	dubbo2_proxy_rewrite "github.com/eolinker/apinto/drivers/plugins/dubbo2-proxy-rewrite" | ||||
| 	dubbo2_to_http "github.com/eolinker/apinto/drivers/plugins/dubbo2-to-http" | ||||
| 	extra_params "github.com/eolinker/apinto/drivers/plugins/extra-params" | ||||
| 	extra_params_v2 "github.com/eolinker/apinto/drivers/plugins/extra-params_v2" | ||||
| 	grpc_to_http "github.com/eolinker/apinto/drivers/plugins/gRPC-to-http" | ||||
| 	grpc_proxy_rewrite "github.com/eolinker/apinto/drivers/plugins/grpc-proxy-rewrite" | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/gzip" | ||||
| @@ -69,6 +71,7 @@ func ProcessWorker() { | ||||
| func registerInnerExtenders() { | ||||
| 	extends.AddInnerExtendProject("eolinker.com", "apinto", Register) | ||||
| } | ||||
|  | ||||
| func Register(extenderRegister eosc.IExtenderDriverRegister) { | ||||
| 	// router | ||||
| 	http_router.Register(extenderRegister) | ||||
| @@ -153,4 +156,7 @@ func Register(extenderRegister eosc.IExtenderDriverRegister) { | ||||
|  | ||||
| 	proxy_mirror.Register(extenderRegister) | ||||
| 	http_mocking.Register(extenderRegister) | ||||
|  | ||||
| 	body_check.Register(extenderRegister) | ||||
| 	extra_params_v2.Register(extenderRegister) | ||||
| } | ||||
|   | ||||
| @@ -49,9 +49,9 @@ func (h *HttpOutput) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWo | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if h.config != nil && !h.config.isConfUpdate(config) { | ||||
| 		return nil | ||||
| 	} | ||||
| 	//if h.config != nil && !h.config.isConfUpdate(config) { | ||||
| 	//	return nil | ||||
| 	//} | ||||
| 	h.config = config | ||||
|  | ||||
| 	if h.running { | ||||
|   | ||||
							
								
								
									
										73
									
								
								drivers/plugins/body-check/body-check.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								drivers/plugins/body-check/body-check.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,73 @@ | ||||
| package body_check | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
|  | ||||
| 	"github.com/eolinker/eosc" | ||||
| 	"github.com/eolinker/eosc/eocontext" | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var _ http_service.HttpFilter = (*BodyCheck)(nil) | ||||
| var _ eocontext.IFilter = (*BodyCheck)(nil) | ||||
|  | ||||
| type BodyCheck struct { | ||||
| 	drivers.WorkerBase | ||||
| 	isEmpty            bool | ||||
| 	allowedPayloadSize int | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) { | ||||
| 	return http_service.DoHttpFilter(b, ctx, next) | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error { | ||||
| 	if ctx.Request().Method() == http.MethodPost || ctx.Request().Method() == http.MethodPut || ctx.Request().Method() == http.MethodPatch { | ||||
| 		body, err := ctx.Request().Body().RawBody() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		bodySize := len([]rune(string(body))) | ||||
| 		if !b.isEmpty && bodySize < 1 { | ||||
| 			ctx.Response().SetStatus(500, "Internal Server Error") | ||||
| 			ctx.Response().SetBody([]byte("请求体不能为空,请检查body的参数情况")) | ||||
| 			return errors.New("Body is required") | ||||
| 		} | ||||
| 		if b.allowedPayloadSize > 0 && bodySize > b.allowedPayloadSize { | ||||
| 			ctx.Response().SetStatus(500, "Internal Server Error") | ||||
| 			ctx.Response().SetBody([]byte("请求体超出长度限制")) | ||||
| 			return errors.New("The request entity is too large") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return next.DoChain(ctx) | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) Start() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { | ||||
| 	cfg, ok := conf.(*Config) | ||||
| 	if !ok { | ||||
| 		return errors.New("invalid config") | ||||
| 	} | ||||
| 	b.isEmpty = cfg.IsEmpty | ||||
| 	b.allowedPayloadSize = cfg.AllowedPayloadSize * 1024 | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) Stop() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) Destroy() { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (b *BodyCheck) CheckSkill(skill string) bool { | ||||
| 	return http_service.FilterSkillName == skill | ||||
| } | ||||
							
								
								
									
										22
									
								
								drivers/plugins/body-check/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								drivers/plugins/body-check/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package body_check | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| type Config struct { | ||||
| 	IsEmpty            bool `json:"is_empty" label:"是否允许为空"` | ||||
| 	AllowedPayloadSize int  `json:"allowed_payload_size" label:"允许的最大请求体大小"` | ||||
| } | ||||
|  | ||||
| func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { | ||||
|  | ||||
| 	bc := &BodyCheck{ | ||||
| 		WorkerBase:         drivers.Worker(id, name), | ||||
| 		isEmpty:            conf.IsEmpty, | ||||
| 		allowedPayloadSize: conf.AllowedPayloadSize * 1024, | ||||
| 	} | ||||
|  | ||||
| 	return bc, nil | ||||
| } | ||||
							
								
								
									
										18
									
								
								drivers/plugins/body-check/factory.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								drivers/plugins/body-check/factory.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
| package body_check | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Name = "body_check" | ||||
| ) | ||||
|  | ||||
| func Register(register eosc.IExtenderDriverRegister) { | ||||
| 	register.RegisterExtenderDriver(Name, NewFactory()) | ||||
| } | ||||
|  | ||||
| func NewFactory() eosc.IExtenderDriverFactory { | ||||
| 	return drivers.NewFactory[Config](Create) | ||||
| } | ||||
							
								
								
									
										41
									
								
								drivers/plugins/counter/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								drivers/plugins/counter/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/counter/separator" | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| type Config struct { | ||||
| 	Match *Match               `json:"match" label:"响应匹配规则"` | ||||
| 	Count *separator.CountRule `json:"count" label:"计数规则"` | ||||
| 	Key   string               `json:"key" label:"计数字段名称"` | ||||
| 	Cache eosc.RequireId       `json:"cache" label:"缓存计数器"` | ||||
| } | ||||
|  | ||||
| type Match struct { | ||||
| 	Params      []*MatchParam `json:"params" label:"匹配参数列表"` | ||||
| 	StatusCodes []int         `json:"status_codes" label:"匹配响应状态码列表"` | ||||
| 	Type        string        `json:"type" label:"匹配类型" enum:"json"` | ||||
| } | ||||
|  | ||||
| func (m *Match) GenerateHandler() []IMatcher { | ||||
| 	matcher := make([]IMatcher, 0, 2) | ||||
| 	matcher = append(matcher, newStatusCodeMatcher(m.StatusCodes)) | ||||
| 	matcher = append(matcher, newJsonMatcher(m.Params)) | ||||
| 	return matcher | ||||
| } | ||||
|  | ||||
| func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { | ||||
| 	counter, err := separator.GetCounter(conf.Count) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	bc := &executor{ | ||||
| 		WorkerBase:       drivers.Worker(id, name), | ||||
| 		matchers:         conf.Match.GenerateHandler(), | ||||
| 		separatorCounter: counter, | ||||
| 	} | ||||
|  | ||||
| 	return bc, nil | ||||
| } | ||||
							
								
								
									
										5
									
								
								drivers/plugins/counter/counter/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								drivers/plugins/counter/counter/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package counter | ||||
|  | ||||
| type IClient interface { | ||||
| 	Get(key string) (int64, error) | ||||
| } | ||||
							
								
								
									
										79
									
								
								drivers/plugins/counter/counter/client_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								drivers/plugins/counter/counter/client_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"math/rand" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| ) | ||||
|  | ||||
| func TestLocalCounter(t *testing.T) { | ||||
| 	client := NewHTTPClient("", nil) | ||||
| 	lc := NewLocalCounter("test", client) | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(100) | ||||
| 	for i := 0; i < 100; i++ { | ||||
| 		time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			count := rand.Int63n(20) | ||||
| 			err := lc.Lock(count) | ||||
| 			if err != nil { | ||||
| 				t.Error(err) | ||||
| 				return | ||||
| 			} | ||||
| 			condition := rand.Intn(100) | ||||
| 			switch condition % 2 { | ||||
| 			case 0: | ||||
| 				err = lc.Complete(count) | ||||
| 			case 1: | ||||
| 				err = lc.RollBack(count) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				t.Error(err) | ||||
| 				return | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| func TestRedisCounter(t *testing.T) { | ||||
| 	client := NewHTTPClient("", nil) | ||||
| 	key := "apinto-apiddww" | ||||
| 	lc := NewLocalCounter(key, client) | ||||
| 	redisConn := redis.NewClient(&redis.Options{ | ||||
| 		Addr:     "172.18.65.42:6380", | ||||
| 		Password: "password", // 如果有密码,请填写密码 | ||||
| 		DB:       9,          // 选择数据库,默认为0 | ||||
| 	}) | ||||
| 	rc := NewRedisCounter(key, redisConn, client, lc) | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(100) | ||||
| 	for i := 0; i < 100; i++ { | ||||
| 		time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			count := rand.Int63n(20) | ||||
| 			err := rc.Lock(count) | ||||
| 			if err != nil { | ||||
| 				t.Error(err) | ||||
| 				return | ||||
| 			} | ||||
| 			condition := rand.Intn(100) | ||||
| 			switch condition % 2 { | ||||
| 			case 0: | ||||
| 				err = rc.Complete(count) | ||||
| 			case 1: | ||||
| 				err = rc.RollBack(count) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				t.Error(err) | ||||
| 				return | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 	wg.Wait() | ||||
| } | ||||
							
								
								
									
										26
									
								
								drivers/plugins/counter/counter/counter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								drivers/plugins/counter/counter/counter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| package counter | ||||
|  | ||||
| import "fmt" | ||||
|  | ||||
| type ICounter interface { | ||||
| 	// Lock 锁定次数 | ||||
| 	Lock(count int64) error | ||||
| 	// Complete 完成扣次操作 | ||||
| 	Complete(count int64) error | ||||
| 	// RollBack 回滚 | ||||
| 	RollBack(count int64) error | ||||
| 	// ResetClient 重置客户端 | ||||
| 	ResetClient(client IClient) | ||||
| } | ||||
|  | ||||
| func getRemainCount(client IClient, key string, count int64) (int64, error) { | ||||
| 	remain, err := client.Get(key) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	remain -= count | ||||
| 	if remain < 0 { | ||||
| 		return 0, fmt.Errorf("no enough, key:%s, remain:%d, count:%d", key, remain, count) | ||||
| 	} | ||||
| 	return remain, nil | ||||
| } | ||||
							
								
								
									
										58
									
								
								drivers/plugins/counter/counter/http-client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								drivers/plugins/counter/counter/http-client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/oj" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/jp" | ||||
| 	"github.com/valyala/fasthttp" | ||||
| ) | ||||
|  | ||||
| var _ IClient = (*HTTPClient)(nil) | ||||
|  | ||||
| var httpClient = fasthttp.Client{ | ||||
| 	Name: "apinto-counter", | ||||
| } | ||||
|  | ||||
| type HTTPClient struct { | ||||
| 	uri     string | ||||
| 	headers map[string]string | ||||
| 	// jsonExpr 经过编译的JSONPath表达式 | ||||
| 	jsonExpr jp.Expr | ||||
| } | ||||
|  | ||||
| func NewHTTPClient(uri string, jsonExpr jp.Expr) *HTTPClient { | ||||
| 	return &HTTPClient{uri: uri, jsonExpr: jsonExpr} | ||||
| } | ||||
|  | ||||
| func (H *HTTPClient) Get(key string) (int64, error) { | ||||
| 	req := fasthttp.AcquireRequest() | ||||
| 	resp := fasthttp.AcquireResponse() | ||||
| 	defer fasthttp.ReleaseRequest(req) | ||||
| 	req.SetRequestURI(H.uri) | ||||
| 	req.Header.SetMethod("GET") | ||||
| 	for name, value := range H.headers { | ||||
| 		req.Header.Set(name, value) | ||||
| 	} | ||||
| 	err := httpClient.Do(req, resp) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if resp.StatusCode() != fasthttp.StatusOK { | ||||
| 		return 0, fmt.Errorf("error http status code: %d,key: %s,uri %s", resp.StatusCode(), key, H.uri) | ||||
| 	} | ||||
| 	result, err := oj.Parse(resp.Body()) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	// 解析JSON | ||||
| 	v := H.jsonExpr.Get(result) | ||||
| 	if v == nil || len(v) < 1 { | ||||
| 		return 0, fmt.Errorf("no found key: %s,uri: %s", key) | ||||
| 	} | ||||
| 	if len(v) != 1 { | ||||
| 		return 0, fmt.Errorf("invalid value: %v,key: %s,uri: %s", v, key, H.uri) | ||||
| 	} | ||||
| 	return v[0].(int64), nil | ||||
| } | ||||
							
								
								
									
										76
									
								
								drivers/plugins/counter/counter/local.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								drivers/plugins/counter/counter/local.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var _ ICounter = (*LocalCounter)(nil) | ||||
|  | ||||
| func NewLocalCounter(key string, client IClient) *LocalCounter { | ||||
| 	return &LocalCounter{key: key, client: client} | ||||
| } | ||||
|  | ||||
| // LocalCounter 本地计数器 | ||||
| type LocalCounter struct { | ||||
| 	key string | ||||
| 	// 剩余次数 | ||||
| 	remain int64 | ||||
| 	// 锁定次数 | ||||
| 	lock int64 | ||||
|  | ||||
| 	locker sync.Mutex | ||||
|  | ||||
| 	resetTime time.Time | ||||
| 	client    IClient | ||||
| } | ||||
|  | ||||
| func (c *LocalCounter) Lock(count int64) error { | ||||
| 	c.locker.Lock() | ||||
| 	defer c.locker.Unlock() | ||||
| 	remain := c.remain - count | ||||
| 	if remain < 0 { | ||||
| 		now := time.Now() | ||||
| 		if now.Sub(c.resetTime) < 10*time.Second { | ||||
| 			return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", c.key, c.remain, count) | ||||
| 		} | ||||
|  | ||||
| 		var err error | ||||
| 		c.resetTime = now | ||||
| 		// 获取最新的次数 | ||||
| 		remain, err = getRemainCount(c.client, c.key, count) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	c.remain = remain | ||||
| 	c.lock += count | ||||
| 	fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "lock", "remain:", c.remain, ",lock:", c.lock, ",count:", count) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *LocalCounter) Complete(count int64) error { | ||||
| 	c.locker.Lock() | ||||
| 	defer c.locker.Unlock() | ||||
| 	// 需要解除已经锁定的部分次数 | ||||
| 	c.lock -= count | ||||
| 	fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "complete", "remain:", c.remain, ",lock:", c.lock, ",count:", count) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *LocalCounter) RollBack(count int64) error { | ||||
| 	c.locker.Lock() | ||||
| 	defer c.locker.Unlock() | ||||
| 	// 需要解除已经锁定的部分次数,并且增加剩余次数 | ||||
| 	c.remain += c.lock | ||||
| 	c.lock -= count | ||||
| 	fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "rollback", "remain:", c.remain, ",lock:", c.lock, ",count:", count) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *LocalCounter) ResetClient(client IClient) { | ||||
| 	c.locker.Lock() | ||||
| 	defer c.locker.Unlock() | ||||
| 	c.client = client | ||||
| } | ||||
							
								
								
									
										171
									
								
								drivers/plugins/counter/counter/redis.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								drivers/plugins/counter/counter/redis.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| ) | ||||
|  | ||||
| var _ ICounter = (*RedisCounter)(nil) | ||||
|  | ||||
| type RedisCounter struct { | ||||
| 	ctx   context.Context | ||||
| 	key   string | ||||
| 	redis redis.Cmdable | ||||
|  | ||||
| 	client    IClient | ||||
| 	locker    sync.Mutex | ||||
| 	resetTime time.Time | ||||
|  | ||||
| 	localCounter ICounter | ||||
|  | ||||
| 	lockerKey string | ||||
| 	lockKey   string | ||||
| 	remainKey string | ||||
| } | ||||
|  | ||||
| func NewRedisCounter(key string, redis redis.Cmdable, client IClient, localCounter ICounter) *RedisCounter { | ||||
|  | ||||
| 	return &RedisCounter{ | ||||
| 		key:          key, | ||||
| 		redis:        redis, | ||||
| 		client:       client, | ||||
| 		localCounter: localCounter, | ||||
| 		ctx:          context.Background(), | ||||
| 		lockerKey:    fmt.Sprintf("%s:locker", key), | ||||
| 		lockKey:      fmt.Sprintf("%s:lock", key), | ||||
| 		remainKey:    fmt.Sprintf("%s:remain", key), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *RedisCounter) Lock(count int64) error { | ||||
| 	r.locker.Lock() | ||||
| 	defer r.locker.Unlock() | ||||
| 	if r.redis == nil { | ||||
| 		// 如果Redis没有配置,使用本地计数器 | ||||
| 		return r.localCounter.Lock(count) | ||||
| 	} | ||||
|  | ||||
| 	err := r.acquireLock() | ||||
| 	if err != nil { | ||||
| 		if err == redis.ErrClosed { | ||||
| 			return r.localCounter.Lock(count) | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	defer r.releaseLock() | ||||
|  | ||||
| 	// 获取最新的次数 | ||||
| 	remain, err := r.redis.Get(r.ctx, r.remainKey).Int64() | ||||
| 	if err != nil { | ||||
| 		if err == redis.ErrClosed { | ||||
| 			return r.localCounter.Lock(count) | ||||
| 		} else if err != redis.Nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	remain -= count | ||||
| 	if remain < 0 { | ||||
| 		now := time.Now() | ||||
| 		if now.Sub(r.resetTime) < 10*time.Second { | ||||
| 			return fmt.Errorf("no enough, ddd key:%s, remain:%d, count:%d", r.key, remain+count, count) | ||||
| 		} | ||||
|  | ||||
| 		r.resetTime = now | ||||
| 		lock, err := r.redis.Get(r.ctx, r.lockKey).Int64() | ||||
| 		if err != nil { | ||||
| 			if err == redis.ErrClosed { | ||||
| 				return r.localCounter.Lock(count) | ||||
| 			} else if err != redis.Nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		remain, err = getRemainCount(r.client, r.key, count+lock) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	r.redis.Set(r.ctx, r.remainKey, remain, -1) | ||||
| 	r.redis.IncrBy(r.ctx, r.lockKey, count) | ||||
| 	fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "lock", "remain:", remain, "count:", count) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *RedisCounter) Complete(count int64) error { | ||||
| 	r.locker.Lock() | ||||
| 	defer r.locker.Unlock() | ||||
| 	if r.redis == nil { | ||||
| 		// 如果Redis没有配置,使用本地计数器 | ||||
| 		return r.localCounter.Complete(count) | ||||
| 	} | ||||
|  | ||||
| 	err := r.acquireLock() | ||||
| 	if err != nil { | ||||
| 		if err == redis.ErrClosed { | ||||
| 			return r.localCounter.Lock(count) | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	defer r.releaseLock() | ||||
|  | ||||
| 	r.redis.IncrBy(r.ctx, r.lockKey, -count) | ||||
| 	fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "complete", "count:", count) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *RedisCounter) RollBack(count int64) error { | ||||
| 	r.locker.Lock() | ||||
| 	defer r.locker.Unlock() | ||||
| 	if r.redis == nil { | ||||
| 		// 如果Redis没有配置,使用本地计数器 | ||||
| 		return r.localCounter.RollBack(count) | ||||
| 	} | ||||
| 	err := r.acquireLock() | ||||
| 	if err != nil { | ||||
| 		if err == redis.ErrClosed { | ||||
| 			return r.localCounter.Lock(count) | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	defer r.releaseLock() | ||||
|  | ||||
| 	r.redis.IncrBy(r.ctx, r.remainKey, count) | ||||
| 	r.redis.IncrBy(r.ctx, r.lockKey, -count) | ||||
| 	fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "rollback", "count:", count) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *RedisCounter) ResetClient(client IClient) { | ||||
| 	r.locker.Lock() | ||||
| 	defer r.locker.Unlock() | ||||
| 	r.client = client | ||||
| } | ||||
|  | ||||
| func (r *RedisCounter) acquireLock() error { | ||||
| 	for { | ||||
| 		// 生成唯一的锁值 | ||||
| 		lockValue := time.Now().UnixNano() | ||||
|  | ||||
| 		// Redis连接失败,使用本地计数器 | ||||
| 		ok, err := r.redis.SetNX(r.ctx, r.lockerKey, lockValue, 10*time.Second).Result() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if ok { | ||||
| 			// 设置锁成功 | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *RedisCounter) releaseLock() error { | ||||
| 	return r.redis.Del(r.ctx, r.lockerKey).Err() | ||||
| } | ||||
							
								
								
									
										113
									
								
								drivers/plugins/counter/executor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								drivers/plugins/counter/executor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/counter/counter" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/counter/separator" | ||||
|  | ||||
| 	"github.com/eolinker/eosc" | ||||
| 	"github.com/eolinker/eosc/eocontext" | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var _ http_service.HttpFilter = (*executor)(nil) | ||||
| var _ eocontext.IFilter = (*executor)(nil) | ||||
|  | ||||
| type executor struct { | ||||
| 	drivers.WorkerBase | ||||
| 	matchers         []IMatcher | ||||
| 	separatorCounter separator.ICounter | ||||
| 	counters         eosc.Untyped[string, counter.ICounter] | ||||
| 	client           counter.IClient | ||||
| 	keyGenerate      IKeyGenerator | ||||
| } | ||||
|  | ||||
| func (b *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) { | ||||
| 	return http_service.DoHttpFilter(b, ctx, next) | ||||
| } | ||||
|  | ||||
| func (b *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error { | ||||
| 	counter, has := b.counters.Get(b.keyGenerate.Key(ctx)) | ||||
| 	if !has { | ||||
|  | ||||
| 	} | ||||
| 	var count int64 = 1 | ||||
| 	var err error | ||||
| 	if b.separatorCounter != nil { | ||||
|  | ||||
| 		separatorCounter := b.separatorCounter | ||||
| 		count, err = separatorCounter.Count(ctx) | ||||
| 		if err != nil { | ||||
| 			ctx.Response().SetStatus(400, "400") | ||||
| 			return fmt.Errorf("%s count error", separatorCounter.Name()) | ||||
| 		} | ||||
| 		if count > separatorCounter.Max() { | ||||
| 			ctx.Response().SetStatus(403, "not allow") | ||||
| 			return fmt.Errorf("%s number exceed", separatorCounter.Name()) | ||||
| 		} else if count == 0 { | ||||
| 			ctx.Response().SetStatus(400, "400") | ||||
| 			return fmt.Errorf("%s value is missing", separatorCounter.Name()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	err = counter.Lock(count) | ||||
| 	if err != nil { | ||||
| 		// 次数不足,直接返回 | ||||
| 		//return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", b.counters.Name(), b.counters.Remain(), count | ||||
| 	} | ||||
| 	if next != nil { | ||||
| 		err = next.DoChain(ctx) | ||||
| 		if err != nil { | ||||
| 			// 转发失败,回滚次数 | ||||
| 			return counter.RollBack(count) | ||||
| 			//return err | ||||
| 		} | ||||
| 	} | ||||
| 	match := true | ||||
| 	for _, matcher := range b.matchers { | ||||
| 		ok := matcher.Match(ctx) | ||||
| 		if !ok { | ||||
| 			match = false | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if match { | ||||
| 		// 匹配,扣减次数 | ||||
| 		return counter.Complete(count) | ||||
| 	} | ||||
| 	// 不匹配,回滚次数 | ||||
| 	return counter.RollBack(count) | ||||
| } | ||||
|  | ||||
| func (b *executor) Start() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (b *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { | ||||
| 	cfg, ok := conf.(*Config) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("invalid config, driver: %s", Name) | ||||
| 	} | ||||
| 	counter, err := separator.GetCounter(cfg.Count) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	b.separatorCounter = counter | ||||
| 	b.matchers = cfg.Match.GenerateHandler() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (b *executor) Stop() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (b *executor) Destroy() { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (b *executor) CheckSkill(skill string) bool { | ||||
| 	return http_service.FilterSkillName == skill | ||||
| } | ||||
							
								
								
									
										18
									
								
								drivers/plugins/counter/factory.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								drivers/plugins/counter/factory.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Name = "counter" | ||||
| ) | ||||
|  | ||||
| func Register(register eosc.IExtenderDriverRegister) { | ||||
| 	register.RegisterExtenderDriver(Name, NewFactory()) | ||||
| } | ||||
|  | ||||
| func NewFactory() eosc.IExtenderDriverFactory { | ||||
| 	return drivers.NewFactory[Config](Create) | ||||
| } | ||||
							
								
								
									
										49
									
								
								drivers/plugins/counter/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								drivers/plugins/counter/key.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var _ IKeyGenerator = (*keyGenerate)(nil) | ||||
|  | ||||
| type IKeyGenerator interface { | ||||
| 	Key(ctx http_service.IHttpContext) string | ||||
| } | ||||
|  | ||||
| func newKeyGenerate(key string) *keyGenerate { | ||||
| 	key = strings.TrimSpace(key) | ||||
| 	tmp := strings.Split(key, ":") | ||||
|  | ||||
| 	keys := make([]string, 0, len(tmp)) | ||||
| 	variables := make([]string, 0, len(tmp)) | ||||
| 	for _, t := range tmp { | ||||
| 		t = strings.TrimSpace(t) | ||||
| 		tLen := len(t) | ||||
| 		if tLen > 0 { | ||||
| 			if tLen > 1 && t[0] == '$' { | ||||
| 				variables = append(variables, t[1:]) | ||||
| 				keys = append(keys, "%s") | ||||
| 			} else { | ||||
| 				keys = append(keys, t) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return &keyGenerate{format: strings.Join(keys, ":"), variables: variables} | ||||
| } | ||||
|  | ||||
| type keyGenerate struct { | ||||
| 	format string | ||||
| 	// 变量列表 | ||||
| 	variables []string | ||||
| } | ||||
|  | ||||
| func (k *keyGenerate) Key(ctx http_service.IHttpContext) string { | ||||
| 	variables := make([]interface{}, 0, len(k.variables)) | ||||
| 	for _, v := range k.variables { | ||||
| 		variables = append(variables, ctx.GetLabel(v)) | ||||
| 	} | ||||
| 	return fmt.Sprintf(k.format, variables...) | ||||
| } | ||||
							
								
								
									
										8
									
								
								drivers/plugins/counter/key_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								drivers/plugins/counter/key_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| package counter | ||||
|  | ||||
| import "testing" | ||||
|  | ||||
| func TestGenerateKey(t *testing.T) { | ||||
| 	key := newKeyGenerate("a:$b:$c:d") | ||||
| 	t.Log(key.format, key.variables) | ||||
| } | ||||
							
								
								
									
										121
									
								
								drivers/plugins/counter/match.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								drivers/plugins/counter/match.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,121 @@ | ||||
| package counter | ||||
|  | ||||
| import ( | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
|  | ||||
| 	"github.com/eolinker/eosc/log" | ||||
| 	"github.com/ohler55/ojg/jp" | ||||
| 	"github.com/ohler55/ojg/oj" | ||||
| ) | ||||
|  | ||||
| type IMatcher interface { | ||||
| 	Match(ctx http_service.IHttpContext) bool | ||||
| } | ||||
|  | ||||
| type MatchParam struct { | ||||
| 	Key   string   `json:"key"` | ||||
| 	Kind  string   `json:"kind"` // int|string|bool | ||||
| 	Value []string `json:"value"` | ||||
| } | ||||
|  | ||||
| func newJsonMatcher(params []*MatchParam) *jsonMatcher { | ||||
| 	ps := make([]*jsonMatchParam, 0, len(params)) | ||||
| 	for _, p := range params { | ||||
| 		key := p.Key | ||||
| 		if !strings.HasPrefix(p.Key, "$.") { | ||||
| 			key = "$." + p.Key | ||||
| 		} | ||||
| 		expr, err := jp.ParseString(key) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("json path parse error: %v,key is %s", err, key) | ||||
| 			continue | ||||
| 		} | ||||
| 		ps = append(ps, &jsonMatchParam{ | ||||
| 			MatchParam: p, | ||||
| 			expr:       expr, | ||||
| 		}) | ||||
| 	} | ||||
| 	return &jsonMatcher{params: ps} | ||||
| } | ||||
|  | ||||
| type jsonMatcher struct { | ||||
| 	params []*jsonMatchParam | ||||
| } | ||||
|  | ||||
| type jsonMatchParam struct { | ||||
| 	*MatchParam | ||||
| 	expr jp.Expr | ||||
| } | ||||
|  | ||||
| func (m *jsonMatcher) Match(ctx http_service.IHttpContext) bool { | ||||
| 	if len(m.params) < 1 { | ||||
| 		return true | ||||
| 	} | ||||
| 	body := ctx.Response().GetBody() | ||||
| 	tmp, err := oj.Parse(body) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("parse body error: %v,body is %s", err, body) | ||||
| 		return true | ||||
| 	} | ||||
| 	for _, p := range m.params { | ||||
| 		results := p.expr.Get(tmp) | ||||
| 		for _, v := range p.Value { | ||||
| 			for _, r := range results { | ||||
| 				switch p.Kind { | ||||
| 				case "int": | ||||
| 					t, ok := r.(int64) | ||||
| 					if !ok { | ||||
| 						return false | ||||
| 					} | ||||
| 					val, _ := strconv.ParseInt(v, 10, 64) | ||||
| 					if t == val { | ||||
| 						return true | ||||
| 					} | ||||
| 				case "bool": | ||||
| 					t, ok := r.(bool) | ||||
| 					if !ok { | ||||
| 						return false | ||||
| 					} | ||||
| 					val, err := strconv.ParseBool(v) | ||||
| 					if err != nil { | ||||
| 						return false | ||||
| 					} | ||||
| 					return t == val | ||||
| 				default: | ||||
| 					t, ok := r.(string) | ||||
| 					if !ok { | ||||
| 						return false | ||||
| 					} | ||||
| 					if t == v { | ||||
| 						return true | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func newStatusCodeMatcher(codes []int) *statusCodeMatcher { | ||||
| 	return &statusCodeMatcher{codes: codes} | ||||
| } | ||||
|  | ||||
| type statusCodeMatcher struct { | ||||
| 	codes []int | ||||
| } | ||||
|  | ||||
| func (m *statusCodeMatcher) Match(ctx http_service.IHttpContext) bool { | ||||
| 	if len(m.codes) < 1 { | ||||
| 		return true | ||||
| 	} | ||||
| 	code := ctx.Response().StatusCode() | ||||
| 	for _, c := range m.codes { | ||||
| 		if c == code { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										67
									
								
								drivers/plugins/counter/separator/count.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								drivers/plugins/counter/separator/count.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | ||||
| package separator | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	ArrayCountType  = "array" | ||||
| 	SplitCountType  = "splite" | ||||
| 	LengthCountType = "length" | ||||
| 	CountTypes      = []string{ | ||||
| 		ArrayCountType, | ||||
| 		SplitCountType, | ||||
| 		LengthCountType, | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| type CountRule struct { | ||||
| 	RequestBodyType string `json:"request_body_type" label:"请求体类型" enum:"form-data,json"` | ||||
| 	Key             string `json:"key" label:"参数名称(支持json path)"` | ||||
| 	Separator       string `json:"separator" label:"分隔符" switch:"separator_type===splite"` | ||||
| 	SeparatorType   string `json:"separator_type" label:"分割类型" enum:"splite,array,length"` | ||||
| 	Max             int64  `json:"max" label:"计数最大值"` | ||||
| } | ||||
|  | ||||
| type ICounter interface { | ||||
| 	Count(ctx http_service.IHttpContext) (int64, error) | ||||
| 	Max() int64 | ||||
| 	Name() string | ||||
| } | ||||
|  | ||||
| func GetCounter(rule *CountRule) (ICounter, error) { | ||||
| 	switch strings.ToLower(rule.RequestBodyType) { | ||||
| 	case "form-data": | ||||
| 		return NewFormDataCounter(rule) | ||||
| 	case "multipart-formdata": | ||||
| 		return NewFileCounter(rule) | ||||
| 	case "json": | ||||
| 		return NewJsonCounter(rule) | ||||
| 	default: | ||||
| 		return NewFormDataCounter(rule) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func splitCount(origin string, split string) int64 { | ||||
| 	if len(split) == 0 { | ||||
| 		return 0 | ||||
| 	} | ||||
|  | ||||
| 	vs := strings.Split(origin, string(split[0])) | ||||
| 	var count int64 = 0 | ||||
| 	for _, v := range vs { | ||||
| 		if v != "" { | ||||
| 			childCount := splitCount(v, split[1:]) | ||||
| 			if childCount == 0 { | ||||
| 				count += 1 | ||||
| 			} else { | ||||
| 				count += childCount | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	return count | ||||
| } | ||||
							
								
								
									
										74
									
								
								drivers/plugins/counter/separator/file.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								drivers/plugins/counter/separator/file.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,74 @@ | ||||
| package separator | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"mime" | ||||
| 	"mime/multipart" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var _ ICounter = (*FileCounter)(nil) | ||||
|  | ||||
| const defaultMultipartMemory = 32 << 20 // 32 MB | ||||
|  | ||||
| type FileCounter struct { | ||||
| 	typ      string | ||||
| 	split    string | ||||
| 	name     string | ||||
| 	max      int64 | ||||
| 	splitLen int | ||||
| } | ||||
|  | ||||
| func (f *FileCounter) Name() string { | ||||
| 	return f.name | ||||
| } | ||||
|  | ||||
| func NewFileCounter(rule *CountRule) (*FileCounter, error) { | ||||
| 	var splitLen int | ||||
| 	if rule.SeparatorType == LengthCountType { | ||||
| 		var err error | ||||
| 		splitLen, err = strconv.Atoi(rule.Separator) | ||||
| 		if err != nil { | ||||
| 			splitLen = 1000 | ||||
| 		} | ||||
| 	} | ||||
| 	return &FileCounter{name: rule.Key, split: rule.Separator, typ: rule.SeparatorType, max: rule.Max, splitLen: splitLen}, nil | ||||
| } | ||||
|  | ||||
| func (f *FileCounter) Count(ctx http_service.IHttpContext) (int64, error) { | ||||
| 	raw, _ := ctx.Request().Body().RawBody() | ||||
| 	d, params, _ := mime.ParseMediaType(ctx.Request().ContentType()) | ||||
| 	if !(d == "multipart/form-data") { | ||||
| 		return -1, fmt.Errorf("need content-type: multipart/form-data,now: %s", d) | ||||
| 	} | ||||
| 	boundary, ok := params["boundary"] | ||||
| 	if !ok { | ||||
| 		return -1, fmt.Errorf("missing boundary") | ||||
| 	} | ||||
| 	body := io.NopCloser(bytes.NewBuffer(raw)) | ||||
| 	reader := multipart.NewReader(body, boundary) | ||||
| 	form, err := reader.ReadForm(defaultMultipartMemory) | ||||
| 	if err != nil { | ||||
| 		return -1, fmt.Errorf("parse form param err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	switch f.typ { | ||||
| 	case LengthCountType: | ||||
| 		value := strings.Join(form.Value[f.name], "") | ||||
| 		l := len([]rune(value)) | ||||
| 		if l%f.splitLen == 0 { | ||||
| 			return int64(l / f.splitLen), nil | ||||
| 		} | ||||
| 		return int64(l/f.splitLen + 1), nil | ||||
| 	} | ||||
| 	return splitCount(strings.Join(form.Value[f.name], f.split), f.split), nil | ||||
| } | ||||
|  | ||||
| func (f *FileCounter) Max() int64 { | ||||
| 	return f.max | ||||
| } | ||||
							
								
								
									
										59
									
								
								drivers/plugins/counter/separator/formdata.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								drivers/plugins/counter/separator/formdata.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| package separator | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var _ ICounter = (*FormDataCounter)(nil) | ||||
|  | ||||
| type FormDataCounter struct { | ||||
| 	typ      string | ||||
| 	split    string | ||||
| 	name     string | ||||
| 	max      int64 | ||||
| 	splitLen int | ||||
| } | ||||
|  | ||||
| func (f *FormDataCounter) Name() string { | ||||
| 	return f.name | ||||
| } | ||||
|  | ||||
| func NewFormDataCounter(rule *CountRule) (*FormDataCounter, error) { | ||||
| 	var splitLen int | ||||
| 	if rule.SeparatorType == LengthCountType { | ||||
| 		var err error | ||||
| 		splitLen, err = strconv.Atoi(rule.Separator) | ||||
| 		if err != nil { | ||||
| 			splitLen = 1000 | ||||
| 		} | ||||
| 	} | ||||
| 	return &FormDataCounter{name: rule.Key, split: rule.Separator, typ: rule.SeparatorType, max: rule.Max, splitLen: splitLen}, nil | ||||
| } | ||||
|  | ||||
| func (f *FormDataCounter) Count(ctx http_service.IHttpContext) (int64, error) { | ||||
| 	body, _ := ctx.Request().Body().RawBody() | ||||
| 	u, err := url.ParseQuery(string(body)) | ||||
| 	if err != nil { | ||||
| 		return -1, fmt.Errorf("parse form data error:%v", err) | ||||
| 	} | ||||
| 	switch f.typ { | ||||
| 	case SplitCountType: | ||||
| 		return splitCount(u.Get(f.name), f.split), nil | ||||
| 	case LengthCountType: | ||||
| 		value := u.Get(f.name) | ||||
| 		l := len([]rune(value)) | ||||
| 		if l%f.splitLen == 0 { | ||||
| 			return int64(l / f.splitLen), nil | ||||
| 		} | ||||
| 		return int64(l/f.splitLen + 1), nil | ||||
| 	} | ||||
| 	return splitCount(u.Get(f.name), f.split), nil | ||||
| } | ||||
|  | ||||
| func (f *FormDataCounter) Max() int64 { | ||||
| 	return f.max | ||||
| } | ||||
							
								
								
									
										104
									
								
								drivers/plugins/counter/separator/json.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								drivers/plugins/counter/separator/json.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,104 @@ | ||||
| package separator | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/oj" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/jp" | ||||
| ) | ||||
|  | ||||
| var _ ICounter = (*JsonCounter)(nil) | ||||
|  | ||||
| type JsonCounter struct { | ||||
| 	max      int64 | ||||
| 	split    string | ||||
| 	expr     jp.Expr | ||||
| 	typ      string | ||||
| 	name     string | ||||
| 	splitLen int | ||||
| } | ||||
|  | ||||
| func (j *JsonCounter) Name() string { | ||||
| 	return j.name | ||||
| } | ||||
|  | ||||
| func NewJsonCounter(rule *CountRule) (*JsonCounter, error) { | ||||
| 	typeValid := false | ||||
| 	for _, t := range CountTypes { | ||||
| 		if t == rule.SeparatorType { | ||||
| 			typeValid = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if !typeValid { | ||||
| 		return nil, fmt.Errorf("json count split type config error,now type is %s, need array or split", rule.SeparatorType) | ||||
| 	} | ||||
| 	expr, err := jp.ParseString(rule.Key) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("json path parse error:%v", err) | ||||
| 	} | ||||
| 	if rule.Max < 1 || rule.Max > 2000 { | ||||
| 		rule.Max = 2000 | ||||
| 	} | ||||
| 	var splitLen int | ||||
| 	if rule.SeparatorType == LengthCountType { | ||||
| 		splitLen, err = strconv.Atoi(rule.Separator) | ||||
| 		if err != nil { | ||||
| 			splitLen = 1000 | ||||
| 		} | ||||
| 	} | ||||
| 	return &JsonCounter{max: rule.Max, split: rule.Separator, expr: expr, typ: rule.SeparatorType, name: strings.TrimPrefix(rule.Key, "$."), splitLen: splitLen}, nil | ||||
| } | ||||
|  | ||||
| func (j *JsonCounter) Count(ctx http_service.IHttpContext) (int64, error) { | ||||
| 	body, _ := ctx.Request().Body().RawBody() | ||||
| 	obj, err := oj.Parse(body) | ||||
| 	if err != nil { | ||||
| 		return -1, fmt.Errorf("parse json body error:%v, body is %s", err, string(body)) | ||||
| 	} | ||||
| 	results := j.expr.Get(obj) | ||||
| 	switch j.typ { | ||||
| 	case SplitCountType: | ||||
| 		if len(results) > 0 { | ||||
| 			origin, ok := results[0].(string) | ||||
| 			if !ok { | ||||
| 				return -1, fmt.Errorf("json path %s get value is not string", j.name) | ||||
| 			} | ||||
| 			return splitCount(origin, j.split), nil | ||||
| 		} | ||||
| 	case ArrayCountType: | ||||
| 		if len(results) > 0 { | ||||
| 			switch v := results[0].(type) { | ||||
| 			case []interface{}: | ||||
| 				{ | ||||
| 					return int64(len(v)), nil | ||||
| 				} | ||||
| 			case map[string]interface{}: | ||||
| 				return int64(len(v)), nil | ||||
| 			} | ||||
| 		} | ||||
| 	case LengthCountType: | ||||
| 		if len(results) > 0 { | ||||
| 			origin, ok := results[0].(string) | ||||
| 			if !ok { | ||||
| 				return -1, fmt.Errorf("json path %s get value is not string", j.name) | ||||
| 			} | ||||
| 			l := len([]rune(origin)) | ||||
|  | ||||
| 			if l%j.splitLen == 0 { | ||||
| 				return int64(l / j.splitLen), nil | ||||
| 			} | ||||
| 			return int64(l/j.splitLen + 1), nil | ||||
| 		} | ||||
| 	} | ||||
| 	return -1, nil | ||||
| } | ||||
|  | ||||
| func (j *JsonCounter) Max() int64 { | ||||
| 	return j.max | ||||
| } | ||||
							
								
								
									
										120
									
								
								drivers/plugins/extra-params_v2/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								drivers/plugins/extra-params_v2/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,120 @@ | ||||
| package extra_params_v2 | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
|  | ||||
| 	dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params" | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| type Config struct { | ||||
| 	Params          []*ExtraParam `json:"params" label:"参数列表"` | ||||
| 	RequestBodyType string        `json:"request_body_type" enum:"form-data,json" label:"请求体类型"` | ||||
| 	ErrorType       string        `json:"error_type" enum:"text,json" label:"报错输出格式"` | ||||
| } | ||||
|  | ||||
| func (c *Config) doCheck() error { | ||||
| 	c.ErrorType = strings.ToLower(c.ErrorType) | ||||
| 	if c.ErrorType != "text" && c.ErrorType != "json" { | ||||
| 		c.ErrorType = "text" | ||||
| 	} | ||||
|  | ||||
| 	for _, param := range c.Params { | ||||
| 		if param.Name == "" { | ||||
| 			return fmt.Errorf(paramNameErrInfo) | ||||
| 		} | ||||
|  | ||||
| 		param.Position = strings.ToLower(param.Position) | ||||
| 		if param.Position != "query" && param.Position != "header" && param.Position != "body" { | ||||
| 			return fmt.Errorf(paramPositionErrInfo, param.Position) | ||||
| 		} | ||||
|  | ||||
| 		param.Conflict = strings.ToLower(param.Conflict) | ||||
| 		if param.Conflict != paramOrigin && param.Conflict != paramConvert && param.Conflict != paramError { | ||||
| 			param.Conflict = paramConvert | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type ExtraParam struct { | ||||
| 	Name     string   `json:"name" label:"参数名"` | ||||
| 	Type     string   `json:"type" label:"参数类型" enum:"string,int,float,bool,$datetime,$md5"` | ||||
| 	Position string   `json:"position" enum:"header,query,body" label:"参数位置"` | ||||
| 	Value    []string `json:"value" label:"参数值列表"` | ||||
| 	Conflict string   `json:"conflict" label:"参数冲突时的处理方式" enum:"origin,convert,error"` | ||||
| } | ||||
|  | ||||
| type baseParam struct { | ||||
| 	header []*paramInfo | ||||
| 	query  []*paramInfo | ||||
| 	body   []*paramInfo | ||||
| } | ||||
|  | ||||
| func generateBaseParam(params []*ExtraParam) *baseParam { | ||||
| 	b := &baseParam{ | ||||
| 		header: make([]*paramInfo, 0), | ||||
| 		query:  make([]*paramInfo, 0), | ||||
| 		body:   make([]*paramInfo, 0), | ||||
| 	} | ||||
| 	for _, param := range params { | ||||
| 		switch param.Position { | ||||
| 		case positionHeader: | ||||
| 			b.header = append(b.header, newParamInfo(param.Name, param.Value, param.Type, param.Conflict)) | ||||
| 		case positionQuery: | ||||
| 			b.query = append(b.query, newParamInfo(param.Name, param.Value, param.Type, param.Conflict)) | ||||
| 		case positionBody: | ||||
| 			b.body = append(b.body, newParamInfo(param.Name, param.Value, param.Type, param.Conflict)) | ||||
| 		} | ||||
| 	} | ||||
| 	return b | ||||
| } | ||||
|  | ||||
| func newParamInfo(name string, value []string, typ string, conflict string) *paramInfo { | ||||
| 	d := ¶mInfo{name: name, value: strings.Join(value, ","), conflict: conflict} | ||||
| 	valueLen := len(d.value) | ||||
| 	if strings.HasPrefix(typ, "$") { | ||||
| 		factory, has := dynamic_params.Get(typ) | ||||
| 		if has { | ||||
| 			driver, err := factory.Create(name, value) | ||||
| 			if err == nil { | ||||
| 				d.driver = driver | ||||
| 			} | ||||
| 		} | ||||
| 	} else if valueLen > 1 && d.value[0] == '$' { | ||||
| 		// 系统变量 | ||||
| 		d.systemValue = true | ||||
| 		d.value = d.value[1:valueLen] | ||||
| 	} | ||||
| 	return d | ||||
| } | ||||
|  | ||||
| type paramInfo struct { | ||||
| 	name        string | ||||
| 	systemValue bool | ||||
| 	value       string | ||||
| 	driver      dynamic_params.IDynamicDriver | ||||
| 	conflict    string | ||||
| } | ||||
|  | ||||
| func (b *paramInfo) Build(ctx http_service.IHttpContext, contentType string, params interface{}) (string, error) { | ||||
| 	if b.driver == nil { | ||||
| 		if b.systemValue { | ||||
| 			return ctx.GetLabel(b.value), nil | ||||
| 		} | ||||
| 		return b.value, nil | ||||
| 	} | ||||
| 	value, err := b.driver.Generate(ctx, contentType, params) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	switch v := value.(type) { | ||||
| 	case string: | ||||
| 		return v, nil | ||||
| 	case int, int32, int64: | ||||
| 		return fmt.Sprintf("%d", v), nil | ||||
| 	} | ||||
| 	return "", nil | ||||
| } | ||||
							
								
								
									
										60
									
								
								drivers/plugins/extra-params_v2/driver.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								drivers/plugins/extra-params_v2/driver.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| package extra_params_v2 | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/datetime" | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/md5" | ||||
| 	"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/timestamp" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
|  | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	once sync.Once | ||||
| ) | ||||
|  | ||||
| type Driver struct { | ||||
| 	profession string | ||||
| 	name       string | ||||
| 	label      string | ||||
| 	desc       string | ||||
| 	configType reflect.Type | ||||
| } | ||||
|  | ||||
| func Check(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error { | ||||
|  | ||||
| 	return conf.doCheck() | ||||
| } | ||||
|  | ||||
| func check(v interface{}) (*Config, error) { | ||||
| 	conf, ok := v.(*Config) | ||||
| 	if !ok { | ||||
| 		return nil, eosc.ErrorConfigType | ||||
| 	} | ||||
| 	err := conf.doCheck() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return conf, nil | ||||
| } | ||||
|  | ||||
| func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { | ||||
| 	once.Do(func() { | ||||
| 		datetime.Register() | ||||
| 		md5.Register() | ||||
| 		timestamp.Register() | ||||
| 	}) | ||||
| 	ep := &executor{ | ||||
| 		WorkerBase:      drivers.Worker(id, name), | ||||
| 		baseParam:       generateBaseParam(conf.Params), | ||||
| 		requestBodyType: conf.RequestBodyType, | ||||
| 		errorType:       conf.ErrorType, | ||||
| 	} | ||||
|  | ||||
| 	return ep, nil | ||||
| } | ||||
| @@ -0,0 +1,28 @@ | ||||
| package datetime | ||||
|  | ||||
| import ( | ||||
| 	"time" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	defaultValue = "2006-01-02 15:04:05" | ||||
| ) | ||||
|  | ||||
| type Datetime struct { | ||||
| 	name  string | ||||
| 	value string | ||||
| } | ||||
|  | ||||
| func NewDatetime(name string, value string) *Datetime { | ||||
| 	return &Datetime{name: name, value: value} | ||||
| } | ||||
|  | ||||
| func (d *Datetime) Name() string { | ||||
| 	return d.name | ||||
| } | ||||
|  | ||||
| func (d *Datetime) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) { | ||||
| 	return time.Now().Format(d.value), nil | ||||
| } | ||||
| @@ -0,0 +1,24 @@ | ||||
| package datetime | ||||
|  | ||||
| import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params" | ||||
|  | ||||
| const name = "$datetime" | ||||
|  | ||||
| func Register() { | ||||
| 	dynamic_params.Register(name, NewFactory()) | ||||
| } | ||||
|  | ||||
| func NewFactory() *Factory { | ||||
| 	return &Factory{} | ||||
| } | ||||
|  | ||||
| type Factory struct { | ||||
| } | ||||
|  | ||||
| func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) { | ||||
| 	v := defaultValue | ||||
| 	if len(value) > 0 { | ||||
| 		v = value[0] | ||||
| 	} | ||||
| 	return NewDatetime(name, v), nil | ||||
| } | ||||
							
								
								
									
										14
									
								
								drivers/plugins/extra-params_v2/dynamic-params/dynamic.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								drivers/plugins/extra-params_v2/dynamic-params/dynamic.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| package dynamic_params | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| type IDynamicFactory interface { | ||||
| 	Create(name string, value []string) (IDynamicDriver, error) | ||||
| } | ||||
|  | ||||
| type IDynamicDriver interface { | ||||
| 	Name() string | ||||
| 	Generate(ctx http_context.IHttpContext, contentType string, args ...interface{}) (interface{}, error) | ||||
| } | ||||
							
								
								
									
										95
									
								
								drivers/plugins/extra-params_v2/dynamic-params/factory.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								drivers/plugins/extra-params_v2/dynamic-params/factory.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,95 @@ | ||||
| package dynamic_params | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/eolinker/eosc/log" | ||||
|  | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	ErrorInvalidBalance    = errors.New("invalid balance") | ||||
| 	defaultFactoryRegister = newFactoryManager() | ||||
| ) | ||||
|  | ||||
| // IFactoryRegister 实现了负载均衡算法工厂管理器 | ||||
| type IFactoryRegister interface { | ||||
| 	RegisterFactoryByKey(key string, factory IDynamicFactory) | ||||
| 	GetFactoryByKey(key string) (IDynamicFactory, bool) | ||||
| 	Keys() []string | ||||
| } | ||||
|  | ||||
| // driverRegister 实现了IBalanceFactoryRegister接口 | ||||
| type driverRegister struct { | ||||
| 	register eosc.IRegister[IDynamicFactory] | ||||
| 	keys     []string | ||||
| } | ||||
|  | ||||
| // newFactoryManager 创建负载均衡算法工厂管理器 | ||||
| func newFactoryManager() IFactoryRegister { | ||||
| 	return &driverRegister{ | ||||
| 		register: eosc.NewRegister[IDynamicFactory](), | ||||
| 		keys:     make([]string, 0, 10), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // GetFactoryByKey 获取指定balance工厂 | ||||
| func (dm *driverRegister) GetFactoryByKey(key string) (IDynamicFactory, bool) { | ||||
| 	o, has := dm.register.Get(key) | ||||
| 	if has { | ||||
| 		log.Debug("GetFactoryByKey:", key, ":has") | ||||
| 		f, ok := o.(IDynamicFactory) | ||||
| 		return f, ok | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
|  | ||||
| // RegisterFactoryByKey 注册balance工厂 | ||||
| func (dm *driverRegister) RegisterFactoryByKey(key string, factory IDynamicFactory) { | ||||
| 	err := dm.register.Register(key, factory, true) | ||||
| 	if err != nil { | ||||
| 		log.Debug("RegisterFactoryByKey:", key, ":", err) | ||||
| 		return | ||||
| 	} | ||||
| 	dm.keys = append(dm.keys, key) | ||||
| } | ||||
|  | ||||
| // Keys 返回所有已注册的key | ||||
| func (dm *driverRegister) Keys() []string { | ||||
| 	return dm.keys | ||||
| } | ||||
|  | ||||
| // Register 注册balance工厂到默认balanceFactory注册器 | ||||
| func Register(key string, factory IDynamicFactory) { | ||||
|  | ||||
| 	defaultFactoryRegister.RegisterFactoryByKey(key, factory) | ||||
| } | ||||
|  | ||||
| // Get 从默认balanceFactory注册器中获取balance工厂 | ||||
| func Get(key string) (IDynamicFactory, bool) { | ||||
| 	return defaultFactoryRegister.GetFactoryByKey(key) | ||||
| } | ||||
|  | ||||
| // Keys 返回默认的balanceFactory注册器中所有已注册的key | ||||
| func Keys() []string { | ||||
| 	return defaultFactoryRegister.Keys() | ||||
| } | ||||
|  | ||||
| // GetFactory 获取指定负载均衡算法工厂,若指定的不存在则返回一个已注册的工厂 | ||||
| func GetFactory(name string) (IDynamicFactory, error) { | ||||
| 	factory, ok := Get(name) | ||||
| 	if !ok { | ||||
| 		for _, key := range Keys() { | ||||
| 			factory, ok = Get(key) | ||||
| 			if ok { | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		if factory == nil { | ||||
| 			return nil, fmt.Errorf("%s:%w", name, ErrorInvalidBalance) | ||||
| 		} | ||||
| 	} | ||||
| 	return factory, nil | ||||
| } | ||||
| @@ -0,0 +1,20 @@ | ||||
| package md5 | ||||
|  | ||||
| import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params" | ||||
|  | ||||
| const name = "$md5" | ||||
|  | ||||
| func Register() { | ||||
| 	dynamic_params.Register(name, NewFactory()) | ||||
| } | ||||
|  | ||||
| func NewFactory() *Factory { | ||||
| 	return &Factory{} | ||||
| } | ||||
|  | ||||
| type Factory struct { | ||||
| } | ||||
|  | ||||
| func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) { | ||||
| 	return NewMD5(name, value), nil | ||||
| } | ||||
							
								
								
									
										178
									
								
								drivers/plugins/extra-params_v2/dynamic-params/md5/md5.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								drivers/plugins/extra-params_v2/dynamic-params/md5/md5.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,178 @@ | ||||
| package md5 | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/oj" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/jp" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/utils" | ||||
|  | ||||
| 	"github.com/eolinker/eosc/log" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	positionCurrent = iota | ||||
| 	positionHeader | ||||
| 	positionQuery | ||||
| 	// body | ||||
| 	positionBody | ||||
| 	positionSystem | ||||
| ) | ||||
|  | ||||
| type MD5 struct { | ||||
| 	name  string | ||||
| 	value []*Value | ||||
| } | ||||
|  | ||||
| type Value struct { | ||||
| 	key      string | ||||
| 	position int | ||||
| 	optional bool | ||||
| } | ||||
|  | ||||
| func NewMD5(name string, value []string) *MD5 { | ||||
| 	vs := make([]*Value, 0, len(value)) | ||||
| 	for _, v := range value { | ||||
| 		v = strings.TrimSpace(v) | ||||
| 		vLen := len(v) | ||||
| 		if vLen > 0 { | ||||
| 			if v[0] == '{' && v[vLen-1] == '}' { | ||||
| 				vars := strings.Split(v[1:vLen-1], ".") | ||||
| 				position := positionBody | ||||
| 				variable := vars[0] | ||||
| 				if len(vars) > 1 { | ||||
| 					variable = vars[1] | ||||
| 					switch vars[0] { | ||||
| 					case "header": | ||||
| 						position = positionHeader | ||||
| 					case "query": | ||||
| 						position = positionQuery | ||||
| 					} | ||||
| 				} | ||||
| 				vs = append(vs, &Value{ | ||||
| 					key:      variable, | ||||
| 					position: position, | ||||
| 				}) | ||||
| 			} else if v[0] == '#' { | ||||
| 				vars := strings.Split(v[1:], ".") | ||||
| 				position := positionBody | ||||
| 				variable := vars[0] | ||||
| 				if len(vars) > 1 { | ||||
| 					variable = vars[1] | ||||
| 					switch vars[0] { | ||||
| 					case "header": | ||||
| 						position = positionHeader | ||||
| 					case "query": | ||||
| 						position = positionQuery | ||||
| 					} | ||||
| 				} | ||||
| 				vs = append(vs, &Value{ | ||||
| 					key:      variable, | ||||
| 					position: position, | ||||
| 					optional: true, | ||||
| 				}) | ||||
| 			} else if vLen > 3 && v[0] == '$' && v[1] == '{' && v[vLen-1] == '}' { | ||||
| 				// 使用系统变量 | ||||
| 				vs = append(vs, &Value{ | ||||
| 					key:      v[2 : vLen-1], | ||||
| 					position: positionSystem, | ||||
| 				}) | ||||
| 			} else { | ||||
| 				vs = append(vs, &Value{ | ||||
| 					key: v, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return &MD5{ | ||||
| 		name:  name, | ||||
| 		value: vs, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (m *MD5) Name() string { | ||||
| 	return m.name | ||||
| } | ||||
|  | ||||
| func (m *MD5) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) { | ||||
| 	builder := strings.Builder{} | ||||
| 	var params interface{} | ||||
| 	if contentType == "application/json" { | ||||
| 		if len(args) < 1 { | ||||
| 			return nil, errors.New("missing args") | ||||
| 		} | ||||
| 		params = args[0] | ||||
| 	} | ||||
| 	for _, v := range m.value { | ||||
| 		if v.key == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		builder.WriteString(retrieveParam(ctx, contentType, params, v)) | ||||
| 	} | ||||
| 	log.Debug("md5 result: ", builder.String()) | ||||
| 	if strings.HasPrefix(m.name, "__") { | ||||
| 		return utils.Md5(builder.String()), nil | ||||
| 	} | ||||
| 	return strings.ToUpper(utils.Md5(builder.String())), nil | ||||
| } | ||||
|  | ||||
| func retrieveParam(ctx http_service.IHttpContext, contentType string, body interface{}, value *Value) string { | ||||
| 	switch value.position { | ||||
| 	case positionCurrent: | ||||
| 		return value.key | ||||
| 	case positionHeader: | ||||
| 		return ctx.Proxy().Header().Headers().Get(value.key) | ||||
| 	case positionQuery: | ||||
| 		return ctx.Proxy().URI().GetQuery(value.key) | ||||
| 	case positionBody: | ||||
|  | ||||
| 		if contentType == "application/x-www-form-urlencoded" { | ||||
| 			if !value.optional { | ||||
| 				return ctx.Proxy().Body().GetForm(value.key) | ||||
| 			} | ||||
| 			form, _ := ctx.Proxy().Body().BodyForm() | ||||
| 			if _, ok := form[value.key]; ok { | ||||
| 				return value.key | ||||
| 			} | ||||
| 		} else if contentType == "application/json" { | ||||
| 			key := value.key | ||||
| 			if !strings.HasPrefix(key, "$.") { | ||||
| 				key = "$." + key | ||||
| 			} | ||||
|  | ||||
| 			x, err := jp.ParseString(key) | ||||
| 			if err != nil { | ||||
| 				log.Errorf("parse json path(%s) error: %v", key, err) | ||||
| 				return "" | ||||
| 			} | ||||
| 			result := x.Get(body) | ||||
|  | ||||
| 			if len(result) > 0 { | ||||
| 				if value.optional { | ||||
| 					return value.key | ||||
| 				} | ||||
|  | ||||
| 				switch r := result[0].(type) { | ||||
| 				case string: | ||||
| 					return r | ||||
| 				case float32, float64: | ||||
| 					return fmt.Sprintf("%.f", r) | ||||
| 				case bool: | ||||
| 					return strconv.FormatBool(r) | ||||
| 				default: | ||||
| 					return oj.JSON(r) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	case positionSystem: | ||||
| 		return ctx.GetLabel(value.key) | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| @@ -0,0 +1,24 @@ | ||||
| package timestamp | ||||
|  | ||||
| import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params" | ||||
|  | ||||
| const name = "$timestamp" | ||||
|  | ||||
| func Register() { | ||||
| 	dynamic_params.Register(name, NewFactory()) | ||||
| } | ||||
|  | ||||
| func NewFactory() *Factory { | ||||
| 	return &Factory{} | ||||
| } | ||||
|  | ||||
| type Factory struct { | ||||
| } | ||||
|  | ||||
| func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) { | ||||
| 	v := defaultValue | ||||
| 	if len(value) > 0 { | ||||
| 		v = value[0] | ||||
| 	} | ||||
| 	return NewTimestamp(name, v), nil | ||||
| } | ||||
| @@ -0,0 +1,35 @@ | ||||
| package timestamp | ||||
|  | ||||
| import ( | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	defaultValue = "string" | ||||
| ) | ||||
|  | ||||
| type Timestamp struct { | ||||
| 	name  string | ||||
| 	value string | ||||
| } | ||||
|  | ||||
| func NewTimestamp(name, value string) *Timestamp { | ||||
| 	return &Timestamp{name: name, value: value} | ||||
| } | ||||
|  | ||||
| func (t *Timestamp) Name() string { | ||||
| 	return t.name | ||||
| } | ||||
|  | ||||
| func (t *Timestamp) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) { | ||||
| 	switch t.value { | ||||
| 	case "string": | ||||
| 		return strconv.FormatInt(time.Now().Unix(), 10), nil | ||||
| 	case "int": | ||||
| 		return time.Now().Unix(), nil | ||||
| 	} | ||||
| 	return strconv.FormatInt(time.Now().Unix(), 10), nil | ||||
| } | ||||
							
								
								
									
										212
									
								
								drivers/plugins/extra-params_v2/executor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								drivers/plugins/extra-params_v2/executor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,212 @@ | ||||
| package extra_params_v2 | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"mime" | ||||
| 	"net/http" | ||||
| 	"net/textproto" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/oj" | ||||
|  | ||||
| 	"github.com/eolinker/eosc/log" | ||||
|  | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
|  | ||||
| 	"github.com/ohler55/ojg/jp" | ||||
|  | ||||
| 	"github.com/eolinker/eosc" | ||||
| 	"github.com/eolinker/eosc/eocontext" | ||||
| 	http_service "github.com/eolinker/eosc/eocontext/http-context" | ||||
| ) | ||||
|  | ||||
| var _ http_service.HttpFilter = (*executor)(nil) | ||||
| var _ eocontext.IFilter = (*executor)(nil) | ||||
|  | ||||
| var ( | ||||
| 	errorExist = "%s: %s is already exists" | ||||
| ) | ||||
|  | ||||
| type executor struct { | ||||
| 	drivers.WorkerBase | ||||
| 	baseParam       *baseParam | ||||
| 	requestBodyType string | ||||
| 	errorType       string | ||||
| } | ||||
|  | ||||
| func (e *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) { | ||||
| 	return http_service.DoHttpFilter(e, ctx, next) | ||||
| } | ||||
|  | ||||
| func (e *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error { | ||||
| 	statusCode, err := e.access(ctx) | ||||
| 	if err != nil { | ||||
| 		ctx.Response().SetBody([]byte(err.Error())) | ||||
| 		ctx.Response().SetStatus(statusCode, strconv.Itoa(statusCode)) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if next != nil { | ||||
| 		return next.DoChain(ctx) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func addParamToBody(ctx http_service.IHttpContext, contentType string, params []*paramInfo) (interface{}, error) { | ||||
|  | ||||
| 	//var bodyParam map[string]interface{} | ||||
| 	if contentType == "application/json" { | ||||
| 		body, _ := ctx.Proxy().Body().RawBody() | ||||
| 		if string(body) == "" { | ||||
| 			body = []byte("{}") | ||||
| 		} | ||||
| 		bodyParam, err := oj.Parse(body) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		for _, param := range params { | ||||
| 			key := param.name | ||||
| 			if !strings.HasPrefix(param.name, "$.") { | ||||
| 				key = "$." + key | ||||
| 			} | ||||
|  | ||||
| 			x, err := jp.ParseString(key) | ||||
| 			if err != nil { | ||||
| 				return nil, fmt.Errorf("parse key error: %v", err) | ||||
| 			} | ||||
| 			result := x.Get(bodyParam) | ||||
| 			if len(result) > 0 { | ||||
| 				if param.conflict == paramError { | ||||
| 					return nil, fmt.Errorf(errorExist, positionBody, param.name) | ||||
| 				} else if param.conflict == paramOrigin { | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
| 			value, err := param.Build(ctx, contentType, bodyParam) | ||||
| 			if err != nil { | ||||
| 				log.Errorf("build param(s) error: %v", key, err) | ||||
| 				continue | ||||
| 			} | ||||
| 			err = x.Set(bodyParam, value) | ||||
| 			if err != nil { | ||||
| 				log.Errorf("set param(s) error: %v", key, err) | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		b, _ := oj.Marshal(bodyParam) | ||||
| 		ctx.Proxy().Body().SetRaw(contentType, b) | ||||
| 		return bodyParam, nil | ||||
| 	} else if contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data" { | ||||
| 		bodyParam := make(map[string]interface{}) | ||||
| 		bodyForm, _ := ctx.Proxy().Body().BodyForm() | ||||
| 		for _, param := range params { | ||||
| 			_, has := bodyParam[param.name] | ||||
| 			if has { | ||||
| 				if param.conflict == paramError { | ||||
| 					return nil, fmt.Errorf("[extra_params] body(%s) has a conflict", param.name) | ||||
| 				} else if param.conflict == paramOrigin { | ||||
| 					continue | ||||
| 				} | ||||
| 			} | ||||
| 			value, err := param.Build(ctx, contentType, nil) | ||||
| 			if err != nil { | ||||
| 				log.Errorf("build param(s) error: %v", param.name, err) | ||||
| 				continue | ||||
| 			} | ||||
| 			bodyParam[param.name] = value | ||||
|  | ||||
| 		} | ||||
| 		ctx.Proxy().Body().SetForm(bodyForm) | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| func (e *executor) access(ctx http_service.IHttpContext) (int, error) { | ||||
| 	// 判断请求携带的content-type | ||||
| 	contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType()) | ||||
| 	var bodyParam interface{} | ||||
| 	var err error | ||||
| 	if ctx.Proxy().Method() == http.MethodPost || ctx.Proxy().Method() == http.MethodPut || ctx.Proxy().Method() == http.MethodPatch { | ||||
| 		if e.requestBodyType != "" { | ||||
| 			if e.requestBodyType == "json" && contentType != "application/json" { | ||||
| 				return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] request body type is not json`, clientErrStatusCode) | ||||
| 			} else if e.requestBodyType == "form-data" && contentType != "multipart/form-data" { | ||||
| 				return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] request body type is not form-data`, clientErrStatusCode) | ||||
| 			} | ||||
| 		} | ||||
| 		bodyParam, err = addParamToBody(ctx, contentType, e.baseParam.body) | ||||
| 		if err != nil { | ||||
| 			return clientErrStatusCode, encodeErr(e.errorType, err.Error(), clientErrStatusCode) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 处理Query参数 | ||||
| 	for _, param := range e.baseParam.query { | ||||
| 		v := ctx.Proxy().URI().GetQuery(param.name) | ||||
| 		if v != "" { | ||||
| 			if param.conflict == paramError { | ||||
| 				return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] query("`+param.name+`") has a conflict.`, clientErrStatusCode) | ||||
| 			} else if param.conflict == paramOrigin { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 		value, err := param.Build(ctx, contentType, bodyParam) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("build query extra param(%s) error: %s", param.name, err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 		ctx.Proxy().URI().SetQuery(param.name, value) | ||||
| 	} | ||||
|  | ||||
| 	// 处理Header参数 | ||||
| 	for _, param := range e.baseParam.header { | ||||
| 		name := textproto.CanonicalMIMEHeaderKey(param.name) | ||||
| 		_, has := ctx.Proxy().Header().Headers()[name] | ||||
| 		if has { | ||||
| 			if param.conflict == paramError { | ||||
| 				return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] header("`+name+`") has a conflict.`, clientErrStatusCode) | ||||
| 			} else if param.conflict == paramOrigin { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 		value, err := param.Build(ctx, contentType, bodyParam) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("build header extra param(%s) error: %s", name, err.Error()) | ||||
| 			continue | ||||
| 		} | ||||
| 		ctx.Proxy().Header().SetHeader(param.name, value) | ||||
| 	} | ||||
| 	return successStatusCode, nil | ||||
| } | ||||
|  | ||||
| func (e *executor) Start() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { | ||||
| 	cfg, err := check(conf) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	e.baseParam = generateBaseParam(cfg.Params) | ||||
| 	e.requestBodyType = cfg.RequestBodyType | ||||
| 	e.errorType = cfg.ErrorType | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (e *executor) Stop() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (e *executor) Destroy() { | ||||
| 	e.baseParam = nil | ||||
| 	e.errorType = "" | ||||
| } | ||||
|  | ||||
| func (e *executor) CheckSkill(skill string) bool { | ||||
| 	return http_service.FilterSkillName == skill | ||||
| } | ||||
							
								
								
									
										17
									
								
								drivers/plugins/extra-params_v2/factory.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								drivers/plugins/extra-params_v2/factory.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| package extra_params_v2 | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eolinker/apinto/drivers" | ||||
| 	"github.com/eolinker/eosc" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Name = "extra_params_v2" | ||||
| ) | ||||
|  | ||||
| func Register(register eosc.IExtenderDriverRegister) { | ||||
| 	register.RegisterExtenderDriver(Name, NewFactory()) | ||||
| } | ||||
| func NewFactory() eosc.IExtenderDriverFactory { | ||||
| 	return drivers.NewFactory[Config](Create, Check) | ||||
| } | ||||
							
								
								
									
										210
									
								
								drivers/plugins/extra-params_v2/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										210
									
								
								drivers/plugins/extra-params_v2/util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,210 @@ | ||||
| package extra_params_v2 | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	positionQuery  = "query" | ||||
| 	positionHeader = "header" | ||||
| 	positionBody   = "body" | ||||
|  | ||||
| 	paramConvert string = "convert" | ||||
| 	paramError   string = "error" | ||||
| 	paramOrigin  string = "origin" | ||||
|  | ||||
| 	clientErrStatusCode = 400 | ||||
| 	successStatusCode   = 200 | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	paramPositionErrInfo = `[plugin extra-params config err] param position must be in the set ["query","header",body]. err position: %s ` | ||||
| 	//parseBodyErrInfo     = `[extra_params] Fail to parse body! [err]: %s` | ||||
| 	paramNameErrInfo = `[plugin extra-params config err] param name must be not null. ` | ||||
| ) | ||||
|  | ||||
| func encodeErr(ent string, origin string, statusCode int) error { | ||||
| 	if ent == "json" { | ||||
| 		tmp := map[string]interface{}{ | ||||
| 			"message":     origin, | ||||
| 			"status_code": statusCode, | ||||
| 		} | ||||
| 		info, _ := json.Marshal(tmp) | ||||
| 		return fmt.Errorf("%s", info) | ||||
| 	} | ||||
| 	return fmt.Errorf("%s statusCode: %d", origin, statusCode) | ||||
| } | ||||
|  | ||||
| //func parseBodyParams(ctx http_service.IHttpContext) (interface{}, url.Values, error) { | ||||
| //	if ctx.Proxy().Method() != http.MethodPost && ctx.Proxy().Method() != http.MethodPut && ctx.Proxy().Method() != http.MethodPatch { | ||||
| //		return nil, nil, nil | ||||
| //	} | ||||
| //	contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType()) | ||||
| //	switch contentType { | ||||
| //	case http_context.FormData, http_context.MultipartForm: | ||||
| //		formParams, err := ctx.Proxy().Body().BodyForm() | ||||
| //		if err != nil { | ||||
| //			return nil, nil, err | ||||
| //		} | ||||
| //		return nil, formParams, nil | ||||
| //	case http_context.JSON: | ||||
| //		body, err := ctx.Proxy().Body().RawBody() | ||||
| //		if err != nil { | ||||
| //			return nil, nil, err | ||||
| //		} | ||||
| //		if string(body) == "" { | ||||
| //			body = []byte("{}") | ||||
| //		} | ||||
| //		bodyParams, err := oj.Parse(body) | ||||
| //		return bodyParams, nil, err | ||||
| //	} | ||||
| //	return nil, nil, errors.New("unsupported content-type: " + contentType) | ||||
| //} | ||||
|  | ||||
| // | ||||
| //func parseBodyParams(ctx http_service.IHttpContext) (map[string]interface{}, map[string][]string, error) { | ||||
| //	contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType()) | ||||
| // | ||||
| //	switch contentType { | ||||
| //	case http_context.FormData, http_context.MultipartForm: | ||||
| //		formParams, err := ctx.Proxy().Body().BodyForm() | ||||
| //		if err != nil { | ||||
| //			return nil, nil, err | ||||
| //		} | ||||
| //		return nil, formParams, nil | ||||
| //	case http_context.JSON: | ||||
| //		body, err := ctx.Proxy().Body().RawBody() | ||||
| //		if err != nil { | ||||
| //			return nil, nil, err | ||||
| //		} | ||||
| //		var bodyParams map[string]interface{} | ||||
| //		err = json.Unmarshal(body, &bodyParams) | ||||
| //		if err != nil { | ||||
| //			return bodyParams, nil, err | ||||
| //		} | ||||
| //	} | ||||
| //	return nil, nil, errors.New("[params_transformer] unsupported content-type: " + contentType) | ||||
| //} | ||||
|  | ||||
| //func getHeaderValue(headers map[string][]string, param *ExtraParam, value string) (string, error) { | ||||
| //	paramName := ConvertHeaderKey(param.Name) | ||||
| // | ||||
| //	if param.Conflict == "" { | ||||
| //		param.Conflict = paramConvert | ||||
| //	} | ||||
| // | ||||
| //	var paramValue string | ||||
| // | ||||
| //	if _, ok := headers[paramName]; !ok { | ||||
| //		param.Conflict = paramConvert | ||||
| //	} else { | ||||
| //		paramValue = headers[paramName][0] | ||||
| //	} | ||||
| // | ||||
| //	if param.Conflict == paramConvert { | ||||
| //		paramValue = value | ||||
| //	} else if param.Conflict == paramError { | ||||
| //		errInfo := `[extra_params] "` + param.Name + `" has a conflict.` | ||||
| //		return "", errors.New(errInfo) | ||||
| //	} | ||||
| // | ||||
| //	return paramValue, nil | ||||
| //} | ||||
|  | ||||
| //func hasQueryValue(rawQuery string, paramName string) bool { | ||||
| //	bytes := []byte(rawQuery) | ||||
| //	if len(bytes) == 0 { | ||||
| //		return false | ||||
| //	} | ||||
| // | ||||
| //	k := 0 | ||||
| //	for i, c := range bytes { | ||||
| //		switch c { | ||||
| //		case '=': | ||||
| //			key := string(bytes[k:i]) | ||||
| //			if key == paramName { | ||||
| //				return true | ||||
| //			} | ||||
| //		case '&': | ||||
| //			k = i + 1 | ||||
| //		} | ||||
| //	} | ||||
| // | ||||
| //	return false | ||||
| //} | ||||
|  | ||||
| //func getQueryValue(ctx http_service.IHttpContext, param *ExtraParam, value string) (string, error) { | ||||
| //	paramValue := "" | ||||
| //	if param.Conflict == "" { | ||||
| //		param.Conflict = paramConvert | ||||
| //	} | ||||
| // | ||||
| //	//判断请求中是否包含对应的query参数 | ||||
| //	if !hasQueryValue(ctx.Proxy().URI().RawQuery(), param.Name) { | ||||
| //		param.Conflict = paramConvert | ||||
| //	} else { | ||||
| //		paramValue = ctx.Proxy().URI().GetQuery(param.Name) | ||||
| //	} | ||||
| // | ||||
| //	if param.Conflict == paramConvert { | ||||
| //		paramValue = value | ||||
| //	} else if param.Conflict == paramError { | ||||
| //		errInfo := `[extra_params] "` + param.Name + `" has a conflict.` | ||||
| //		return "", errors.New(errInfo) | ||||
| //	} | ||||
| // | ||||
| //	return paramValue, nil | ||||
| //} | ||||
| // | ||||
| //func getBodyValue(bodyParams map[string]interface{}, formParams map[string][]string, param *ExtraParam, contentType string, value interface{}) (interface{}, error) { | ||||
| //	var paramValue interface{} = nil | ||||
| //	Conflict := param.Conflict | ||||
| //	if Conflict == "" { | ||||
| //		Conflict = paramConvert | ||||
| //	} | ||||
| //	if strings.Contains(contentType, http_context.FormData) || strings.Contains(contentType, http_context.MultipartForm) { | ||||
| //		if _, ok := formParams[param.Name]; !ok { | ||||
| //			Conflict = paramConvert | ||||
| //		} else { | ||||
| //			paramValue = formParams[param.Name][0] | ||||
| //		} | ||||
| //	} else if strings.Contains(contentType, http_context.JSON) { | ||||
| //		if _, ok := bodyParams[param.Name]; !ok { | ||||
| //			param.Conflict = paramConvert | ||||
| //		} else { | ||||
| //			paramValue = bodyParams[param.Name] | ||||
| //		} | ||||
| //	} | ||||
| //	if Conflict == paramConvert { | ||||
| //		paramValue = value | ||||
| //	} else if Conflict == paramError { | ||||
| //		errInfo := `[extra_params] "` + param.Name + `" has a conflict.` | ||||
| //		return "", errors.New(errInfo) | ||||
| //	} | ||||
| // | ||||
| //	return paramValue, nil | ||||
| //} | ||||
|  | ||||
| //func ConvertHeaderKey(header string) string { | ||||
| //	header = strings.ToLower(header) | ||||
| //	headerArray := strings.Split(header, "-") | ||||
| //	h := "" | ||||
| //	arrLen := len(headerArray) | ||||
| //	for i, value := range headerArray { | ||||
| //		vLen := len(value) | ||||
| //		if vLen < 1 { | ||||
| //			continue | ||||
| //		} else { | ||||
| //			if vLen == 1 { | ||||
| //				h += strings.ToUpper(value) | ||||
| //			} else { | ||||
| //				h += strings.ToUpper(string(value[0])) + value[1:] | ||||
| //			} | ||||
| //			if i != arrLen-1 { | ||||
| //				h += "-" | ||||
| //			} | ||||
| //		} | ||||
| //	} | ||||
| //	return h | ||||
| //} | ||||
							
								
								
									
										44
									
								
								node/fasthttp-client/client_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								node/fasthttp-client/client_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| package fasthttp_client | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/valyala/fasthttp" | ||||
| ) | ||||
|  | ||||
| func TestProxyTimeout(t *testing.T) { | ||||
| 	//addr := "https://gw.kuaidizs.cn" | ||||
| 	addr := fmt.Sprintf("%s://%s", "https", "gw.kuaidizs.cn") | ||||
| 	req := fasthttp.AcquireRequest() | ||||
| 	resp := fasthttp.AcquireResponse() | ||||
| 	req.URI().SetPath("/open/api") | ||||
| 	req.URI().SetHost("gw.kuaidizs.cn") | ||||
| 	req.Header.SetMethod("POST") | ||||
| 	req.Header.SetContentType("application/json") | ||||
| 	req.SetBody([]byte(`{"cpCode":"YTO","version":"1.0","timestamp":"2023-08-10 11:57:13","province":"广东省","city":"广州市","appKey":"DBB812347A1E44829159FE82F5C4303E","format":"json","sign_method":"md5","method":"kdzs.address.reachable","sign":"10A4B5A59340F9B98DAFA3CFCCF65449"}`)) | ||||
| 	err := defaultClient.ProxyTimeout(addr, req, resp, 0) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	t.Log(string(resp.Body())) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func TestMyselfProxyTimeout(t *testing.T) { | ||||
| 	//addr := "https://gw.kuaidizs.cn" | ||||
| 	addr := "http://127.0.0.1:8099" | ||||
| 	req := fasthttp.AcquireRequest() | ||||
| 	resp := fasthttp.AcquireResponse() | ||||
| 	req.URI().SetPath("/open/api") | ||||
| 	req.URI().SetHost("127.0.0.1:8099") | ||||
| 	req.Header.SetMethod("POST") | ||||
| 	req.Header.SetContentType("application/json") | ||||
| 	req.SetBody([]byte(`{"cpCode":"YTO","province":"广东省","city":"广州市"}`)) | ||||
| 	err := defaultClient.ProxyTimeout(addr, req, resp, 0) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	t.Log(string(resp.Body())) | ||||
| 	return | ||||
| } | ||||
| @@ -46,14 +46,15 @@ var ( | ||||
| func (r *ProxyRequest) reset(request *fasthttp.Request, remoteAddr string) { | ||||
|  | ||||
| 	r.RequestReader.reset(request, remoteAddr) | ||||
|  | ||||
| 	forwardedFor := r.req.Header.PeekBytes(xforwardedforKey) | ||||
| 	if len(forwardedFor) > 0 { | ||||
| 		r.req.Header.Set("x-forwarded-for", fmt.Sprint(string(forwardedFor), ",", remoteAddr)) | ||||
| 		r.req.Header.Set("x-forwarded-for", fmt.Sprint(string(forwardedFor), ",", r.remoteAddr)) | ||||
| 	} else { | ||||
| 		r.req.Header.Set("x-forwarded-for", remoteAddr) | ||||
| 		r.req.Header.Set("x-forwarded-for", r.remoteAddr) | ||||
| 	} | ||||
|  | ||||
| 	r.req.Header.Set("x-real-ip", r.RequestReader.realIP) | ||||
| 	r.req.Header.Set("x-real-ip", r.realIP) | ||||
|  | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Liujian
					Liujian