额外参数v2插件、请求体限制插件完成

This commit is contained in:
Liujian
2023-08-14 15:47:54 +08:00
parent 5f9dfde4db
commit bc65cfa5bb
36 changed files with 2276 additions and 6 deletions

View File

@@ -18,11 +18,13 @@ import (
plugin_manager "github.com/eolinker/apinto/drivers/plugin-manager"
access_log "github.com/eolinker/apinto/drivers/plugins/access-log"
plugin_app "github.com/eolinker/apinto/drivers/plugins/app"
body_check "github.com/eolinker/apinto/drivers/plugins/body-check"
circuit_breaker "github.com/eolinker/apinto/drivers/plugins/circuit-breaker"
"github.com/eolinker/apinto/drivers/plugins/cors"
dubbo2_proxy_rewrite "github.com/eolinker/apinto/drivers/plugins/dubbo2-proxy-rewrite"
dubbo2_to_http "github.com/eolinker/apinto/drivers/plugins/dubbo2-to-http"
extra_params "github.com/eolinker/apinto/drivers/plugins/extra-params"
extra_params_v2 "github.com/eolinker/apinto/drivers/plugins/extra-params_v2"
grpc_to_http "github.com/eolinker/apinto/drivers/plugins/gRPC-to-http"
grpc_proxy_rewrite "github.com/eolinker/apinto/drivers/plugins/grpc-proxy-rewrite"
"github.com/eolinker/apinto/drivers/plugins/gzip"
@@ -69,6 +71,7 @@ func ProcessWorker() {
func registerInnerExtenders() {
extends.AddInnerExtendProject("eolinker.com", "apinto", Register)
}
func Register(extenderRegister eosc.IExtenderDriverRegister) {
// router
http_router.Register(extenderRegister)
@@ -153,4 +156,7 @@ func Register(extenderRegister eosc.IExtenderDriverRegister) {
proxy_mirror.Register(extenderRegister)
http_mocking.Register(extenderRegister)
body_check.Register(extenderRegister)
extra_params_v2.Register(extenderRegister)
}

View File

@@ -49,9 +49,9 @@ func (h *HttpOutput) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWo
if err != nil {
return err
}
if h.config != nil && !h.config.isConfUpdate(config) {
return nil
}
//if h.config != nil && !h.config.isConfUpdate(config) {
// return nil
//}
h.config = config
if h.running {

View File

@@ -0,0 +1,73 @@
package body_check
import (
"errors"
"net/http"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var _ http_service.HttpFilter = (*BodyCheck)(nil)
var _ eocontext.IFilter = (*BodyCheck)(nil)
type BodyCheck struct {
drivers.WorkerBase
isEmpty bool
allowedPayloadSize int
}
func (b *BodyCheck) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
return http_service.DoHttpFilter(b, ctx, next)
}
func (b *BodyCheck) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
if ctx.Request().Method() == http.MethodPost || ctx.Request().Method() == http.MethodPut || ctx.Request().Method() == http.MethodPatch {
body, err := ctx.Request().Body().RawBody()
if err != nil {
return err
}
bodySize := len([]rune(string(body)))
if !b.isEmpty && bodySize < 1 {
ctx.Response().SetStatus(500, "Internal Server Error")
ctx.Response().SetBody([]byte("请求体不能为空请检查body的参数情况"))
return errors.New("Body is required")
}
if b.allowedPayloadSize > 0 && bodySize > b.allowedPayloadSize {
ctx.Response().SetStatus(500, "Internal Server Error")
ctx.Response().SetBody([]byte("请求体超出长度限制"))
return errors.New("The request entity is too large")
}
}
return next.DoChain(ctx)
}
func (b *BodyCheck) Start() error {
return nil
}
func (b *BodyCheck) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
cfg, ok := conf.(*Config)
if !ok {
return errors.New("invalid config")
}
b.isEmpty = cfg.IsEmpty
b.allowedPayloadSize = cfg.AllowedPayloadSize * 1024
return nil
}
func (b *BodyCheck) Stop() error {
return nil
}
func (b *BodyCheck) Destroy() {
return
}
func (b *BodyCheck) CheckSkill(skill string) bool {
return http_service.FilterSkillName == skill
}

View File

@@ -0,0 +1,22 @@
package body_check
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
type Config struct {
IsEmpty bool `json:"is_empty" label:"是否允许为空"`
AllowedPayloadSize int `json:"allowed_payload_size" label:"允许的最大请求体大小"`
}
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
bc := &BodyCheck{
WorkerBase: drivers.Worker(id, name),
isEmpty: conf.IsEmpty,
allowedPayloadSize: conf.AllowedPayloadSize * 1024,
}
return bc, nil
}

View File

@@ -0,0 +1,18 @@
package body_check
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
const (
Name = "body_check"
)
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(Name, NewFactory())
}
func NewFactory() eosc.IExtenderDriverFactory {
return drivers.NewFactory[Config](Create)
}

View File

@@ -0,0 +1,41 @@
package counter
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/apinto/drivers/plugins/counter/separator"
"github.com/eolinker/eosc"
)
type Config struct {
Match *Match `json:"match" label:"响应匹配规则"`
Count *separator.CountRule `json:"count" label:"计数规则"`
Key string `json:"key" label:"计数字段名称"`
Cache eosc.RequireId `json:"cache" label:"缓存计数器"`
}
type Match struct {
Params []*MatchParam `json:"params" label:"匹配参数列表"`
StatusCodes []int `json:"status_codes" label:"匹配响应状态码列表"`
Type string `json:"type" label:"匹配类型" enum:"json"`
}
func (m *Match) GenerateHandler() []IMatcher {
matcher := make([]IMatcher, 0, 2)
matcher = append(matcher, newStatusCodeMatcher(m.StatusCodes))
matcher = append(matcher, newJsonMatcher(m.Params))
return matcher
}
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
counter, err := separator.GetCounter(conf.Count)
if err != nil {
return nil, err
}
bc := &executor{
WorkerBase: drivers.Worker(id, name),
matchers: conf.Match.GenerateHandler(),
separatorCounter: counter,
}
return bc, nil
}

View File

@@ -0,0 +1,5 @@
package counter
type IClient interface {
Get(key string) (int64, error)
}

View File

