mirror of
https://github.com/eolinker/apinto
synced 2025-10-23 00:39:46 +08:00
额外参数v2插件、请求体限制插件完成
This commit is contained in:
@@ -18,11 +18,13 @@ import (
|
||||
plugin_manager "github.com/eolinker/apinto/drivers/plugin-manager"
|
||||
access_log "github.com/eolinker/apinto/drivers/plugins/access-log"
|
||||
plugin_app "github.com/eolinker/apinto/drivers/plugins/app"
|
||||
body_check "github.com/eolinker/apinto/drivers/plugins/body-check"
|
||||
circuit_breaker "github.com/eolinker/apinto/drivers/plugins/circuit-breaker"
|
||||
"github.com/eolinker/apinto/drivers/plugins/cors"
|
||||
dubbo2_proxy_rewrite "github.com/eolinker/apinto/drivers/plugins/dubbo2-proxy-rewrite"
|
||||
dubbo2_to_http "github.com/eolinker/apinto/drivers/plugins/dubbo2-to-http"
|
||||
extra_params "github.com/eolinker/apinto/drivers/plugins/extra-params"
|
||||
extra_params_v2 "github.com/eolinker/apinto/drivers/plugins/extra-params_v2"
|
||||
grpc_to_http "github.com/eolinker/apinto/drivers/plugins/gRPC-to-http"
|
||||
grpc_proxy_rewrite "github.com/eolinker/apinto/drivers/plugins/grpc-proxy-rewrite"
|
||||
"github.com/eolinker/apinto/drivers/plugins/gzip"
|
||||
@@ -69,6 +71,7 @@ func ProcessWorker() {
|
||||
func registerInnerExtenders() {
|
||||
extends.AddInnerExtendProject("eolinker.com", "apinto", Register)
|
||||
}
|
||||
|
||||
func Register(extenderRegister eosc.IExtenderDriverRegister) {
|
||||
// router
|
||||
http_router.Register(extenderRegister)
|
||||
@@ -153,4 +156,7 @@ func Register(extenderRegister eosc.IExtenderDriverRegister) {
|
||||
|
||||
proxy_mirror.Register(extenderRegister)
|
||||
http_mocking.Register(extenderRegister)
|
||||
|
||||
body_check.Register(extenderRegister)
|
||||
extra_params_v2.Register(extenderRegister)
|
||||
}
|
||||
|
@@ -49,9 +49,9 @@ func (h *HttpOutput) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if h.config != nil && !h.config.isConfUpdate(config) {
|
||||
return nil
|
||||
}
|
||||
//if h.config != nil && !h.config.isConfUpdate(config) {
|
||||
// return nil
|
||||
//}
|
||||
h.config = config
|
||||
|
||||
if h.running {
|
||||
|
73
drivers/plugins/body-check/body-check.go
Normal file
73
drivers/plugins/body-check/body-check.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package body_check
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var _ http_service.HttpFilter = (*BodyCheck)(nil)
|
||||
var _ eocontext.IFilter = (*BodyCheck)(nil)
|
||||
|
||||
type BodyCheck struct {
|
||||
drivers.WorkerBase
|
||||
isEmpty bool
|
||||
allowedPayloadSize int
|
||||
}
|
||||
|
||||
func (b *BodyCheck) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
|
||||
return http_service.DoHttpFilter(b, ctx, next)
|
||||
}
|
||||
|
||||
func (b *BodyCheck) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
|
||||
if ctx.Request().Method() == http.MethodPost || ctx.Request().Method() == http.MethodPut || ctx.Request().Method() == http.MethodPatch {
|
||||
body, err := ctx.Request().Body().RawBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bodySize := len([]rune(string(body)))
|
||||
if !b.isEmpty && bodySize < 1 {
|
||||
ctx.Response().SetStatus(500, "Internal Server Error")
|
||||
ctx.Response().SetBody([]byte("请求体不能为空,请检查body的参数情况"))
|
||||
return errors.New("Body is required")
|
||||
}
|
||||
if b.allowedPayloadSize > 0 && bodySize > b.allowedPayloadSize {
|
||||
ctx.Response().SetStatus(500, "Internal Server Error")
|
||||
ctx.Response().SetBody([]byte("请求体超出长度限制"))
|
||||
return errors.New("The request entity is too large")
|
||||
}
|
||||
}
|
||||
|
||||
return next.DoChain(ctx)
|
||||
}
|
||||
|
||||
func (b *BodyCheck) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BodyCheck) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
cfg, ok := conf.(*Config)
|
||||
if !ok {
|
||||
return errors.New("invalid config")
|
||||
}
|
||||
b.isEmpty = cfg.IsEmpty
|
||||
b.allowedPayloadSize = cfg.AllowedPayloadSize * 1024
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BodyCheck) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BodyCheck) Destroy() {
|
||||
return
|
||||
}
|
||||
|
||||
func (b *BodyCheck) CheckSkill(skill string) bool {
|
||||
return http_service.FilterSkillName == skill
|
||||
}
|
22
drivers/plugins/body-check/config.go
Normal file
22
drivers/plugins/body-check/config.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package body_check
|
||||
|
||||
import (
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
IsEmpty bool `json:"is_empty" label:"是否允许为空"`
|
||||
AllowedPayloadSize int `json:"allowed_payload_size" label:"允许的最大请求体大小"`
|
||||
}
|
||||
|
||||
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
|
||||
|
||||
bc := &BodyCheck{
|
||||
WorkerBase: drivers.Worker(id, name),
|
||||
isEmpty: conf.IsEmpty,
|
||||
allowedPayloadSize: conf.AllowedPayloadSize * 1024,
|
||||
}
|
||||
|
||||
return bc, nil
|
||||
}
|
18
drivers/plugins/body-check/factory.go
Normal file
18
drivers/plugins/body-check/factory.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package body_check
|
||||
|
||||
import (
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "body_check"
|
||||
)
|
||||
|
||||
func Register(register eosc.IExtenderDriverRegister) {
|
||||
register.RegisterExtenderDriver(Name, NewFactory())
|
||||
}
|
||||
|
||||
func NewFactory() eosc.IExtenderDriverFactory {
|
||||
return drivers.NewFactory[Config](Create)
|
||||
}
|
41
drivers/plugins/counter/config.go
Normal file
41
drivers/plugins/counter/config.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/apinto/drivers/plugins/counter/separator"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Match *Match `json:"match" label:"响应匹配规则"`
|
||||
Count *separator.CountRule `json:"count" label:"计数规则"`
|
||||
Key string `json:"key" label:"计数字段名称"`
|
||||
Cache eosc.RequireId `json:"cache" label:"缓存计数器"`
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
Params []*MatchParam `json:"params" label:"匹配参数列表"`
|
||||
StatusCodes []int `json:"status_codes" label:"匹配响应状态码列表"`
|
||||
Type string `json:"type" label:"匹配类型" enum:"json"`
|
||||
}
|
||||
|
||||
func (m *Match) GenerateHandler() []IMatcher {
|
||||
matcher := make([]IMatcher, 0, 2)
|
||||
matcher = append(matcher, newStatusCodeMatcher(m.StatusCodes))
|
||||
matcher = append(matcher, newJsonMatcher(m.Params))
|
||||
return matcher
|
||||
}
|
||||
|
||||
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
|
||||
counter, err := separator.GetCounter(conf.Count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bc := &executor{
|
||||
WorkerBase: drivers.Worker(id, name),
|
||||
matchers: conf.Match.GenerateHandler(),
|
||||
separatorCounter: counter,
|
||||
}
|
||||
|
||||
return bc, nil
|
||||
}
|
5
drivers/plugins/counter/counter/client.go
Normal file
5
drivers/plugins/counter/counter/client.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package counter
|
||||
|
||||
type IClient interface {
|
||||
Get(key string) (int64, error)
|
||||
}
|
79
drivers/plugins/counter/counter/client_test.go
Normal file
79
drivers/plugins/counter/counter/client_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func TestLocalCounter(t *testing.T) {
|
||||
client := NewHTTPClient("", nil)
|
||||
lc := NewLocalCounter("test", client)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(100)
|
||||
for i := 0; i < 100; i++ {
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
count := rand.Int63n(20)
|
||||
err := lc.Lock(count)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
condition := rand.Intn(100)
|
||||
switch condition % 2 {
|
||||
case 0:
|
||||
err = lc.Complete(count)
|
||||
case 1:
|
||||
err = lc.RollBack(count)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRedisCounter(t *testing.T) {
|
||||
client := NewHTTPClient("", nil)
|
||||
key := "apinto-apiddww"
|
||||
lc := NewLocalCounter(key, client)
|
||||
redisConn := redis.NewClient(&redis.Options{
|
||||
Addr: "172.18.65.42:6380",
|
||||
Password: "password", // 如果有密码,请填写密码
|
||||
DB: 9, // 选择数据库,默认为0
|
||||
})
|
||||
rc := NewRedisCounter(key, redisConn, client, lc)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(100)
|
||||
for i := 0; i < 100; i++ {
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
count := rand.Int63n(20)
|
||||
err := rc.Lock(count)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
condition := rand.Intn(100)
|
||||
switch condition % 2 {
|
||||
case 0:
|
||||
err = rc.Complete(count)
|
||||
case 1:
|
||||
err = rc.RollBack(count)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
26
drivers/plugins/counter/counter/counter.go
Normal file
26
drivers/plugins/counter/counter/counter.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package counter
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ICounter interface {
|
||||
// Lock 锁定次数
|
||||
Lock(count int64) error
|
||||
// Complete 完成扣次操作
|
||||
Complete(count int64) error
|
||||
// RollBack 回滚
|
||||
RollBack(count int64) error
|
||||
// ResetClient 重置客户端
|
||||
ResetClient(client IClient)
|
||||
}
|
||||
|
||||
func getRemainCount(client IClient, key string, count int64) (int64, error) {
|
||||
remain, err := client.Get(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
remain -= count
|
||||
if remain < 0 {
|
||||
return 0, fmt.Errorf("no enough, key:%s, remain:%d, count:%d", key, remain, count)
|
||||
}
|
||||
return remain, nil
|
||||
}
|
58
drivers/plugins/counter/counter/http-client.go
Normal file
58
drivers/plugins/counter/counter/http-client.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ohler55/ojg/oj"
|
||||
|
||||
"github.com/ohler55/ojg/jp"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var _ IClient = (*HTTPClient)(nil)
|
||||
|
||||
var httpClient = fasthttp.Client{
|
||||
Name: "apinto-counter",
|
||||
}
|
||||
|
||||
type HTTPClient struct {
|
||||
uri string
|
||||
headers map[string]string
|
||||
// jsonExpr 经过编译的JSONPath表达式
|
||||
jsonExpr jp.Expr
|
||||
}
|
||||
|
||||
func NewHTTPClient(uri string, jsonExpr jp.Expr) *HTTPClient {
|
||||
return &HTTPClient{uri: uri, jsonExpr: jsonExpr}
|
||||
}
|
||||
|
||||
func (H *HTTPClient) Get(key string) (int64, error) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
req.SetRequestURI(H.uri)
|
||||
req.Header.SetMethod("GET")
|
||||
for name, value := range H.headers {
|
||||
req.Header.Set(name, value)
|
||||
}
|
||||
err := httpClient.Do(req, resp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return 0, fmt.Errorf("error http status code: %d,key: %s,uri %s", resp.StatusCode(), key, H.uri)
|
||||
}
|
||||
result, err := oj.Parse(resp.Body())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// 解析JSON
|
||||
v := H.jsonExpr.Get(result)
|
||||
if v == nil || len(v) < 1 {
|
||||
return 0, fmt.Errorf("no found key: %s,uri: %s", key)
|
||||
}
|
||||
if len(v) != 1 {
|
||||
return 0, fmt.Errorf("invalid value: %v,key: %s,uri: %s", v, key, H.uri)
|
||||
}
|
||||
return v[0].(int64), nil
|
||||
}
|
76
drivers/plugins/counter/counter/local.go
Normal file
76
drivers/plugins/counter/counter/local.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ ICounter = (*LocalCounter)(nil)
|
||||
|
||||
func NewLocalCounter(key string, client IClient) *LocalCounter {
|
||||
return &LocalCounter{key: key, client: client}
|
||||
}
|
||||
|
||||
// LocalCounter 本地计数器
|
||||
type LocalCounter struct {
|
||||
key string
|
||||
// 剩余次数
|
||||
remain int64
|
||||
// 锁定次数
|
||||
lock int64
|
||||
|
||||
locker sync.Mutex
|
||||
|
||||
resetTime time.Time
|
||||
client IClient
|
||||
}
|
||||
|
||||
func (c *LocalCounter) Lock(count int64) error {
|
||||
c.locker.Lock()
|
||||
defer c.locker.Unlock()
|
||||
remain := c.remain - count
|
||||
if remain < 0 {
|
||||
now := time.Now()
|
||||
if now.Sub(c.resetTime) < 10*time.Second {
|
||||
return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", c.key, c.remain, count)
|
||||
}
|
||||
|
||||
var err error
|
||||
c.resetTime = now
|
||||
// 获取最新的次数
|
||||
remain, err = getRemainCount(c.client, c.key, count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.remain = remain
|
||||
c.lock += count
|
||||
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "lock", "remain:", c.remain, ",lock:", c.lock, ",count:", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LocalCounter) Complete(count int64) error {
|
||||
c.locker.Lock()
|
||||
defer c.locker.Unlock()
|
||||
// 需要解除已经锁定的部分次数
|
||||
c.lock -= count
|
||||
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "complete", "remain:", c.remain, ",lock:", c.lock, ",count:", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LocalCounter) RollBack(count int64) error {
|
||||
c.locker.Lock()
|
||||
defer c.locker.Unlock()
|
||||
// 需要解除已经锁定的部分次数,并且增加剩余次数
|
||||
c.remain += c.lock
|
||||
c.lock -= count
|
||||
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "rollback", "remain:", c.remain, ",lock:", c.lock, ",count:", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LocalCounter) ResetClient(client IClient) {
|
||||
c.locker.Lock()
|
||||
defer c.locker.Unlock()
|
||||
c.client = client
|
||||
}
|
171
drivers/plugins/counter/counter/redis.go
Normal file
171
drivers/plugins/counter/counter/redis.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var _ ICounter = (*RedisCounter)(nil)
|
||||
|
||||
type RedisCounter struct {
|
||||
ctx context.Context
|
||||
key string
|
||||
redis redis.Cmdable
|
||||
|
||||
client IClient
|
||||
locker sync.Mutex
|
||||
resetTime time.Time
|
||||
|
||||
localCounter ICounter
|
||||
|
||||
lockerKey string
|
||||
lockKey string
|
||||
remainKey string
|
||||
}
|
||||
|
||||
func NewRedisCounter(key string, redis redis.Cmdable, client IClient, localCounter ICounter) *RedisCounter {
|
||||
|
||||
return &RedisCounter{
|
||||
key: key,
|
||||
redis: redis,
|
||||
client: client,
|
||||
localCounter: localCounter,
|
||||
ctx: context.Background(),
|
||||
lockerKey: fmt.Sprintf("%s:locker", key),
|
||||
lockKey: fmt.Sprintf("%s:lock", key),
|
||||
remainKey: fmt.Sprintf("%s:remain", key),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisCounter) Lock(count int64) error {
|
||||
r.locker.Lock()
|
||||
defer r.locker.Unlock()
|
||||
if r.redis == nil {
|
||||
// 如果Redis没有配置,使用本地计数器
|
||||
return r.localCounter.Lock(count)
|
||||
}
|
||||
|
||||
err := r.acquireLock()
|
||||
if err != nil {
|
||||
if err == redis.ErrClosed {
|
||||
return r.localCounter.Lock(count)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer r.releaseLock()
|
||||
|
||||
// 获取最新的次数
|
||||
remain, err := r.redis.Get(r.ctx, r.remainKey).Int64()
|
||||
if err != nil {
|
||||
if err == redis.ErrClosed {
|
||||
return r.localCounter.Lock(count)
|
||||
} else if err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
remain -= count
|
||||
if remain < 0 {
|
||||
now := time.Now()
|
||||
if now.Sub(r.resetTime) < 10*time.Second {
|
||||
return fmt.Errorf("no enough, ddd key:%s, remain:%d, count:%d", r.key, remain+count, count)
|
||||
}
|
||||
|
||||
r.resetTime = now
|
||||
lock, err := r.redis.Get(r.ctx, r.lockKey).Int64()
|
||||
if err != nil {
|
||||
if err == redis.ErrClosed {
|
||||
return r.localCounter.Lock(count)
|
||||
} else if err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
remain, err = getRemainCount(r.client, r.key, count+lock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.redis.Set(r.ctx, r.remainKey, remain, -1)
|
||||
r.redis.IncrBy(r.ctx, r.lockKey, count)
|
||||
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "lock", "remain:", remain, "count:", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisCounter) Complete(count int64) error {
|
||||
r.locker.Lock()
|
||||
defer r.locker.Unlock()
|
||||
if r.redis == nil {
|
||||
// 如果Redis没有配置,使用本地计数器
|
||||
return r.localCounter.Complete(count)
|
||||
}
|
||||
|
||||
err := r.acquireLock()
|
||||
if err != nil {
|
||||
if err == redis.ErrClosed {
|
||||
return r.localCounter.Lock(count)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer r.releaseLock()
|
||||
|
||||
r.redis.IncrBy(r.ctx, r.lockKey, -count)
|
||||
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "complete", "count:", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisCounter) RollBack(count int64) error {
|
||||
r.locker.Lock()
|
||||
defer r.locker.Unlock()
|
||||
if r.redis == nil {
|
||||
// 如果Redis没有配置,使用本地计数器
|
||||
return r.localCounter.RollBack(count)
|
||||
}
|
||||
err := r.acquireLock()
|
||||
if err != nil {
|
||||
if err == redis.ErrClosed {
|
||||
return r.localCounter.Lock(count)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer r.releaseLock()
|
||||
|
||||
r.redis.IncrBy(r.ctx, r.remainKey, count)
|
||||
r.redis.IncrBy(r.ctx, r.lockKey, -count)
|
||||
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "rollback", "count:", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisCounter) ResetClient(client IClient) {
|
||||
r.locker.Lock()
|
||||
defer r.locker.Unlock()
|
||||
r.client = client
|
||||
}
|
||||
|
||||
func (r *RedisCounter) acquireLock() error {
|
||||
for {
|
||||
// 生成唯一的锁值
|
||||
lockValue := time.Now().UnixNano()
|
||||
|
||||
// Redis连接失败,使用本地计数器
|
||||
ok, err := r.redis.SetNX(r.ctx, r.lockerKey, lockValue, 10*time.Second).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
// 设置锁成功
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisCounter) releaseLock() error {
|
||||
return r.redis.Del(r.ctx, r.lockerKey).Err()
|
||||
}
|
113
drivers/plugins/counter/executor.go
Normal file
113
drivers/plugins/counter/executor.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/eolinker/apinto/drivers/plugins/counter/counter"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/apinto/drivers/plugins/counter/separator"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var _ http_service.HttpFilter = (*executor)(nil)
|
||||
var _ eocontext.IFilter = (*executor)(nil)
|
||||
|
||||
type executor struct {
|
||||
drivers.WorkerBase
|
||||
matchers []IMatcher
|
||||
separatorCounter separator.ICounter
|
||||
counters eosc.Untyped[string, counter.ICounter]
|
||||
client counter.IClient
|
||||
keyGenerate IKeyGenerator
|
||||
}
|
||||
|
||||
func (b *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
|
||||
return http_service.DoHttpFilter(b, ctx, next)
|
||||
}
|
||||
|
||||
func (b *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
|
||||
counter, has := b.counters.Get(b.keyGenerate.Key(ctx))
|
||||
if !has {
|
||||
|
||||
}
|
||||
var count int64 = 1
|
||||
var err error
|
||||
if b.separatorCounter != nil {
|
||||
|
||||
separatorCounter := b.separatorCounter
|
||||
count, err = separatorCounter.Count(ctx)
|
||||
if err != nil {
|
||||
ctx.Response().SetStatus(400, "400")
|
||||
return fmt.Errorf("%s count error", separatorCounter.Name())
|
||||
}
|
||||
if count > separatorCounter.Max() {
|
||||
ctx.Response().SetStatus(403, "not allow")
|
||||
return fmt.Errorf("%s number exceed", separatorCounter.Name())
|
||||
} else if count == 0 {
|
||||
ctx.Response().SetStatus(400, "400")
|
||||
return fmt.Errorf("%s value is missing", separatorCounter.Name())
|
||||
}
|
||||
}
|
||||
|
||||
err = counter.Lock(count)
|
||||
if err != nil {
|
||||
// 次数不足,直接返回
|
||||
//return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", b.counters.Name(), b.counters.Remain(), count
|
||||
}
|
||||
if next != nil {
|
||||
err = next.DoChain(ctx)
|
||||
if err != nil {
|
||||
// 转发失败,回滚次数
|
||||
return counter.RollBack(count)
|
||||
//return err
|
||||
}
|
||||
}
|
||||
match := true
|
||||
for _, matcher := range b.matchers {
|
||||
ok := matcher.Match(ctx)
|
||||
if !ok {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
// 匹配,扣减次数
|
||||
return counter.Complete(count)
|
||||
}
|
||||
// 不匹配,回滚次数
|
||||
return counter.RollBack(count)
|
||||
}
|
||||
|
||||
func (b *executor) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
cfg, ok := conf.(*Config)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid config, driver: %s", Name)
|
||||
}
|
||||
counter, err := separator.GetCounter(cfg.Count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.separatorCounter = counter
|
||||
b.matchers = cfg.Match.GenerateHandler()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *executor) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *executor) Destroy() {
|
||||
return
|
||||
}
|
||||
|
||||
func (b *executor) CheckSkill(skill string) bool {
|
||||
return http_service.FilterSkillName == skill
|
||||
}
|
18
drivers/plugins/counter/factory.go
Normal file
18
drivers/plugins/counter/factory.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "counter"
|
||||
)
|
||||
|
||||
func Register(register eosc.IExtenderDriverRegister) {
|
||||
register.RegisterExtenderDriver(Name, NewFactory())
|
||||
}
|
||||
|
||||
func NewFactory() eosc.IExtenderDriverFactory {
|
||||
return drivers.NewFactory[Config](Create)
|
||||
}
|
49
drivers/plugins/counter/key.go
Normal file
49
drivers/plugins/counter/key.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var _ IKeyGenerator = (*keyGenerate)(nil)
|
||||
|
||||
type IKeyGenerator interface {
|
||||
Key(ctx http_service.IHttpContext) string
|
||||
}
|
||||
|
||||
func newKeyGenerate(key string) *keyGenerate {
|
||||
key = strings.TrimSpace(key)
|
||||
tmp := strings.Split(key, ":")
|
||||
|
||||
keys := make([]string, 0, len(tmp))
|
||||
variables := make([]string, 0, len(tmp))
|
||||
for _, t := range tmp {
|
||||
t = strings.TrimSpace(t)
|
||||
tLen := len(t)
|
||||
if tLen > 0 {
|
||||
if tLen > 1 && t[0] == '$' {
|
||||
variables = append(variables, t[1:])
|
||||
keys = append(keys, "%s")
|
||||
} else {
|
||||
keys = append(keys, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &keyGenerate{format: strings.Join(keys, ":"), variables: variables}
|
||||
}
|
||||
|
||||
type keyGenerate struct {
|
||||
format string
|
||||
// 变量列表
|
||||
variables []string
|
||||
}
|
||||
|
||||
func (k *keyGenerate) Key(ctx http_service.IHttpContext) string {
|
||||
variables := make([]interface{}, 0, len(k.variables))
|
||||
for _, v := range k.variables {
|
||||
variables = append(variables, ctx.GetLabel(v))
|
||||
}
|
||||
return fmt.Sprintf(k.format, variables...)
|
||||
}
|
8
drivers/plugins/counter/key_test.go
Normal file
8
drivers/plugins/counter/key_test.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package counter
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
key := newKeyGenerate("a:$b:$c:d")
|
||||
t.Log(key.format, key.variables)
|
||||
}
|
121
drivers/plugins/counter/match.go
Normal file
121
drivers/plugins/counter/match.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package counter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
"github.com/ohler55/ojg/jp"
|
||||
"github.com/ohler55/ojg/oj"
|
||||
)
|
||||
|
||||
type IMatcher interface {
|
||||
Match(ctx http_service.IHttpContext) bool
|
||||
}
|
||||
|
||||
type MatchParam struct {
|
||||
Key string `json:"key"`
|
||||
Kind string `json:"kind"` // int|string|bool
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
|
||||
func newJsonMatcher(params []*MatchParam) *jsonMatcher {
|
||||
ps := make([]*jsonMatchParam, 0, len(params))
|
||||
for _, p := range params {
|
||||
key := p.Key
|
||||
if !strings.HasPrefix(p.Key, "$.") {
|
||||
key = "$." + p.Key
|
||||
}
|
||||
expr, err := jp.ParseString(key)
|
||||
if err != nil {
|
||||
log.Errorf("json path parse error: %v,key is %s", err, key)
|
||||
continue
|
||||
}
|
||||
ps = append(ps, &jsonMatchParam{
|
||||
MatchParam: p,
|
||||
expr: expr,
|
||||
})
|
||||
}
|
||||
return &jsonMatcher{params: ps}
|
||||
}
|
||||
|
||||
type jsonMatcher struct {
|
||||
params []*jsonMatchParam
|
||||
}
|
||||
|
||||
type jsonMatchParam struct {
|
||||
*MatchParam
|
||||
expr jp.Expr
|
||||
}
|
||||
|
||||
func (m *jsonMatcher) Match(ctx http_service.IHttpContext) bool {
|
||||
if len(m.params) < 1 {
|
||||
return true
|
||||
}
|
||||
body := ctx.Response().GetBody()
|
||||
tmp, err := oj.Parse(body)
|
||||
if err != nil {
|
||||
log.Errorf("parse body error: %v,body is %s", err, body)
|
||||
return true
|
||||
}
|
||||
for _, p := range m.params {
|
||||
results := p.expr.Get(tmp)
|
||||
for _, v := range p.Value {
|
||||
for _, r := range results {
|
||||
switch p.Kind {
|
||||
case "int":
|
||||
t, ok := r.(int64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
val, _ := strconv.ParseInt(v, 10, 64)
|
||||
if t == val {
|
||||
return true
|
||||
}
|
||||
case "bool":
|
||||
t, ok := r.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
val, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return t == val
|
||||
default:
|
||||
t, ok := r.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if t == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newStatusCodeMatcher(codes []int) *statusCodeMatcher {
|
||||
return &statusCodeMatcher{codes: codes}
|
||||
}
|
||||
|
||||
type statusCodeMatcher struct {
|
||||
codes []int
|
||||
}
|
||||
|
||||
func (m *statusCodeMatcher) Match(ctx http_service.IHttpContext) bool {
|
||||
if len(m.codes) < 1 {
|
||||
return true
|
||||
}
|
||||
code := ctx.Response().StatusCode()
|
||||
for _, c := range m.codes {
|
||||
if c == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
67
drivers/plugins/counter/separator/count.go
Normal file
67
drivers/plugins/counter/separator/count.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package separator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var (
|
||||
ArrayCountType = "array"
|
||||
SplitCountType = "splite"
|
||||
LengthCountType = "length"
|
||||
CountTypes = []string{
|
||||
ArrayCountType,
|
||||
SplitCountType,
|
||||
LengthCountType,
|
||||
}
|
||||
)
|
||||
|
||||
type CountRule struct {
|
||||
RequestBodyType string `json:"request_body_type" label:"请求体类型" enum:"form-data,json"`
|
||||
Key string `json:"key" label:"参数名称(支持json path)"`
|
||||
Separator string `json:"separator" label:"分隔符" switch:"separator_type===splite"`
|
||||
SeparatorType string `json:"separator_type" label:"分割类型" enum:"splite,array,length"`
|
||||
Max int64 `json:"max" label:"计数最大值"`
|
||||
}
|
||||
|
||||
type ICounter interface {
|
||||
Count(ctx http_service.IHttpContext) (int64, error)
|
||||
Max() int64
|
||||
Name() string
|
||||
}
|
||||
|
||||
func GetCounter(rule *CountRule) (ICounter, error) {
|
||||
switch strings.ToLower(rule.RequestBodyType) {
|
||||
case "form-data":
|
||||
return NewFormDataCounter(rule)
|
||||
case "multipart-formdata":
|
||||
return NewFileCounter(rule)
|
||||
case "json":
|
||||
return NewJsonCounter(rule)
|
||||
default:
|
||||
return NewFormDataCounter(rule)
|
||||
}
|
||||
}
|
||||
|
||||
func splitCount(origin string, split string) int64 {
|
||||
if len(split) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
vs := strings.Split(origin, string(split[0]))
|
||||
var count int64 = 0
|
||||
for _, v := range vs {
|
||||
if v != "" {
|
||||
childCount := splitCount(v, split[1:])
|
||||
if childCount == 0 {
|
||||
count += 1
|
||||
} else {
|
||||
count += childCount
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
74
drivers/plugins/counter/separator/file.go
Normal file
74
drivers/plugins/counter/separator/file.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package separator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var _ ICounter = (*FileCounter)(nil)
|
||||
|
||||
const defaultMultipartMemory = 32 << 20 // 32 MB
|
||||
|
||||
type FileCounter struct {
|
||||
typ string
|
||||
split string
|
||||
name string
|
||||
max int64
|
||||
splitLen int
|
||||
}
|
||||
|
||||
func (f *FileCounter) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func NewFileCounter(rule *CountRule) (*FileCounter, error) {
|
||||
var splitLen int
|
||||
if rule.SeparatorType == LengthCountType {
|
||||
var err error
|
||||
splitLen, err = strconv.Atoi(rule.Separator)
|
||||
if err != nil {
|
||||
splitLen = 1000
|
||||
}
|
||||
}
|
||||
return &FileCounter{name: rule.Key, split: rule.Separator, typ: rule.SeparatorType, max: rule.Max, splitLen: splitLen}, nil
|
||||
}
|
||||
|
||||
func (f *FileCounter) Count(ctx http_service.IHttpContext) (int64, error) {
|
||||
raw, _ := ctx.Request().Body().RawBody()
|
||||
d, params, _ := mime.ParseMediaType(ctx.Request().ContentType())
|
||||
if !(d == "multipart/form-data") {
|
||||
return -1, fmt.Errorf("need content-type: multipart/form-data,now: %s", d)
|
||||
}
|
||||
boundary, ok := params["boundary"]
|
||||
if !ok {
|
||||
return -1, fmt.Errorf("missing boundary")
|
||||
}
|
||||
body := io.NopCloser(bytes.NewBuffer(raw))
|
||||
reader := multipart.NewReader(body, boundary)
|
||||
form, err := reader.ReadForm(defaultMultipartMemory)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse form param err: %v", err)
|
||||
}
|
||||
|
||||
switch f.typ {
|
||||
case LengthCountType:
|
||||
value := strings.Join(form.Value[f.name], "")
|
||||
l := len([]rune(value))
|
||||
if l%f.splitLen == 0 {
|
||||
return int64(l / f.splitLen), nil
|
||||
}
|
||||
return int64(l/f.splitLen + 1), nil
|
||||
}
|
||||
return splitCount(strings.Join(form.Value[f.name], f.split), f.split), nil
|
||||
}
|
||||
|
||||
func (f *FileCounter) Max() int64 {
|
||||
return f.max
|
||||
}
|
59
drivers/plugins/counter/separator/formdata.go
Normal file
59
drivers/plugins/counter/separator/formdata.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package separator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var _ ICounter = (*FormDataCounter)(nil)
|
||||
|
||||
type FormDataCounter struct {
|
||||
typ string
|
||||
split string
|
||||
name string
|
||||
max int64
|
||||
splitLen int
|
||||
}
|
||||
|
||||
func (f *FormDataCounter) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func NewFormDataCounter(rule *CountRule) (*FormDataCounter, error) {
|
||||
var splitLen int
|
||||
if rule.SeparatorType == LengthCountType {
|
||||
var err error
|
||||
splitLen, err = strconv.Atoi(rule.Separator)
|
||||
if err != nil {
|
||||
splitLen = 1000
|
||||
}
|
||||
}
|
||||
return &FormDataCounter{name: rule.Key, split: rule.Separator, typ: rule.SeparatorType, max: rule.Max, splitLen: splitLen}, nil
|
||||
}
|
||||
|
||||
func (f *FormDataCounter) Count(ctx http_service.IHttpContext) (int64, error) {
|
||||
body, _ := ctx.Request().Body().RawBody()
|
||||
u, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse form data error:%v", err)
|
||||
}
|
||||
switch f.typ {
|
||||
case SplitCountType:
|
||||
return splitCount(u.Get(f.name), f.split), nil
|
||||
case LengthCountType:
|
||||
value := u.Get(f.name)
|
||||
l := len([]rune(value))
|
||||
if l%f.splitLen == 0 {
|
||||
return int64(l / f.splitLen), nil
|
||||
}
|
||||
return int64(l/f.splitLen + 1), nil
|
||||
}
|
||||
return splitCount(u.Get(f.name), f.split), nil
|
||||
}
|
||||
|
||||
func (f *FormDataCounter) Max() int64 {
|
||||
return f.max
|
||||
}
|
104
drivers/plugins/counter/separator/json.go
Normal file
104
drivers/plugins/counter/separator/json.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package separator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
|
||||
"github.com/ohler55/ojg/oj"
|
||||
|
||||
"github.com/ohler55/ojg/jp"
|
||||
)
|
||||
|
||||
var _ ICounter = (*JsonCounter)(nil)
|
||||
|
||||
type JsonCounter struct {
|
||||
max int64
|
||||
split string
|
||||
expr jp.Expr
|
||||
typ string
|
||||
name string
|
||||
splitLen int
|
||||
}
|
||||
|
||||
func (j *JsonCounter) Name() string {
|
||||
return j.name
|
||||
}
|
||||
|
||||
func NewJsonCounter(rule *CountRule) (*JsonCounter, error) {
|
||||
typeValid := false
|
||||
for _, t := range CountTypes {
|
||||
if t == rule.SeparatorType {
|
||||
typeValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !typeValid {
|
||||
return nil, fmt.Errorf("json count split type config error,now type is %s, need array or split", rule.SeparatorType)
|
||||
}
|
||||
expr, err := jp.ParseString(rule.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("json path parse error:%v", err)
|
||||
}
|
||||
if rule.Max < 1 || rule.Max > 2000 {
|
||||
rule.Max = 2000
|
||||
}
|
||||
var splitLen int
|
||||
if rule.SeparatorType == LengthCountType {
|
||||
splitLen, err = strconv.Atoi(rule.Separator)
|
||||
if err != nil {
|
||||
splitLen = 1000
|
||||
}
|
||||
}
|
||||
return &JsonCounter{max: rule.Max, split: rule.Separator, expr: expr, typ: rule.SeparatorType, name: strings.TrimPrefix(rule.Key, "$."), splitLen: splitLen}, nil
|
||||
}
|
||||
|
||||
func (j *JsonCounter) Count(ctx http_service.IHttpContext) (int64, error) {
|
||||
body, _ := ctx.Request().Body().RawBody()
|
||||
obj, err := oj.Parse(body)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse json body error:%v, body is %s", err, string(body))
|
||||
}
|
||||
results := j.expr.Get(obj)
|
||||
switch j.typ {
|
||||
case SplitCountType:
|
||||
if len(results) > 0 {
|
||||
origin, ok := results[0].(string)
|
||||
if !ok {
|
||||
return -1, fmt.Errorf("json path %s get value is not string", j.name)
|
||||
}
|
||||
return splitCount(origin, j.split), nil
|
||||
}
|
||||
case ArrayCountType:
|
||||
if len(results) > 0 {
|
||||
switch v := results[0].(type) {
|
||||
case []interface{}:
|
||||
{
|
||||
return int64(len(v)), nil
|
||||
}
|
||||
case map[string]interface{}:
|
||||
return int64(len(v)), nil
|
||||
}
|
||||
}
|
||||
case LengthCountType:
|
||||
if len(results) > 0 {
|
||||
origin, ok := results[0].(string)
|
||||
if !ok {
|
||||
return -1, fmt.Errorf("json path %s get value is not string", j.name)
|
||||
}
|
||||
l := len([]rune(origin))
|
||||
|
||||
if l%j.splitLen == 0 {
|
||||
return int64(l / j.splitLen), nil
|
||||
}
|
||||
return int64(l/j.splitLen + 1), nil
|
||||
}
|
||||
}
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
func (j *JsonCounter) Max() int64 {
|
||||
return j.max
|
||||
}
|
120
drivers/plugins/extra-params_v2/config.go
Normal file
120
drivers/plugins/extra-params_v2/config.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package extra_params_v2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Params []*ExtraParam `json:"params" label:"参数列表"`
|
||||
RequestBodyType string `json:"request_body_type" enum:"form-data,json" label:"请求体类型"`
|
||||
ErrorType string `json:"error_type" enum:"text,json" label:"报错输出格式"`
|
||||
}
|
||||
|
||||
func (c *Config) doCheck() error {
|
||||
c.ErrorType = strings.ToLower(c.ErrorType)
|
||||
if c.ErrorType != "text" && c.ErrorType != "json" {
|
||||
c.ErrorType = "text"
|
||||
}
|
||||
|
||||
for _, param := range c.Params {
|
||||
if param.Name == "" {
|
||||
return fmt.Errorf(paramNameErrInfo)
|
||||
}
|
||||
|
||||
param.Position = strings.ToLower(param.Position)
|
||||
if param.Position != "query" && param.Position != "header" && param.Position != "body" {
|
||||
return fmt.Errorf(paramPositionErrInfo, param.Position)
|
||||
}
|
||||
|
||||
param.Conflict = strings.ToLower(param.Conflict)
|
||||
if param.Conflict != paramOrigin && param.Conflict != paramConvert && param.Conflict != paramError {
|
||||
param.Conflict = paramConvert
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ExtraParam struct {
|
||||
Name string `json:"name" label:"参数名"`
|
||||
Type string `json:"type" label:"参数类型" enum:"string,int,float,bool,$datetime,$md5"`
|
||||
Position string `json:"position" enum:"header,query,body" label:"参数位置"`
|
||||
Value []string `json:"value" label:"参数值列表"`
|
||||
Conflict string `json:"conflict" label:"参数冲突时的处理方式" enum:"origin,convert,error"`
|
||||
}
|
||||
|
||||
type baseParam struct {
|
||||
header []*paramInfo
|
||||
query []*paramInfo
|
||||
body []*paramInfo
|
||||
}
|
||||
|
||||
func generateBaseParam(params []*ExtraParam) *baseParam {
|
||||
b := &baseParam{
|
||||
header: make([]*paramInfo, 0),
|
||||
query: make([]*paramInfo, 0),
|
||||
body: make([]*paramInfo, 0),
|
||||
}
|
||||
for _, param := range params {
|
||||
switch param.Position {
|
||||
case positionHeader:
|
||||
b.header = append(b.header, newParamInfo(param.Name, param.Value, param.Type, param.Conflict))
|
||||
case positionQuery:
|
||||
b.query = append(b.query, newParamInfo(param.Name, param.Value, param.Type, param.Conflict))
|
||||
case positionBody:
|
||||
b.body = append(b.body, newParamInfo(param.Name, param.Value, param.Type, param.Conflict))
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func newParamInfo(name string, value []string, typ string, conflict string) *paramInfo {
|
||||
d := ¶mInfo{name: name, value: strings.Join(value, ","), conflict: conflict}
|
||||
valueLen := len(d.value)
|
||||
if strings.HasPrefix(typ, "$") {
|
||||
factory, has := dynamic_params.Get(typ)
|
||||
if has {
|
||||
driver, err := factory.Create(name, value)
|
||||
if err == nil {
|
||||
d.driver = driver
|
||||
}
|
||||
}
|
||||
} else if valueLen > 1 && d.value[0] == '$' {
|
||||
// 系统变量
|
||||
d.systemValue = true
|
||||
d.value = d.value[1:valueLen]
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
type paramInfo struct {
|
||||
name string
|
||||
systemValue bool
|
||||
value string
|
||||
driver dynamic_params.IDynamicDriver
|
||||
conflict string
|
||||
}
|
||||
|
||||
func (b *paramInfo) Build(ctx http_service.IHttpContext, contentType string, params interface{}) (string, error) {
|
||||
if b.driver == nil {
|
||||
if b.systemValue {
|
||||
return ctx.GetLabel(b.value), nil
|
||||
}
|
||||
return b.value, nil
|
||||
}
|
||||
value, err := b.driver.Generate(ctx, contentType, params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case int, int32, int64:
|
||||
return fmt.Sprintf("%d", v), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
60
drivers/plugins/extra-params_v2/driver.go
Normal file
60
drivers/plugins/extra-params_v2/driver.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package extra_params_v2
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/datetime"
|
||||
"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/md5"
|
||||
"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/timestamp"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
type Driver struct {
|
||||
profession string
|
||||
name string
|
||||
label string
|
||||
desc string
|
||||
configType reflect.Type
|
||||
}
|
||||
|
||||
func Check(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
|
||||
return conf.doCheck()
|
||||
}
|
||||
|
||||
func check(v interface{}) (*Config, error) {
|
||||
conf, ok := v.(*Config)
|
||||
if !ok {
|
||||
return nil, eosc.ErrorConfigType
|
||||
}
|
||||
err := conf.doCheck()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
|
||||
once.Do(func() {
|
||||
datetime.Register()
|
||||
md5.Register()
|
||||
timestamp.Register()
|
||||
})
|
||||
ep := &executor{
|
||||
WorkerBase: drivers.Worker(id, name),
|
||||
baseParam: generateBaseParam(conf.Params),
|
||||
requestBodyType: conf.RequestBodyType,
|
||||
errorType: conf.ErrorType,
|
||||
}
|
||||
|
||||
return ep, nil
|
||||
}
|
@@ -0,0 +1,28 @@
|
||||
package datetime
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValue = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
type Datetime struct {
|
||||
name string
|
||||
value string
|
||||
}
|
||||
|
||||
func NewDatetime(name string, value string) *Datetime {
|
||||
return &Datetime{name: name, value: value}
|
||||
}
|
||||
|
||||
func (d *Datetime) Name() string {
|
||||
return d.name
|
||||
}
|
||||
|
||||
func (d *Datetime) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) {
|
||||
return time.Now().Format(d.value), nil
|
||||
}
|
@@ -0,0 +1,24 @@
|
||||
package datetime
|
||||
|
||||
import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
|
||||
|
||||
const name = "$datetime"
|
||||
|
||||
func Register() {
|
||||
dynamic_params.Register(name, NewFactory())
|
||||
}
|
||||
|
||||
func NewFactory() *Factory {
|
||||
return &Factory{}
|
||||
}
|
||||
|
||||
type Factory struct {
|
||||
}
|
||||
|
||||
func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) {
|
||||
v := defaultValue
|
||||
if len(value) > 0 {
|
||||
v = value[0]
|
||||
}
|
||||
return NewDatetime(name, v), nil
|
||||
}
|
14
drivers/plugins/extra-params_v2/dynamic-params/dynamic.go
Normal file
14
drivers/plugins/extra-params_v2/dynamic-params/dynamic.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package dynamic_params
|
||||
|
||||
import (
|
||||
"github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
type IDynamicFactory interface {
|
||||
Create(name string, value []string) (IDynamicDriver, error)
|
||||
}
|
||||
|
||||
type IDynamicDriver interface {
|
||||
Name() string
|
||||
Generate(ctx http_context.IHttpContext, contentType string, args ...interface{}) (interface{}, error)
|
||||
}
|
95
drivers/plugins/extra-params_v2/dynamic-params/factory.go
Normal file
95
drivers/plugins/extra-params_v2/dynamic-params/factory.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package dynamic_params
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorInvalidBalance = errors.New("invalid balance")
|
||||
defaultFactoryRegister = newFactoryManager()
|
||||
)
|
||||
|
||||
// IFactoryRegister 实现了负载均衡算法工厂管理器
|
||||
type IFactoryRegister interface {
|
||||
RegisterFactoryByKey(key string, factory IDynamicFactory)
|
||||
GetFactoryByKey(key string) (IDynamicFactory, bool)
|
||||
Keys() []string
|
||||
}
|
||||
|
||||
// driverRegister 实现了IBalanceFactoryRegister接口
|
||||
type driverRegister struct {
|
||||
register eosc.IRegister[IDynamicFactory]
|
||||
keys []string
|
||||
}
|
||||
|
||||
// newFactoryManager 创建负载均衡算法工厂管理器
|
||||
func newFactoryManager() IFactoryRegister {
|
||||
return &driverRegister{
|
||||
register: eosc.NewRegister[IDynamicFactory](),
|
||||
keys: make([]string, 0, 10),
|
||||
}
|
||||
}
|
||||
|
||||
// GetFactoryByKey 获取指定balance工厂
|
||||
func (dm *driverRegister) GetFactoryByKey(key string) (IDynamicFactory, bool) {
|
||||
o, has := dm.register.Get(key)
|
||||
if has {
|
||||
log.Debug("GetFactoryByKey:", key, ":has")
|
||||
f, ok := o.(IDynamicFactory)
|
||||
return f, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RegisterFactoryByKey 注册balance工厂
|
||||
func (dm *driverRegister) RegisterFactoryByKey(key string, factory IDynamicFactory) {
|
||||
err := dm.register.Register(key, factory, true)
|
||||
if err != nil {
|
||||
log.Debug("RegisterFactoryByKey:", key, ":", err)
|
||||
return
|
||||
}
|
||||
dm.keys = append(dm.keys, key)
|
||||
}
|
||||
|
||||
// Keys 返回所有已注册的key
|
||||
func (dm *driverRegister) Keys() []string {
|
||||
return dm.keys
|
||||
}
|
||||
|
||||
// Register 注册balance工厂到默认balanceFactory注册器
|
||||
func Register(key string, factory IDynamicFactory) {
|
||||
|
||||
defaultFactoryRegister.RegisterFactoryByKey(key, factory)
|
||||
}
|
||||
|
||||
// Get 从默认balanceFactory注册器中获取balance工厂
|
||||
func Get(key string) (IDynamicFactory, bool) {
|
||||
return defaultFactoryRegister.GetFactoryByKey(key)
|
||||
}
|
||||
|
||||
// Keys 返回默认的balanceFactory注册器中所有已注册的key
|
||||
func Keys() []string {
|
||||
return defaultFactoryRegister.Keys()
|
||||
}
|
||||
|
||||
// GetFactory 获取指定负载均衡算法工厂,若指定的不存在则返回一个已注册的工厂
|
||||
func GetFactory(name string) (IDynamicFactory, error) {
|
||||
factory, ok := Get(name)
|
||||
if !ok {
|
||||
for _, key := range Keys() {
|
||||
factory, ok = Get(key)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if factory == nil {
|
||||
return nil, fmt.Errorf("%s:%w", name, ErrorInvalidBalance)
|
||||
}
|
||||
}
|
||||
return factory, nil
|
||||
}
|
@@ -0,0 +1,20 @@
|
||||
package md5
|
||||
|
||||
import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
|
||||
|
||||
const name = "$md5"
|
||||
|
||||
func Register() {
|
||||
dynamic_params.Register(name, NewFactory())
|
||||
}
|
||||
|
||||
func NewFactory() *Factory {
|
||||
return &Factory{}
|
||||
}
|
||||
|
||||
type Factory struct {
|
||||
}
|
||||
|
||||
func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) {
|
||||
return NewMD5(name, value), nil
|
||||
}
|
178
drivers/plugins/extra-params_v2/dynamic-params/md5/md5.go
Normal file
178
drivers/plugins/extra-params_v2/dynamic-params/md5/md5.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package md5
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
|
||||
"github.com/ohler55/ojg/oj"
|
||||
|
||||
"github.com/ohler55/ojg/jp"
|
||||
|
||||
"github.com/eolinker/apinto/utils"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
)
|
||||
|
||||
const (
|
||||
positionCurrent = iota
|
||||
positionHeader
|
||||
positionQuery
|
||||
// body
|
||||
positionBody
|
||||
positionSystem
|
||||
)
|
||||
|
||||
type MD5 struct {
|
||||
name string
|
||||
value []*Value
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
key string
|
||||
position int
|
||||
optional bool
|
||||
}
|
||||
|
||||
func NewMD5(name string, value []string) *MD5 {
|
||||
vs := make([]*Value, 0, len(value))
|
||||
for _, v := range value {
|
||||
v = strings.TrimSpace(v)
|
||||
vLen := len(v)
|
||||
if vLen > 0 {
|
||||
if v[0] == '{' && v[vLen-1] == '}' {
|
||||
vars := strings.Split(v[1:vLen-1], ".")
|
||||
position := positionBody
|
||||
variable := vars[0]
|
||||
if len(vars) > 1 {
|
||||
variable = vars[1]
|
||||
switch vars[0] {
|
||||
case "header":
|
||||
position = positionHeader
|
||||
case "query":
|
||||
position = positionQuery
|
||||
}
|
||||
}
|
||||
vs = append(vs, &Value{
|
||||
key: variable,
|
||||
position: position,
|
||||
})
|
||||
} else if v[0] == '#' {
|
||||
vars := strings.Split(v[1:], ".")
|
||||
position := positionBody
|
||||
variable := vars[0]
|
||||
if len(vars) > 1 {
|
||||
variable = vars[1]
|
||||
switch vars[0] {
|
||||
case "header":
|
||||
position = positionHeader
|
||||
case "query":
|
||||
position = positionQuery
|
||||
}
|
||||
}
|
||||
vs = append(vs, &Value{
|
||||
key: variable,
|
||||
position: position,
|
||||
optional: true,
|
||||
})
|
||||
} else if vLen > 3 && v[0] == '$' && v[1] == '{' && v[vLen-1] == '}' {
|
||||
// 使用系统变量
|
||||
vs = append(vs, &Value{
|
||||
key: v[2 : vLen-1],
|
||||
position: positionSystem,
|
||||
})
|
||||
} else {
|
||||
vs = append(vs, &Value{
|
||||
key: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return &MD5{
|
||||
name: name,
|
||||
value: vs,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MD5) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MD5) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) {
|
||||
builder := strings.Builder{}
|
||||
var params interface{}
|
||||
if contentType == "application/json" {
|
||||
if len(args) < 1 {
|
||||
return nil, errors.New("missing args")
|
||||
}
|
||||
params = args[0]
|
||||
}
|
||||
for _, v := range m.value {
|
||||
if v.key == "" {
|
||||
continue
|
||||
}
|
||||
builder.WriteString(retrieveParam(ctx, contentType, params, v))
|
||||
}
|
||||
log.Debug("md5 result: ", builder.String())
|
||||
if strings.HasPrefix(m.name, "__") {
|
||||
return utils.Md5(builder.String()), nil
|
||||
}
|
||||
return strings.ToUpper(utils.Md5(builder.String())), nil
|
||||
}
|
||||
|
||||
func retrieveParam(ctx http_service.IHttpContext, contentType string, body interface{}, value *Value) string {
|
||||
switch value.position {
|
||||
case positionCurrent:
|
||||
return value.key
|
||||
case positionHeader:
|
||||
return ctx.Proxy().Header().Headers().Get(value.key)
|
||||
case positionQuery:
|
||||
return ctx.Proxy().URI().GetQuery(value.key)
|
||||
case positionBody:
|
||||
|
||||
if contentType == "application/x-www-form-urlencoded" {
|
||||
if !value.optional {
|
||||
return ctx.Proxy().Body().GetForm(value.key)
|
||||
}
|
||||
form, _ := ctx.Proxy().Body().BodyForm()
|
||||
if _, ok := form[value.key]; ok {
|
||||
return value.key
|
||||
}
|
||||
} else if contentType == "application/json" {
|
||||
key := value.key
|
||||
if !strings.HasPrefix(key, "$.") {
|
||||
key = "$." + key
|
||||
}
|
||||
|
||||
x, err := jp.ParseString(key)
|
||||
if err != nil {
|
||||
log.Errorf("parse json path(%s) error: %v", key, err)
|
||||
return ""
|
||||
}
|
||||
result := x.Get(body)
|
||||
|
||||
if len(result) > 0 {
|
||||
if value.optional {
|
||||
return value.key
|
||||
}
|
||||
|
||||
switch r := result[0].(type) {
|
||||
case string:
|
||||
return r
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%.f", r)
|
||||
case bool:
|
||||
return strconv.FormatBool(r)
|
||||
default:
|
||||
return oj.JSON(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
case positionSystem:
|
||||
return ctx.GetLabel(value.key)
|
||||
}
|
||||
return ""
|
||||
}
|
@@ -0,0 +1,24 @@
|
||||
package timestamp
|
||||
|
||||
import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
|
||||
|
||||
const name = "$timestamp"
|
||||
|
||||
func Register() {
|
||||
dynamic_params.Register(name, NewFactory())
|
||||
}
|
||||
|
||||
func NewFactory() *Factory {
|
||||
return &Factory{}
|
||||
}
|
||||
|
||||
type Factory struct {
|
||||
}
|
||||
|
||||
func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) {
|
||||
v := defaultValue
|
||||
if len(value) > 0 {
|
||||
v = value[0]
|
||||
}
|
||||
return NewTimestamp(name, v), nil
|
||||
}
|
@@ -0,0 +1,35 @@
|
||||
package timestamp
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValue = "string"
|
||||
)
|
||||
|
||||
type Timestamp struct {
|
||||
name string
|
||||
value string
|
||||
}
|
||||
|
||||
func NewTimestamp(name, value string) *Timestamp {
|
||||
return &Timestamp{name: name, value: value}
|
||||
}
|
||||
|
||||
func (t *Timestamp) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *Timestamp) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) {
|
||||
switch t.value {
|
||||
case "string":
|
||||
return strconv.FormatInt(time.Now().Unix(), 10), nil
|
||||
case "int":
|
||||
return time.Now().Unix(), nil
|
||||
}
|
||||
return strconv.FormatInt(time.Now().Unix(), 10), nil
|
||||
}
|
212
drivers/plugins/extra-params_v2/executor.go
Normal file
212
drivers/plugins/extra-params_v2/executor.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package extra_params_v2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ohler55/ojg/oj"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
|
||||
"github.com/ohler55/ojg/jp"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
var _ http_service.HttpFilter = (*executor)(nil)
|
||||
var _ eocontext.IFilter = (*executor)(nil)
|
||||
|
||||
var (
|
||||
errorExist = "%s: %s is already exists"
|
||||
)
|
||||
|
||||
type executor struct {
|
||||
drivers.WorkerBase
|
||||
baseParam *baseParam
|
||||
requestBodyType string
|
||||
errorType string
|
||||
}
|
||||
|
||||
func (e *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
|
||||
return http_service.DoHttpFilter(e, ctx, next)
|
||||
}
|
||||
|
||||
func (e *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
|
||||
statusCode, err := e.access(ctx)
|
||||
if err != nil {
|
||||
ctx.Response().SetBody([]byte(err.Error()))
|
||||
ctx.Response().SetStatus(statusCode, strconv.Itoa(statusCode))
|
||||
return err
|
||||
}
|
||||
|
||||
if next != nil {
|
||||
return next.DoChain(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addParamToBody(ctx http_service.IHttpContext, contentType string, params []*paramInfo) (interface{}, error) {
|
||||
|
||||
//var bodyParam map[string]interface{}
|
||||
if contentType == "application/json" {
|
||||
body, _ := ctx.Proxy().Body().RawBody()
|
||||
if string(body) == "" {
|
||||
body = []byte("{}")
|
||||
}
|
||||
bodyParam, err := oj.Parse(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, param := range params {
|
||||
key := param.name
|
||||
if !strings.HasPrefix(param.name, "$.") {
|
||||
key = "$." + key
|
||||
}
|
||||
|
||||
x, err := jp.ParseString(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse key error: %v", err)
|
||||
}
|
||||
result := x.Get(bodyParam)
|
||||
if len(result) > 0 {
|
||||
if param.conflict == paramError {
|
||||
return nil, fmt.Errorf(errorExist, positionBody, param.name)
|
||||
} else if param.conflict == paramOrigin {
|
||||
continue
|
||||
}
|
||||
}
|
||||
value, err := param.Build(ctx, contentType, bodyParam)
|
||||
if err != nil {
|
||||
log.Errorf("build param(s) error: %v", key, err)
|
||||
continue
|
||||
}
|
||||
err = x.Set(bodyParam, value)
|
||||
if err != nil {
|
||||
log.Errorf("set param(s) error: %v", key, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
b, _ := oj.Marshal(bodyParam)
|
||||
ctx.Proxy().Body().SetRaw(contentType, b)
|
||||
return bodyParam, nil
|
||||
} else if contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data" {
|
||||
bodyParam := make(map[string]interface{})
|
||||
bodyForm, _ := ctx.Proxy().Body().BodyForm()
|
||||
for _, param := range params {
|
||||
_, has := bodyParam[param.name]
|
||||
if has {
|
||||
if param.conflict == paramError {
|
||||
return nil, fmt.Errorf("[extra_params] body(%s) has a conflict", param.name)
|
||||
} else if param.conflict == paramOrigin {
|
||||
continue
|
||||
}
|
||||
}
|
||||
value, err := param.Build(ctx, contentType, nil)
|
||||
if err != nil {
|
||||
log.Errorf("build param(s) error: %v", param.name, err)
|
||||
continue
|
||||
}
|
||||
bodyParam[param.name] = value
|
||||
|
||||
}
|
||||
ctx.Proxy().Body().SetForm(bodyForm)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *executor) access(ctx http_service.IHttpContext) (int, error) {
|
||||
// 判断请求携带的content-type
|
||||
contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType())
|
||||
var bodyParam interface{}
|
||||
var err error
|
||||
if ctx.Proxy().Method() == http.MethodPost || ctx.Proxy().Method() == http.MethodPut || ctx.Proxy().Method() == http.MethodPatch {
|
||||
if e.requestBodyType != "" {
|
||||
if e.requestBodyType == "json" && contentType != "application/json" {
|
||||
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] request body type is not json`, clientErrStatusCode)
|
||||
} else if e.requestBodyType == "form-data" && contentType != "multipart/form-data" {
|
||||
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] request body type is not form-data`, clientErrStatusCode)
|
||||
}
|
||||
}
|
||||
bodyParam, err = addParamToBody(ctx, contentType, e.baseParam.body)
|
||||
if err != nil {
|
||||
return clientErrStatusCode, encodeErr(e.errorType, err.Error(), clientErrStatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理Query参数
|
||||
for _, param := range e.baseParam.query {
|
||||
v := ctx.Proxy().URI().GetQuery(param.name)
|
||||
if v != "" {
|
||||
if param.conflict == paramError {
|
||||
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] query("`+param.name+`") has a conflict.`, clientErrStatusCode)
|
||||
} else if param.conflict == paramOrigin {
|
||||
continue
|
||||
}
|
||||
}
|
||||
value, err := param.Build(ctx, contentType, bodyParam)
|
||||
if err != nil {
|
||||
log.Errorf("build query extra param(%s) error: %s", param.name, err.Error())
|
||||
continue
|
||||
}
|
||||
ctx.Proxy().URI().SetQuery(param.name, value)
|
||||
}
|
||||
|
||||
// 处理Header参数
|
||||
for _, param := range e.baseParam.header {
|
||||
name := textproto.CanonicalMIMEHeaderKey(param.name)
|
||||
_, has := ctx.Proxy().Header().Headers()[name]
|
||||
if has {
|
||||
if param.conflict == paramError {
|
||||
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] header("`+name+`") has a conflict.`, clientErrStatusCode)
|
||||
} else if param.conflict == paramOrigin {
|
||||
continue
|
||||
}
|
||||
}
|
||||
value, err := param.Build(ctx, contentType, bodyParam)
|
||||
if err != nil {
|
||||
log.Errorf("build header extra param(%s) error: %s", name, err.Error())
|
||||
continue
|
||||
}
|
||||
ctx.Proxy().Header().SetHeader(param.name, value)
|
||||
}
|
||||
return successStatusCode, nil
|
||||
}
|
||||
|
||||
func (e *executor) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
cfg, err := check(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.baseParam = generateBaseParam(cfg.Params)
|
||||
e.requestBodyType = cfg.RequestBodyType
|
||||
e.errorType = cfg.ErrorType
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Destroy() {
|
||||
e.baseParam = nil
|
||||
e.errorType = ""
|
||||
}
|
||||
|
||||
func (e *executor) CheckSkill(skill string) bool {
|
||||
return http_service.FilterSkillName == skill
|
||||
}
|
17
drivers/plugins/extra-params_v2/factory.go
Normal file
17
drivers/plugins/extra-params_v2/factory.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package extra_params_v2
|
||||
|
||||
import (
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "extra_params_v2"
|
||||
)
|
||||
|
||||
func Register(register eosc.IExtenderDriverRegister) {
|
||||
register.RegisterExtenderDriver(Name, NewFactory())
|
||||
}
|
||||
func NewFactory() eosc.IExtenderDriverFactory {
|
||||
return drivers.NewFactory[Config](Create, Check)
|
||||
}
|
210
drivers/plugins/extra-params_v2/util.go
Normal file
210
drivers/plugins/extra-params_v2/util.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package extra_params_v2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
positionQuery = "query"
|
||||
positionHeader = "header"
|
||||
positionBody = "body"
|
||||
|
||||
paramConvert string = "convert"
|
||||
paramError string = "error"
|
||||
paramOrigin string = "origin"
|
||||
|
||||
clientErrStatusCode = 400
|
||||
successStatusCode = 200
|
||||
)
|
||||
|
||||
var (
|
||||
paramPositionErrInfo = `[plugin extra-params config err] param position must be in the set ["query","header",body]. err position: %s `
|
||||
//parseBodyErrInfo = `[extra_params] Fail to parse body! [err]: %s`
|
||||
paramNameErrInfo = `[plugin extra-params config err] param name must be not null. `
|
||||
)
|
||||
|
||||
func encodeErr(ent string, origin string, statusCode int) error {
|
||||
if ent == "json" {
|
||||
tmp := map[string]interface{}{
|
||||
"message": origin,
|
||||
"status_code": statusCode,
|
||||
}
|
||||
info, _ := json.Marshal(tmp)
|
||||
return fmt.Errorf("%s", info)
|
||||
}
|
||||
return fmt.Errorf("%s statusCode: %d", origin, statusCode)
|
||||
}
|
||||
|
||||
//func parseBodyParams(ctx http_service.IHttpContext) (interface{}, url.Values, error) {
|
||||
// if ctx.Proxy().Method() != http.MethodPost && ctx.Proxy().Method() != http.MethodPut && ctx.Proxy().Method() != http.MethodPatch {
|
||||
// return nil, nil, nil
|
||||
// }
|
||||
// contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType())
|
||||
// switch contentType {
|
||||
// case http_context.FormData, http_context.MultipartForm:
|
||||
// formParams, err := ctx.Proxy().Body().BodyForm()
|
||||
// if err != nil {
|
||||
// return nil, nil, err
|
||||
// }
|
||||
// return nil, formParams, nil
|
||||
// case http_context.JSON:
|
||||
// body, err := ctx.Proxy().Body().RawBody()
|
||||
// if err != nil {
|
||||
// return nil, nil, err
|
||||
// }
|
||||
// if string(body) == "" {
|
||||
// body = []byte("{}")
|
||||
// }
|
||||
// bodyParams, err := oj.Parse(body)
|
||||
// return bodyParams, nil, err
|
||||
// }
|
||||
// return nil, nil, errors.New("unsupported content-type: " + contentType)
|
||||
//}
|
||||
|
||||
//
|
||||
//func parseBodyParams(ctx http_service.IHttpContext) (map[string]interface{}, map[string][]string, error) {
|
||||
// contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType())
|
||||
//
|
||||
// switch contentType {
|
||||
// case http_context.FormData, http_context.MultipartForm:
|
||||
// formParams, err := ctx.Proxy().Body().BodyForm()
|
||||
// if err != nil {
|
||||
// return nil, nil, err
|
||||
// }
|
||||
// return nil, formParams, nil
|
||||
// case http_context.JSON:
|
||||
// body, err := ctx.Proxy().Body().RawBody()
|
||||
// if err != nil {
|
||||
// return nil, nil, err
|
||||
// }
|
||||
// var bodyParams map[string]interface{}
|
||||
// err = json.Unmarshal(body, &bodyParams)
|
||||
// if err != nil {
|
||||
// return bodyParams, nil, err
|
||||
// }
|
||||
// }
|
||||
// return nil, nil, errors.New("[params_transformer] unsupported content-type: " + contentType)
|
||||
//}
|
||||
|
||||
//func getHeaderValue(headers map[string][]string, param *ExtraParam, value string) (string, error) {
|
||||
// paramName := ConvertHeaderKey(param.Name)
|
||||
//
|
||||
// if param.Conflict == "" {
|
||||
// param.Conflict = paramConvert
|
||||
// }
|
||||
//
|
||||
// var paramValue string
|
||||
//
|
||||
// if _, ok := headers[paramName]; !ok {
|
||||
// param.Conflict = paramConvert
|
||||
// } else {
|
||||
// paramValue = headers[paramName][0]
|
||||
// }
|
||||
//
|
||||
// if param.Conflict == paramConvert {
|
||||
// paramValue = value
|
||||
// } else if param.Conflict == paramError {
|
||||
// errInfo := `[extra_params] "` + param.Name + `" has a conflict.`
|
||||
// return "", errors.New(errInfo)
|
||||
// }
|
||||
//
|
||||
// return paramValue, nil
|
||||
//}
|
||||
|
||||
//func hasQueryValue(rawQuery string, paramName string) bool {
|
||||
// bytes := []byte(rawQuery)
|
||||
// if len(bytes) == 0 {
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// k := 0
|
||||
// for i, c := range bytes {
|
||||
// switch c {
|
||||
// case '=':
|
||||
// key := string(bytes[k:i])
|
||||
// if key == paramName {
|
||||
// return true
|
||||
// }
|
||||
// case '&':
|
||||
// k = i + 1
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return false
|
||||
//}
|
||||
|
||||
//func getQueryValue(ctx http_service.IHttpContext, param *ExtraParam, value string) (string, error) {
|
||||
// paramValue := ""
|
||||
// if param.Conflict == "" {
|
||||
// param.Conflict = paramConvert
|
||||
// }
|
||||
//
|
||||
// //判断请求中是否包含对应的query参数
|
||||
// if !hasQueryValue(ctx.Proxy().URI().RawQuery(), param.Name) {
|
||||
// param.Conflict = paramConvert
|
||||
// } else {
|
||||
// paramValue = ctx.Proxy().URI().GetQuery(param.Name)
|
||||
// }
|
||||
//
|
||||
// if param.Conflict == paramConvert {
|
||||
// paramValue = value
|
||||
// } else if param.Conflict == paramError {
|
||||
// errInfo := `[extra_params] "` + param.Name + `" has a conflict.`
|
||||
// return "", errors.New(errInfo)
|
||||
// }
|
||||
//
|
||||
// return paramValue, nil
|
||||
//}
|
||||
//
|
||||
//func getBodyValue(bodyParams map[string]interface{}, formParams map[string][]string, param *ExtraParam, contentType string, value interface{}) (interface{}, error) {
|
||||
// var paramValue interface{} = nil
|
||||
// Conflict := param.Conflict
|
||||
// if Conflict == "" {
|
||||
// Conflict = paramConvert
|
||||
// }
|
||||
// if strings.Contains(contentType, http_context.FormData) || strings.Contains(contentType, http_context.MultipartForm) {
|
||||
// if _, ok := formParams[param.Name]; !ok {
|
||||
// Conflict = paramConvert
|
||||
// } else {
|
||||
// paramValue = formParams[param.Name][0]
|
||||
// }
|
||||
// } else if strings.Contains(contentType, http_context.JSON) {
|
||||
// if _, ok := bodyParams[param.Name]; !ok {
|
||||
// param.Conflict = paramConvert
|
||||
// } else {
|
||||
// paramValue = bodyParams[param.Name]
|
||||
// }
|
||||
// }
|
||||
// if Conflict == paramConvert {
|
||||
// paramValue = value
|
||||
// } else if Conflict == paramError {
|
||||
// errInfo := `[extra_params] "` + param.Name + `" has a conflict.`
|
||||
// return "", errors.New(errInfo)
|
||||
// }
|
||||
//
|
||||
// return paramValue, nil
|
||||
//}
|
||||
|
||||
//func ConvertHeaderKey(header string) string {
|
||||
// header = strings.ToLower(header)
|
||||
// headerArray := strings.Split(header, "-")
|
||||
// h := ""
|
||||
// arrLen := len(headerArray)
|
||||
// for i, value := range headerArray {
|
||||
// vLen := len(value)
|
||||
// if vLen < 1 {
|
||||
// continue
|
||||
// } else {
|
||||
// if vLen == 1 {
|
||||
// h += strings.ToUpper(value)
|
||||
// } else {
|
||||
// h += strings.ToUpper(string(value[0])) + value[1:]
|
||||
// }
|
||||
// if i != arrLen-1 {
|
||||
// h += "-"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// return h
|
||||
//}
|
44
node/fasthttp-client/client_test.go
Normal file
44
node/fasthttp-client/client_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package fasthttp_client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestProxyTimeout(t *testing.T) {
|
||||
//addr := "https://gw.kuaidizs.cn"
|
||||
addr := fmt.Sprintf("%s://%s", "https", "gw.kuaidizs.cn")
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.URI().SetPath("/open/api")
|
||||
req.URI().SetHost("gw.kuaidizs.cn")
|
||||
req.Header.SetMethod("POST")
|
||||
req.Header.SetContentType("application/json")
|
||||
req.SetBody([]byte(`{"cpCode":"YTO","version":"1.0","timestamp":"2023-08-10 11:57:13","province":"广东省","city":"广州市","appKey":"DBB812347A1E44829159FE82F5C4303E","format":"json","sign_method":"md5","method":"kdzs.address.reachable","sign":"10A4B5A59340F9B98DAFA3CFCCF65449"}`))
|
||||
err := defaultClient.ProxyTimeout(addr, req, resp, 0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log(string(resp.Body()))
|
||||
return
|
||||
}
|
||||
|
||||
func TestMyselfProxyTimeout(t *testing.T) {
|
||||
//addr := "https://gw.kuaidizs.cn"
|
||||
addr := "http://127.0.0.1:8099"
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.URI().SetPath("/open/api")
|
||||
req.URI().SetHost("127.0.0.1:8099")
|
||||
req.Header.SetMethod("POST")
|
||||
req.Header.SetContentType("application/json")
|
||||
req.SetBody([]byte(`{"cpCode":"YTO","province":"广东省","city":"广州市"}`))
|
||||
err := defaultClient.ProxyTimeout(addr, req, resp, 0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log(string(resp.Body()))
|
||||
return
|
||||
}
|
@@ -46,14 +46,15 @@ var (
|
||||
func (r *ProxyRequest) reset(request *fasthttp.Request, remoteAddr string) {
|
||||
|
||||
r.RequestReader.reset(request, remoteAddr)
|
||||
|
||||
forwardedFor := r.req.Header.PeekBytes(xforwardedforKey)
|
||||
if len(forwardedFor) > 0 {
|
||||
r.req.Header.Set("x-forwarded-for", fmt.Sprint(string(forwardedFor), ",", remoteAddr))
|
||||
r.req.Header.Set("x-forwarded-for", fmt.Sprint(string(forwardedFor), ",", r.remoteAddr))
|
||||
} else {
|
||||
r.req.Header.Set("x-forwarded-for", remoteAddr)
|
||||
r.req.Header.Set("x-forwarded-for", r.remoteAddr)
|
||||
}
|
||||
|
||||
r.req.Header.Set("x-real-ip", r.RequestReader.realIP)
|
||||
r.req.Header.Set("x-real-ip", r.realIP)
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user