diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 70ad7086..e7d8f6d6 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -40,8 +40,10 @@ jobs: VALIDATE_TERRAGRUNT: false VALIDATE_DOCKERFILE_HADOLINT: false VALIDATE_TRIVY: false - # Prettier and StandardJS are incompatible + # Prettier, Biome and StandardJS are incompatible VALIDATE_JAVASCRIPT_PRETTIER: false VALIDATE_TYPESCRIPT_PRETTIER: false + VALIDATE_BIOME_FORMAT: false + VALIDATE_BIOME_LINT: false # Conflicts with MARKDOWN VALIDATE_MARKDOWN_PRETTIER: false diff --git a/caddy/extinit.go b/caddy/extinit.go index 6a944be3..f0327f57 100644 --- a/caddy/extinit.go +++ b/caddy/extinit.go @@ -2,12 +2,13 @@ package caddy import ( "errors" - "github.com/dunglas/frankenphp/internal/extgen" "log" "os" "path/filepath" "strings" + "github.com/dunglas/frankenphp/internal/extgen" + caddycmd "github.com/caddyserver/caddy/v2/cmd" "github.com/spf13/cobra" ) @@ -27,27 +28,21 @@ Initializes a PHP extension from a Go file. This command generates the necessary }) } -func cmdInitExtension(fs caddycmd.Flags) (int, error) { +func cmdInitExtension(_ caddycmd.Flags) (int, error) { if len(os.Args) < 3 { return 1, errors.New("the path to the Go source is required") } sourceFile := os.Args[2] + baseName := extgen.SanitizePackageName(strings.TrimSuffix(filepath.Base(sourceFile), ".go")) - baseName := strings.TrimSuffix(filepath.Base(sourceFile), ".go") - - baseName = extgen.SanitizePackageName(baseName) - - sourceDir := filepath.Dir(sourceFile) - buildDir := filepath.Join(sourceDir, "build") - - generator := extgen.Generator{BaseName: baseName, SourceFile: sourceFile, BuildDir: buildDir} + generator := extgen.Generator{BaseName: baseName, SourceFile: sourceFile, BuildDir: filepath.Dir(sourceFile)} if err := generator.Generate(); err != nil { return 1, err } - log.Printf("PHP extension %q initialized successfully in %q", baseName, generator.BuildDir) + log.Printf("PHP extension %q initialized successfully in directory %q", baseName, generator.BuildDir) return 0, nil } diff --git a/cgi.go b/cgi.go index 8d98f75e..128c9bc9 100644 --- a/cgi.go +++ b/cgi.go @@ -214,8 +214,10 @@ func go_register_variables(threadIndex C.uintptr_t, trackVarsArray *C.zval) { thread := phpThreads[threadIndex] fc := thread.getRequestContext() - addKnownVariablesToServer(thread, fc, trackVarsArray) - addHeadersToServer(fc, trackVarsArray) + if fc.request != nil { + addKnownVariablesToServer(thread, fc, trackVarsArray) + addHeadersToServer(fc, trackVarsArray) + } // The Prepared Environment is registered last and can overwrite any previous values addPreparedEnvToServer(fc, trackVarsArray) @@ -280,6 +282,10 @@ func go_update_request_info(threadIndex C.uintptr_t, info *C.sapi_request_info) fc := thread.getRequestContext() request := fc.request + if request == nil { + return C.bool(fc.worker != nil) + } + authUser, authPassword, ok := request.BasicAuth() if ok { if authPassword != "" { diff --git a/context.go b/context.go index 2e897cd5..b039feba 100644 --- a/context.go +++ b/context.go @@ -42,13 +42,18 @@ func fromContext(ctx context.Context) (fctx *frankenPHPContext, ok bool) { return } -// NewRequestWithContext creates a new FrankenPHP request context. -func NewRequestWithContext(r *http.Request, opts ...RequestOption) (*http.Request, error) { - fc := &frankenPHPContext{ +func newFrankenPHPContext() *frankenPHPContext { + return &frankenPHPContext{ done: make(chan any), startedAt: time.Now(), - request: r, } +} + +// NewRequestWithContext creates a new FrankenPHP request context. +func NewRequestWithContext(r *http.Request, opts ...RequestOption) (*http.Request, error) { + fc := newFrankenPHPContext() + fc.request = r + for _, o := range opts { if err := o(fc); err != nil { return nil, err @@ -132,6 +137,10 @@ func (fc *frankenPHPContext) validate() bool { } func (fc *frankenPHPContext) clientHasClosed() bool { + if fc.request == nil { + return false + } + select { case <-fc.request.Context().Done(): return true diff --git a/docs/cn/extensions.md b/docs/cn/extensions.md index 4203dc5e..883c52f9 100644 --- a/docs/cn/extensions.md +++ b/docs/cn/extensions.md @@ -146,11 +146,11 @@ func process_data(arr *C.zval) unsafe.Pointer { **可用方法:** -- `SetInt(key int64, value interface{})` - 使用整数键设置值 -- `SetString(key string, value interface{})` - 使用字符串键设置值 -- `Append(value interface{})` - 使用下一个可用整数键添加值 +- `SetInt(key int64, value any)` - 使用整数键设置值 +- `SetString(key string, value any)` - 使用字符串键设置值 +- `Append(value any)` - 使用下一个可用整数键添加值 - `Len() uint32` - 获取元素数量 -- `At(index uint32) (PHPKey, interface{})` - 获取索引处的键值对 +- `At(index uint32) (PHPKey, any)` - 获取索引处的键值对 - `frankenphp.PHPArray(arr *frankenphp.Array) unsafe.Pointer` - 转换为 PHP 数组 ### 声明原生 PHP 类 diff --git a/docs/extensions.md b/docs/extensions.md index dff1844e..27865d4c 100644 --- a/docs/extensions.md +++ b/docs/extensions.md @@ -33,7 +33,7 @@ As covered in the manual implementation section below as well, you need to [get The first step to writing a PHP extension in Go is to create a new Go module. You can use the following command for this: ```console -go mod init github.com/my-account/my-module +go mod init example.com/example ``` The second step is to [get the PHP sources](https://www.php.net/downloads.php) for the next steps. Once you have them, decompress them into the directory of your choice, not inside your Go module: @@ -47,10 +47,14 @@ tar xf php-* Everything is now setup to write your native function in Go. Create a new file named `stringext.go`. Our first function will take a string as an argument, the number of times to repeat it, a boolean to indicate whether to reverse the string, and return the resulting string. This should look like this: ```go +package example + +// #include +import "C" import ( - "C" - "github.com/dunglas/frankenphp" "strings" + + "github.com/dunglas/frankenphp" ) //export_php:function repeat_this(string $str, int $count, bool $reverse): string @@ -98,6 +102,7 @@ This table summarizes what you need to know: | `object` | `struct` | ❌ | _Not yet implemented_ | _Not yet implemented_ | ❌ | > [!NOTE] +> > This table is not exhaustive yet and will be completed as the FrankenPHP types API gets more complete. > > For class methods specifically, primitive types and arrays are currently supported. Objects cannot be used as method parameters or return types yet. @@ -115,6 +120,16 @@ If order or association are not needed, it's also possible to directly convert t **Creating and manipulating arrays in Go:** ```go +package example + +// #include +import "C" +import ( + "unsafe" + + "github.com/dunglas/frankenphp" +) + // export_php:function process_data_ordered(array $input): array func process_data_ordered_map(arr *C.zval) unsafe.Pointer { // Convert PHP associative array to Go while keeping the order @@ -128,7 +143,7 @@ func process_data_ordered_map(arr *C.zval) unsafe.Pointer { // return an ordered array // if 'Order' is not empty, only the key-value pairs in 'Order' will be respected - return frankenphp.PHPAssociativeArray(AssociativeArray{ + return frankenphp.PHPAssociativeArray(frankenphp.AssociativeArray{ Map: map[string]any{ "key1": "value1", "key2": "value2", @@ -192,6 +207,8 @@ func process_data_packed(arr *C.zval) unsafe.Pointer { The generator supports declaring **opaque classes** as Go structs, which can be used to create PHP objects. You can use the `//export_php:class` directive comment to define a PHP class. For example: ```go +package example + //export_php:class User type UserStruct struct { Name string @@ -216,6 +233,16 @@ This approach provides better encapsulation and prevents PHP code from accidenta Since properties are not directly accessible, you **must define methods** to interact with your opaque classes. Use the `//export_php:method` directive to define behavior: ```go +package example + +// #include +import "C" +import ( + "unsafe" + + "github.com/dunglas/frankenphp" +) + //export_php:class User type UserStruct struct { Name string @@ -248,6 +275,16 @@ func (us *UserStruct) SetNamePrefix(prefix *C.zend_string) { The generator supports nullable parameters using the `?` prefix in PHP signatures. When a parameter is nullable, it becomes a pointer in your Go function, allowing you to check if the value was `null` in PHP: ```go +package example + +// #include +import "C" +import ( + "unsafe" + + "github.com/dunglas/frankenphp" +) + //export_php:method User::updateInfo(?string $name, ?int $age, ?bool $active): void func (us *UserStruct) UpdateInfo(name *C.zend_string, age *int64, active *bool) { // Check if name was provided (not null) @@ -275,6 +312,7 @@ func (us *UserStruct) UpdateInfo(name *C.zend_string, age *int64, active *bool) - **PHP `null` becomes Go `nil`** - when PHP passes `null`, your Go function receives a `nil` pointer > [!WARNING] +> > Currently, class methods have the following limitations. **Objects are not supported** as parameter types or return types. **Arrays are fully supported** for both parameters and return types. Supported types: `string`, `int`, `float`, `bool`, `array`, and `void` (for return type). **Nullable parameter types are fully supported** for all scalar types (`?string`, `?int`, `?float`, `?bool`). After generating the extension, you will be allowed to use the class and its methods in PHP. Note that you **cannot access properties directly**: @@ -311,6 +349,8 @@ The generator supports exporting Go constants to PHP using two directives: `//ex Use the `//export_php:const` directive to create global PHP constants: ```go +package example + //export_php:const const MAX_CONNECTIONS = 100 @@ -329,6 +369,8 @@ const STATUS_ERROR = iota Use the `//export_php:classconstant ClassName` directive to create constants that belong to a specific PHP class: ```go +package example + //export_php:classconstant User const STATUS_ACTIVE = 1 @@ -368,10 +410,15 @@ The directive supports various value types including strings, integers, booleans You can use constants just like you are used to in the Go code. For example, let's take the `repeat_this()` function we declared earlier and change the last argument to an integer: ```go +package example + +// #include +import "C" import ( - "C" - "github.com/dunglas/frankenphp" - "strings" + "strings" + "unsafe" + + "github.com/dunglas/frankenphp" ) //export_php:const @@ -388,37 +435,37 @@ const MODE_UPPERCASE = 2 //export_php:function repeat_this(string $str, int $count, int $mode): string func repeat_this(s *C.zend_string, count int64, mode int) unsafe.Pointer { - str := frankenphp.GoString(unsafe.Pointer(s)) + str := frankenphp.GoString(unsafe.Pointer(s)) - result := strings.Repeat(str, int(count)) - if mode == STR_REVERSE { - // reverse the string - } + result := strings.Repeat(str, int(count)) + if mode == STR_REVERSE { + // reverse the string + } - if mode == STR_NORMAL { - // no-op, just to showcase the constant - } + if mode == STR_NORMAL { + // no-op, just to showcase the constant + } - return frankenphp.PHPString(result, false) + return frankenphp.PHPString(result, false) } //export_php:class StringProcessor type StringProcessorStruct struct { - // internal fields + // internal fields } //export_php:method StringProcessor::process(string $input, int $mode): string func (sp *StringProcessorStruct) Process(input *C.zend_string, mode int64) unsafe.Pointer { - str := frankenphp.GoString(unsafe.Pointer(input)) + str := frankenphp.GoString(unsafe.Pointer(input)) - switch mode { - case MODE_LOWERCASE: - str = strings.ToLower(str) - case MODE_UPPERCASE: - str = strings.ToUpper(str) - } + switch mode { + case MODE_LOWERCASE: + str = strings.ToLower(str) + case MODE_UPPERCASE: + str = strings.ToUpper(str) + } - return frankenphp.PHPString(str, false) + return frankenphp.PHPString(str, false) } ``` @@ -432,9 +479,13 @@ Use the `//export_php:namespace` directive at the top of your Go file to place a ```go //export_php:namespace My\Extension -package main +package example -import "C" +import ( + "unsafe" + + "github.com/dunglas/frankenphp" +) //export_php:function hello(): string func hello() string { @@ -537,25 +588,26 @@ We'll see how to write a simple PHP extension in Go that defines a new native fu In your module, you need to define a new native function that will be called from PHP. To do this, create a file with the name you want, for example, `extension.go`, and add the following code: ```go -package ext_go +package example -//#include "extension.h" +// #include "extension.h" import "C" import ( - "unsafe" - "github.com/caddyserver/caddy/v2" - "github.com/dunglas/frankenphp" + "log/slog" + "unsafe" + + "github.com/dunglas/frankenphp" ) func init() { - frankenphp.RegisterExtension(unsafe.Pointer(&C.ext_module_entry)) + frankenphp.RegisterExtension(unsafe.Pointer(&C.ext_module_entry)) } //export go_print_something func go_print_something() { - go func() { - caddy.Log().Info("Hello from a goroutine!") - }() + go func() { + slog.Info("Hello from a goroutine!") + }() } ``` @@ -731,7 +783,16 @@ There's only one thing left to do: implement the `go_upper` function in Go. Our Go function will take a `*C.zend_string` as a parameter, convert it to a Go string using FrankenPHP's helper function, process it, and return the result as a new `*C.zend_string`. The helper functions handle all the memory management and conversion complexity for us. ```go -import "strings" +package example + +// #include +import "C" +import ( + "unsafe" + "strings" + + "github.com/dunglas/frankenphp" +) //export go_upper func go_upper(s *C.zend_string) *C.zend_string { @@ -743,9 +804,12 @@ func go_upper(s *C.zend_string) *C.zend_string { } ``` -This approach is much cleaner and safer than manual memory management. FrankenPHP's helper functions handle the conversion between PHP's `zend_string` format and Go strings automatically. The `false` parameter in `PHPString()` indicates that we want to create a new non-persistent string (freed at the end of the request). +This approach is much cleaner and safer than manual memory management. +FrankenPHP's helper functions handle the conversion between PHP's `zend_string` format and Go strings automatically. +The `false` parameter in `PHPString()` indicates that we want to create a new non-persistent string (freed at the end of the request). > [!TIP] +> > In this example, we don't perform any error handling, but you should always check that pointers are not `nil` and that the data is valid before using it in your Go functions. ### Integrating the Extension into FrankenPHP diff --git a/docs/fr/extensions.md b/docs/fr/extensions.md index 8464bd09..1dacbde5 100644 --- a/docs/fr/extensions.md +++ b/docs/fr/extensions.md @@ -146,11 +146,11 @@ func process_data(arr *C.zval) unsafe.Pointer { **Méthodes disponibles :** -- `SetInt(key int64, value interface{})` - Définir une valeur avec une clé entière -- `SetString(key string, value interface{})` - Définir une valeur avec une clé chaîne -- `Append(value interface{})` - Ajouter une valeur avec la prochaine clé entière disponible +- `SetInt(key int64, value any)` - Définir une valeur avec une clé entière +- `SetString(key string, value any)` - Définir une valeur avec une clé chaîne +- `Append(value any)` - Ajouter une valeur avec la prochaine clé entière disponible - `Len() uint32` - Obtenir le nombre d'éléments -- `At(index uint32) (PHPKey, interface{})` - Obtenir la paire clé-valeur à l'index +- `At(index uint32) (PHPKey, any)` - Obtenir la paire clé-valeur à l'index - `frankenphp.PHPArray(arr *frankenphp.Array) unsafe.Pointer` - Convertir vers un tableau PHP ### Déclarer une Classe PHP Native diff --git a/docs/worker.md b/docs/worker.md index 64aed5c1..169e0a63 100644 --- a/docs/worker.md +++ b/docs/worker.md @@ -78,9 +78,15 @@ $myApp->boot(); // Handler outside the loop for better performance (doing less work) $handler = static function () use ($myApp) { - // Called when a request is received, - // superglobals, php://input and the like are reset - echo $myApp->handle($_GET, $_POST, $_COOKIE, $_FILES, $_SERVER); + try { + // Called when a request is received, + // superglobals, php://input and the like are reset + echo $myApp->handle($_GET, $_POST, $_COOKIE, $_FILES, $_SERVER); + } catch (\Throwable $exception) { + // `set_exception_handler` is called only when the worker script ends, + // which may not be what you expect, so catch and handle exceptions here + (new \MyCustomExceptionHandler)->handleException($exception); + } }; $maxRequests = (int)($_SERVER['MAX_REQUESTS'] ?? 0); diff --git a/frankenphp.c b/frankenphp.c index 8a79ddab..75876fd6 100644 --- a/frankenphp.c +++ b/frankenphp.c @@ -1131,8 +1131,7 @@ static char **cli_argv; * Parts based on CGI SAPI Module by Rasmus Lerdorf, Stig * Bakken and Zeev Suraski */ -static void cli_register_file_handles(bool no_close) /* {{{ */ -{ +static void cli_register_file_handles(void) { php_stream *s_in, *s_out, *s_err; php_stream_context *sc_in = NULL, *sc_out = NULL, *sc_err = NULL; zend_constant ic, oc, ec; @@ -1141,6 +1140,17 @@ static void cli_register_file_handles(bool no_close) /* {{{ */ s_out = php_stream_open_wrapper_ex("php://stdout", "wb", 0, NULL, sc_out); s_err = php_stream_open_wrapper_ex("php://stderr", "wb", 0, NULL, sc_err); + /* Release stream resources, but don't free the underlying handles. Othewrise, + * extensions which write to stderr or company during mshutdown/gshutdown + * won't have the expected functionality. + */ + if (s_in) + s_in->flags |= PHP_STREAM_FLAG_NO_RSCR_DTOR_CLOSE; + if (s_out) + s_out->flags |= PHP_STREAM_FLAG_NO_RSCR_DTOR_CLOSE; + if (s_err) + s_err->flags |= PHP_STREAM_FLAG_NO_RSCR_DTOR_CLOSE; + if (s_in == NULL || s_out == NULL || s_err == NULL) { if (s_in) php_stream_close(s_in); @@ -1151,12 +1161,6 @@ static void cli_register_file_handles(bool no_close) /* {{{ */ return; } - if (no_close) { - s_in->flags |= PHP_STREAM_FLAG_NO_CLOSE; - s_out->flags |= PHP_STREAM_FLAG_NO_CLOSE; - s_err->flags |= PHP_STREAM_FLAG_NO_CLOSE; - } - /*s_in_process = s_in;*/ php_stream_to_zval(s_in, &ic.value); @@ -1175,7 +1179,6 @@ static void cli_register_file_handles(bool no_close) /* {{{ */ ec.name = zend_string_init_interned("STDERR", sizeof("STDERR") - 1, 0); zend_register_constant(&ec); } -/* }}} */ static void sapi_cli_register_variables(zval *track_vars_array) /* {{{ */ { @@ -1220,7 +1223,7 @@ static void *execute_script_cli(void *arg) { php_embed_init(cli_argc, cli_argv); - cli_register_file_handles(false); + cli_register_file_handles(); zend_first_try { if (eval) { /* evaluate the cli_script as literal PHP code (php-cli -r "...") */ diff --git a/frankenphp.go b/frankenphp.go index 5c2044f0..3d77acc5 100644 --- a/frankenphp.go +++ b/frankenphp.go @@ -222,7 +222,7 @@ func Init(options ...Option) error { registerExtensions() // add registered external workers - for _, ew := range externalWorkers { + for _, ew := range extensionWorkers { options = append(options, WithWorkers(ew.Name(), ew.FileName(), ew.GetMinThreads(), WithWorkerEnv(ew.Env()))) } @@ -405,7 +405,7 @@ func go_apache_request_headers(threadIndex C.uintptr_t) (*C.go_string, C.size_t) if fc.responseWriter == nil { // worker mode, not handling a request - logger.LogAttrs(context.Background(), slog.LevelDebug, "apache_request_headers() called in non-HTTP context", slog.String("worker", fc.scriptFilename)) + logger.LogAttrs(context.Background(), slog.LevelDebug, "apache_request_headers() called in non-HTTP context", slog.String("worker", fc.worker.name)) return nil, 0 } @@ -550,8 +550,12 @@ func go_read_post(threadIndex C.uintptr_t, cBuf *C.char, countBytes C.size_t) (r //export go_read_cookies func go_read_cookies(threadIndex C.uintptr_t) *C.char { - cookies := phpThreads[threadIndex].getRequestContext().request.Header.Values("Cookie") - cookie := strings.Join(cookies, "; ") + request := phpThreads[threadIndex].getRequestContext().request + if request == nil { + return nil + } + + cookie := strings.Join(request.Header.Values("Cookie"), "; ") if cookie == "" { return nil } diff --git a/internal/extgen/cfile_namespace_test.go b/internal/extgen/cfile_namespace_test.go index b5f7ee87..954bace4 100644 --- a/internal/extgen/cfile_namespace_test.go +++ b/internal/extgen/cfile_namespace_test.go @@ -1,10 +1,11 @@ package extgen import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "os" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNamespacedClassName(t *testing.T) { diff --git a/internal/extgen/cfile_phpmethod_test.go b/internal/extgen/cfile_phpmethod_test.go index 1c952578..b8633bd2 100644 --- a/internal/extgen/cfile_phpmethod_test.go +++ b/internal/extgen/cfile_phpmethod_test.go @@ -1,8 +1,9 @@ package extgen import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestCFile_NamespacedPHPMethods(t *testing.T) { diff --git a/internal/extgen/cfile_test.go b/internal/extgen/cfile_test.go index 63b696ff..f9d8c749 100644 --- a/internal/extgen/cfile_test.go +++ b/internal/extgen/cfile_test.go @@ -1,12 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "path/filepath" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCFileGenerator_Generate(t *testing.T) { diff --git a/internal/extgen/classparser.go b/internal/extgen/classparser.go index 273df6a1..caef0ea2 100644 --- a/internal/extgen/classparser.go +++ b/internal/extgen/classparser.go @@ -177,22 +177,23 @@ func (cp *classParser) typeToString(expr ast.Expr) string { case *ast.MapType: return "map[" + cp.typeToString(t.Key) + "]" + cp.typeToString(t.Value) default: - return "interface{}" + return "any" } } +var goToPhpTypeMap = map[string]phpType{ + "string": phpString, + "int": phpInt, "int64": phpInt, "int32": phpInt, "int16": phpInt, "int8": phpInt, + "uint": phpInt, "uint64": phpInt, "uint32": phpInt, "uint16": phpInt, "uint8": phpInt, + "float64": phpFloat, "float32": phpFloat, + "bool": phpBool, + "any": phpMixed, +} + func (cp *classParser) goTypeToPHPType(goType string) phpType { goType = strings.TrimPrefix(goType, "*") - typeMap := map[string]phpType{ - "string": phpString, - "int": phpInt, "int64": phpInt, "int32": phpInt, "int16": phpInt, "int8": phpInt, - "uint": phpInt, "uint64": phpInt, "uint32": phpInt, "uint16": phpInt, "uint8": phpInt, - "float64": phpFloat, "float32": phpFloat, - "bool": phpBool, - } - - if phpType, exists := typeMap[goType]; exists { + if phpType, exists := goToPhpTypeMap[goType]; exists { return phpType } @@ -244,7 +245,7 @@ func (cp *classParser) parseMethods(filename string) (methods []phpClassMethod, IsReturnNullable: method.isReturnNullable, } - if err := validator.validateScalarTypes(phpFunc); err != nil { + if err := validator.validateTypes(phpFunc); err != nil { fmt.Printf("Warning: Method \"%s::%s\" uses unsupported types: %v\n", className, method.Name, err) continue diff --git a/internal/extgen/classparser_test.go b/internal/extgen/classparser_test.go index a05e5cbc..11454c48 100644 --- a/internal/extgen/classparser_test.go +++ b/internal/extgen/classparser_test.go @@ -1,12 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClassParser(t *testing.T) { @@ -282,7 +282,7 @@ func TestGoTypeToPHPType(t *testing.T) { {"[]string", phpArray}, {"map[string]int", phpArray}, {"*[]int", phpArray}, - {"interface{}", phpMixed}, + {"any", phpMixed}, {"CustomType", phpMixed}, } @@ -335,7 +335,7 @@ type NullableStruct struct { type CollectionStruct struct { StringSlice []string IntMap map[string]int - MixedSlice []interface{} + MixedSlice []any }`, expected: []phpType{phpArray, phpArray, phpArray}, }, @@ -381,7 +381,7 @@ type TestClass struct { } //export_php:method TestClass::arrayMethod(array $data): string -func (tc *TestClass) arrayMethod(data interface{}) unsafe.Pointer { +func (tc *TestClass) arrayMethod(data any) unsafe.Pointer { return nil }`, expectedClasses: 1, @@ -398,7 +398,7 @@ type TestClass struct { } //export_php:method TestClass::objectMethod(object $obj): string -func (tc *TestClass) objectMethod(obj interface{}) unsafe.Pointer { +func (tc *TestClass) objectMethod(obj any) unsafe.Pointer { return nil }`, expectedClasses: 1, @@ -415,7 +415,7 @@ type TestClass struct { } //export_php:method TestClass::mixedMethod(mixed $value): string -func (tc *TestClass) mixedMethod(value interface{}) unsafe.Pointer { +func (tc *TestClass) mixedMethod(value any) unsafe.Pointer { return nil }`, expectedClasses: 1, @@ -432,7 +432,7 @@ type TestClass struct { } //export_php:method TestClass::arrayReturn(string $name): array -func (tc *TestClass) arrayReturn(name *C.zend_string) interface{} { +func (tc *TestClass) arrayReturn(name *C.zend_string) any { return []string{"result"} }`, expectedClasses: 1, @@ -449,8 +449,8 @@ type TestClass struct { } //export_php:method TestClass::objectReturn(string $name): object -func (tc *TestClass) objectReturn(name *C.zend_string) interface{} { - return map[string]interface{}{"key": "value"} +func (tc *TestClass) objectReturn(name *C.zend_string) any { + return map[string]any{"key": "value"} }`, expectedClasses: 1, expectedMethods: 0, diff --git a/internal/extgen/constants_test.go b/internal/extgen/constants_test.go index cf14fcac..b4f028b5 100644 --- a/internal/extgen/constants_test.go +++ b/internal/extgen/constants_test.go @@ -1,12 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConstantsIntegration(t *testing.T) { diff --git a/internal/extgen/constparser_test.go b/internal/extgen/constparser_test.go index 63c594eb..549bc912 100644 --- a/internal/extgen/constparser_test.go +++ b/internal/extgen/constparser_test.go @@ -1,12 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConstantParser(t *testing.T) { diff --git a/internal/extgen/docs_test.go b/internal/extgen/docs_test.go index da7479ff..27b2b6af 100644 --- a/internal/extgen/docs_test.go +++ b/internal/extgen/docs_test.go @@ -1,12 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDocumentationGenerator_Generate(t *testing.T) { diff --git a/internal/extgen/funcparser.go b/internal/extgen/funcparser.go index b2c9f3ec..5e64de1b 100644 --- a/internal/extgen/funcparser.go +++ b/internal/extgen/funcparser.go @@ -50,7 +50,7 @@ func (fp *FuncParser) parse(filename string) (functions []phpFunction, err error continue } - if err := validator.validateScalarTypes(*phpFunc); err != nil { + if err := validator.validateTypes(*phpFunc); err != nil { fmt.Printf("Warning: Function '%s' uses unsupported types: %v\n", phpFunc.Name, err) continue diff --git a/internal/extgen/funcparser_test.go b/internal/extgen/funcparser_test.go index 51e5a3ee..282d23b4 100644 --- a/internal/extgen/funcparser_test.go +++ b/internal/extgen/funcparser_test.go @@ -1,12 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestFunctionParser(t *testing.T) { @@ -306,7 +306,7 @@ func TestFunctionParserUnsupportedTypes(t *testing.T) { input: `package main //export_php:function arrayFunc(array $data): string -func arrayFunc(data interface{}) unsafe.Pointer { +func arrayFunc(data any) unsafe.Pointer { return String("processed") }`, expected: 0, @@ -317,7 +317,7 @@ func arrayFunc(data interface{}) unsafe.Pointer { input: `package main //export_php:function objectFunc(object $obj): string -func objectFunc(obj interface{}) unsafe.Pointer { +func objectFunc(obj any) unsafe.Pointer { return String("processed") }`, expected: 0, @@ -328,7 +328,7 @@ func objectFunc(obj interface{}) unsafe.Pointer { input: `package main //export_php:function mixedFunc(mixed $value): string -func mixedFunc(value interface{}) unsafe.Pointer { +func mixedFunc(value any) unsafe.Pointer { return String("processed") }`, expected: 0, @@ -339,7 +339,7 @@ func mixedFunc(value interface{}) unsafe.Pointer { input: `package main //export_php:function arrayReturnFunc(string $name): array -func arrayReturnFunc(name *C.zend_string) interface{} { +func arrayReturnFunc(name *C.zend_string) any { return []string{"result"} }`, expected: 0, @@ -350,8 +350,8 @@ func arrayReturnFunc(name *C.zend_string) interface{} { input: `package main //export_php:function objectReturnFunc(string $name): object -func objectReturnFunc(name *C.zend_string) interface{} { - return map[string]interface{}{"key": "value"} +func objectReturnFunc(name *C.zend_string) any { + return map[string]any{"key": "value"} }`, expected: 0, hasWarning: true, diff --git a/internal/extgen/generator.go b/internal/extgen/generator.go index f3c31e81..95879f63 100644 --- a/internal/extgen/generator.go +++ b/internal/extgen/generator.go @@ -5,8 +5,6 @@ import ( "os" ) -const BuildDir = "build" - type Generator struct { BaseName string SourceFile string diff --git a/internal/extgen/gofile.go b/internal/extgen/gofile.go index 43e16218..83bcfd46 100644 --- a/internal/extgen/gofile.go +++ b/internal/extgen/gofile.go @@ -47,7 +47,7 @@ func (gg *GoFileGenerator) buildContent() (string, error) { filteredImports := make([]string, 0, len(imports)) for _, imp := range imports { - if imp != `"C"` { + if imp != `"C"` && imp != `"unsafe"` && imp != `"github.com/dunglas/frankenphp"` { filteredImports = append(filteredImports, imp) } } @@ -104,20 +104,20 @@ type GoParameter struct { Type string } -func (gg *GoFileGenerator) phpTypeToGoType(phpT phpType) string { - typeMap := map[phpType]string{ - phpString: "string", - phpInt: "int64", - phpFloat: "float64", - phpBool: "bool", - phpArray: "*frankenphp.Array", - phpMixed: "interface{}", - phpVoid: "", - } +var phpToGoTypeMap = map[phpType]string{ + phpString: "string", + phpInt: "int64", + phpFloat: "float64", + phpBool: "bool", + phpArray: "*frankenphp.Array", + phpMixed: "any", + phpVoid: "", +} - if goType, exists := typeMap[phpT]; exists { +func (gg *GoFileGenerator) phpTypeToGoType(phpT phpType) string { + if goType, exists := phpToGoTypeMap[phpT]; exists { return goType } - return "interface{}" + return "any" } diff --git a/internal/extgen/gofile_test.go b/internal/extgen/gofile_test.go index ce7fe2c5..baca290a 100644 --- a/internal/extgen/gofile_test.go +++ b/internal/extgen/gofile_test.go @@ -1,13 +1,13 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGoFileGenerator_Generate(t *testing.T) { @@ -109,7 +109,7 @@ func test() { contains: []string{ "package simple", `#include "simple.h"`, - "import \"C\"", + `import "C"`, "func init()", "frankenphp.RegisterExtension(", "//export test", @@ -143,11 +143,11 @@ func process(data *go_string) *go_value { }, contains: []string{ "package complex", - `import "fmt"`, - `import "strings"`, - `import "encoding/json"`, + `"fmt"`, + `"strings"`, + `"encoding/json"`, "//export process", - `import "C"`, + `"C"`, }, }, { @@ -193,7 +193,7 @@ func internalFunc2(data string) { require.NoError(t, err) for _, expected := range tt.contains { - assert.Contains(t, content, expected, "Generated Go content should contain '%s'", expected) + assert.Contains(t, content, expected, "Generated Go content should contain %q", expected) } }) } @@ -305,9 +305,9 @@ func test() {}` require.NoError(t, err) expectedImports := []string{ - `import "fmt"`, - `import "strings"`, - `import "github.com/other/package"`, + `"fmt"`, + `"strings"`, + `"github.com/other/package"`, } for _, imp := range expectedImports { @@ -315,10 +315,10 @@ func test() {}` } forbiddenImports := []string{ - `import "C"`, + `"C"`, } - cImportCount := strings.Count(content, `import "C"`) + cImportCount := strings.Count(content, `"C"`) assert.Equal(t, 1, cImportCount, "Expected exactly 1 occurrence of 'import \"C\"'") for _, imp := range forbiddenImports[1:] { @@ -340,7 +340,7 @@ import ( func processData(input *go_string, options *go_nullable) *go_value { data := CStringToGoString(input) processed := internalProcess(data) - return types.Array([]interface{}{processed}) + return types.Array([]any{processed}) } //export_php: validateInput(data string): bool @@ -358,7 +358,7 @@ func validateFormat(input string) bool { return !strings.Contains(input, "invalid") } -func jsonHelper(data interface{}) ([]byte, error) { +func jsonHelper(data any) ([]byte, error) { return json.Marshal(data) } @@ -375,7 +375,7 @@ func debugPrint(msg string) { GoFunction: `func processData(input *go_string, options *go_nullable) *go_value { data := CStringToGoString(input) processed := internalProcess(data) - return Array([]interface{}{processed}) + return Array([]any{processed}) }`, }, { @@ -403,7 +403,7 @@ func debugPrint(msg string) { internalFuncs := []string{ "func internalProcess(data string) string", "func validateFormat(input string) bool", - "func jsonHelper(data interface{}) ([]byte, error)", + "func jsonHelper(data any) ([]byte, error)", "func debugPrint(msg string)", } @@ -510,7 +510,7 @@ import "fmt" //export_php:class ArrayClass type ArrayStruct struct { - data []interface{} + data []any } //export_php:method ArrayClass::processArray(array $items): array @@ -675,10 +675,8 @@ func createTempSourceFile(t *testing.T, content string) string { func testGoFileBasicStructure(t *testing.T, content, baseName string) { requiredElements := []string{ "package " + SanitizePackageName(baseName), - "/*", - "#include ", - `#include "` + baseName + `.h"`, - "*/", + "// #include ", + `// #include "` + baseName + `.h"`, `import "C"`, "func init() {", "frankenphp.RegisterExtension(", @@ -691,7 +689,7 @@ func testGoFileBasicStructure(t *testing.T, content, baseName string) { } func testGoFileImports(t *testing.T, content string) { - cImportCount := strings.Count(content, `import "C"`) + cImportCount := strings.Count(content, `"C"`) assert.Equal(t, 1, cImportCount, "Expected exactly 1 C import") } diff --git a/internal/extgen/hfile_test.go b/internal/extgen/hfile_test.go index 26a60f7f..0798719a 100644 --- a/internal/extgen/hfile_test.go +++ b/internal/extgen/hfile_test.go @@ -5,9 +5,8 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHeaderGenerator_Generate(t *testing.T) { diff --git a/internal/extgen/namespace_test.go b/internal/extgen/namespace_test.go index 7b728280..5f777d55 100644 --- a/internal/extgen/namespace_test.go +++ b/internal/extgen/namespace_test.go @@ -1,10 +1,11 @@ package extgen import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "os" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNamespaceParser(t *testing.T) { diff --git a/internal/extgen/paramparser.go b/internal/extgen/paramparser.go index 7c203a35..8da8895e 100644 --- a/internal/extgen/paramparser.go +++ b/internal/extgen/paramparser.go @@ -68,7 +68,7 @@ func (pp *ParameterParser) generateSingleParamDeclaration(param phpParameter) [] if param.IsNullable { decls = append(decls, fmt.Sprintf("zend_bool %s_is_null = 0;", param.Name)) } - case phpArray: + case phpArray, phpMixed: decls = append(decls, fmt.Sprintf("zval *%s = NULL;", param.Name)) } @@ -119,6 +119,8 @@ func (pp *ParameterParser) generateParamParsingMacro(param phpParameter) string return fmt.Sprintf("\n Z_PARAM_BOOL_OR_NULL(%s, %s_is_null)", param.Name, param.Name) case phpArray: return fmt.Sprintf("\n Z_PARAM_ARRAY_OR_NULL(%s)", param.Name) + case phpMixed: + return fmt.Sprintf("\n Z_PARAM_ZVAL_OR_NULL(%s)", param.Name) default: return "" } @@ -134,6 +136,8 @@ func (pp *ParameterParser) generateParamParsingMacro(param phpParameter) string return fmt.Sprintf("\n Z_PARAM_BOOL(%s)", param.Name) case phpArray: return fmt.Sprintf("\n Z_PARAM_ARRAY(%s)", param.Name) + case phpMixed: + return fmt.Sprintf("\n Z_PARAM_ZVAL(%s)", param.Name) default: return "" } @@ -164,25 +168,19 @@ func (pp *ParameterParser) generateSingleGoCallParam(param phpParameter) string return fmt.Sprintf("%s_is_null ? NULL : &%s", param.Name, param.Name) case phpBool: return fmt.Sprintf("%s_is_null ? NULL : &%s", param.Name, param.Name) - case phpArray: - return param.Name - default: - return param.Name - } - } else { - switch param.PhpType { - case phpString: - return param.Name - case phpInt: - return fmt.Sprintf("(long) %s", param.Name) - case phpFloat: - return fmt.Sprintf("(double) %s", param.Name) - case phpBool: - return fmt.Sprintf("(int) %s", param.Name) - case phpArray: - return param.Name default: return param.Name } } + + switch param.PhpType { + case phpInt: + return fmt.Sprintf("(long) %s", param.Name) + case phpFloat: + return fmt.Sprintf("(double) %s", param.Name) + case phpBool: + return fmt.Sprintf("(int) %s", param.Name) + default: + return param.Name + } } diff --git a/internal/extgen/paramparser_test.go b/internal/extgen/paramparser_test.go index 251719dc..5752c3a5 100644 --- a/internal/extgen/paramparser_test.go +++ b/internal/extgen/paramparser_test.go @@ -163,6 +163,20 @@ func TestParameterParser_GenerateParamDeclarations(t *testing.T) { }, expected: " zend_string *name = NULL;\n zval *items = NULL;\n zend_long count = 5;", }, + { + name: "mixed parameter", + params: []phpParameter{ + {Name: "m", PhpType: phpMixed, HasDefault: false}, + }, + expected: " zval *m = NULL;", + }, + { + name: "nullable mixed parameter", + params: []phpParameter{ + {Name: "m", PhpType: phpMixed, HasDefault: false, IsNullable: true}, + }, + expected: " zval *m = NULL;", + }, } for _, tt := range tests { @@ -346,6 +360,16 @@ func TestParameterParser_GenerateParamParsingMacro(t *testing.T) { param: phpParameter{Name: "items", PhpType: phpArray, IsNullable: true}, expected: "\n Z_PARAM_ARRAY_OR_NULL(items)", }, + { + name: "mixed parameter", + param: phpParameter{Name: "m", PhpType: phpMixed}, + expected: "\n Z_PARAM_ZVAL(m)", + }, + { + name: "nullable mixed parameter", + param: phpParameter{Name: "m", PhpType: phpMixed, IsNullable: true}, + expected: "\n Z_PARAM_ZVAL_OR_NULL(m)", + }, { name: "unknown type", param: phpParameter{Name: "unknown", PhpType: phpType("unknown")}, diff --git a/internal/extgen/phpfunc.go b/internal/extgen/phpfunc.go index 8298bcb5..13cad820 100644 --- a/internal/extgen/phpfunc.go +++ b/internal/extgen/phpfunc.go @@ -50,21 +50,21 @@ func (pfg *PHPFuncGenerator) generateGoCall(fn phpFunction) string { return fmt.Sprintf(" zend_array *result = %s(%s);", fn.Name, callParams) } + if fn.ReturnType == phpMixed { + return fmt.Sprintf(" zval *result = %s(%s);", fn.Name, callParams) + } + return fmt.Sprintf(" %s result = %s(%s);", pfg.getCReturnType(fn.ReturnType), fn.Name, callParams) } func (pfg *PHPFuncGenerator) getCReturnType(returnType phpType) string { switch returnType { - case phpString: - return "zend_string*" case phpInt: return "long" case phpFloat: return "double" case phpBool: return "int" - case phpArray: - return "zend_array*" default: return "void" } diff --git a/internal/extgen/phpfunc_namespace_test.go b/internal/extgen/phpfunc_namespace_test.go index 6ba8855e..04cae786 100644 --- a/internal/extgen/phpfunc_namespace_test.go +++ b/internal/extgen/phpfunc_namespace_test.go @@ -1,8 +1,9 @@ package extgen import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestPHPFuncGenerator_NamespacedFunctions(t *testing.T) { diff --git a/internal/extgen/srcanalyzer_test.go b/internal/extgen/srcanalyzer_test.go index 926591a1..717f99b5 100644 --- a/internal/extgen/srcanalyzer_test.go +++ b/internal/extgen/srcanalyzer_test.go @@ -1,11 +1,12 @@ package extgen import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) diff --git a/internal/extgen/templates/extension.c.tpl b/internal/extgen/templates/extension.c.tpl index 6ae0b347..0dd4608e 100644 --- a/internal/extgen/templates/extension.c.tpl +++ b/internal/extgen/templates/extension.c.tpl @@ -156,7 +156,7 @@ void register_all_classes() { PHP_MINIT_FUNCTION({{.BaseName}}) { {{ if .Classes}}register_all_classes();{{end}} - + {{- range .Constants}} {{- if eq .ClassName ""}} {{if .IsIota}}REGISTER_LONG_CONSTANT("{{.Name}}", {{.Name}}, CONST_CS | CONST_PERSISTENT); @@ -180,4 +180,3 @@ zend_module_entry {{.BaseName}}_module_entry = {STANDARD_MODULE_HEADER, NULL, /* MINFO */ "1.0.0", /* Version */ STANDARD_MODULE_PROPERTIES}; - diff --git a/internal/extgen/templates/extension.go.tpl b/internal/extgen/templates/extension.go.tpl index a3edaa82..ed0e0854 100644 --- a/internal/extgen/templates/extension.go.tpl +++ b/internal/extgen/templates/extension.go.tpl @@ -1,52 +1,55 @@ package {{.PackageName}} -/* -#include -#include "{{.BaseName}}.h" -*/ +// #include +// #include "{{.BaseName}}.h" import "C" +import ( + "unsafe" + + "github.com/dunglas/frankenphp" {{- range .Imports}} -import {{.}} + {{.}} {{- end}} +) func init() { frankenphp.RegisterExtension(unsafe.Pointer(&C.{{.BaseName}}_module_entry)) } -{{- range .Constants}} + +{{ range .Constants}} const {{.Name}} = {{.Value}} -{{- end}} -{{ range .Variables}} +{{- end}} +{{- range .Variables}} + {{.}} {{- end}} - -{{range .InternalFunctions}} +{{- range .InternalFunctions}} {{.}} -{{- end}} +{{- end}} {{- range .Functions}} //export {{.Name}} {{.GoFunction}} -{{- end}} +{{- end}} {{- range .Classes}} type {{.GoStruct}} struct { {{- range .Properties}} {{.Name}} {{.GoType}} {{- end}} } + {{- end}} - {{- if .Classes}} - //export registerGoObject -func registerGoObject(obj interface{}) C.uintptr_t { +func registerGoObject(obj any) C.uintptr_t { handle := cgo.NewHandle(obj) return C.uintptr_t(handle) } //export getGoObject -func getGoObject(handle C.uintptr_t) interface{} { +func getGoObject(handle C.uintptr_t) any { h := cgo.Handle(handle) return h.Value() } @@ -58,7 +61,6 @@ func removeGoObject(handle C.uintptr_t) { } {{- end}} - {{- range $class := .Classes}} //export create_{{.GoStruct}}_object func create_{{.GoStruct}}_object() C.uintptr_t { @@ -70,8 +72,8 @@ func create_{{.GoStruct}}_object() C.uintptr_t { {{- if .GoFunction}} {{.GoFunction}} {{- end}} -{{- end}} +{{- end}} {{- range .Methods}} //export {{.Name}}_wrapper func {{.Name}}_wrapper(handle C.uintptr_t{{range .Params}}{{if eq .PhpType "string"}}, {{.Name}} *C.zend_string{{else if eq .PhpType "array"}}, {{.Name}} *C.zval{{else}}, {{.Name}} {{if .IsNullable}}*{{end}}{{phpTypeToGoType .PhpType}}{{end}}{{end}}){{if not (isVoid .ReturnType)}}{{if isStringOrArray .ReturnType}} unsafe.Pointer{{else}} {{phpTypeToGoType .ReturnType}}{{end}}{{end}} { diff --git a/internal/extgen/utils_namespace_test.go b/internal/extgen/utils_namespace_test.go index 8ac806ba..5f7dd4c3 100644 --- a/internal/extgen/utils_namespace_test.go +++ b/internal/extgen/utils_namespace_test.go @@ -1,8 +1,9 @@ package extgen import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestNamespacedName(t *testing.T) { diff --git a/internal/extgen/validator.go b/internal/extgen/validator.go index 777dd0d6..44fd1380 100644 --- a/internal/extgen/validator.go +++ b/internal/extgen/validator.go @@ -10,26 +10,17 @@ import ( "strings" ) -func scalarTypes() []phpType { - return []phpType{phpString, phpInt, phpFloat, phpBool, phpArray} -} +var ( + paramTypes = []phpType{phpString, phpInt, phpFloat, phpBool, phpArray, phpObject, phpMixed} + returnTypes = []phpType{phpVoid, phpString, phpInt, phpFloat, phpBool, phpArray, phpObject, phpMixed, phpNull, phpTrue, phpFalse} + propTypes = []phpType{phpString, phpInt, phpFloat, phpBool, phpArray, phpObject, phpMixed} + supportedTypes = []phpType{phpString, phpInt, phpFloat, phpBool, phpArray, phpMixed} -func paramTypes() []phpType { - return []phpType{phpString, phpInt, phpFloat, phpBool, phpArray, phpObject, phpMixed} -} - -func returnTypes() []phpType { - return []phpType{phpVoid, phpString, phpInt, phpFloat, phpBool, phpArray, phpObject, phpMixed, phpNull, phpTrue, phpFalse} -} - -func propTypes() []phpType { - return []phpType{phpString, phpInt, phpFloat, phpBool, phpArray, phpObject, phpMixed} -} - -var functionNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) -var parameterNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) -var classNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) -var propNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + functionNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + parameterNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + classNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + propNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) +) type Validator struct{} @@ -64,8 +55,7 @@ func (v *Validator) validateParameter(param phpParameter) error { return fmt.Errorf("invalid parameter name: %s", param.Name) } - validTypes := paramTypes() - if !v.isValidPHPType(param.PhpType, validTypes) { + if !slices.Contains(paramTypes, param.PhpType) { return fmt.Errorf("invalid parameter type: %s", param.PhpType) } @@ -73,8 +63,7 @@ func (v *Validator) validateParameter(param phpParameter) error { } func (v *Validator) validateReturnType(returnType phpType) error { - validReturnTypes := returnTypes() - if !v.isValidPHPType(returnType, validReturnTypes) { + if !slices.Contains(returnTypes, returnType) { return fmt.Errorf("invalid return type: %s", returnType) } return nil @@ -107,43 +96,32 @@ func (v *Validator) validateClassProperty(prop phpClassProperty) error { return fmt.Errorf("invalid property name: %s", prop.Name) } - validTypes := propTypes() - if !v.isValidPHPType(prop.PhpType, validTypes) { + if !slices.Contains(propTypes, prop.PhpType) { return fmt.Errorf("invalid property type: %s", prop.PhpType) } return nil } -func (v *Validator) isValidPHPType(phpType phpType, validTypes []phpType) bool { - return slices.Contains(validTypes, phpType) -} - -// validateScalarTypes checks if PHP signature contains only supported scalar types -func (v *Validator) validateScalarTypes(fn phpFunction) error { - supportedTypes := scalarTypes() - +// validateTypes checks if PHP signature contains only supported types +func (v *Validator) validateTypes(fn phpFunction) error { for i, param := range fn.Params { - if !v.isScalarPHPType(param.PhpType, supportedTypes) { - return fmt.Errorf("parameter %d (%s) has unsupported type '%s'. Only scalar types (string, int, float, bool, array) and their nullable variants are supported", i+1, param.Name, param.PhpType) + if !slices.Contains(supportedTypes, param.PhpType) { + return fmt.Errorf("parameter %d %q has unsupported type %q, supported typed: string, int, float, bool, array and mixed, can be nullable", i+1, param.Name, param.PhpType) } } - if fn.ReturnType != phpVoid && !v.isScalarPHPType(fn.ReturnType, supportedTypes) { - return fmt.Errorf("return type '%s' is not supported. Only scalar types (string, int, float, bool, array), void, and their nullable variants are supported", fn.ReturnType) + if fn.ReturnType != phpVoid && !slices.Contains(supportedTypes, fn.ReturnType) { + return fmt.Errorf("return type %q is not supported, supported typed: string, int, float, bool, array and mixed, can be nullable", fn.ReturnType) } return nil } -func (v *Validator) isScalarPHPType(phpType phpType, supportedTypes []phpType) bool { - return slices.Contains(supportedTypes, phpType) -} - // validateGoFunctionSignatureWithOptions validates with option for method vs function func (v *Validator) validateGoFunctionSignatureWithOptions(phpFunc phpFunction, isMethod bool) error { if phpFunc.GoFunction == "" { - return fmt.Errorf("no Go function found for PHP function '%s'", phpFunc.Name) + return fmt.Errorf("no Go function found for PHP function %q", phpFunc.Name) } fset := token.NewFileSet() @@ -199,7 +177,7 @@ func (v *Validator) validateGoFunctionSignatureWithOptions(phpFunc phpFunction, actualGoType := v.goTypeToString(goParam.Type) if !v.isCompatibleGoType(expectedGoType, actualGoType) { - return fmt.Errorf("parameter %d type mismatch: PHP '%s' requires Go type '%s' but found '%s'", i+1, phpParam.PhpType, expectedGoType, actualGoType) + return fmt.Errorf("parameter %d type mismatch: PHP %q requires Go type %q but found %q", i+1, phpParam.PhpType, expectedGoType, actualGoType) } } } @@ -208,7 +186,7 @@ func (v *Validator) validateGoFunctionSignatureWithOptions(phpFunc phpFunction, actualGoReturnType := v.goReturnTypeToString(goFunc.Type.Results) if !v.isCompatibleGoType(expectedGoReturnType, actualGoReturnType) { - return fmt.Errorf("return type mismatch: PHP '%s' requires Go return type '%s' but found '%s'", phpFunc.ReturnType, expectedGoReturnType, actualGoReturnType) + return fmt.Errorf("return type mismatch: PHP %q requires Go return type %q but found %q", phpFunc.ReturnType, expectedGoReturnType, actualGoReturnType) } return nil @@ -225,10 +203,10 @@ func (v *Validator) phpTypeToGoType(t phpType, isNullable bool) string { baseType = "float64" case phpBool: baseType = "bool" - case phpArray: + case phpArray, phpMixed: baseType = "*C.zval" default: - baseType = "interface{}" + baseType = "any" } if isNullable && t != phpString && t != phpArray { @@ -271,7 +249,7 @@ func (v *Validator) phpReturnTypeToGoType(phpReturnType phpType) string { case phpArray: return "unsafe.Pointer" default: - return "interface{}" + return "any" } } diff --git a/internal/extgen/validator_test.go b/internal/extgen/validator_test.go index 53d941c1..bfd232a1 100644 --- a/internal/extgen/validator_test.go +++ b/internal/extgen/validator_test.go @@ -417,7 +417,7 @@ func TestValidateClass(t *testing.T) { } } -func TestValidateScalarTypes(t *testing.T) { +func TestValidateTypes(t *testing.T) { tests := []struct { name string function phpFunction @@ -494,19 +494,7 @@ func TestValidateScalarTypes(t *testing.T) { }, }, expectError: true, - errorMsg: "parameter 1 (objectParam) has unsupported type 'object'", - }, - { - name: "invalid mixed parameter", - function: phpFunction{ - Name: "mixedFunction", - ReturnType: phpString, - Params: []phpParameter{ - {Name: "mixedParam", PhpType: phpMixed}, - }, - }, - expectError: true, - errorMsg: "parameter 1 (mixedParam) has unsupported type 'mixed'", + errorMsg: `parameter 1 "objectParam" has unsupported type "object"`, }, { name: "invalid object return type", @@ -518,7 +506,7 @@ func TestValidateScalarTypes(t *testing.T) { }, }, expectError: true, - errorMsg: "return type 'object' is not supported", + errorMsg: `return type "object" is not supported`, }, { name: "mixed scalar and invalid parameters", @@ -532,20 +520,20 @@ func TestValidateScalarTypes(t *testing.T) { }, }, expectError: true, - errorMsg: "parameter 2 (invalidParam) has unsupported type 'object'", + errorMsg: `parameter 2 "invalidParam" has unsupported type "object"`, }, } validator := Validator{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validator.validateScalarTypes(tt.function) + err := validator.validateTypes(tt.function) if tt.expectError { - assert.Error(t, err, "validateScalarTypes() should return an error for function %s", tt.function.Name) + assert.Error(t, err, "validateTypes() should return an error for function %s", tt.function.Name) assert.Contains(t, err.Error(), tt.errorMsg, "Error message should contain expected text") } else { - assert.NoError(t, err, "validateScalarTypes() should not return an error for function %s", tt.function.Name) + assert.NoError(t, err, "validateTypes() should not return an error for function %s", tt.function.Name) } }) } @@ -628,7 +616,7 @@ func TestValidateGoFunctionSignature(t *testing.T) { }`, }, expectError: true, - errorMsg: "parameter 2 type mismatch: PHP 'int' requires Go type 'int64' but found 'string'", + errorMsg: `parameter 2 type mismatch: PHP "int" requires Go type "int64" but found "string"`, }, { name: "return type mismatch", @@ -643,7 +631,7 @@ func TestValidateGoFunctionSignature(t *testing.T) { }`, }, expectError: true, - errorMsg: "return type mismatch: PHP 'int' requires Go return type 'int64' but found 'string'", + errorMsg: `return type mismatch: PHP "int" requires Go return type "int64" but found "string"`, }, { name: "valid bool parameter and return", @@ -751,7 +739,7 @@ func TestPhpTypeToGoType(t *testing.T) { {"bool", true, "*bool"}, {"array", false, "*C.zval"}, {"array", true, "*C.zval"}, - {"unknown", false, "interface{}"}, + {"unknown", false, "any"}, } validator := Validator{} @@ -780,7 +768,7 @@ func TestPhpReturnTypeToGoType(t *testing.T) { {"bool", "bool"}, {"array", "unsafe.Pointer"}, {"array", "unsafe.Pointer"}, - {"unknown", "interface{}"}, + {"unknown", "any"}, } validator := Validator{} diff --git a/threadFramework.go b/threadFramework.go deleted file mode 100644 index a18996c7..00000000 --- a/threadFramework.go +++ /dev/null @@ -1,101 +0,0 @@ -package frankenphp - -import ( - "context" - "log/slog" - "net/http" - "sync" -) - -// EXPERIMENTAL: WorkerExtension allows you to register an external worker where instead of calling frankenphp handlers on -// frankenphp_handle_request(), the ProvideRequest method is called. You are responsible for providing a standard -// http.Request that will be conferred to the underlying worker script. -// -// A worker script with the provided Name and FileName will be registered, along with the provided -// configuration. You can also provide any environment variables that you want through Env. GetMinThreads allows you to -// reserve a minimum number of threads from the frankenphp thread pool. This number must be positive. -// These methods are only called once at startup, so register them in an init() function. -// -// When a thread is activated and nearly ready, ThreadActivatedNotification will be called with an opaque threadId; -// this is a time for setting up any per-thread resources. When a thread is about to be returned to the thread pool, -// you will receive a call to ThreadDrainNotification that will inform you of the threadId. -// After the thread is returned to the thread pool, ThreadDeactivatedNotification will be called. -// -// Once you have at least one thread activated, you will receive calls to ProvideRequest where you should respond with -// a request. FrankenPHP will automatically pipe these requests to the worker script and handle the response. -// The piping process is designed to run indefinitely and will be gracefully shut down when FrankenPHP shuts down. -// -// Note: External workers receive the lowest priority when determining thread allocations. If GetMinThreads cannot be -// allocated, then frankenphp will panic and provide this information to the user (who will need to allocate more -// total threads). Don't be greedy. -type WorkerExtension interface { - Name() string - FileName() string - Env() PreparedEnv - GetMinThreads() int - ThreadActivatedNotification(threadId int) - ThreadDrainNotification(threadId int) - ThreadDeactivatedNotification(threadId int) - ProvideRequest() *WorkerRequest[any, any] -} - -// EXPERIMENTAL -type WorkerRequest[P any, R any] struct { - // The request for your worker script to handle - Request *http.Request - // Response is a response writer that provides the output of the provided request, it must not be nil to access the request body - Response http.ResponseWriter - // CallbackParameters is an optional field that will be converted in PHP types and passed as parameter to the PHP callback - CallbackParameters P - // AfterFunc is an optional function that will be called after the request is processed with the original value, the return of the PHP callback, converted in Go types, is passed as parameter - AfterFunc func(callbackReturn R) -} - -var externalWorkers = make(map[string]WorkerExtension) -var externalWorkerMutex sync.Mutex - -// EXPERIMENTAL -func RegisterExternalWorker(worker WorkerExtension) { - externalWorkerMutex.Lock() - defer externalWorkerMutex.Unlock() - - externalWorkers[worker.Name()] = worker -} - -// startExternalWorkerPipe creates a pipe from an external worker to the main worker. -func startExternalWorkerPipe(w *worker, externalWorker WorkerExtension, thread *phpThread) { - for { - rq := externalWorker.ProvideRequest() - - if rq == nil || rq.Request == nil { - logger.LogAttrs(context.Background(), slog.LevelWarn, "external worker provided nil request", slog.String("worker", w.name), slog.Int("thread", thread.threadIndex)) - continue - } - - r := rq.Request - fr, err := NewRequestWithContext(r, WithOriginalRequest(r), WithWorkerName(w.name)) - if err != nil { - logger.LogAttrs(context.Background(), slog.LevelError, "error creating request for external worker", slog.String("worker", w.name), slog.Int("thread", thread.threadIndex), slog.Any("error", err)) - continue - } - - if fc, ok := fromContext(fr.Context()); ok { - fc.responseWriter = rq.Response - fc.handlerParameters = rq.CallbackParameters - - // Queue the request and wait for completion if Done channel was provided - logger.LogAttrs(context.Background(), slog.LevelInfo, "queue the external worker request", slog.String("worker", w.name), slog.Int("thread", thread.threadIndex)) - - w.requestChan <- fc - if rq.AfterFunc != nil { - go func() { - <-fc.done - - if rq.AfterFunc != nil { - rq.AfterFunc(fc.handlerReturn) - } - }() - } - } - } -} diff --git a/threadFramework_test.go b/threadFramework_test.go deleted file mode 100644 index 7519d9ab..00000000 --- a/threadFramework_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package frankenphp - -import ( - "io" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// mockWorkerExtension implements the WorkerExtension interface -type mockWorkerExtension struct { - name string - fileName string - env PreparedEnv - minThreads int - requestChan chan *WorkerRequest[any, any] - activatedCount int - drainCount int - deactivatedCount int - mu sync.Mutex -} - -func newMockWorkerExtension(name, fileName string, minThreads int) *mockWorkerExtension { - return &mockWorkerExtension{ - name: name, - fileName: fileName, - env: make(PreparedEnv), - minThreads: minThreads, - requestChan: make(chan *WorkerRequest[any, any], 10), // Buffer to avoid blocking - } -} - -func (m *mockWorkerExtension) Name() string { - return m.name -} - -func (m *mockWorkerExtension) FileName() string { - return m.fileName -} - -func (m *mockWorkerExtension) Env() PreparedEnv { - return m.env -} - -func (m *mockWorkerExtension) GetMinThreads() int { - return m.minThreads -} - -func (m *mockWorkerExtension) ThreadActivatedNotification(threadId int) { - m.mu.Lock() - defer m.mu.Unlock() - m.activatedCount++ -} - -func (m *mockWorkerExtension) ThreadDrainNotification(threadId int) { - m.mu.Lock() - defer m.mu.Unlock() - m.drainCount++ -} - -func (m *mockWorkerExtension) ThreadDeactivatedNotification(threadId int) { - m.mu.Lock() - defer m.mu.Unlock() - m.deactivatedCount++ -} - -func (m *mockWorkerExtension) ProvideRequest() *WorkerRequest[any, any] { - return <-m.requestChan -} - -func (m *mockWorkerExtension) InjectRequest(r *WorkerRequest[any, any]) { - m.requestChan <- r -} - -func (m *mockWorkerExtension) GetActivatedCount() int { - m.mu.Lock() - defer m.mu.Unlock() - return m.activatedCount -} - -func TestWorkerExtension(t *testing.T) { - // Create a mock extension - mockExt := newMockWorkerExtension("mockWorker", "testdata/worker.php", 1) - - // Register the mock extension - RegisterExternalWorker(mockExt) - - // Clean up external workers after test to avoid interfering with other tests - defer func() { - delete(externalWorkers, mockExt.Name()) - }() - - // Initialize FrankenPHP with a worker that has a different name than our extension - err := Init() - require.NoError(t, err) - defer Shutdown() - - // Wait a bit for the worker to be ready - time.Sleep(100 * time.Millisecond) - - // Verify that the extension's thread was activated - assert.GreaterOrEqual(t, mockExt.GetActivatedCount(), 1, "Thread should have been activated") - - // Create a test request - req := httptest.NewRequest("GET", "http://example.com/test/?foo=bar", nil) - req.Header.Set("X-Test-Header", "test-value") - - w := httptest.NewRecorder() - - // Create a channel to signal when the request is done - done := make(chan struct{}) - - // Inject the request into the worker through the extension - mockExt.InjectRequest(&WorkerRequest[any, any]{ - Request: req, - Response: w, - AfterFunc: func(callbackReturn any) { - close(done) - }, - }) - - // Wait for the request to be fully processed - <-done - - // Check the response - now safe from race conditions - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - // The worker.php script should output information about the request - // We're just checking that we got a response, not the specific content - assert.NotEmpty(t, body, "Response body should not be empty") -} diff --git a/threadworker.go b/threadworker.go index b0346680..20c88596 100644 --- a/threadworker.go +++ b/threadworker.go @@ -20,12 +20,12 @@ type workerThread struct { dummyContext *frankenPHPContext workerContext *frankenPHPContext backoff *exponentialBackoff - externalWorker WorkerExtension + externalWorker Worker isBootingScript bool // true if the worker has not reached frankenphp_handle_request yet } func convertToWorkerThread(thread *phpThread, worker *worker) { - externalWorker := externalWorkers[worker.name] + externalWorker := extensionWorkers[worker.name] thread.setHandler(&workerThread{ state: thread.state, @@ -205,7 +205,11 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) { handler.workerContext = fc handler.state.markAsWaiting(false) - logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex), slog.String("url", fc.request.RequestURI)) + if fc.request == nil { + logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) + } else { + logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex), slog.String("url", fc.request.RequestURI)) + } return true, fc.handlerParameters } @@ -218,10 +222,18 @@ func go_frankenphp_worker_handle_request_start(threadIndex C.uintptr_t) (C.bool, hasRequest, parameters := handler.waitForWorkerRequest() if parameters != nil { - p := PHPValue(parameters) - handler.thread.Pin(p) + var ptr unsafe.Pointer - return C.bool(hasRequest), p + switch p := parameters.(type) { + case unsafe.Pointer: + ptr = p + + default: + ptr = PHPValue(p) + } + handler.thread.Pin(ptr) + + return C.bool(hasRequest), ptr } return C.bool(hasRequest), nil @@ -240,7 +252,11 @@ func go_frankenphp_finish_worker_request(threadIndex C.uintptr_t, retval *C.zval fc.closeContext() thread.handler.(*workerThread).workerContext = nil - fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.scriptFilename), slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) + if fc.request == nil { + fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex)) + } else { + fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) + } } // when frankenphp_finish_request() is directly called from PHP diff --git a/types.go b/types.go index d02f627e..f38cdbca 100644 --- a/types.go +++ b/types.go @@ -5,6 +5,7 @@ package frankenphp */ import "C" import ( + "fmt" "strconv" "unsafe" ) @@ -287,7 +288,7 @@ func phpValue(value any) *C.zval { case []any: return (*C.zval)(PHPPackedArray(v)) default: - C.__zval_null__(&zval) + panic(fmt.Sprintf("unsupported Go type %T", v)) } return &zval diff --git a/types.h b/types.h index e667a6f7..72442cf3 100644 --- a/types.h +++ b/types.h @@ -1,11 +1,11 @@ #ifndef TYPES_H #define TYPES_H -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include zval *get_ht_packed_data(HashTable *, uint32_t index); Bucket *get_ht_bucket_data(HashTable *, uint32_t index); diff --git a/worker.go b/worker.go index 429d3539..ec66a918 100644 --- a/worker.go +++ b/worker.go @@ -51,7 +51,7 @@ func initWorkers(opt []workerOpt) error { // create a pipe from the external worker to the main worker // note: this is locked to the initial thread size the external worker requested if workerThread, ok := thread.handler.(*workerThread); ok && workerThread.externalWorker != nil { - go startExternalWorkerPipe(w, workerThread.externalWorker, thread) + go startWorker(w, workerThread.externalWorker, thread) } workersReady.Done() }() diff --git a/workerextension.go b/workerextension.go new file mode 100644 index 00000000..4e7c29d5 --- /dev/null +++ b/workerextension.go @@ -0,0 +1,167 @@ +package frankenphp + +import ( + "context" + "log/slog" + "net/http" + "sync" + "sync/atomic" +) + +// EXPERIMENTAL: Worker allows you to register a worker where instead of calling FrankenPHP handlers on +// frankenphp_handle_request(), the ProvideRequest method is called. You may provide a standard +// http.Request that will be conferred to the underlying worker script. +// +// A worker script with the provided Name and FileName will be registered, along with the provided +// configuration. You can also provide any environment variables that you want through Env. GetMinThreads allows you to +// reserve a minimum number of threads from the frankenphp thread pool. This number must be positive. +// These methods are only called once at startup, so register them in an init() function. +// +// When a thread is activated and nearly ready, ThreadActivatedNotification will be called with an opaque threadId; +// this is a time for setting up any per-thread resources. When a thread is about to be returned to the thread pool, +// you will receive a call to ThreadDrainNotification that will inform you of the threadId. +// After the thread is returned to the thread pool, ThreadDeactivatedNotification will be called. +// +// Once you have at least one thread activated, you will receive calls to ProvideRequest where you should respond with +// a request. FrankenPHP will automatically pipe these requests to the worker script and handle the response. +// The piping process is designed to run indefinitely and will be gracefully shut down when FrankenPHP shuts down. +// +// Note: External workers receive the lowest priority when determining thread allocations. If GetMinThreads cannot be +// allocated, then frankenphp will panic and provide this information to the user (who will need to allocate more +// total threads). Don't be greedy. +type Worker interface { + Name() string + FileName() string + Env() PreparedEnv + GetMinThreads() int + ThreadActivatedNotification(threadId int) + ThreadDrainNotification(threadId int) + ThreadDeactivatedNotification(threadId int) + ProvideRequest() *WorkerRequest + InjectRequest(r *WorkerRequest) +} + +// EXPERIMENTAL +type WorkerRequest struct { + // The request for your worker script to handle + Request *http.Request + // Response is a response writer that provides the output of the provided request, it must not be nil to access the request body + Response http.ResponseWriter + // CallbackParameters is an optional field that will be converted in PHP types and passed as parameter to the PHP callback + CallbackParameters any + // AfterFunc is an optional function that will be called after the request is processed with the original value, the return of the PHP callback, converted in Go types, is passed as parameter + AfterFunc func(callbackReturn any) +} + +var extensionWorkers = make(map[string]Worker) +var extensionWorkersMutex sync.Mutex + +// EXPERIMENTAL +func RegisterWorker(worker Worker) { + extensionWorkersMutex.Lock() + defer extensionWorkersMutex.Unlock() + + extensionWorkers[worker.Name()] = worker +} + +// startWorker creates a pipe from a worker to the main worker. +func startWorker(w *worker, extensionWorker Worker, thread *phpThread) { + for { + rq := extensionWorker.ProvideRequest() + + var fc *frankenPHPContext + if rq.Request == nil { + fc = newFrankenPHPContext() + fc.logger = logger + } else { + fr, err := NewRequestWithContext(rq.Request, WithOriginalRequest(rq.Request)) + if err != nil { + logger.LogAttrs(context.Background(), slog.LevelError, "error creating request for external worker", slog.String("worker", w.name), slog.Int("thread", thread.threadIndex), slog.Any("error", err)) + continue + } + + var ok bool + if fc, ok = fromContext(fr.Context()); !ok { + continue + } + } + + fc.worker = w + + fc.responseWriter = rq.Response + fc.handlerParameters = rq.CallbackParameters + + // Queue the request and wait for completion if Done channel was provided + logger.LogAttrs(context.Background(), slog.LevelInfo, "queue the external worker request", slog.String("worker", w.name), slog.Int("thread", thread.threadIndex)) + + w.requestChan <- fc + if rq.AfterFunc != nil { + go func() { + <-fc.done + + if rq.AfterFunc != nil { + rq.AfterFunc(fc.handlerReturn) + } + }() + } + } +} + +func NewWorker(name, fileName string, minThreads int, env PreparedEnv) Worker { + return &defaultWorker{ + name: name, + fileName: fileName, + env: env, + minThreads: minThreads, + requestChan: make(chan *WorkerRequest), + activatedCount: atomic.Int32{}, + drainCount: atomic.Int32{}, + } +} + +type defaultWorker struct { + name string + fileName string + env PreparedEnv + minThreads int + requestChan chan *WorkerRequest + activatedCount atomic.Int32 + drainCount atomic.Int32 +} + +func (w *defaultWorker) Name() string { + return w.name +} + +func (w *defaultWorker) FileName() string { + return w.fileName +} + +func (w *defaultWorker) Env() PreparedEnv { + return w.env +} + +func (w *defaultWorker) GetMinThreads() int { + return w.minThreads +} + +func (w *defaultWorker) ThreadActivatedNotification(_ int) { + w.activatedCount.Add(1) +} + +func (w *defaultWorker) ThreadDrainNotification(_ int) { + w.drainCount.Add(1) +} + +func (w *defaultWorker) ThreadDeactivatedNotification(_ int) { + w.drainCount.Add(-1) + w.activatedCount.Add(-1) +} + +func (w *defaultWorker) ProvideRequest() *WorkerRequest { + return <-w.requestChan +} + +func (w *defaultWorker) InjectRequest(r *WorkerRequest) { + w.requestChan <- r +} diff --git a/workerextension_test.go b/workerextension_test.go new file mode 100644 index 00000000..2a900f6e --- /dev/null +++ b/workerextension_test.go @@ -0,0 +1,71 @@ +package frankenphp + +import ( + "io" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockWorker implements the Worker interface +type mockWorker struct { + Worker +} + +func TestWorkerExtension(t *testing.T) { + // Create a mock worker extension + mockExt := &mockWorker{ + Worker: NewWorker("mockWorker", "testdata/worker.php", 1, nil), + } + + // Register the mock extension + RegisterWorker(mockExt) + + // Clean up external workers after test to avoid interfering with other tests + defer func() { + delete(extensionWorkers, mockExt.Name()) + }() + + // Initialize FrankenPHP with a worker that has a different name than our extension + err := Init() + require.NoError(t, err) + defer Shutdown() + + // Wait a bit for the worker to be ready + time.Sleep(100 * time.Millisecond) + + // Verify that the extension's thread was activated + assert.GreaterOrEqual(t, int(mockExt.Worker.(*defaultWorker).activatedCount.Load()), 1, "Thread should have been activated") + + // Create a test request + req := httptest.NewRequest("GET", "https://example.com/test/?foo=bar", nil) + req.Header.Set("X-Test-Header", "test-value") + + w := httptest.NewRecorder() + + // Create a channel to signal when the request is done + done := make(chan struct{}) + + // Inject the request into the worker through the extension + mockExt.InjectRequest(&WorkerRequest{ + Request: req, + Response: w, + AfterFunc: func(callbackReturn any) { + close(done) + }, + }) + + // Wait for the request to be fully processed + <-done + + // Check the response - now safe from race conditions + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + // The worker.php script should output information about the request + // We're just checking that we got a response, not the specific content + assert.NotEmpty(t, body, "Response body should not be empty") +}