@@ -0,0 +1,79 @@
package counter
import (
"math/rand"
"sync"
"testing"
"time"
"github.com/go-redis/redis/v8"
)
func TestLocalCounter(t *testing.T) {
client := NewHTTPClient("", nil)
lc := NewLocalCounter("test", client)
wg := sync.WaitGroup{}
wg.Add(100)
for i := 0; i < 100; i++ {
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
go func() {
defer wg.Done()
count := rand.Int63n(20)
err := lc.Lock(count)
if err != nil {
t.Error(err)
return
}
condition := rand.Intn(100)
switch condition % 2 {
case 0:
err = lc.Complete(count)
case 1:
err = lc.RollBack(count)
}
if err != nil {
t.Error(err)
return
}
}()
}
wg.Wait()
}
func TestRedisCounter(t *testing.T) {
client := NewHTTPClient("", nil)
key := "apinto-apiddww"
lc := NewLocalCounter(key, client)
redisConn := redis.NewClient(&redis.Options{
Addr: "172.18.65.42:6380",
Password: "password", // 如果有密码,请填写密码
DB: 9, // 选择数据库默认为0
})
rc := NewRedisCounter(key, redisConn, client, lc)
wg := sync.WaitGroup{}
wg.Add(100)
for i := 0; i < 100; i++ {
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
go func() {
defer wg.Done()
count := rand.Int63n(20)
err := rc.Lock(count)
if err != nil {
t.Error(err)
return
}
condition := rand.Intn(100)
switch condition % 2 {
case 0:
err = rc.Complete(count)
case 1:
err = rc.RollBack(count)
}
if err != nil {
t.Error(err)
return
}
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,26 @@
package counter
import "fmt"
type ICounter interface {
// Lock 锁定次数
Lock(count int64) error
// Complete 完成扣次操作
Complete(count int64) error
// RollBack 回滚
RollBack(count int64) error
// ResetClient 重置客户端
ResetClient(client IClient)
}
func getRemainCount(client IClient, key string, count int64) (int64, error) {
remain, err := client.Get(key)
if err != nil {
return 0, err
}
remain -= count
if remain < 0 {
return 0, fmt.Errorf("no enough, key:%s, remain:%d, count:%d", key, remain, count)
}
return remain, nil
}

View File

@@ -0,0 +1,58 @@
package counter
import (
"fmt"
"github.com/ohler55/ojg/oj"
"github.com/ohler55/ojg/jp"
"github.com/valyala/fasthttp"
)
var _ IClient = (*HTTPClient)(nil)
var httpClient = fasthttp.Client{
Name: "apinto-counter",
}
type HTTPClient struct {
uri string
headers map[string]string
// jsonExpr 经过编译的JSONPath表达式
jsonExpr jp.Expr
}
func NewHTTPClient(uri string, jsonExpr jp.Expr) *HTTPClient {
return &HTTPClient{uri: uri, jsonExpr: jsonExpr}
}
func (H *HTTPClient) Get(key string) (int64, error) {
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
req.SetRequestURI(H.uri)
req.Header.SetMethod("GET")
for name, value := range H.headers {
req.Header.Set(name, value)
}
err := httpClient.Do(req, resp)
if err != nil {
return 0, err
}
if resp.StatusCode() != fasthttp.StatusOK {
return 0, fmt.Errorf("error http status code: %d,key: %s,uri %s", resp.StatusCode(), key, H.uri)
}
result, err := oj.Parse(resp.Body())
if err != nil {
return 0, err
}
// 解析JSON
v := H.jsonExpr.Get(result)
if v == nil || len(v) < 1 {
return 0, fmt.Errorf("no found key: %s,uri: %s", key)
}
if len(v) != 1 {
return 0, fmt.Errorf("invalid value: %v,key: %s,uri: %s", v, key, H.uri)
}
return v[0].(int64), nil
}

View File

@@ -0,0 +1,76 @@
package counter
import (
"fmt"
"sync"
"time"
)
var _ ICounter = (*LocalCounter)(nil)
func NewLocalCounter(key string, client IClient) *LocalCounter {
return &LocalCounter{key: key, client: client}
}
// LocalCounter 本地计数器
type LocalCounter struct {
key string
// 剩余次数
remain int64
// 锁定次数
lock int64
locker sync.Mutex
resetTime time.Time
client IClient
}
func (c *LocalCounter) Lock(count int64) error {
c.locker.Lock()
defer c.locker.Unlock()
remain := c.remain - count
if remain < 0 {
now := time.Now()
if now.Sub(c.resetTime) < 10*time.Second {
return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", c.key, c.remain, count)
}
var err error
c.resetTime = now
// 获取最新的次数
remain, err = getRemainCount(c.client, c.key, count)
if err != nil {
return err
}
}
c.remain = remain
c.lock += count
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "lock", "remain:", c.remain, ",lock:", c.lock, ",count:", count)
return nil
}
func (c *LocalCounter) Complete(count int64) error {
c.locker.Lock()
defer c.locker.Unlock()
// 需要解除已经锁定的部分次数
c.lock -= count
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "complete", "remain:", c.remain, ",lock:", c.lock, ",count:", count)
return nil
}
func (c *LocalCounter) RollBack(count int64) error {
c.locker.Lock()
defer c.locker.Unlock()
// 需要解除已经锁定的部分次数,并且增加剩余次数
c.remain += c.lock
c.lock -= count
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "rollback", "remain:", c.remain, ",lock:", c.lock, ",count:", count)
return nil
}
func (c *LocalCounter) ResetClient(client IClient) {
c.locker.Lock()
defer c.locker.Unlock()
c.client = client
}

View File

@@ -0,0 +1,171 @@
package counter
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-redis/redis/v8"
)
var _ ICounter = (*RedisCounter)(nil)
type RedisCounter struct {
ctx context.Context
key string
redis redis.Cmdable
client IClient
locker sync.Mutex
resetTime time.Time
localCounter ICounter
lockerKey string
lockKey string
remainKey string
}
func NewRedisCounter(key string, redis redis.Cmdable, client IClient, localCounter ICounter) *RedisCounter {
return &RedisCounter{
key: key,
redis: redis,
client: client,
localCounter: localCounter,
ctx: context.Background(),
lockerKey: fmt.Sprintf("%s:locker", key),
lockKey: fmt.Sprintf("%s:lock", key),
remainKey: fmt.Sprintf("%s:remain", key),
}
}
func (r *RedisCounter) Lock(count int64) error {
r.locker.Lock()
defer r.locker.Unlock()
if r.redis == nil {
// 如果Redis没有配置使用本地计数器
return r.localCounter.Lock(count)
}
err := r.acquireLock()
if err != nil {
if err == redis.ErrClosed {
return r.localCounter.Lock(count)
}
return err
}
defer r.releaseLock()
// 获取最新的次数
remain, err := r.redis.Get(r.ctx, r.remainKey).Int64()
if err != nil {
if err == redis.ErrClosed {
return r.localCounter.Lock(count)
} else if err != redis.Nil {
return err
}
}
remain -= count
if remain < 0 {
now := time.Now()
if now.Sub(r.resetTime) < 10*time.Second {
return fmt.Errorf("no enough, ddd key:%s, remain:%d, count:%d", r.key, remain+count, count)
}
r.resetTime = now
lock, err := r.redis.Get(r.ctx, r.lockKey).Int64()
if err != nil {
if err == redis.ErrClosed {
return r.localCounter.Lock(count)
} else if err != redis.Nil {
return err
}
}
remain, err = getRemainCount(r.client, r.key, count+lock)
if err != nil {
return err
}
}
r.redis.Set(r.ctx, r.remainKey, remain, -1)
r.redis.IncrBy(r.ctx, r.lockKey, count)
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "lock", "remain:", remain, "count:", count)
return nil
}
func (r *RedisCounter) Complete(count int64) error {
r.locker.Lock()
defer r.locker.Unlock()
if r.redis == nil {
// 如果Redis没有配置使用本地计数器
return r.localCounter.Complete(count)
}
err := r.acquireLock()
if err != nil {
if err == redis.ErrClosed {
return r.localCounter.Lock(count)
}
return err
}
defer r.releaseLock()
r.redis.IncrBy(r.ctx, r.lockKey, -count)
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "complete", "count:", count)
return nil
}
func (r *RedisCounter) RollBack(count int64) error {
r.locker.Lock()
defer r.locker.Unlock()
if r.redis == nil {
// 如果Redis没有配置使用本地计数器
return r.localCounter.RollBack(count)
}
err := r.acquireLock()
if err != nil {
if err == redis.ErrClosed {
return r.localCounter.Lock(count)
}
return err
}
defer r.releaseLock()
r.redis.IncrBy(r.ctx, r.remainKey, count)
r.redis.IncrBy(r.ctx, r.lockKey, -count)
fmt.Println(time.Now().Format("2006-01-02 15:04:05"), "rollback", "count:", count)
return nil
}
func (r *RedisCounter) ResetClient(client IClient) {
r.locker.Lock()
defer r.locker.Unlock()
r.client = client
}
func (r *RedisCounter) acquireLock() error {
for {
// 生成唯一的锁值
lockValue := time.Now().UnixNano()
// Redis连接失败使用本地计数器
ok, err := r.redis.SetNX(r.ctx, r.lockerKey, lockValue, 10*time.Second).Result()
if err != nil {
return err
}
if ok {
// 设置锁成功
break
}
}
return nil
}
func (r *RedisCounter) releaseLock() error {
return r.redis.Del(r.ctx, r.lockerKey).Err()
}

