diff --git a/examples/plugin/docs/autoreflect.md b/examples/plugin/docs/autoreflect.md new file mode 100644 index 0000000..70d87ac --- /dev/null +++ b/examples/plugin/docs/autoreflect.md @@ -0,0 +1,125 @@ +# 插件自动方法注册与调用机制 + +## 简介 + +插件自动方法注册与调用机制是一个基于 Go 反射的工具,它能够自动发现插件中的方法,并使这些方法可以通过统一的 `Execute` 接口调用,而无需手动为每个方法编写对应的处理代码。 + +这个机制大大简化了插件开发,提高了代码的可维护性,并降低了出错的可能性。 + +## 工作原理 + +1. 在插件初始化时,通过反射获取插件实例的所有可用方法 +2. 将这些方法注册为可通过 `Execute` 调用的动作,动作名称为方法名的小写形式 +3. 当调用 `Execute` 方法时,自动查找对应的方法并调用,处理参数转换和错误处理 + +## 使用方法 + +### 1. 创建插件 + +创建插件时,只需继承 `BasePluginImpl` 结构体,然后实现你的方法: + +```go +type MyPlugin struct { + *plugin.BasePluginImpl + // 插件的属性 +} + +// 创建导出的插件变量 +var Plugin = &MyPlugin{ + BasePluginImpl: plugin.NewPlugin( + "MyPlugin", + "1.0.0", + "我的插件", + "开发者", + plugin.PluginTypeUtils, + ), + // 初始化属性 +} + +// 这个方法会被自动注册为 "doSomething" 操作 +func (p *MyPlugin) DoSomething(name string, value int) (string, error) { + // 方法实现 + return fmt.Sprintf("%s: %d", name, value), nil +} +``` + +### 2. 方法参数约定 + +方法的参数和返回值需要遵循一定的约定,以便自动注册系统能够正确处理: + +1. **Context 参数**:如果方法的第一个参数是 `context.Context`,调用时会自动传入上下文参数 +2. **普通参数**:可以使用基本类型参数,如 `string`、`int`、`bool` 等 +3. **结构体参数**:可以定义结构体参数,用于接收多个参数,结构体字段名会被用作参数名 +4. **返回值**: + - 单个返回值:直接作为操作结果返回 + - 两个返回值,第二个为 `error`:遵循 Go 惯例,返回结果和错误 + - 多个返回值:打包成 map 返回 + +### 3. 结构体参数示例 + +使用结构体参数可以使方法定义更加清晰,并支持更复杂的参数: + +```go +// 参数结构体 +type RequestParams struct { + Name string `json:"name"` + Count int `json:"count"` + IsEnabled bool `json:"isEnabled"` + Metadata map[string]string `json:"metadata"` +} + +// 使用结构体参数的方法 +func (p *MyPlugin) ProcessRequest(ctx context.Context, params RequestParams) (interface{}, error) { + // 处理请求 + return result, nil +} +``` + +调用示例: + +```go +// 调用带结构体参数的方法 +result, err := pm.ExecutePlugin(ctx, "MyPlugin", "processRequest", map[string]interface{}{ + "name": "测试请求", + "count": 5, + "isEnabled": true, + "metadata": map[string]string{ + "key1": "value1", + "key2": "value2", + }, +}) +``` + +### 4. 获取可用操作列表 + +插件可以实现 `GetAvailableOperations` 方法,返回所有可用的操作列表: + +```go +func (p *MyPlugin) GetAvailableOperations() []string { + return p.GetAvailableActions() +} +``` + +调用示例: + +```go +// 获取可用操作 +operations, err := pm.ExecutePlugin(ctx, "MyPlugin", "getAvailableOperations", nil) +fmt.Println("可用操作列表:", operations) +``` + +## 注意事项 + +1. **方法命名**:方法名会被转换为小写形式作为操作名,调用时不区分大小写 +2. **跳过接口方法**:接口方法如 `Name()`、`Version()` 等不会被注册为可调用操作 +3. **参数转换**:系统会尝试进行参数类型转换,但复杂的自定义类型可能需要手动处理 +4. **错误处理**:如果找不到对应的方法,会回退到基础的 `Execute` 实现 +5. **性能考虑**:反射调用的性能略低于直接调用,但在大多数场景下差异可忽略不计 + +## 最佳实践 + +1. **使用结构体参数**:对于多参数方法,使用结构体参数会使代码更清晰 +2. **提供帮助方法**:实现 `GetAvailableOperations` 方法,便于用户了解插件支持的操作 +3. **遵循命名约定**:方法名使用驼峰命名法,参数名使用小写 +4. **添加适当的文档**:为方法添加注释,说明参数和返回值的含义 +5. **考虑向后兼容性**:新增方法不会影响现有功能,是扩展插件功能的理想方式 \ No newline at end of file diff --git a/examples/plugin/example/dynamic_params.go b/examples/plugin/example/dynamic_params.go index 71f22fb..fb0a3d1 100644 --- a/examples/plugin/example/dynamic_params.go +++ b/examples/plugin/example/dynamic_params.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "path/filepath" - "time" "github.com/darkit/goproxy/examples/plugin" ) @@ -51,8 +50,10 @@ func DynamicParamsExample() { } fmt.Println() - // 示例1: 调用日志插件 - fmt.Println("=== 示例1: 调用日志插件 ===") + // ======= 旧方式:直接使用Execute方法 ======= + fmt.Println("=== 示例1: 直接使用Execute方法 ===") + + // 调用日志插件 logResult, err := pm.ExecutePlugin(ctx, "LoggerPlugin", "info", map[string]interface{}{ "message": "这是通过动态参数传递的日志消息", }) @@ -62,8 +63,60 @@ func DynamicParamsExample() { statusResult, err := pm.ExecutePlugin(ctx, "LoggerPlugin", "getLoggerStatus", nil) fmt.Printf("日志插件状态: %v, 错误: %v\n", statusResult, err) - // 示例2: 调用存储插件 - fmt.Println("\n=== 示例2: 调用存储插件 ===") + // ======= 新方式:使用自动发现的方法 ======= + fmt.Println("\n=== 示例2: 使用自动注册的方法 ===") + + // 获取StatsPlugin + _, ok := pm.GetPlugin("StatsPlugin") + if !ok { + fmt.Println("未找到StatsPlugin") + return + } + + // 获取StatsPlugin支持的所有操作 + operationsResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "getAvailableOperations", nil) + if err != nil { + fmt.Printf("获取操作失败: %v\n", err) + } else { + fmt.Println("StatsPlugin支持的操作:") + operations, ok := operationsResult.([]string) + if ok { + for i, op := range operations { + fmt.Printf(" %d. %s\n", i+1, op) + } + } + } + + // 使用结构体参数 + fmt.Println("\n使用结构体参数记录请求:") + recordResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "recordrequest", map[string]interface{}{ + "bytesReceived": 2048, + "bytesSent": 4096, + "isError": false, + }) + fmt.Printf("记录请求结果: %v, 错误: %v\n", recordResult, err) + + // 获取统计数据 + fmt.Println("\n获取统计数据:") + statsResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "getallstats", nil) + fmt.Printf("统计数据: %v, 错误: %v\n", statsResult, err) + + // 增加自定义统计项 + fmt.Println("\n增加自定义统计项:") + incrResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "incrementstat", map[string]interface{}{ + "name": "custom_counter", + "value": 42, + }) + fmt.Printf("增加统计项结果: %v, 错误: %v\n", incrResult, err) + + // 获取统计报告 + fmt.Println("\n获取统计报告:") + reportResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "generatestatsreport", nil) + fmt.Printf("统计报告: %v, 错误: %v\n", reportResult, err) + + // 动态调用StoragePlugin方法 + fmt.Println("\n=== 示例3: 动态调用StoragePlugin方法 ===") + // 保存文件 saveResult, err := pm.ExecutePlugin(ctx, "StoragePlugin", "saveFile", map[string]interface{}{ "filename": "test.txt", @@ -71,7 +124,7 @@ func DynamicParamsExample() { }) fmt.Printf("保存文件结果: %v, 错误: %v\n", saveResult, err) - // 列出文件 + // 列出所有文件 listResult, err := pm.ExecutePlugin(ctx, "StoragePlugin", "listFiles", nil) fmt.Printf("文件列表: %v, 错误: %v\n", listResult, err) @@ -81,47 +134,19 @@ func DynamicParamsExample() { }) fmt.Printf("读取文件内容: %v, 错误: %v\n", loadResult, err) - // 获取存储信息 - infoResult, err := pm.ExecutePlugin(ctx, "StoragePlugin", "getStorageInfo", nil) - fmt.Printf("存储插件信息: %v, 错误: %v\n", infoResult, err) - - // 示例3: 调用统计插件 - fmt.Println("\n=== 示例3: 调用统计插件 ===") - // 记录请求 - pm.ExecutePlugin(ctx, "StatsPlugin", "recordRequest", map[string]interface{}{ - "bytesReceived": 1024.0, - "bytesSent": 2048.0, - "isError": false, + // 展示操作名不区分大小写 + fmt.Println("\n=== 示例4: 操作名不区分大小写 ===") + caseSensitiveResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "IncrementStat", map[string]interface{}{ + "name": "case_test", + "value": 100, }) + fmt.Printf("使用大写操作名结果: %v, 错误: %v\n", caseSensitiveResult, err) - // 睡眠一段时间以便观察 - time.Sleep(1 * time.Second) - - // 再次记录请求 - pm.ExecutePlugin(ctx, "StatsPlugin", "recordRequest", map[string]interface{}{ - "bytesReceived": 512.0, - "bytesSent": 768.0, - "isError": true, + // 获取后验证 + checkResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "getstat", map[string]interface{}{ + "name": "case_test", }) - - // 获取统计报告 - reportResult, err := pm.ExecutePlugin(ctx, "StatsPlugin", "getStatsReport", nil) - fmt.Printf("统计报告: %v, 错误: %v\n", reportResult, err) - - // 示例4: 使用ExecutePluginsByType调用所有工具类插件 - fmt.Println("\n=== 示例4: 按类型调用插件 ===") - typeResults := pm.ExecutePluginsByType(ctx, plugin.PluginTypeUtils, "info", map[string]interface{}{ - "message": "这条消息将发送给所有工具类插件", - }) - fmt.Printf("工具类插件调用结果: %v\n", typeResults) - - // 示例5: 使用ExecuteAllPlugins调用所有插件 - fmt.Println("\n=== 示例5: 调用所有插件 ===") - allResults := pm.ExecuteAllPlugins(ctx, "getLoggerStatus", nil) - fmt.Println("所有插件调用结果:") - for name, result := range allResults { - fmt.Printf(" %s: %v\n", name, result) - } + fmt.Printf("验证结果: %v, 错误: %v\n", checkResult, err) // 停止所有插件 if err := pm.StopPlugins(ctx); err != nil { diff --git a/examples/plugin/plugin.go b/examples/plugin/plugin.go index 56402ce..c402855 100644 --- a/examples/plugin/plugin.go +++ b/examples/plugin/plugin.go @@ -9,7 +9,10 @@ import ( "os" "path/filepath" "plugin" + "reflect" "runtime" + "strconv" + "strings" "sync" ) @@ -42,24 +45,453 @@ type Plugin interface { Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) } +// PluginHelper 插件辅助器 +// 用于自动发现和注册插件方法,简化Execute方法的实现 +type PluginHelper struct { + instance interface{} // 插件实例 + methods map[string]reflect.Method // 注册的方法 + methodParams map[string][]reflect.Type // 方法的参数类型 + methodRetValues map[string][]reflect.Type // 方法的返回值类型 +} + +// NewPluginHelper 创建一个新的插件辅助器 +func NewPluginHelper(instance interface{}) *PluginHelper { + helper := &PluginHelper{ + instance: instance, + methods: make(map[string]reflect.Method), + methodParams: make(map[string][]reflect.Type), + methodRetValues: make(map[string][]reflect.Type), + } + helper.discoverMethods() + return helper +} + +// discoverMethods 发现插件的所有可用方法 +func (h *PluginHelper) discoverMethods() { + instanceType := reflect.TypeOf(h.instance) + + // 遍历实例的所有方法 + for i := 0; i < instanceType.NumMethod(); i++ { + method := instanceType.Method(i) + + // 跳过特定的内部方法和接口方法 + if shouldSkipMethod(method.Name) { + continue + } + + // 获取方法的参数和返回值类型 + var paramTypes []reflect.Type + var returnTypes []reflect.Type + + // 跳过接收者参数 (第一个参数) + for j := 1; j < method.Type.NumIn(); j++ { + paramTypes = append(paramTypes, method.Type.In(j)) + } + + for j := 0; j < method.Type.NumOut(); j++ { + returnTypes = append(returnTypes, method.Type.Out(j)) + } + + // 注册方法 + actionName := strings.ToLower(method.Name) + h.methods[actionName] = method + h.methodParams[actionName] = paramTypes + h.methodRetValues[actionName] = returnTypes + } +} + +// shouldSkipMethod 检查是否应该跳过某些方法 +func shouldSkipMethod(name string) bool { + // 跳过Plugin接口的方法和特定的内部方法 + skipMethods := map[string]bool{ + "Name": true, + "Version": true, + "Description": true, + "Author": true, + "Type": true, + "Init": true, + "Start": true, + "Stop": true, + "IsEnabled": true, + "SetEnabled": true, + "Execute": true, + } + + return skipMethods[name] +} + +// GetAvailableActions 获取所有可用的动作 +func (h *PluginHelper) GetAvailableActions() []string { + actions := make([]string, 0, len(h.methods)) + for action := range h.methods { + actions = append(actions, action) + } + return actions +} + +// ExecuteAction 执行指定的动作 +func (h *PluginHelper) ExecuteAction(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) { + // 转换为小写以实现不区分大小写的匹配 + action = strings.ToLower(action) + + method, exists := h.methods[action] + if !exists { + return nil, fmt.Errorf("未知的操作: %s", action) + } + + paramTypes := h.methodParams[action] + + // 准备参数 + var args []reflect.Value + + // 添加接收者参数 + args = append(args, reflect.ValueOf(h.instance)) + + // 处理context参数 + if len(paramTypes) > 0 && paramTypes[0].String() == "context.Context" { + args = append(args, reflect.ValueOf(ctx)) + paramTypes = paramTypes[1:] // 移除已处理的context参数 + } + + // 处理其他参数 + for i, paramType := range paramTypes { + paramName := fmt.Sprintf("arg%d", i) + + // 如果是结构体参数,尝试将整个params映射转换为该结构体 + if paramType.Kind() == reflect.Struct { + structValue := reflect.New(paramType).Elem() + if err := mapToStruct(params, structValue); err != nil { + return nil, fmt.Errorf("转换参数失败: %v", err) + } + args = append(args, structValue) + continue + } + + // 从params中获取参数 + paramValue, ok := params[paramName] + if !ok { + // 尝试使用参数类型名称作为键 + typeName := paramType.Name() + paramValue, ok = params[strings.ToLower(typeName)] + + if !ok { + return nil, fmt.Errorf("缺少必需参数: %s", paramName) + } + } + + // 转换参数类型 + convertedValue, err := convertParamValue(paramValue, paramType) + if err != nil { + return nil, err + } + + args = append(args, convertedValue) + } + + // 调用方法 + result := method.Func.Call(args) + + // 处理返回值 + if len(result) == 0 { + return nil, nil + } else if len(result) == 1 { + return result[0].Interface(), nil + } else { + // 处理多个返回值,通常最后一个是error + lastIndex := len(result) - 1 + if result[lastIndex].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if !result[lastIndex].IsNil() { + return nil, result[lastIndex].Interface().(error) + } + + if lastIndex == 0 { + return nil, nil + } + return result[lastIndex-1].Interface(), nil + } + + // 将所有返回值打包成一个映射 + resultMap := make(map[string]interface{}) + for i, v := range result { + resultMap[fmt.Sprintf("result%d", i)] = v.Interface() + } + return resultMap, nil + } +} + +// mapToStruct 将map转换为结构体 +func mapToStruct(m map[string]interface{}, structValue reflect.Value) error { + structType := structValue.Type() + + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + fieldValue := structValue.Field(i) + + if !fieldValue.CanSet() { + continue + } + + // 尝试不同的命名方式 + fieldName := field.Name + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + parts := strings.Split(jsonTag, ",") + fieldName = parts[0] + } + + // 检查不同大小写 + value, ok := m[fieldName] + if !ok { + value, ok = m[strings.ToLower(fieldName)] + } + if !ok { + continue // 跳过未找到的字段 + } + + // 转换并设置字段值 + convertedValue, err := convertParamValue(value, field.Type) + if err != nil { + return err + } + + fieldValue.Set(convertedValue) + } + + return nil +} + +// convertParamValue 将接口值转换为指定类型 +func convertParamValue(value interface{}, targetType reflect.Type) (reflect.Value, error) { + // 处理nil值 + if value == nil { + return reflect.Zero(targetType), nil + } + + valueType := reflect.TypeOf(value) + + // 如果类型已经匹配,直接返回 + if valueType.AssignableTo(targetType) { + return reflect.ValueOf(value), nil + } + + // 类型转换 + switch targetType.Kind() { + case reflect.String: + return reflect.ValueOf(fmt.Sprintf("%v", value)), nil + + case reflect.Bool: + switch v := value.(type) { + case bool: + return reflect.ValueOf(v), nil + case string: + b, err := strconv.ParseBool(v) + if err != nil { + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为布尔值", value) + } + return reflect.ValueOf(b), nil + case float64: + return reflect.ValueOf(v != 0), nil + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var intVal int64 + switch v := value.(type) { + case int: + intVal = int64(v) + case int8: + intVal = int64(v) + case int16: + intVal = int64(v) + case int32: + intVal = int64(v) + case int64: + intVal = v + case float32: + intVal = int64(v) + case float64: + intVal = int64(v) + case string: + var err error + intVal, err = strconv.ParseInt(v, 10, 64) + if err != nil { + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为整数", value) + } + default: + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为整数", value) + } + + switch targetType.Kind() { + case reflect.Int: + return reflect.ValueOf(int(intVal)), nil + case reflect.Int8: + return reflect.ValueOf(int8(intVal)), nil + case reflect.Int16: + return reflect.ValueOf(int16(intVal)), nil + case reflect.Int32: + return reflect.ValueOf(int32(intVal)), nil + case reflect.Int64: + return reflect.ValueOf(intVal), nil + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var uintVal uint64 + switch v := value.(type) { + case uint: + uintVal = uint64(v) + case uint8: + uintVal = uint64(v) + case uint16: + uintVal = uint64(v) + case uint32: + uintVal = uint64(v) + case uint64: + uintVal = v + case int: + if v < 0 { + return reflect.Value{}, fmt.Errorf("无法将负数 %v 转换为无符号整数", value) + } + uintVal = uint64(v) + case float64: + if v < 0 { + return reflect.Value{}, fmt.Errorf("无法将负数 %v 转换为无符号整数", value) + } + uintVal = uint64(v) + case string: + var err error + uintVal, err = strconv.ParseUint(v, 10, 64) + if err != nil { + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为无符号整数", value) + } + default: + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为无符号整数", value) + } + + switch targetType.Kind() { + case reflect.Uint: + return reflect.ValueOf(uint(uintVal)), nil + case reflect.Uint8: + return reflect.ValueOf(uint8(uintVal)), nil + case reflect.Uint16: + return reflect.ValueOf(uint16(uintVal)), nil + case reflect.Uint32: + return reflect.ValueOf(uint32(uintVal)), nil + case reflect.Uint64: + return reflect.ValueOf(uintVal), nil + } + + case reflect.Float32, reflect.Float64: + var floatVal float64 + switch v := value.(type) { + case float32: + floatVal = float64(v) + case float64: + floatVal = v + case int: + floatVal = float64(v) + case int64: + floatVal = float64(v) + case string: + var err error + floatVal, err = strconv.ParseFloat(v, 64) + if err != nil { + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为浮点数", value) + } + default: + return reflect.Value{}, fmt.Errorf("无法将 %v 转换为浮点数", value) + } + + if targetType.Kind() == reflect.Float32 { + return reflect.ValueOf(float32(floatVal)), nil + } + return reflect.ValueOf(floatVal), nil + + case reflect.Slice: + // 处理切片类型 + srcVal := reflect.ValueOf(value) + if srcVal.Kind() == reflect.Slice { + // 创建目标类型的新切片 + elemType := targetType.Elem() + newSlice := reflect.MakeSlice(targetType, srcVal.Len(), srcVal.Cap()) + + // 转换每个元素 + for i := 0; i < srcVal.Len(); i++ { + elemValue, err := convertParamValue(srcVal.Index(i).Interface(), elemType) + if err != nil { + return reflect.Value{}, err + } + newSlice.Index(i).Set(elemValue) + } + return newSlice, nil + } + + case reflect.Map: + // 处理映射类型 + srcVal := reflect.ValueOf(value) + if srcVal.Kind() == reflect.Map { + keyType := targetType.Key() + elemType := targetType.Elem() + newMap := reflect.MakeMap(targetType) + + iter := srcVal.MapRange() + for iter.Next() { + k := iter.Key() + v := iter.Value() + + newKey, err := convertParamValue(k.Interface(), keyType) + if err != nil { + return reflect.Value{}, err + } + + newValue, err := convertParamValue(v.Interface(), elemType) + if err != nil { + return reflect.Value{}, err + } + + newMap.SetMapIndex(newKey, newValue) + } + + return newMap, nil + } + + case reflect.Ptr: + // 处理指针类型 + elemType := targetType.Elem() + elemValue, err := convertParamValue(value, elemType) + if err != nil { + return reflect.Value{}, err + } + + ptrValue := reflect.New(elemType) + ptrValue.Elem().Set(elemValue) + return ptrValue, nil + } + + return reflect.Value{}, fmt.Errorf("不支持将 %T 类型转换为 %s", value, targetType) +} + // BasePluginImpl 提供插件接口的基本实现,用于适配Plugin接口 // 这个结构体包装了BasePlugin,以便兼容context参数 type BasePluginImpl struct { *BasePlugin + helper *PluginHelper // 添加插件辅助器 } // NewPlugin 创建一个基本插件实现,带有插件类型 func NewPlugin(name, version, description, author string, pluginType PluginType) *BasePluginImpl { - return &BasePluginImpl{ + plugin := &BasePluginImpl{ BasePlugin: NewBasePlugin(name, version, description, author, pluginType), } + plugin.helper = NewPluginHelper(plugin) + return plugin } // NewPluginWithDefaultType 创建一个基本插件实现,使用默认的通用插件类型 func NewPluginWithDefaultType(name, version, description, author string) *BasePluginImpl { - return &BasePluginImpl{ + plugin := &BasePluginImpl{ BasePlugin: NewBasePluginWithDefaultType(name, version, description, author), } + plugin.helper = NewPluginHelper(plugin) + return plugin } // Init 适配Init方法以支持context参数 @@ -77,9 +509,26 @@ func (p *BasePluginImpl) Stop(ctx context.Context) error { return p.BasePlugin.Stop() } -// Execute 适配Execute方法以支持context参数 +// Execute 通过辅助器自动执行插件方法 func (p *BasePluginImpl) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) { - return p.BasePlugin.Execute(action, params) + // 首先尝试通过辅助器执行方法 + result, err := p.helper.ExecuteAction(ctx, action, params) + if err == nil { + return result, nil + } + + // 如果辅助器执行失败,检查错误类型 + if strings.Contains(err.Error(), "未知的操作") { + // 回退到基础实现 + return p.BasePlugin.Execute(action, params) + } + + return nil, err +} + +// GetAvailableActions 获取插件支持的所有操作 +func (p *BasePluginImpl) GetAvailableActions() []string { + return p.helper.GetAvailableActions() } // PluginInfo 插件信息 diff --git a/examples/plugin/plugins/stats/stats_plugin.go b/examples/plugin/plugins/stats/stats_plugin.go index 1015394..754e1e6 100644 --- a/examples/plugin/plugins/stats/stats_plugin.go +++ b/examples/plugin/plugins/stats/stats_plugin.go @@ -12,19 +12,29 @@ import ( // StatsPlugin 统计插件 // 用于收集和记录系统运行时统计数据 type StatsPlugin struct { - *plugin.BasePlugin - stats map[string]int64 - startTime time.Time - mu sync.RWMutex - tickerStop chan bool - ticker *time.Ticker - config map[string]interface{} + *plugin.BasePluginImpl // 改为使用BasePluginImpl + stats map[string]int64 + startTime time.Time + mu sync.RWMutex + tickerStop chan bool + ticker *time.Ticker + config map[string]interface{} +} + +// StatsParams 统计请求参数结构体 +// 允许通过结构体传递参数,简化调用 +type StatsParams struct { + Name string `json:"name"` // 统计项名称 + Value int64 `json:"value"` // 统计值 + BytesReceived int64 `json:"bytesReceived"` // 接收字节数 + BytesSent int64 `json:"bytesSent"` // 发送字节数 + IsError bool `json:"isError"` // 是否为错误请求 } // Plugin 导出的插件变量 var Plugin = &StatsPlugin{ // 使用默认构造函数,不指定插件类型,将默认为通用插件 - BasePlugin: plugin.NewBasePluginWithDefaultType( + BasePluginImpl: plugin.NewPluginWithDefaultType( "StatsPlugin", "1.0.0", "系统运行时统计插件", @@ -109,98 +119,68 @@ func (p *StatsPlugin) Stop(ctx context.Context) error { return nil } -// Execute 执行插件功能 -func (p *StatsPlugin) Execute(ctx context.Context, action string, params map[string]interface{}) (interface{}, error) { - switch action { - case "incrementStat": - // 需要参数: name, value - name, ok := params["name"].(string) - if !ok { - return nil, fmt.Errorf("缺少必需参数: name") - } +// 以下方法将被自动注册为可通过Execute调用的操作 - // 处理整数值参数 - var value int64 - if floatValue, ok := params["value"].(float64); ok { - value = int64(floatValue) - } else if strValue, ok := params["value"].(string); ok { - var err error - _, err = fmt.Sscanf(strValue, "%d", &value) - if err != nil { - return nil, fmt.Errorf("参数value必须是整数: %v", err) - } - } else { - return nil, fmt.Errorf("缺少必需参数: value") - } +// IncrementStat 增加统计值 +// 会被自动注册为"incrementstat"操作 +func (p *StatsPlugin) IncrementStat(name string, value int64) error { + p.mu.Lock() + defer p.mu.Unlock() - p.IncrementStat(name, value) - return true, nil - - case "getStat": - // 需要参数: name - name, ok := params["name"].(string) - if !ok { - return nil, fmt.Errorf("缺少必需参数: name") - } - - value := p.GetStat(name) - return value, nil - - case "getAllStats": - // 不需要参数 - return p.GetAllStats(), nil - - case "recordRequest": - // 需要参数: bytesReceived, bytesSent, isError - var bytesReceived, bytesSent int64 - var isError bool - - // 处理bytesReceived参数 - if floatValue, ok := params["bytesReceived"].(float64); ok { - bytesReceived = int64(floatValue) - } else { - return nil, fmt.Errorf("缺少必需参数: bytesReceived") - } - - // 处理bytesSent参数 - if floatValue, ok := params["bytesSent"].(float64); ok { - bytesSent = int64(floatValue) - } else { - return nil, fmt.Errorf("缺少必需参数: bytesSent") - } - - // 处理isError参数 - if value, ok := params["isError"].(bool); ok { - isError = value - } - - p.RecordRequest(bytesReceived, bytesSent, isError) - return true, nil - - case "resetStats": - // 不需要参数 - p.mu.Lock() - p.stats = make(map[string]int64) - p.stats["requests"] = 0 - p.stats["errors"] = 0 - p.stats["bytes_sent"] = 0 - p.stats["bytes_received"] = 0 - p.startTime = time.Now() - p.mu.Unlock() - return true, nil - - case "getStatsReport": - // 生成统计报告 - report := p.generateStatsReport() - return report, nil - - default: - return nil, fmt.Errorf("未知的操作: %s", action) + if _, exists := p.stats[name]; exists { + p.stats[name] += value + } else { + p.stats[name] = value } + + return nil } -// generateStatsReport 生成统计报告 -func (p *StatsPlugin) generateStatsReport() map[string]interface{} { +// GetStat 获取统计值 +// 会被自动注册为"getstat"操作 +func (p *StatsPlugin) GetStat(name string) (int64, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + if value, exists := p.stats[name]; exists { + return value, nil + } + return 0, fmt.Errorf("统计项 %s 不存在", name) +} + +// RecordRequest 记录请求 +// 会被自动注册为"recordrequest"操作 +func (p *StatsPlugin) RecordRequest(ctx context.Context, params StatsParams) error { + p.IncrementStat("requests", 1) + p.IncrementStat("bytes_received", params.BytesReceived) + p.IncrementStat("bytes_sent", params.BytesSent) + + if params.IsError { + p.IncrementStat("errors", 1) + } + + return nil +} + +// ResetStats 重置统计数据 +// 会被自动注册为"resetstats"操作 +func (p *StatsPlugin) ResetStats() error { + p.mu.Lock() + defer p.mu.Unlock() + + p.stats = make(map[string]int64) + p.stats["requests"] = 0 + p.stats["errors"] = 0 + p.stats["bytes_sent"] = 0 + p.stats["bytes_received"] = 0 + p.startTime = time.Now() + + return nil +} + +// GenerateStatsReport 生成统计报告 +// 会被自动注册为"generatestatsreport"操作 +func (p *StatsPlugin) GenerateStatsReport() (map[string]interface{}, error) { p.mu.RLock() defer p.mu.RUnlock() @@ -215,10 +195,29 @@ func (p *StatsPlugin) generateStatsReport() map[string]interface{} { report["error_rate"] = float64(p.stats["errors"]) * 100 / float64(p.stats["requests"]) } - return report + return report, nil +} + +// GetAllStats 获取所有统计数据 +// 会被自动注册为"getallstats"操作 +func (p *StatsPlugin) GetAllStats() (map[string]int64, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + // 创建一个副本 + statsCopy := make(map[string]int64, len(p.stats)) + for k, v := range p.stats { + statsCopy[k] = v + } + + // 添加运行时间 + statsCopy["uptime_seconds"] = int64(time.Since(p.startTime).Seconds()) + + return statsCopy, nil } // logStats 记录当前统计信息 +// 不会被注册为操作,因为它是内部方法 func (p *StatsPlugin) logStats() { p.mu.RLock() defer p.mu.RUnlock() @@ -240,55 +239,10 @@ func (p *StatsPlugin) logStats() { fmt.Printf("=======================\n") } -// IncrementStat 增加统计值 -func (p *StatsPlugin) IncrementStat(name string, value int64) { - p.mu.Lock() - defer p.mu.Unlock() - - if _, exists := p.stats[name]; exists { - p.stats[name] += value - } else { - p.stats[name] = value - } -} - -// GetStat 获取统计值 -func (p *StatsPlugin) GetStat(name string) int64 { - p.mu.RLock() - defer p.mu.RUnlock() - - if value, exists := p.stats[name]; exists { - return value - } - return 0 -} - -// RecordRequest 记录请求 -func (p *StatsPlugin) RecordRequest(bytesReceived, bytesSent int64, isError bool) { - p.IncrementStat("requests", 1) - p.IncrementStat("bytes_received", bytesReceived) - p.IncrementStat("bytes_sent", bytesSent) - - if isError { - p.IncrementStat("errors", 1) - } -} - -// GetAllStats 获取所有统计数据 -func (p *StatsPlugin) GetAllStats() map[string]int64 { - p.mu.RLock() - defer p.mu.RUnlock() - - // 创建一个副本 - statsCopy := make(map[string]int64, len(p.stats)) - for k, v := range p.stats { - statsCopy[k] = v - } - - // 添加运行时间 - statsCopy["uptime_seconds"] = int64(time.Since(p.startTime).Seconds()) - - return statsCopy +// GetAvailableOperations 获取可用操作列表 +// 这是一个帮助方法,列出所有可通过Execute调用的操作 +func (p *StatsPlugin) GetAvailableOperations() []string { + return p.GetAvailableActions() } // main 函数是必须的,但不会被调用 diff --git a/examples/plugin/test/run_autoreflect_test.sh b/examples/plugin/test/run_autoreflect_test.sh new file mode 100644 index 0000000..aefc7f8 --- /dev/null +++ b/examples/plugin/test/run_autoreflect_test.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# 编译并运行自动方法注册测试 + +echo "===== 编译自动方法注册测试 =====" +go build -o test_autoreflect test_autoreflect.go + +echo "" +echo "===== 运行自动方法注册测试 =====" +./test_autoreflect + +# 运行完成后清理 +echo "" +echo "===== 清理 =====" +rm -f test_autoreflect +echo "完成" \ No newline at end of file diff --git a/examples/plugin/test/test_autoreflect.go b/examples/plugin/test/test_autoreflect.go new file mode 100644 index 0000000..7b5d87e --- /dev/null +++ b/examples/plugin/test/test_autoreflect.go @@ -0,0 +1,207 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/darkit/goproxy/examples/plugin" +) + +// TestPlugin 用于测试自动方法注册的插件 +type TestPlugin struct { + *plugin.BasePluginImpl + counter int +} + +// TestParams 测试参数结构体 +type TestParams struct { + Name string `json:"name"` + Value int `json:"value"` + Enabled bool `json:"enabled"` + Tags []string `json:"tags"` + Meta map[string]string `json:"meta"` +} + +// TestResult 测试结果结构体 +type TestResult struct { + Status string `json:"status"` + Count int `json:"count"` + Time time.Time `json:"time"` + Results interface{} `json:"results"` +} + +// NewTestPlugin 创建测试插件 +func NewTestPlugin() *TestPlugin { + return &TestPlugin{ + BasePluginImpl: plugin.NewPlugin( + "TestPlugin", + "1.0.0", + "用于测试自动方法注册的插件", + "开发者", + plugin.PluginTypeUtils, + ), + counter: 0, + } +} + +// 以下方法都会被自动注册为可通过Execute调用的操作 + +// SimpleMethod 简单方法,没有参数 +func (p *TestPlugin) SimpleMethod() string { + p.counter++ + return "SimpleMethod执行成功" +} + +// StringMethod 简单字符串方法 +func (p *TestPlugin) StringMethod(text string) string { + p.counter++ + return fmt.Sprintf("收到的文本: %s", text) +} + +// AddNumbers 带返回值的简单方法 +func (p *TestPlugin) AddNumbers(a, b int) int { + p.counter++ + return a + b +} + +// WithContext 带Context参数的方法 +func (p *TestPlugin) WithContext(ctx context.Context, timeout int) string { + p.counter++ + + select { + case <-time.After(time.Duration(timeout) * time.Millisecond): + return "操作完成" + case <-ctx.Done(): + return "操作被取消" + } +} + +// StructParams 接收结构体参数的方法 +func (p *TestPlugin) StructParams(params TestParams) TestResult { + p.counter++ + + return TestResult{ + Status: "success", + Count: p.counter, + Time: time.Now(), + Results: params, + } +} + +// MultipleReturns 多返回值方法 +func (p *TestPlugin) MultipleReturns(value int) (int, string, error) { + p.counter++ + + if value < 0 { + return 0, "", fmt.Errorf("值不能为负数") + } + + return value * 2, fmt.Sprintf("处理了值: %d", value), nil +} + +// WithError 返回错误的方法 +func (p *TestPlugin) WithError(shouldError bool) (string, error) { + p.counter++ + + if shouldError { + return "", fmt.Errorf("请求的错误") + } + + return "没有错误", nil +} + +// GetCounter 获取计数器值 +func (p *TestPlugin) GetCounter() int { + return p.counter +} + +func main() { + // 创建上下文 + ctx := context.Background() + + // 创建测试插件 + plugin := NewTestPlugin() + + fmt.Println("===== 测试自动方法注册与调用机制 =====\n") + + // 获取可用的操作列表 + fmt.Println("可用操作列表:") + actions := plugin.GetAvailableActions() + for i, action := range actions { + fmt.Printf(" %d. %s\n", i+1, action) + } + fmt.Println() + + // 测试1: 调用简单方法 + fmt.Println("测试1: 调用简单方法") + result1, err := plugin.Execute(ctx, "simplemethod", nil) + fmt.Printf("结果: %v, 错误: %v\n\n", result1, err) + + // 测试2: 调用带字符串参数的方法 + fmt.Println("测试2: 调用带字符串参数的方法") + result2, err := plugin.Execute(ctx, "stringmethod", map[string]interface{}{ + "text": "Hello, 反射!", + }) + fmt.Printf("结果: %v, 错误: %v\n\n", result2, err) + + // 测试3: 调用带数字参数的方法 + fmt.Println("测试3: 调用带数字参数的方法") + result3, err := plugin.Execute(ctx, "addnumbers", map[string]interface{}{ + "a": 5, + "b": 7, + }) + fmt.Printf("结果: %v, 错误: %v\n\n", result3, err) + + // 测试4: 调用带Context参数的方法 + fmt.Println("测试4: 调用带Context参数的方法") + result4, err := plugin.Execute(ctx, "withcontext", map[string]interface{}{ + "timeout": 100, // 100毫秒 + }) + fmt.Printf("结果: %v, 错误: %v\n\n", result4, err) + + // 测试5: 使用结构体参数 + fmt.Println("测试5: 使用结构体参数") + result5, err := plugin.Execute(ctx, "structparams", map[string]interface{}{ + "name": "测试名称", + "value": 42, + "enabled": true, + "tags": []string{"tag1", "tag2", "tag3"}, + "meta": map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }) + fmt.Printf("结果: %+v, 错误: %v\n\n", result5, err) + + // 测试6: 多返回值方法 + fmt.Println("测试6: 多返回值方法") + result6, err := plugin.Execute(ctx, "multiplereturns", map[string]interface{}{ + "value": 10, + }) + fmt.Printf("结果: %+v, 错误: %v\n\n", result6, err) + + // 测试7: 返回错误的方法 + fmt.Println("测试7: 返回错误的方法") + result7a, err := plugin.Execute(ctx, "witherror", map[string]interface{}{ + "shouldError": false, + }) + fmt.Printf("无错误情况: %v, 错误: %v\n", result7a, err) + + result7b, err := plugin.Execute(ctx, "witherror", map[string]interface{}{ + "shouldError": true, + }) + fmt.Printf("有错误情况: %v, 错误: %v\n\n", result7b, err) + + // 测试8: 获取计数器值 + fmt.Println("测试8: 获取计数器值 (应该是7)") + result8, err := plugin.Execute(ctx, "getcounter", nil) + fmt.Printf("计数器值: %v, 错误: %v\n\n", result8, err) + + // 测试9: 调用不存在的方法 + fmt.Println("测试9: 调用不存在的方法") + result9, err := plugin.Execute(ctx, "nonexistentmethod", nil) + fmt.Printf("结果: %v, 错误: %v\n\n", result9, err) + + fmt.Println("===== 测试完成 =====") +}