mirror of
https://github.com/eolinker/apinto
synced 2025-10-23 16:53:18 +08:00
114 lines
2.7 KiB
Go
114 lines
2.7 KiB
Go
package counter
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/eolinker/apinto/drivers/plugins/counter/counter"
|
|
|
|
"github.com/eolinker/apinto/drivers"
|
|
"github.com/eolinker/apinto/drivers/plugins/counter/separator"
|
|
|
|
"github.com/eolinker/eosc"
|
|
"github.com/eolinker/eosc/eocontext"
|
|
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
|
)
|
|
|
|
var _ http_service.HttpFilter = (*executor)(nil)
|
|
var _ eocontext.IFilter = (*executor)(nil)
|
|
|
|
type executor struct {
|
|
drivers.WorkerBase
|
|
matchers []IMatcher
|
|
separatorCounter separator.ICounter
|
|
counters eosc.Untyped[string, counter.ICounter]
|
|
client counter.IClient
|
|
keyGenerate IKeyGenerator
|
|
}
|
|
|
|
func (b *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
|
|
return http_service.DoHttpFilter(b, ctx, next)
|
|
}
|
|
|
|
func (b *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) error {
|
|
counter, has := b.counters.Get(b.keyGenerate.Key(ctx))
|
|
if !has {
|
|
|
|
}
|
|
var count int64 = 1
|
|
var err error
|
|
if b.separatorCounter != nil {
|
|
|
|
separatorCounter := b.separatorCounter
|
|
count, err = separatorCounter.Count(ctx)
|
|
if err != nil {
|
|
ctx.Response().SetStatus(400, "400")
|
|
return fmt.Errorf("%s count error", separatorCounter.Name())
|
|
}
|
|
if count > separatorCounter.Max() {
|
|
ctx.Response().SetStatus(403, "not allow")
|
|
return fmt.Errorf("%s number exceed", separatorCounter.Name())
|
|
} else if count == 0 {
|
|
ctx.Response().SetStatus(400, "400")
|
|
return fmt.Errorf("%s value is missing", separatorCounter.Name())
|
|
}
|
|
}
|
|
|
|
err = counter.Lock(count)
|
|
if err != nil {
|
|
// 次数不足,直接返回
|
|
//return fmt.Errorf("no enough, key:%s, remain:%d, count:%d", b.counters.Name(), b.counters.Remain(), count
|
|
}
|
|
if next != nil {
|
|
err = next.DoChain(ctx)
|
|
if err != nil {
|
|
// 转发失败,回滚次数
|
|
return counter.RollBack(count)
|
|
//return err
|
|
}
|
|
}
|
|
match := true
|
|
for _, matcher := range b.matchers {
|
|
ok := matcher.Match(ctx)
|
|
if !ok {
|
|
match = false
|
|
break
|
|
}
|
|
}
|
|
if match {
|
|
// 匹配,扣减次数
|
|
return counter.Complete(count)
|
|
}
|
|
// 不匹配,回滚次数
|
|
return counter.RollBack(count)
|
|
}
|
|
|
|
func (b *executor) Start() error {
|
|
return nil
|
|
}
|
|
|
|
func (b *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
|
|
cfg, ok := conf.(*Config)
|
|
if !ok {
|
|
return fmt.Errorf("invalid config, driver: %s", Name)
|
|
}
|
|
counter, err := separator.GetCounter(cfg.Count)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.separatorCounter = counter
|
|
b.matchers = cfg.Match.GenerateHandler()
|
|
return nil
|
|
}
|
|
|
|
func (b *executor) Stop() error {
|
|
return nil
|
|
}
|
|
|
|
func (b *executor) Destroy() {
|
|
return
|
|
}
|
|
|
|
func (b *executor) CheckSkill(skill string) bool {
|
|
return http_service.FilterSkillName == skill
|
|
}
|