View File

@@ -0,0 +1,113 @@
package counter
import (
"fmt"
"github.com/eolinker/apinto/drivers/plugins/counter/counter"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/apinto/drivers/plugins/counter/separator"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var _ http_service.HttpFilter = (*executor)(nil)
var _ eocontext.IFilter = (*executor)(nil)
type executor struct {
drivers.WorkerBase
matchers []IMatcher
separatorCounter separator.ICounter
counters eosc.Untyped[string, counter.ICounter]
client counter.IClient
keyGenerate IKeyGenerator
}
func (b *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
return http_service.DoHttpFilter(b, ctx, next)
}
func (b *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
counter, has := b.counters.Get(b.keyGenerate.Key(ctx))
if !has {
}
var count int64 = 1
var err error
if b.separatorCounter != nil {
separatorCounter := b.separatorCounter
count, err = separatorCounter.Count(ctx)
if err != nil {
ctx.Response().SetStatus(400, "400")
return fmt.Errorf("%s count error", separatorCounter.Name())
}
if count > separatorCounter.Max() {
ctx.Response().SetStatus(403, "not allow")
return fmt.Errorf("%s number exceed", separatorCounter.Name())
} else if count == 0 {
ctx.Response().SetStatus(400, "400")
return fmt.Errorf("%s value is missing", separatorCounter.Name())
}
}
err = counter.Lock(count)
if err != nil {
// 次数不足,直接返回
//return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", b.counters.Name(), b.counters.Remain(), count
}
if next != nil {
err = next.DoChain(ctx)
if err != nil {
// 转发失败,回滚次数
return counter.RollBack(count)
//return err
}
}
match := true
for _, matcher := range b.matchers {
ok := matcher.Match(ctx)
if !ok {
match = false
break
}
}
if match {
// 匹配,扣减次数
return counter.Complete(count)
}
// 不匹配,回滚次数
return counter.RollBack(count)
}
func (b *executor) Start() error {
return nil
}
func (b *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
cfg, ok := conf.(*Config)
if !ok {
return fmt.Errorf("invalid config, driver: %s", Name)
}
counter, err := separator.GetCounter(cfg.Count)
if err != nil {
return err
}
b.separatorCounter = counter
b.matchers = cfg.Match.GenerateHandler()
return nil
}
func (b *executor) Stop() error {
return nil
}
func (b *executor) Destroy() {
return
}
func (b *executor) CheckSkill(skill string) bool {
return http_service.FilterSkillName == skill
}

View File

@@ -0,0 +1,18 @@
package counter
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
const (
Name = "counter"
)
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(Name, NewFactory())
}
func NewFactory() eosc.IExtenderDriverFactory {
return drivers.NewFactory[Config](Create)
}

View File

@@ -0,0 +1,49 @@
package counter
import (
"fmt"
"strings"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var _ IKeyGenerator = (*keyGenerate)(nil)
type IKeyGenerator interface {
Key(ctx http_service.IHttpContext) string
}
func newKeyGenerate(key string) *keyGenerate {
key = strings.TrimSpace(key)
tmp := strings.Split(key, ":")
keys := make([]string, 0, len(tmp))
variables := make([]string, 0, len(tmp))
for _, t := range tmp {
t = strings.TrimSpace(t)
tLen := len(t)
if tLen > 0 {
if tLen > 1 && t[0] == '$' {
variables = append(variables, t[1:])
keys = append(keys, "%s")
} else {
keys = append(keys, t)
}
}
}
return &keyGenerate{format: strings.Join(keys, ":"), variables: variables}
}
type keyGenerate struct {
format string
// 变量列表
variables []string
}
func (k *keyGenerate) Key(ctx http_service.IHttpContext) string {
variables := make([]interface{}, 0, len(k.variables))
for _, v := range k.variables {
variables = append(variables, ctx.GetLabel(v))
}
return fmt.Sprintf(k.format, variables...)
}

View File

@@ -0,0 +1,8 @@
package counter
import "testing"
func TestGenerateKey(t *testing.T) {
key := newKeyGenerate("a:$b:$c:d")
t.Log(key.format, key.variables)
}

View File

@@ -0,0 +1,121 @@
package counter
import (
"strconv"
"strings"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"github.com/eolinker/eosc/log"
"github.com/ohler55/ojg/jp"
"github.com/ohler55/ojg/oj"
)
type IMatcher interface {
Match(ctx http_service.IHttpContext) bool
}
type MatchParam struct {
Key string `json:"key"`
Kind string `json:"kind"` // int|string|bool
Value []string `json:"value"`
}
func newJsonMatcher(params []*MatchParam) *jsonMatcher {
ps := make([]*jsonMatchParam, 0, len(params))
for _, p := range params {
key := p.Key
if !strings.HasPrefix(p.Key, "$.") {
key = "$." + p.Key
}
expr, err := jp.ParseString(key)
if err != nil {
log.Errorf("json path parse error: %v,key is %s", err, key)
continue
}
ps = append(ps, &jsonMatchParam{
MatchParam: p,
expr: expr,
})
}
return &jsonMatcher{params: ps}
}
type jsonMatcher struct {
params []*jsonMatchParam
}
type jsonMatchParam struct {
*MatchParam
expr jp.Expr
}
func (m *jsonMatcher) Match(ctx http_service.IHttpContext) bool {
if len(m.params) < 1 {
return true
}
body := ctx.Response().GetBody()
tmp, err := oj.Parse(body)
if err != nil {
log.Errorf("parse body error: %v,body is %s", err, body)
return true
}
for _, p := range m.params {
results := p.expr.Get(tmp)
for _, v := range p.Value {
for _, r := range results {
switch p.Kind {
case "int":
t, ok := r.(int64)
if !ok {
return false
}
val, _ := strconv.ParseInt(v, 10, 64)
if t == val {
return true
}
case "bool":
t, ok := r.(bool)
if !ok {
return false
}
val, err := strconv.ParseBool(v)
if err != nil {
return false
}
return t == val
default:
t, ok := r.(string)
if !ok {
return false
}
if t == v {
return true
}
}
}
}
}
return false
}
func newStatusCodeMatcher(codes []int) *statusCodeMatcher {
return &statusCodeMatcher{codes: codes}
}
type statusCodeMatcher struct {
codes []int
}
func (m *statusCodeMatcher) Match(ctx http_service.IHttpContext) bool {
if len(m.codes) < 1 {
return true
}
code := ctx.Response().StatusCode()
for _, c := range m.codes {
if c == code {
return true
}
}
return false
}

View File

@@ -0,0 +1,67 @@
package separator
import (
"strings"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var (
ArrayCountType = "array"
SplitCountType = "splite"
LengthCountType = "length"
CountTypes = []string{
ArrayCountType,
SplitCountType,
LengthCountType,
}
)
type CountRule struct {
RequestBodyType string `json:"request_body_type" label:"请求体类型" enum:"form-data,json"`
Key string `json:"key" label:"参数名称支持json path"`
Separator string `json:"separator" label:"分隔符" switch:"separator_type===splite"`
SeparatorType string `json:"separator_type" label:"分割类型" enum:"splite,array,length"`
Max int64 `json:"max" label:"计数最大值"`
}
type ICounter interface {
Count(ctx http_service.IHttpContext) (int64, error)
Max() int64
Name() string
}
func GetCounter(rule *CountRule) (ICounter, error) {
switch strings.ToLower(rule.RequestBodyType) {
case "form-data":
return NewFormDataCounter(rule)
case "multipart-formdata":
return NewFileCounter(rule)
case "json":
return NewJsonCounter(rule)
default:
return NewFormDataCounter(rule)
}
}
func splitCount(origin string, split string) int64 {
if len(split) == 0 {
return 0
}
vs := strings.Split(origin, string(split[0]))
var count int64 = 0
for _, v := range vs {
if v != "" {
childCount := splitCount(v, split[1:])
if childCount == 0 {
count += 1
} else {
count += childCount
}
}
}
return count
}

View File

@@ -0,0 +1,74 @@
package separator
import (
"bytes"
"fmt"
"io"
"mime"
"mime/multipart"
"strconv"
"strings"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var _ ICounter = (*FileCounter)(nil)
const defaultMultipartMemory = 32 << 20 // 32 MB
type FileCounter struct {
typ string
split string
name string
max int64
splitLen int
}
func (f *FileCounter) Name() string {
return f.name
}
func NewFileCounter(rule *CountRule) (*FileCounter, error) {
var splitLen int
if rule.SeparatorType == LengthCountType {
var err error
splitLen, err = strconv.Atoi(rule.Separator)
if err != nil {
splitLen = 1000
}
}
return &FileCounter{name: rule.Key, split: rule.Separator, typ: rule.SeparatorType, max: rule.Max, splitLen: splitLen}, nil
}
func (f *FileCounter) Count(ctx http_service.IHttpContext) (int64, error) {
raw, _ := ctx.Request().Body().RawBody()
d, params, _ := mime.ParseMediaType(ctx.Request().ContentType())
if !(d == "multipart/form-data") {
return -1, fmt.Errorf("need content-type: multipart/form-data,now: %s", d)
}
boundary, ok := params["boundary"]
if !ok {
return -1, fmt.Errorf("missing boundary")
}
body := io.NopCloser(bytes.NewBuffer(raw))
reader := multipart.NewReader(body, boundary)
form, err := reader.ReadForm(defaultMultipartMemory)
if err != nil {
return -1, fmt.Errorf("parse form param err: %v", err)
}
switch f.typ {
case LengthCountType:
value := strings.Join(form.Value[f.name], "")
l := len([]rune(value))
if l%f.splitLen == 0 {
return int64(l / f.splitLen), nil
}
return int64(l/f.splitLen + 1), nil
}
return splitCount(strings.Join(form.Value[f.name], f.split), f.split), nil
}
func (f *FileCounter) Max() int64 {
return f.max
}

View File

@@ -0,0 +1,59 @@
package separator
import (
"fmt"
"net/url"
"strconv"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var _ ICounter = (*FormDataCounter)(nil)
type FormDataCounter struct {
typ string
split string
name string
max int64
splitLen int
}
func (f *FormDataCounter) Name() string {
return f.name
}
func NewFormDataCounter(rule *CountRule) (*FormDataCounter, error) {
var splitLen int
if rule.SeparatorType == LengthCountType {
var err error
splitLen, err = strconv.Atoi(rule.Separator)
if err != nil {
splitLen = 1000
}
}
return &FormDataCounter{name: rule.Key, split: rule.Separator, typ: rule.SeparatorType, max: rule.Max, splitLen: splitLen}, nil
}
func (f *FormDataCounter) Count(ctx http_service.IHttpContext) (int64, error) {
body, _ := ctx.Request().Body().RawBody()
u, err := url.ParseQuery(string(body))
if err != nil {
return -1, fmt.Errorf("parse form data error:%v", err)
}
switch f.typ {
case SplitCountType:
return splitCount(u.Get(f.name), f.split), nil
case LengthCountType:
value := u.Get(f.name)
l := len([]rune(value))
if l%f.splitLen == 0 {
return int64(l / f.splitLen), nil
}
return int64(l/f.splitLen + 1), nil
}
return splitCount(u.Get(f.name), f.split), nil
}
func (f *FormDataCounter) Max() int64 {
return f.max
}

View File

@@ -0,0 +1,104 @@
package separator
import (
"fmt"
"strconv"
"strings"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"github.com/ohler55/ojg/oj"
"github.com/ohler55/ojg/jp"
)
var _ ICounter = (*JsonCounter)(nil)
type JsonCounter struct {
max int64
split string
expr jp.Expr
typ string
name string
splitLen int
}
func (j *JsonCounter) Name() string {
return j.name
}
func NewJsonCounter(rule *CountRule) (*JsonCounter, error) {
typeValid := false
for _, t := range CountTypes {
if t == rule.SeparatorType {
typeValid = true
break
}
}
if !typeValid {
return nil, fmt.Errorf("json count split type config error,now type is %s, need array or split", rule.SeparatorType)
}
expr, err := jp.ParseString(rule.Key)
if err != nil {
return nil, fmt.Errorf("json path parse error:%v", err)
}
if rule.Max < 1 || rule.Max > 2000 {
rule.Max = 2000
}
var splitLen int
if rule.SeparatorType == LengthCountType {
splitLen, err = strconv.Atoi(rule.Separator)
if err != nil {
splitLen = 1000
}
}
return &JsonCounter{max: rule.Max, split: rule.Separator, expr: expr, typ: rule.SeparatorType, name: strings.TrimPrefix(rule.Key, "$."), splitLen: splitLen}, nil
}
func (j *JsonCounter) Count(ctx http_service.IHttpContext) (int64, error) {
body, _ := ctx.Request().Body().RawBody()
obj, err := oj.Parse(body)
if err != nil {
return -1, fmt.Errorf("parse json body error:%v, body is %s", err, string(body))
}
results := j.expr.Get(obj)
switch j.typ {
case SplitCountType:
if len(results) > 0 {
origin, ok := results[0].(string)
if !ok {
return -1, fmt.Errorf("json path %s get value is not string", j.name)
}
return splitCount(origin, j.split), nil
}
case ArrayCountType:
if len(results) > 0 {
switch v := results[0].(type) {
case []interface{}:
{
return int64(len(v)), nil
}
case map[string]interface{}:
return int64(len(v)), nil
}
}
case LengthCountType:
if len(results) > 0 {
origin, ok := results[0].(string)
if !ok {
return -1, fmt.Errorf("json path %s get value is not string", j.name)
}
l := len([]rune(origin))
if l%j.splitLen == 0 {
return int64(l / j.splitLen), nil
}
return int64(l/j.splitLen + 1), nil
}
}
return -1, nil
}
func (j *JsonCounter) Max() int64 {
return j.max
}

View File

@@ -0,0 +1,120 @@
package extra_params_v2
import (
"fmt"
"strings"
dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
type Config struct {
Params []*ExtraParam `json:"params" label:"参数列表"`
RequestBodyType string `json:"request_body_type" enum:"form-data,json" label:"请求体类型"`
ErrorType string `json:"error_type" enum:"text,json" label:"报错输出格式"`
}
func (c *Config) doCheck() error {
c.ErrorType = strings.ToLower(c.ErrorType)
if c.ErrorType != "text" && c.ErrorType != "json" {
c.ErrorType = "text"
}
for _, param := range c.Params {
if param.Name == "" {
return fmt.Errorf(paramNameErrInfo)
}
param.Position = strings.ToLower(param.Position)
if param.Position != "query" && param.Position != "header" && param.Position != "body" {
return fmt.Errorf(paramPositionErrInfo, param.Position)
}
param.Conflict = strings.ToLower(param.Conflict)
if param.Conflict != paramOrigin && param.Conflict != paramConvert && param.Conflict != paramError {
param.Conflict = paramConvert
}
}
return nil
}
type ExtraParam struct {
Name string `json:"name" label:"参数名"`
Type string `json:"type" label:"参数类型" enum:"string,int,float,bool,$datetime,$md5"`
Position string `json:"position" enum:"header,query,body" label:"参数位置"`
Value []string `json:"value" label:"参数值列表"`
Conflict string `json:"conflict" label:"参数冲突时的处理方式" enum:"origin,convert,error"`
}
type baseParam struct {
header []*paramInfo
query []*paramInfo
body []*paramInfo
}
func generateBaseParam(params []*ExtraParam) *baseParam {
b := &baseParam{
header: make([]*paramInfo, 0),
query: make([]*paramInfo, 0),
body: make([]*paramInfo, 0),
}
for _, param := range params {
switch param.Position {
case positionHeader:
b.header = append(b.header, newParamInfo(param.Name, param.Value, param.Type, param.Conflict))
case positionQuery:
b.query = append(b.query, newParamInfo(param.Name, param.Value, param.Type, param.Conflict))
case positionBody:
b.body = append(b.body, newParamInfo(param.Name, param.Value, param.Type, param.Conflict))
}
}
return b
}
func newParamInfo(name string, value []string, typ string, conflict string) *paramInfo {
d := &paramInfo{name: name, value: strings.Join(value, ","), conflict: conflict}
valueLen := len(d.value)
if strings.HasPrefix(typ, "$") {
factory, has := dynamic_params.Get(typ)
if has {
driver, err := factory.Create(name, value)
if err == nil {
d.driver = driver
}
}
} else if valueLen > 1 && d.value[0] == '$' {
// 系统变量
d.systemValue = true
d.value = d.value[1:valueLen]
}
return d
}
type paramInfo struct {
name string
systemValue bool
value string
driver dynamic_params.IDynamicDriver
conflict string
}
func (b *paramInfo) Build(ctx http_service.IHttpContext, contentType string, params interface{}) (string, error) {
if b.driver == nil {
if b.systemValue {
return ctx.GetLabel(b.value), nil
}
return b.value, nil
}
value, err := b.driver.Generate(ctx, contentType, params)
if err != nil {
return "", err
}
switch v := value.(type) {
case string:
return v, nil
case int, int32, int64:
return fmt.Sprintf("%d", v), nil
}
return "", nil
}

View File

@@ -0,0 +1,60 @@
package extra_params_v2
import (
"reflect"
"sync"
"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/datetime"
"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/md5"
"github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params/timestamp"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
var (
once sync.Once
)
type Driver struct {
profession string
name string
label string
desc string
configType reflect.Type
}
func Check(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error {
return conf.doCheck()
}
func check(v interface{}) (*Config, error) {
conf, ok := v.(*Config)
if !ok {
return nil, eosc.ErrorConfigType
}
err := conf.doCheck()
if err != nil {
return nil, err
}
return conf, nil
}
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
once.Do(func() {
datetime.Register()
md5.Register()
timestamp.Register()
})
ep := &executor{
WorkerBase: drivers.Worker(id, name),
baseParam: generateBaseParam(conf.Params),
requestBodyType: conf.RequestBodyType,
errorType: conf.ErrorType,
}
return ep, nil
}

View File

@@ -0,0 +1,28 @@
package datetime
import (
"time"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
const (
defaultValue = "2006-01-02 15:04:05"
)
type Datetime struct {
name string
value string
}
func NewDatetime(name string, value string) *Datetime {
return &Datetime{name: name, value: value}
}
func (d *Datetime) Name() string {
return d.name
}
func (d *Datetime) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) {
return time.Now().Format(d.value), nil
}

View File

@@ -0,0 +1,24 @@
package datetime
import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
const name = "$datetime"
func Register() {
dynamic_params.Register(name, NewFactory())
}
func NewFactory() *Factory {
return &Factory{}
}
type Factory struct {
}
func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) {
v := defaultValue
if len(value) > 0 {
v = value[0]
}
return NewDatetime(name, v), nil
}

View File

@@ -0,0 +1,14 @@
package dynamic_params
import (
"github.com/eolinker/eosc/eocontext/http-context"
)
type IDynamicFactory interface {
Create(name string, value []string) (IDynamicDriver, error)
}
type IDynamicDriver interface {
Name() string
Generate(ctx http_context.IHttpContext, contentType string, args ...interface{}) (interface{}, error)
}

View File

@@ -0,0 +1,95 @@
package dynamic_params
import (
"errors"
"fmt"
"github.com/eolinker/eosc/log"
"github.com/eolinker/eosc"
)
var (
ErrorInvalidBalance = errors.New("invalid balance")
defaultFactoryRegister = newFactoryManager()
)
// IFactoryRegister 实现了负载均衡算法工厂管理器
type IFactoryRegister interface {
RegisterFactoryByKey(key string, factory IDynamicFactory)
GetFactoryByKey(key string) (IDynamicFactory, bool)
Keys() []string
}
// driverRegister 实现了IBalanceFactoryRegister接口
type driverRegister struct {
register eosc.IRegister[IDynamicFactory]
keys []string
}
// newFactoryManager 创建负载均衡算法工厂管理器
func newFactoryManager() IFactoryRegister {
return &driverRegister{
register: eosc.NewRegister[IDynamicFactory](),
keys: make([]string, 0, 10),
}
}
// GetFactoryByKey 获取指定balance工厂
func (dm *driverRegister) GetFactoryByKey(key string) (IDynamicFactory, bool) {
o, has := dm.register.Get(key)
if has {
log.Debug("GetFactoryByKey:", key, ":has")
f, ok := o.(IDynamicFactory)
return f, ok
}
return nil, false
}
// RegisterFactoryByKey 注册balance工厂
func (dm *driverRegister) RegisterFactoryByKey(key string, factory IDynamicFactory) {
err := dm.register.Register(key, factory, true)
if err != nil {
log.Debug("RegisterFactoryByKey:", key, ":", err)
return
}
dm.keys = append(dm.keys, key)
}
// Keys 返回所有已注册的key
func (dm *driverRegister) Keys() []string {
return dm.keys
}
// Register 注册balance工厂到默认balanceFactory注册器
func Register(key string, factory IDynamicFactory) {
defaultFactoryRegister.RegisterFactoryByKey(key, factory)
}
// Get 从默认balanceFactory注册器中获取balance工厂
func Get(key string) (IDynamicFactory, bool) {
return defaultFactoryRegister.GetFactoryByKey(key)
}
// Keys 返回默认的balanceFactory注册器中所有已注册的key
func Keys() []string {
return defaultFactoryRegister.Keys()
}
// GetFactory 获取指定负载均衡算法工厂,若指定的不存在则返回一个已注册的工厂
func GetFactory(name string) (IDynamicFactory, error) {
factory, ok := Get(name)
if !ok {
for _, key := range Keys() {
factory, ok = Get(key)
if ok {
break
}
}
if factory == nil {
return nil, fmt.Errorf("%s:%w", name, ErrorInvalidBalance)
}
}
return factory, nil
}

View File

@@ -0,0 +1,20 @@
package md5
import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
const name = "$md5"
func Register() {
dynamic_params.Register(name, NewFactory())
}
func NewFactory() *Factory {
return &Factory{}
}
type Factory struct {
}
func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) {
return NewMD5(name, value), nil
}

View File

@@ -0,0 +1,178 @@
package md5
import (
"errors"
"fmt"
"strconv"
"strings"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"github.com/ohler55/ojg/oj"
"github.com/ohler55/ojg/jp"
"github.com/eolinker/apinto/utils"
"github.com/eolinker/eosc/log"
)
const (
positionCurrent = iota
positionHeader
positionQuery
// body
positionBody
positionSystem
)
type MD5 struct {
name string
value []*Value
}
type Value struct {
key string
position int
optional bool
}
func NewMD5(name string, value []string) *MD5 {
vs := make([]*Value, 0, len(value))
for _, v := range value {
v = strings.TrimSpace(v)
vLen := len(v)
if vLen > 0 {
if v[0] == '{' && v[vLen-1] == '}' {
vars := strings.Split(v[1:vLen-1], ".")
position := positionBody
variable := vars[0]
if len(vars) > 1 {
variable = vars[1]
switch vars[0] {
case "header":
position = positionHeader
case "query":
position = positionQuery
}
}
vs = append(vs, &Value{
key: variable,
position: position,
})
} else if v[0] == '#' {
vars := strings.Split(v[1:], ".")
position := positionBody
variable := vars[0]
if len(vars) > 1 {
variable = vars[1]
switch vars[0] {
case "header":
position = positionHeader
case "query":
position = positionQuery
}
}
vs = append(vs, &Value{
key: variable,
position: position,
optional: true,
})
} else if vLen > 3 && v[0] == '$' && v[1] == '{' && v[vLen-1] == '}' {
// 使用系统变量
vs = append(vs, &Value{
key: v[2 : vLen-1],
position: positionSystem,
})
} else {
vs = append(vs, &Value{
key: v,
})
}
}
}
return &MD5{
name: name,
value: vs,
}
}
func (m *MD5) Name() string {
return m.name
}
func (m *MD5) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) {
builder := strings.Builder{}
var params interface{}
if contentType == "application/json" {
if len(args) < 1 {
return nil, errors.New("missing args")
}
params = args[0]
}
for _, v := range m.value {
if v.key == "" {
continue
}
builder.WriteString(retrieveParam(ctx, contentType, params, v))
}
log.Debug("md5 result: ", builder.String())
if strings.HasPrefix(m.name, "__") {
return utils.Md5(builder.String()), nil
}
return strings.ToUpper(utils.Md5(builder.String())), nil
}
func retrieveParam(ctx http_service.IHttpContext, contentType string, body interface{}, value *Value) string {
switch value.position {
case positionCurrent:
return value.key
case positionHeader:
return ctx.Proxy().Header().Headers().Get(value.key)
case positionQuery:
return ctx.Proxy().URI().GetQuery(value.key)
case positionBody:
if contentType == "application/x-www-form-urlencoded" {
if !value.optional {
return ctx.Proxy().Body().GetForm(value.key)
}
form, _ := ctx.Proxy().Body().BodyForm()
if _, ok := form[value.key]; ok {
return value.key
}
} else if contentType == "application/json" {
key := value.key
if !strings.HasPrefix(key, "$.") {
key = "$." + key
}
x, err := jp.ParseString(key)
if err != nil {
log.Errorf("parse json path(%s) error: %v", key, err)
return ""
}
result := x.Get(body)
if len(result) > 0 {
if value.optional {
return value.key
}
switch r := result[0].(type) {
case string:
return r
case float32, float64:
return fmt.Sprintf("%.f", r)
case bool:
return strconv.FormatBool(r)
default:
return oj.JSON(r)
}
}
}
case positionSystem:
return ctx.GetLabel(value.key)
}
return ""
}

View File

@@ -0,0 +1,24 @@
package timestamp
import dynamic_params "github.com/eolinker/apinto/drivers/plugins/extra-params_v2/dynamic-params"
const name = "$timestamp"
func Register() {
dynamic_params.Register(name, NewFactory())
}
func NewFactory() *Factory {
return &Factory{}
}
type Factory struct {
}
func (f *Factory) Create(name string, value []string) (dynamic_params.IDynamicDriver, error) {
v := defaultValue
if len(value) > 0 {
v = value[0]
}
return NewTimestamp(name, v), nil
}

View File

@@ -0,0 +1,35 @@
package timestamp
import (
"strconv"
"time"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
const (
defaultValue = "string"
)
type Timestamp struct {
name string
value string
}
func NewTimestamp(name, value string) *Timestamp {
return &Timestamp{name: name, value: value}
}
func (t *Timestamp) Name() string {
return t.name
}
func (t *Timestamp) Generate(ctx http_service.IHttpContext, contentType string, args ...interface{}) (interface{}, error) {
switch t.value {
case "string":
return strconv.FormatInt(time.Now().Unix(), 10), nil
case "int":
return time.Now().Unix(), nil
}
return strconv.FormatInt(time.Now().Unix(), 10), nil
}

View File

@@ -0,0 +1,212 @@
package extra_params_v2
import (
"fmt"
"mime"
"net/http"
"net/textproto"
"strconv"
"strings"
"github.com/ohler55/ojg/oj"
"github.com/eolinker/eosc/log"
"github.com/eolinker/apinto/drivers"
"github.com/ohler55/ojg/jp"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
http_service "github.com/eolinker/eosc/eocontext/http-context"
)
var _ http_service.HttpFilter = (*executor)(nil)
var _ eocontext.IFilter = (*executor)(nil)
var (
errorExist = "%s: %s is already exists"
)
type executor struct {
drivers.WorkerBase
baseParam *baseParam
requestBodyType string
errorType string
}
func (e *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
return http_service.DoHttpFilter(e, ctx, next)
}
func (e *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
statusCode, err := e.access(ctx)
if err != nil {
ctx.Response().SetBody([]byte(err.Error()))
ctx.Response().SetStatus(statusCode, strconv.Itoa(statusCode))
return err
}
if next != nil {
return next.DoChain(ctx)
}
return nil
}
func addParamToBody(ctx http_service.IHttpContext, contentType string, params []*paramInfo) (interface{}, error) {
//var bodyParam map[string]interface{}
if contentType == "application/json" {
body, _ := ctx.Proxy().Body().RawBody()
if string(body) == "" {
body = []byte("{}")
}
bodyParam, err := oj.Parse(body)
if err != nil {
return nil, err
}
for _, param := range params {
key := param.name
if !strings.HasPrefix(param.name, "$.") {
key = "$." + key
}
x, err := jp.ParseString(key)
if err != nil {
return nil, fmt.Errorf("parse key error: %v", err)
}
result := x.Get(bodyParam)
if len(result) > 0 {
if param.conflict == paramError {
return nil, fmt.Errorf(errorExist, positionBody, param.name)
} else if param.conflict == paramOrigin {
continue
}
}
value, err := param.Build(ctx, contentType, bodyParam)
if err != nil {
log.Errorf("build param(s) error: %v", key, err)
continue
}
err = x.Set(bodyParam, value)
if err != nil {
log.Errorf("set param(s) error: %v", key, err)
continue
}
}
b, _ := oj.Marshal(bodyParam)
ctx.Proxy().Body().SetRaw(contentType, b)
return bodyParam, nil
} else if contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data" {
bodyParam := make(map[string]interface{})
bodyForm, _ := ctx.Proxy().Body().BodyForm()
for _, param := range params {
_, has := bodyParam[param.name]
if has {
if param.conflict == paramError {
return nil, fmt.Errorf("[extra_params] body(%s) has a conflict", param.name)
} else if param.conflict == paramOrigin {
continue
}
}
value, err := param.Build(ctx, contentType, nil)
if err != nil {
log.Errorf("build param(s) error: %v", param.name, err)
continue
}
bodyParam[param.name] = value
}
ctx.Proxy().Body().SetForm(bodyForm)
}
return nil, nil
}
func (e *executor) access(ctx http_service.IHttpContext) (int, error) {
// 判断请求携带的content-type
contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType())
var bodyParam interface{}
var err error
if ctx.Proxy().Method() == http.MethodPost || ctx.Proxy().Method() == http.MethodPut || ctx.Proxy().Method() == http.MethodPatch {
if e.requestBodyType != "" {
if e.requestBodyType == "json" && contentType != "application/json" {
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] request body type is not json`, clientErrStatusCode)
} else if e.requestBodyType == "form-data" && contentType != "multipart/form-data" {
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] request body type is not form-data`, clientErrStatusCode)
}
}
bodyParam, err = addParamToBody(ctx, contentType, e.baseParam.body)
if err != nil {
return clientErrStatusCode, encodeErr(e.errorType, err.Error(), clientErrStatusCode)
}
}
// 处理Query参数
for _, param := range e.baseParam.query {
v := ctx.Proxy().URI().GetQuery(param.name)
if v != "" {
if param.conflict == paramError {
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] query("`+param.name+`") has a conflict.`, clientErrStatusCode)
} else if param.conflict == paramOrigin {
continue
}
}
value, err := param.Build(ctx, contentType, bodyParam)
if err != nil {
log.Errorf("build query extra param(%s) error: %s", param.name, err.Error())
continue
}
ctx.Proxy().URI().SetQuery(param.name, value)
}
// 处理Header参数
for _, param := range e.baseParam.header {
name := textproto.CanonicalMIMEHeaderKey(param.name)
_, has := ctx.Proxy().Header().Headers()[name]
if has {
if param.conflict == paramError {
return clientErrStatusCode, encodeErr(e.errorType, `[extra_params] header("`+name+`") has a conflict.`, clientErrStatusCode)
} else if param.conflict == paramOrigin {
continue
}
}
value, err := param.Build(ctx, contentType, bodyParam)
if err != nil {
log.Errorf("build header extra param(%s) error: %s", name, err.Error())
continue
}
ctx.Proxy().Header().SetHeader(param.name, value)
}
return successStatusCode, nil
}
func (e *executor) Start() error {
return nil
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
cfg, err := check(conf)
if err != nil {
return err
}
e.baseParam = generateBaseParam(cfg.Params)
e.requestBodyType = cfg.RequestBodyType
e.errorType = cfg.ErrorType
return nil
}
func (e *executor) Stop() error {
return nil
}
func (e *executor) Destroy() {
e.baseParam = nil
e.errorType = ""
}
func (e *executor) CheckSkill(skill string) bool {
return http_service.FilterSkillName == skill
}

View File

@@ -0,0 +1,17 @@
package extra_params_v2
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
const (
Name = "extra_params_v2"
)
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(Name, NewFactory())
}
func NewFactory() eosc.IExtenderDriverFactory {
return drivers.NewFactory[Config](Create, Check)
}

View File

@@ -0,0 +1,210 @@
package extra_params_v2
import (
"encoding/json"
"fmt"
)
const (
positionQuery = "query"
positionHeader = "header"
positionBody = "body"
paramConvert string = "convert"
paramError string = "error"
paramOrigin string = "origin"
clientErrStatusCode = 400
successStatusCode = 200
)
var (
paramPositionErrInfo = `[plugin extra-params config err] param position must be in the set ["query","header",body]. err position: %s `
//parseBodyErrInfo = `[extra_params] Fail to parse body! [err]: %s`
paramNameErrInfo = `[plugin extra-params config err] param name must be not null. `
)
func encodeErr(ent string, origin string, statusCode int) error {
if ent == "json" {
tmp := map[string]interface{}{
"message": origin,
"status_code": statusCode,
}
info, _ := json.Marshal(tmp)
return fmt.Errorf("%s", info)
}
return fmt.Errorf("%s statusCode: %d", origin, statusCode)
}
//func parseBodyParams(ctx http_service.IHttpContext) (interface{}, url.Values, error) {
// if ctx.Proxy().Method() != http.MethodPost && ctx.Proxy().Method() != http.MethodPut && ctx.Proxy().Method() != http.MethodPatch {
// return nil, nil, nil
// }
// contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType())
// switch contentType {
// case http_context.FormData, http_context.MultipartForm:
// formParams, err := ctx.Proxy().Body().BodyForm()
// if err != nil {
// return nil, nil, err
// }
// return nil, formParams, nil
// case http_context.JSON:
// body, err := ctx.Proxy().Body().RawBody()
// if err != nil {
// return nil, nil, err
// }
// if string(body) == "" {
// body = []byte("{}")
// }
// bodyParams, err := oj.Parse(body)
// return bodyParams, nil, err
// }
// return nil, nil, errors.New("unsupported content-type: " + contentType)
//}
//
//func parseBodyParams(ctx http_service.IHttpContext) (map[string]interface{}, map[string][]string, error) {
// contentType, _, _ := mime.ParseMediaType(ctx.Proxy().Body().ContentType())
//
// switch contentType {
// case http_context.FormData, http_context.MultipartForm:
// formParams, err := ctx.Proxy().Body().BodyForm()
// if err != nil {
// return nil, nil, err
// }
// return nil, formParams, nil
// case http_context.JSON:
// body, err := ctx.Proxy().Body().RawBody()
// if err != nil {
// return nil, nil, err
// }
// var bodyParams map[string]interface{}
// err = json.Unmarshal(body, &bodyParams)
// if err != nil {
// return bodyParams, nil, err
// }
// }
// return nil, nil, errors.New("[params_transformer] unsupported content-type: " + contentType)
//}
//func getHeaderValue(headers map[string][]string, param *ExtraParam, value string) (string, error) {
// paramName := ConvertHeaderKey(param.Name)
//
// if param.Conflict == "" {
// param.Conflict = paramConvert
// }
//
// var paramValue string
//
// if _, ok := headers[paramName]; !ok {
// param.Conflict = paramConvert
// } else {
// paramValue = headers[paramName][0]
// }
//
// if param.Conflict == paramConvert {
// paramValue = value
// } else if param.Conflict == paramError {
// errInfo := `[extra_params] "` + param.Name + `" has a conflict.`
// return "", errors.New(errInfo)
// }
//
// return paramValue, nil
//}
//func hasQueryValue(rawQuery string, paramName string) bool {
// bytes := []byte(rawQuery)
// if len(bytes) == 0 {
// return false
// }
//
// k := 0
// for i, c := range bytes {
// switch c {
// case '=':
// key := string(bytes[k:i])
// if key == paramName {
// return true
// }
// case '&':
// k = i + 1
// }
// }
//
// return false
//}
//func getQueryValue(ctx http_service.IHttpContext, param *ExtraParam, value string) (string, error) {
// paramValue := ""
// if param.Conflict == "" {
// param.Conflict = paramConvert
// }
//
// //判断请求中是否包含对应的query参数
// if !hasQueryValue(ctx.Proxy().URI().RawQuery(), param.Name) {
// param.Conflict = paramConvert
// } else {
// paramValue = ctx.Proxy().URI().GetQuery(param.Name)
// }
//
// if param.Conflict == paramConvert {
// paramValue = value
// } else if param.Conflict == paramError {
// errInfo := `[extra_params] "` + param.Name + `" has a conflict.`
// return "", errors.New(errInfo)
// }
//
// return paramValue, nil
//}
//
//func getBodyValue(bodyParams map[string]interface{}, formParams map[string][]string, param *ExtraParam, contentType string, value interface{}) (interface{}, error) {
// var paramValue interface{} = nil
// Conflict := param.Conflict
// if Conflict == "" {
// Conflict = paramConvert
// }
// if strings.Contains(contentType, http_context.FormData) || strings.Contains(contentType, http_context.MultipartForm) {
// if _, ok := formParams[param.Name]; !ok {
// Conflict = paramConvert
// } else {
// paramValue = formParams[param.Name][0]
// }
// } else if strings.Contains(contentType, http_context.JSON) {
// if _, ok := bodyParams[param.Name]; !ok {
// param.Conflict = paramConvert
// } else {
// paramValue = bodyParams[param.Name]
// }
// }
// if Conflict == paramConvert {
// paramValue = value
// } else if Conflict == paramError {
// errInfo := `[extra_params] "` + param.Name + `" has a conflict.`
// return "", errors.New(errInfo)
// }
//
// return paramValue, nil
//}
//func ConvertHeaderKey(header string) string {
// header = strings.ToLower(header)
// headerArray := strings.Split(header, "-")
// h := ""
// arrLen := len(headerArray)
// for i, value := range headerArray {
// vLen := len(value)
// if vLen < 1 {
// continue
// } else {
// if vLen == 1 {
// h += strings.ToUpper(value)
// } else {
// h += strings.ToUpper(string(value[0])) + value[1:]
// }
// if i != arrLen-1 {
// h += "-"
// }
// }
// }
// return h
//}

View File

@@ -0,0 +1,44 @@
package fasthttp_client
import (
"fmt"
"testing"
"github.com/valyala/fasthttp"
)
func TestProxyTimeout(t *testing.T) {
//addr := "https://gw.kuaidizs.cn"
addr := fmt.Sprintf("%s://%s", "https", "gw.kuaidizs.cn")
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.URI().SetPath("/open/api")
req.URI().SetHost("gw.kuaidizs.cn")
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.SetBody([]byte(`{"cpCode":"YTO","version":"1.0","timestamp":"2023-08-10 11:57:13","province":"广东省","city":"广州市","appKey":"DBB812347A1E44829159FE82F5C4303E","format":"json","sign_method":"md5","method":"kdzs.address.reachable","sign":"10A4B5A59340F9B98DAFA3CFCCF65449"}`))
err := defaultClient.ProxyTimeout(addr, req, resp, 0)
if err != nil {
t.Error(err)
}
t.Log(string(resp.Body()))
return
}
func TestMyselfProxyTimeout(t *testing.T) {
//addr := "https://gw.kuaidizs.cn"
addr := "http://127.0.0.1:8099"
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.URI().SetPath("/open/api")
req.URI().SetHost("127.0.0.1:8099")
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.SetBody([]byte(`{"cpCode":"YTO","province":"广东省","city":"广州市"}`))
err := defaultClient.ProxyTimeout(addr, req, resp, 0)
if err != nil {
t.Error(err)
}
t.Log(string(resp.Body()))
return
}

View File

@@ -46,14 +46,15 @@ var (
func (r *ProxyRequest) reset(request *fasthttp.Request, remoteAddr string) {
r.RequestReader.reset(request, remoteAddr)
forwardedFor := r.req.Header.PeekBytes(xforwardedforKey)
if len(forwardedFor) > 0 {
r.req.Header.Set("x-forwarded-for", fmt.Sprint(string(forwardedFor), ",", remoteAddr))
r.req.Header.Set("x-forwarded-for", fmt.Sprint(string(forwardedFor), ",", r.remoteAddr))
} else {
r.req.Header.Set("x-forwarded-for", remoteAddr)
r.req.Header.Set("x-forwarded-for", r.remoteAddr)
}
r.req.Header.Set("x-real-ip", r.RequestReader.realIP)
r.req.Header.Set("x-real-ip", r.realIP)
}