diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index a46f1c55..827c0dd2 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -197,7 +197,7 @@ jobs: run: | docker run --platform=${{ matrix.platform }} --rm \ "$(jq -r '."builder-${{ matrix.variant }}"."containerimage.config.digest"' <<< "${METADATA}")" \ - sh -c 'go test -tags ${{ matrix.race }} -v ./... && cd caddy && go test -tags nobadger,nomysql,nopgx ${{ matrix.race }} -v ./...' + sh -c 'go test -tags ${{ matrix.race }} -v $(go list ./... | grep -v github.com/dunglas/frankenphp/internal/testext | grep -v github.com/dunglas/frankenphp/internal/extgen) && cd caddy && go test -tags nobadger,nomysql,nopgx ${{ matrix.race }} -v ./...' env: METADATA: ${{ steps.build.outputs.metadata }} # Adapted from https://docs.docker.com/build/ci/github-actions/multi-platform/ diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 85e5469a..5a2645ee 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -53,8 +53,10 @@ jobs: - name: Build testcli binary working-directory: internal/testcli/ run: go build + - name: Compile library tests + run: go test -race -v -x -c - name: Run library tests - run: go test -race -v ./... + run: ./frankenphp.test -test.v - name: Run Caddy module tests working-directory: caddy/ run: go test -tags nobadger,nomysql,nopgx -race -v ./... diff --git a/caddy/extinit.go b/caddy/extinit.go new file mode 100644 index 00000000..6a944be3 --- /dev/null +++ b/caddy/extinit.go @@ -0,0 +1,53 @@ +package caddy + +import ( + "errors" + "github.com/dunglas/frankenphp/internal/extgen" + "log" + "os" + "path/filepath" + "strings" + + caddycmd "github.com/caddyserver/caddy/v2/cmd" + "github.com/spf13/cobra" +) + +func init() { + caddycmd.RegisterCommand(caddycmd.Command{ + Name: "extension-init", + Usage: "go_extension.go [--verbose]", + Short: "(Experimental) Initializes a PHP extension from a Go file", + Long: ` +Initializes a PHP extension from a Go file. This command generates the necessary C files for the extension, including the header and source files, as well as the arginfo file.`, + CobraFunc: func(cmd *cobra.Command) { + cmd.Flags().BoolP("debug", "v", false, "Enable verbose debug logs") + + cmd.RunE = caddycmd.WrapCommandFuncForCobra(cmdInitExtension) + }, + }) +} + +func cmdInitExtension(fs caddycmd.Flags) (int, error) { + if len(os.Args) < 3 { + return 1, errors.New("the path to the Go source is required") + } + + sourceFile := os.Args[2] + + baseName := strings.TrimSuffix(filepath.Base(sourceFile), ".go") + + baseName = extgen.SanitizePackageName(baseName) + + sourceDir := filepath.Dir(sourceFile) + buildDir := filepath.Join(sourceDir, "build") + + generator := extgen.Generator{BaseName: baseName, SourceFile: sourceFile, BuildDir: buildDir} + + if err := generator.Generate(); err != nil { + return 1, err + } + + log.Printf("PHP extension %q initialized successfully in %q", baseName, generator.BuildDir) + + return 0, nil +} diff --git a/ext.go b/ext.go new file mode 100644 index 00000000..8d565d4b --- /dev/null +++ b/ext.go @@ -0,0 +1,29 @@ +package frankenphp + +//#include "frankenphp.h" +import "C" +import ( + "sync" + "unsafe" +) + +var ( + extensions []*C.zend_module_entry + registerOnce sync.Once +) + +// RegisterExtension registers a new PHP extension. +func RegisterExtension(me unsafe.Pointer) { + extensions = append(extensions, (*C.zend_module_entry)(me)) +} + +func registerExtensions() { + if len(extensions) == 0 { + return + } + + registerOnce.Do(func() { + C.register_extensions(extensions[0], C.int(len(extensions))) + extensions = nil + }) +} diff --git a/frankenphp.c b/frankenphp.c index a9cd534d..27dc103a 100644 --- a/frankenphp.c +++ b/frankenphp.c @@ -1182,3 +1182,34 @@ int frankenphp_reset_opcache(void) { } int frankenphp_get_current_memory_limit() { return PG(memory_limit); } + +static zend_module_entry *modules = NULL; +static int modules_len = 0; +static int (*original_php_register_internal_extensions_func)(void) = NULL; + +PHPAPI int register_internal_extensions(void) { + if (original_php_register_internal_extensions_func != NULL && + original_php_register_internal_extensions_func() != SUCCESS) { + return FAILURE; + } + + for (int i = 0; i < modules_len; i++) { + if (zend_register_internal_module(&modules[i]) == NULL) { + return FAILURE; + } + } + + modules = NULL; + modules_len = 0; + + return SUCCESS; +} + +void register_extensions(zend_module_entry *m, int len) { + modules = m; + modules_len = len; + + original_php_register_internal_extensions_func = + php_register_internal_extensions_func; + php_register_internal_extensions_func = register_internal_extensions; +} diff --git a/frankenphp.go b/frankenphp.go index afb4b77a..37fb2367 100644 --- a/frankenphp.go +++ b/frankenphp.go @@ -226,6 +226,8 @@ func Init(options ...Option) error { // Docker/Moby has a similar hack: https://github.com/moby/moby/blob/d828b032a87606ae34267e349bf7f7ccb1f6495a/cmd/dockerd/docker.go#L87-L90 signal.Ignore(syscall.SIGPIPE) + registerExtensions() + opt := &opt{} for _, o := range options { if err := o(opt); err != nil { diff --git a/frankenphp.h b/frankenphp.h index 6636bfbf..6d956290 100644 --- a/frankenphp.h +++ b/frankenphp.h @@ -1,6 +1,7 @@ #ifndef _FRANKENPPHP_H #define _FRANKENPPHP_H +#include #include #include #include @@ -92,4 +93,6 @@ void frankenphp_register_bulk( ht_key_value_pair auth_type, ht_key_value_pair remote_ident, ht_key_value_pair request_uri); +void register_extensions(zend_module_entry *m, int len); + #endif diff --git a/go.mod b/go.mod index 3095f6d0..e9f63252 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.0 retract v1.0.0-rc.1 // Human error require ( + github.com/Masterminds/sprig/v3 v3.3.0 github.com/maypok86/otter v1.2.4 github.com/prometheus/client_golang v1.22.0 github.com/stretchr/testify v1.10.0 @@ -14,19 +15,29 @@ require ( ) require ( + dario.cat/mergo v1.0.1 // indirect + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver/v3 v3.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dolthub/maphash v0.1.0 // indirect github.com/gammazero/deque v1.0.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/huandu/xstrings v1.5.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.64.0 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect + github.com/spf13/cast v1.7.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/crypto v0.39.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.26.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/go.sum b/go.sum index 9dcfb2e2..da9a30c8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,11 @@ +dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= +dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0= +github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= +github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -6,10 +14,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/gammazero/deque v1.0.0 h1:LTmimT8H7bXkkCy6gZX7zNLtkbz4NdS2z8LZuor3j34= github.com/gammazero/deque v1.0.0/go.mod h1:iflpYvtGfM3U8S8j+sZEKIak3SAKYpA5/SQewgfXDKo= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= +github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -18,6 +32,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/maypok86/otter v1.2.4 h1:HhW1Pq6VdJkmWwcZZq19BlEQkHtI8xgsQzBVXJU0nfc= github.com/maypok86/otter v1.2.4/go.mod h1:mKLfoI7v1HOmQMwFgX4QkRk23mX6ge3RDvjdHOWG4R4= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -32,6 +50,10 @@ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzM github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -42,6 +64,8 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U= go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= diff --git a/internal/extgen/arginfo.go b/internal/extgen/arginfo.go new file mode 100644 index 00000000..c827d1ce --- /dev/null +++ b/internal/extgen/arginfo.go @@ -0,0 +1,50 @@ +package extgen + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +type arginfoGenerator struct { + generator *Generator +} + +func (ag *arginfoGenerator) generate() error { + genStubPath := os.Getenv("GEN_STUB_SCRIPT") + if genStubPath == "" { + genStubPath = "/usr/local/src/php/build/gen_stub.php" + } + + if _, err := os.Stat(genStubPath); err != nil { + return fmt.Errorf(`the PHP "gen_stub.php" file couldn't be found under %q, you can set the "GEN_STUB_SCRIPT" environement variable to set a custom location`, genStubPath) + } + + stubFile := ag.generator.BaseName + ".stub.php" + cmd := exec.Command("php", genStubPath, filepath.Join(ag.generator.BuildDir, stubFile)) + + if err := cmd.Run(); err != nil { + return fmt.Errorf("running gen_stub script: %w", err) + } + + return ag.fixArginfoFile(stubFile) +} + +func (ag *arginfoGenerator) fixArginfoFile(stubFile string) error { + arginfoFile := strings.TrimSuffix(stubFile, ".stub.php") + "_arginfo.h" + arginfoPath := filepath.Join(ag.generator.BuildDir, arginfoFile) + + content, err := ReadFile(arginfoPath) + if err != nil { + return fmt.Errorf("reading arginfo file: %w", err) + } + + // FIXME: the script generate "zend_register_internal_class_with_flags" but it is not recognized by the compiler + fixedContent := strings.ReplaceAll(content, + "zend_register_internal_class_with_flags(&ce, NULL, 0)", + "zend_register_internal_class(&ce)") + + return WriteFile(arginfoPath, fixedContent) +} diff --git a/internal/extgen/cfile.go b/internal/extgen/cfile.go new file mode 100644 index 00000000..693e6995 --- /dev/null +++ b/internal/extgen/cfile.go @@ -0,0 +1,68 @@ +package extgen + +import ( + "github.com/Masterminds/sprig/v3" + + "bytes" + _ "embed" + "path/filepath" + "strings" + "text/template" +) + +//go:embed templates/extension.c.tpl +var cFileContent string + +type cFileGenerator struct { + generator *Generator +} + +type cTemplateData struct { + BaseName string + Functions []phpFunction + Classes []phpClass + Constants []phpConstant +} + +func (cg *cFileGenerator) generate() error { + filename := filepath.Join(cg.generator.BuildDir, cg.generator.BaseName+".c") + content, err := cg.buildContent() + if err != nil { + return err + } + + return WriteFile(filename, content) +} + +func (cg *cFileGenerator) buildContent() (string, error) { + var builder strings.Builder + + templateContent, err := cg.getTemplateContent() + if err != nil { + return "", err + } + builder.WriteString(templateContent) + + for _, fn := range cg.generator.Functions { + fnGen := PHPFuncGenerator{paramParser: &ParameterParser{}} + builder.WriteString(fnGen.generate(fn)) + } + + return builder.String(), nil +} + +func (cg *cFileGenerator) getTemplateContent() (string, error) { + tmpl := template.Must(template.New("cfile").Funcs(sprig.FuncMap()).Parse(cFileContent)) + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, cTemplateData{ + BaseName: cg.generator.BaseName, + Functions: cg.generator.Functions, + Classes: cg.generator.Classes, + Constants: cg.generator.Constants, + }); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/internal/extgen/cfile_test.go b/internal/extgen/cfile_test.go new file mode 100644 index 00000000..e8799565 --- /dev/null +++ b/internal/extgen/cfile_test.go @@ -0,0 +1,461 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCFileGenerator_Generate(t *testing.T) { + tmpDir := t.TempDir() + + generator := &Generator{ + BaseName: "test_extension", + BuildDir: tmpDir, + Functions: []phpFunction{ + { + Name: "simpleFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "input", PhpType: "string"}, + }, + }, + { + Name: "complexFunction", + ReturnType: "array", + Params: []phpParameter{ + {Name: "data", PhpType: "string"}, + {Name: "count", PhpType: "int", IsNullable: true}, + {Name: "options", PhpType: "array", HasDefault: true, DefaultValue: "[]"}, + }, + }, + }, + Classes: []phpClass{ + { + Name: "TestClass", + GoStruct: "TestStruct", + Properties: []phpClassProperty{ + {Name: "id", PhpType: "int"}, + {Name: "name", PhpType: "string"}, + }, + }, + }, + } + + cGen := cFileGenerator{generator} + require.NoError(t, cGen.generate()) + + expectedFile := filepath.Join(tmpDir, "test_extension.c") + require.FileExists(t, expectedFile, "Expected C file was not created: %s", expectedFile) + + content, err := ReadFile(expectedFile) + require.NoError(t, err) + + testCFileBasicStructure(t, content, "test_extension") + testCFileFunctions(t, content, generator.Functions) + testCFileClasses(t, content, generator.Classes) +} + +func TestCFileGenerator_BuildContent(t *testing.T) { + tests := []struct { + name string + baseName string + functions []phpFunction + classes []phpClass + contains []string + notContains []string + }{ + { + name: "empty extension", + baseName: "empty", + contains: []string{ + "#include ", + "#include ", + `#include "empty.h"`, + "PHP_MINIT_FUNCTION(empty)", + "empty_module_entry", + "return SUCCESS;", + }, + }, + { + name: "extension with functions only", + baseName: "func_only", + functions: []phpFunction{ + {Name: "testFunc", ReturnType: "string"}, + }, + contains: []string{ + "PHP_FUNCTION(testFunc)", + `#include "func_only.h"`, + "func_only_module_entry", + "PHP_MINIT_FUNCTION(func_only)", + }, + }, + { + name: "extension with classes only", + baseName: "class_only", + classes: []phpClass{ + {Name: "MyClass", GoStruct: "MyStruct"}, + }, + contains: []string{ + "register_all_classes()", + "register_class_MyClass();", + "PHP_METHOD(MyClass, __construct)", + `#include "class_only.h"`, + }, + }, + { + name: "extension with functions and classes", + baseName: "full", + functions: []phpFunction{ + {Name: "doSomething", ReturnType: "void"}, + }, + classes: []phpClass{ + {Name: "FullClass", GoStruct: "FullStruct"}, + }, + contains: []string{ + "PHP_FUNCTION(doSomething)", + "PHP_METHOD(FullClass, __construct)", + "register_all_classes()", + "register_class_FullClass();", + `#include "full.h"`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &Generator{ + BaseName: tt.baseName, + Functions: tt.functions, + Classes: tt.classes, + } + + cGen := cFileGenerator{generator} + content, err := cGen.buildContent() + require.NoError(t, err) + + for _, expected := range tt.contains { + assert.Contains(t, content, expected, "Generated C content should contain '%s'", expected) + } + }) + } +} + +func TestCFileGenerator_GetTemplateContent(t *testing.T) { + tests := []struct { + name string + baseName string + classes []phpClass + contains []string + notContains []string + }{ + { + name: "extension without classes", + baseName: "myext", + contains: []string{ + `#include "myext.h"`, + `#include "myext_arginfo.h"`, + "PHP_MINIT_FUNCTION(myext)", + "myext_module_entry", + "return SUCCESS;", + }, + }, + { + name: "extension with classes", + baseName: "complex_name", + classes: []phpClass{ + {Name: "TestClass", GoStruct: "TestStruct"}, + {Name: "AnotherClass", GoStruct: "AnotherStruct"}, + }, + contains: []string{ + `#include "complex_name.h"`, + `#include "complex_name_arginfo.h"`, + "PHP_MINIT_FUNCTION(complex_name)", + "complex_name_module_entry", + "register_all_classes()", + "register_class_TestClass();", + "register_class_AnotherClass();", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &Generator{ + BaseName: tt.baseName, + Classes: tt.classes, + } + cGen := cFileGenerator{generator} + content, err := cGen.getTemplateContent() + require.NoError(t, err) + + for _, expected := range tt.contains { + assert.Contains(t, content, expected, "Template content should contain '%s'", expected) + } + + for _, notExpected := range tt.notContains { + assert.NotContains(t, content, notExpected, "Template content should NOT contain '%s'", notExpected) + } + }) + } +} + +func TestCFileIntegrationWithGenerators(t *testing.T) { + tmpDir := t.TempDir() + + functions := []phpFunction{ + { + Name: "processData", + ReturnType: "array", + IsReturnNullable: true, + Params: []phpParameter{ + {Name: "input", PhpType: "string"}, + {Name: "options", PhpType: "array", HasDefault: true, DefaultValue: "[]"}, + {Name: "callback", PhpType: "object", IsNullable: true}, + }, + }, + { + Name: "validateInput", + ReturnType: "bool", + Params: []phpParameter{ + {Name: "data", PhpType: "string", IsNullable: true}, + {Name: "strict", PhpType: "bool", HasDefault: true, DefaultValue: "false"}, + }, + }, + } + + classes := []phpClass{ + { + Name: "DataProcessor", + GoStruct: "DataProcessorStruct", + Properties: []phpClassProperty{ + {Name: "mode", PhpType: "string"}, + {Name: "timeout", PhpType: "int", IsNullable: true}, + {Name: "options", PhpType: "array"}, + }, + }, + { + Name: "Result", + GoStruct: "ResultStruct", + Properties: []phpClassProperty{ + {Name: "success", PhpType: "bool"}, + {Name: "data", PhpType: "mixed", IsNullable: true}, + {Name: "errors", PhpType: "array"}, + }, + }, + } + + generator := &Generator{ + BaseName: "integration_test", + BuildDir: tmpDir, + Functions: functions, + Classes: classes, + } + + cGen := cFileGenerator{generator} + require.NoError(t, cGen.generate()) + + content, err := ReadFile(filepath.Join(tmpDir, "integration_test.c")) + require.NoError(t, err) + + for _, fn := range functions { + expectedFunc := "PHP_FUNCTION(" + fn.Name + ")" + assert.Contains(t, content, expectedFunc, "Generated C file should contain function: %s", expectedFunc) + } + + for _, class := range classes { + expectedMethod := "PHP_METHOD(" + class.Name + ", __construct)" + assert.Contains(t, content, expectedMethod, "Generated C file should contain class method: %s", expectedMethod) + } + + assert.Contains(t, content, "register_all_classes()", "Generated C file should contain class registration call") + assert.Contains(t, content, "integration_test_module_entry", "Generated C file should contain integration_test_module_entry") +} + +func TestCFileErrorHandling(t *testing.T) { + // Test with invalid build directory + generator := &Generator{ + BaseName: "test", + BuildDir: "/invalid/readonly/path", + Functions: []phpFunction{ + {Name: "test", ReturnType: "void"}, + }, + } + + cGen := cFileGenerator{generator} + err := cGen.generate() + assert.Error(t, err, "Expected error when writing to invalid directory") +} + +func TestCFileSpecialCharacters(t *testing.T) { + tests := []struct { + baseName string + expected string + }{ + {"simple", "simple"}, + {"my_extension", "my_extension"}, + {"ext-with-dashes", "ext-with-dashes"}, + } + + for _, tt := range tests { + t.Run(tt.baseName, func(t *testing.T) { + generator := &Generator{ + BaseName: tt.baseName, + Functions: []phpFunction{ + {Name: "test", ReturnType: "void"}, + }, + } + + cGen := cFileGenerator{generator} + content, err := cGen.buildContent() + require.NoError(t, err) + + expectedInclude := "#include \"" + tt.expected + ".h\"" + assert.Contains(t, content, expectedInclude, "Content should contain include: %s", expectedInclude) + }) + } +} + +func testCFileBasicStructure(t *testing.T, content, baseName string) { + requiredElements := []string{ + "#include ", + "#include ", + `#include "_cgo_export.h"`, + `#include "` + baseName + `.h"`, + `#include "` + baseName + `_arginfo.h"`, + "PHP_MINIT_FUNCTION(" + baseName + ")", + baseName + "_module_entry", + } + + for _, element := range requiredElements { + assert.Contains(t, content, element, "C file should contain: %s", element) + } +} + +func testCFileFunctions(t *testing.T, content string, functions []phpFunction) { + for _, fn := range functions { + phpFunc := "PHP_FUNCTION(" + fn.Name + ")" + assert.Contains(t, content, phpFunc, "C file should contain function declaration: %s", phpFunc) + } +} + +func testCFileClasses(t *testing.T, content string, classes []phpClass) { + if len(classes) == 0 { + // Si pas de classes, ne devrait pas contenir register_all_classes + assert.NotContains(t, content, "register_all_classes()", "C file should NOT contain register_all_classes call when no classes") + return + } + + assert.Contains(t, content, "void register_all_classes() {", "C file should contain register_all_classes function") + assert.Contains(t, content, "register_all_classes();", "C file should contain register_all_classes call in MINIT") + + for _, class := range classes { + expectedCall := "register_class_" + class.Name + "();" + assert.Contains(t, content, expectedCall, "C file should contain class registration call: %s", expectedCall) + + constructor := "PHP_METHOD(" + class.Name + ", __construct)" + assert.Contains(t, content, constructor, "C file should contain constructor: %s", constructor) + } +} + +func TestCFileContentValidation(t *testing.T) { + generator := &Generator{ + BaseName: "syntax_test", + Functions: []phpFunction{ + { + Name: "testFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "param", PhpType: "string"}, + }, + }, + }, + Classes: []phpClass{ + {Name: "TestClass", GoStruct: "TestStruct"}, + }, + } + + cGen := cFileGenerator{generator} + content, err := cGen.buildContent() + require.NoError(t, err) + + syntaxElements := []string{ + "{", "}", "(", ")", ";", + "static", "void", "int", + "#include", + } + + for _, element := range syntaxElements { + assert.Contains(t, content, element, "Generated C content should contain basic C syntax: %s", element) + } + + openBraces := strings.Count(content, "{") + closeBraces := strings.Count(content, "}") + + assert.Equal(t, openBraces, closeBraces, "Unbalanced braces in generated C code: %d open, %d close", openBraces, closeBraces) + assert.False(t, strings.Contains(content, ";;"), "Generated C code contains double semicolons") + assert.False(t, strings.Contains(content, "{{") || strings.Contains(content, "}}"), "Generated C code contains unresolved template syntax") +} + +func TestCFileConstants(t *testing.T) { + tests := []struct { + name string + baseName string + constants []phpConstant + classes []phpClass + contains []string + }{ + { + name: "global constants only", + baseName: "const_test", + constants: []phpConstant{ + { + Name: "GLOBAL_INT", + Value: "42", + PhpType: "int", + }, + { + Name: "GLOBAL_STRING", + Value: `"test"`, + PhpType: "string", + }, + }, + contains: []string{ + "REGISTER_LONG_CONSTANT(\"GLOBAL_INT\", 42, CONST_CS | CONST_PERSISTENT);", + "REGISTER_STRING_CONSTANT(\"GLOBAL_STRING\", \"test\", CONST_CS | CONST_PERSISTENT);", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &Generator{ + BaseName: tt.baseName, + Constants: tt.constants, + Classes: tt.classes, + } + + cGen := cFileGenerator{generator} + content, err := cGen.buildContent() + require.NoError(t, err) + + for _, expected := range tt.contains { + assert.Contains(t, content, expected, "Generated C content should contain '%s'", expected) + } + }) + } +} + +func TestCFileTemplateErrorHandling(t *testing.T) { + generator := &Generator{ + BaseName: "error_test", + } + + cGen := cFileGenerator{generator} + + _, err := cGen.getTemplateContent() + assert.NoError(t, err, "getTemplateContent() should not fail with valid template") +} diff --git a/internal/extgen/classparser.go b/internal/extgen/classparser.go new file mode 100644 index 00000000..6ac39c75 --- /dev/null +++ b/internal/extgen/classparser.go @@ -0,0 +1,390 @@ +package extgen + +import ( + "bufio" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "regexp" + "strings" +) + +var phpClassRegex = regexp.MustCompile(`//\s*export_php:class\s+(\w+)`) +var phpMethodRegex = regexp.MustCompile(`//\s*export_php:method\s+(\w+)::([^{}\n]+)(?:\s*{\s*})?`) +var methodSignatureRegex = regexp.MustCompile(`(\w+)\s*\(([^)]*)\)\s*:\s*(\??[\w|]+)`) +var methodParamTypeNameRegex = regexp.MustCompile(`(\??[\w|]+)\s+\$?(\w+)`) + +type exportDirective struct { + line int + className string +} + +type classParser struct{} + +func (cp *classParser) Parse(filename string) ([]phpClass, error) { + return cp.parse(filename) +} + +func (cp *classParser) parse(filename string) (classes []phpClass, err error) { + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("parsing file: %w", err) + } + + validator := Validator{} + + exportDirectives := cp.collectExportDirectives(node, fset) + methods, err := cp.parseMethods(filename) + if err != nil { + return nil, fmt.Errorf("parsing methods: %w", err) + } + + // match structs to directives + matchedDirectives := make(map[int]bool) + + var genDecl *ast.GenDecl + var ok bool + for _, decl := range node.Decls { + if genDecl, ok = decl.(*ast.GenDecl); !ok || genDecl.Tok != token.TYPE { + continue + } + + for _, spec := range genDecl.Specs { + var typeSpec *ast.TypeSpec + if typeSpec, ok = spec.(*ast.TypeSpec); !ok { + continue + } + + var structType *ast.StructType + if structType, ok = typeSpec.Type.(*ast.StructType); !ok { + continue + } + + var phpCl string + var directiveLine int + if phpCl, directiveLine = cp.extractPHPClassCommentWithLine(genDecl.Doc, fset); phpCl == "" { + continue + } + + matchedDirectives[directiveLine] = true + + class := phpClass{ + Name: phpCl, + GoStruct: typeSpec.Name.Name, + } + + class.Properties = cp.parseStructFields(structType.Fields.List) + + // associate methods with this class + for _, method := range methods { + if method.ClassName == phpCl { + class.Methods = append(class.Methods, method) + } + } + + if err := validator.validateClass(class); err != nil { + fmt.Printf("Warning: Invalid class '%s': %v\n", class.Name, err) + continue + } + + classes = append(classes, class) + } + } + + for _, directive := range exportDirectives { + if !matchedDirectives[directive.line] { + return nil, fmt.Errorf("//export_php class directive at line %d is not followed by a struct declaration", directive.line) + } + } + + return classes, nil +} + +func (cp *classParser) collectExportDirectives(node *ast.File, fset *token.FileSet) []exportDirective { + var directives []exportDirective + + for _, commentGroup := range node.Comments { + for _, comment := range commentGroup.List { + if matches := phpClassRegex.FindStringSubmatch(comment.Text); matches != nil { + pos := fset.Position(comment.Pos()) + directives = append(directives, exportDirective{ + line: pos.Line, + className: matches[1], + }) + } + } + } + + return directives +} + +func (cp *classParser) extractPHPClassCommentWithLine(commentGroup *ast.CommentGroup, fset *token.FileSet) (string, int) { + if commentGroup == nil { + return "", 0 + } + + for _, comment := range commentGroup.List { + if matches := phpClassRegex.FindStringSubmatch(comment.Text); matches != nil { + pos := fset.Position(comment.Pos()) + return matches[1], pos.Line + } + } + + return "", 0 +} + +func (cp *classParser) parseStructFields(fields []*ast.Field) []phpClassProperty { + var properties []phpClassProperty + + for _, field := range fields { + for _, name := range field.Names { + prop := cp.parseStructField(name.Name, field) + properties = append(properties, prop) + } + } + + return properties +} + +func (cp *classParser) parseStructField(fieldName string, field *ast.Field) phpClassProperty { + prop := phpClassProperty{Name: fieldName} + + // check if field is a pointer (nullable) + if starExpr, isPointer := field.Type.(*ast.StarExpr); isPointer { + prop.IsNullable = true + prop.GoType = cp.typeToString(starExpr.X) + } else { + prop.IsNullable = false + prop.GoType = cp.typeToString(field.Type) + } + + prop.PhpType = cp.goTypeToPHPType(prop.GoType) + + return prop +} + +func (cp *classParser) typeToString(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + return "*" + cp.typeToString(t.X) + case *ast.ArrayType: + return "[]" + cp.typeToString(t.Elt) + case *ast.MapType: + return "map[" + cp.typeToString(t.Key) + "]" + cp.typeToString(t.Value) + default: + return "interface{}" + } +} + +func (cp *classParser) goTypeToPHPType(goType string) string { + goType = strings.TrimPrefix(goType, "*") + + typeMap := map[string]string{ + "string": "string", + "int": "int", "int64": "int", "int32": "int", "int16": "int", "int8": "int", + "uint": "int", "uint64": "int", "uint32": "int", "uint16": "int", "uint8": "int", + "float64": "float", "float32": "float", + "bool": "bool", + } + + if phpType, exists := typeMap[goType]; exists { + return phpType + } + + if strings.HasPrefix(goType, "[]") || strings.HasPrefix(goType, "map[") { + return "array" + } + + return "mixed" +} + +func (cp *classParser) parseMethods(filename string) (methods []phpClassMethod, err error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + + defer func() { + e := file.Close() + if err != nil { + err = e + } + }() + + scanner := bufio.NewScanner(file) + var currentMethod *phpClassMethod + + lineNumber := 0 + for scanner.Scan() { + lineNumber++ + line := strings.TrimSpace(scanner.Text()) + + if matches := phpMethodRegex.FindStringSubmatch(line); matches != nil { + className := strings.TrimSpace(matches[1]) + signature := strings.TrimSpace(matches[2]) + + method, err := cp.parseMethodSignature(className, signature) + if err != nil { + fmt.Printf("Warning: Error parsing method signature %q: %v\n", signature, err) + + continue + } + + validator := Validator{} + phpFunc := phpFunction{ + Name: method.Name, + Signature: method.Signature, + Params: method.Params, + ReturnType: method.ReturnType, + IsReturnNullable: method.isReturnNullable, + } + + if err := validator.validateScalarTypes(phpFunc); err != nil { + fmt.Printf("Warning: Method \"%s::%s\" uses unsupported types: %v\n", className, method.Name, err) + + continue + } + + method.lineNumber = lineNumber + currentMethod = method + } + + if currentMethod != nil && strings.HasPrefix(line, "func ") { + goFunc, err := cp.extractGoMethodFunction(scanner, line) + if err != nil { + return nil, fmt.Errorf("extracting Go method function: %w", err) + } + + currentMethod.GoFunction = goFunc + + validator := Validator{} + phpFunc := phpFunction{ + Name: currentMethod.Name, + Signature: currentMethod.Signature, + GoFunction: currentMethod.GoFunction, + Params: currentMethod.Params, + ReturnType: currentMethod.ReturnType, + IsReturnNullable: currentMethod.isReturnNullable, + } + + if err := validator.validateGoFunctionSignatureWithOptions(phpFunc, true); err != nil { + fmt.Printf("Warning: Go method signature mismatch for '%s::%s': %v\n", currentMethod.ClassName, currentMethod.Name, err) + currentMethod = nil + continue + } + + methods = append(methods, *currentMethod) + currentMethod = nil + } + } + + if currentMethod != nil { + return nil, fmt.Errorf("//export_php:method directive at line %d is not followed by a function declaration", currentMethod.lineNumber) + } + + return methods, scanner.Err() +} + +func (cp *classParser) parseMethodSignature(className, signature string) (*phpClassMethod, error) { + matches := methodSignatureRegex.FindStringSubmatch(signature) + + if len(matches) != 4 { + return nil, fmt.Errorf("invalid method signature format") + } + + methodName := matches[1] + paramsStr := strings.TrimSpace(matches[2]) + returnTypeStr := strings.TrimSpace(matches[3]) + + isReturnNullable := strings.HasPrefix(returnTypeStr, "?") + returnType := strings.TrimPrefix(returnTypeStr, "?") + + var params []phpParameter + if paramsStr != "" { + paramParts := strings.Split(paramsStr, ",") + for _, part := range paramParts { + param, err := cp.parseMethodParameter(strings.TrimSpace(part)) + if err != nil { + return nil, fmt.Errorf("parsing parameter '%s': %w", part, err) + } + + params = append(params, param) + } + } + + return &phpClassMethod{ + Name: methodName, + PhpName: methodName, + ClassName: className, + Signature: signature, + Params: params, + ReturnType: returnType, + isReturnNullable: isReturnNullable, + }, nil +} + +func (cp *classParser) parseMethodParameter(paramStr string) (phpParameter, error) { + parts := strings.Split(paramStr, "=") + typePart := strings.TrimSpace(parts[0]) + + param := phpParameter{HasDefault: len(parts) > 1} + + if param.HasDefault { + param.DefaultValue = cp.sanitizeDefaultValue(strings.TrimSpace(parts[1])) + } + + matches := methodParamTypeNameRegex.FindStringSubmatch(typePart) + + if len(matches) < 3 { + return phpParameter{}, fmt.Errorf("invalid parameter format: %s", paramStr) + } + + typeStr := strings.TrimSpace(matches[1]) + param.Name = strings.TrimSpace(matches[2]) + param.IsNullable = strings.HasPrefix(typeStr, "?") + param.PhpType = strings.TrimPrefix(typeStr, "?") + + return param, nil +} + +func (cp *classParser) sanitizeDefaultValue(value string) string { + if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + return value + } + + if strings.ToLower(value) == "null" { + return "null" + } + + return strings.Trim(value, `'"`) +} + +func (cp *classParser) extractGoMethodFunction(scanner *bufio.Scanner, firstLine string) (string, error) { + goFunc := firstLine + "\n" + braceCount := 1 + + for scanner.Scan() { + line := scanner.Text() + goFunc += line + "\n" + + for _, char := range line { + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + } + } + + if braceCount == 0 { + break + } + } + + return goFunc, nil +} diff --git a/internal/extgen/classparser_test.go b/internal/extgen/classparser_test.go new file mode 100644 index 00000000..5acaff08 --- /dev/null +++ b/internal/extgen/classparser_test.go @@ -0,0 +1,641 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClassParser(t *testing.T) { + tests := []struct { + name string + input string + expected int + }{ + { + name: "single class", + input: `package main + +//export_php:class User +type UserStruct struct { + name string + Age int +}`, + expected: 1, + }, + { + name: "multiple classes", + input: `package main + +//export_php:class User +type UserStruct struct { + name string + Age int +} + +//export_php:class Product +type ProductStruct struct { + Title string + Price float64 +}`, + expected: 2, + }, + { + name: "no php classes", + input: `package main + +type RegularStruct struct { + Data string +}`, + expected: 0, + }, + { + name: "class with nullable fields", + input: `package main + +//export_php:class OptionalData +type OptionalStruct struct { + Required string + Optional *string + Count *int +}`, + expected: 1, + }, + { + name: "class with methods", + input: `package main + +//export_php:class User +type UserStruct struct { + name string + Age int +} + +//export_php:method User::getName(): string +func GetUserName(u UserStruct) string { + return u.name +} + +//export_php:method User::setAge(int $age): void +func SetUserAge(u *UserStruct, age int) { + u.Age = age +}`, + expected: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(fileName, []byte(tt.input), 0644)) + + parser := classParser{} + classes, err := parser.parse(fileName) + require.NoError(t, err) + + assert.Len(t, classes, tt.expected, "parse() got wrong number of classes") + + if tt.name == "single class" && len(classes) > 0 { + class := classes[0] + assert.Equal(t, "User", class.Name, "Expected class name 'User'") + assert.Equal(t, "UserStruct", class.GoStruct, "Expected Go struct 'UserStruct'") + assert.Len(t, class.Properties, 2, "Expected 2 properties") + } + + if tt.name == "class with nullable fields" && len(classes) > 0 { + class := classes[0] + if len(class.Properties) >= 3 { + assert.False(t, class.Properties[0].IsNullable, "Required field should not be nullable") + assert.True(t, class.Properties[1].IsNullable, "Optional field should be nullable") + assert.True(t, class.Properties[2].IsNullable, "Count field should be nullable") + } + } + }) + } +} + +func TestClassMethods(t *testing.T) { + var input = []byte(`package main + +//export_php:class User +type UserStruct struct { + name string + Age int +} + +//export_php:method User::getName(): string +func GetUserName(u UserStruct) unsafe.Pointer { + return nil +} + +//export_php:method User::setAge(int $age): void +func SetUserAge(u *UserStruct, age int64) { + u.Age = int(age) +} + +//export_php:method User::getInfo(string $prefix = "User"): string +func GetUserInfo(u UserStruct, prefix *C.zend_string) unsafe.Pointer { + return nil +}`) + + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, "test.go") + require.NoError(t, os.WriteFile(fileName, input, 0644)) + + parser := classParser{} + classes, err := parser.parse(fileName) + require.NoError(t, err) + + require.Len(t, classes, 1, "Expected 1 class") + + class := classes[0] + require.Len(t, class.Methods, 3, "Expected 3 methods") + + getName := class.Methods[0] + assert.Equal(t, "getName", getName.Name, "Expected method name 'getName'") + assert.Equal(t, "string", getName.ReturnType, "Expected return type 'string'") + assert.Empty(t, getName.Params, "Expected 0 params") + assert.Equal(t, "User", getName.ClassName, "Expected class name 'User'") + + setAge := class.Methods[1] + assert.Equal(t, "setAge", setAge.Name, "Expected method name 'setAge'") + assert.Equal(t, "void", setAge.ReturnType, "Expected return type 'void'") + require.Len(t, setAge.Params, 1, "Expected 1 param") + + param := setAge.Params[0] + assert.Equal(t, "age", param.Name, "Expected param name 'age'") + assert.Equal(t, "int", param.PhpType, "Expected param type 'int'") + assert.False(t, param.IsNullable, "Expected param to not be nullable") + assert.False(t, param.HasDefault, "Expected param to not have default value") + + getInfo := class.Methods[2] + assert.Equal(t, "getInfo", getInfo.Name, "Expected method name 'getInfo'") + assert.Equal(t, "string", getInfo.ReturnType, "Expected return type 'string'") + require.Len(t, getInfo.Params, 1, "Expected 1 param") + + param = getInfo.Params[0] + assert.Equal(t, "prefix", param.Name, "Expected param name 'prefix'") + assert.Equal(t, "string", param.PhpType, "Expected param type 'string'") + assert.True(t, param.HasDefault, "Expected param to have default value") + assert.Equal(t, "User", param.DefaultValue, "Expected default value 'User'") +} + +func TestMethodParameterParsing(t *testing.T) { + tests := []struct { + name string + paramStr string + expectedParam phpParameter + expectError bool + }{ + { + name: "simple int parameter", + paramStr: "int $age", + expectedParam: phpParameter{ + Name: "age", + PhpType: "int", + IsNullable: false, + HasDefault: false, + }, + expectError: false, + }, + { + name: "nullable string parameter", + paramStr: "?string $name", + expectedParam: phpParameter{ + Name: "name", + PhpType: "string", + IsNullable: true, + HasDefault: false, + }, + expectError: false, + }, + { + name: "parameter with default value", + paramStr: `string $prefix = "default"`, + expectedParam: phpParameter{ + Name: "prefix", + PhpType: "string", + IsNullable: false, + HasDefault: true, + DefaultValue: "default", + }, + expectError: false, + }, + { + name: "nullable parameter with default null", + paramStr: "?int $count = null", + expectedParam: phpParameter{ + Name: "count", + PhpType: "int", + IsNullable: true, + HasDefault: true, + DefaultValue: "null", + }, + expectError: false, + }, + { + name: "invalid parameter format", + paramStr: "invalid", + expectError: true, + }, + } + + parser := classParser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + param, err := parser.parseMethodParameter(tt.paramStr) + + if tt.expectError { + assert.Error(t, err, "Expected error for parameter '%s', but got none", tt.paramStr) + return + } + + require.NoError(t, err, "parseMethodParameter(%s) error", tt.paramStr) + + assert.Equal(t, tt.expectedParam.Name, param.Name, "Expected name '%s'", tt.expectedParam.Name) + assert.Equal(t, tt.expectedParam.PhpType, param.PhpType, "Expected type '%s'", tt.expectedParam.PhpType) + assert.Equal(t, tt.expectedParam.IsNullable, param.IsNullable, "Expected isNullable %v", tt.expectedParam.IsNullable) + assert.Equal(t, tt.expectedParam.HasDefault, param.HasDefault, "Expected hasDefault %v", tt.expectedParam.HasDefault) + assert.Equal(t, tt.expectedParam.DefaultValue, param.DefaultValue, "Expected defaultValue '%s'", tt.expectedParam.DefaultValue) + }) + } +} + +func TestGoTypeToPHPType(t *testing.T) { + tests := []struct { + goType string + expected string + }{ + {"string", "string"}, + {"*string", "string"}, + {"int", "int"}, + {"int64", "int"}, + {"*int", "int"}, + {"float64", "float"}, + {"*float32", "float"}, + {"bool", "bool"}, + {"*bool", "bool"}, + {"[]string", "array"}, + {"map[string]int", "array"}, + {"*[]int", "array"}, + {"interface{}", "mixed"}, + {"CustomType", "mixed"}, + } + + parser := classParser{} + for _, tt := range tests { + t.Run(tt.goType, func(t *testing.T) { + result := parser.goTypeToPHPType(tt.goType) + assert.Equal(t, tt.expected, result, "goTypeToPHPType(%s) = %s, want %s", tt.goType, result, tt.expected) + }) + } +} + +func TestTypeToString(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "basic types", + input: `package main + +//export_php:class TestClass +type TestStruct struct { + StringField string + IntField int + FloatField float64 + BoolField bool +}`, + expected: []string{"string", "int", "float", "bool"}, + }, + { + name: "pointer types", + input: `package main + +//export_php:class NullableClass +type NullableStruct struct { + NullableString *string + NullableInt *int + NullableFloat *float64 + NullableBool *bool +}`, + expected: []string{"string", "int", "float", "bool"}, + }, + { + name: "collection types", + input: `package main + +//export_php:class CollectionClass +type CollectionStruct struct { + StringSlice []string + IntMap map[string]int + MixedSlice []interface{} +}`, + expected: []string{"array", "array", "array"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(fileName, []byte(tt.input), 0o644)) + + parser := classParser{} + classes, err := parser.parse(fileName) + require.NoError(t, err) + + require.Len(t, classes, 1, "Expected 1 class") + + class := classes[0] + require.Len(t, class.Properties, len(tt.expected), "Expected %d properties", len(tt.expected)) + + for i, expectedType := range tt.expected { + assert.Equal(t, expectedType, class.Properties[i].PhpType, "Property %d: expected type %s", i, expectedType) + } + }) + } +} + +func TestClassParserUnsupportedTypes(t *testing.T) { + tests := []struct { + name string + input string + expectedClasses int + expectedMethods int + hasWarning bool + }{ + { + name: "method with array parameter should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::arrayMethod(array $data): string +func (tc *TestClass) arrayMethod(data interface{}) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "method with object parameter should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::objectMethod(object $obj): string +func (tc *TestClass) objectMethod(obj interface{}) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "method with mixed parameter should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::mixedMethod(mixed $value): string +func (tc *TestClass) mixedMethod(value interface{}) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "method with array return type should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::arrayReturn(string $name): array +func (tc *TestClass) arrayReturn(name *C.zend_string) interface{} { + return []string{"result"} +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "method with object return type should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::objectReturn(string $name): object +func (tc *TestClass) objectReturn(name *C.zend_string) interface{} { + return map[string]interface{}{"key": "value"} +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "valid scalar types should pass", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::validMethod(string $name, int $count, float $rate, bool $active): string +func validMethod(tc *TestClass, name *C.zend_string, count int64, rate float64, active bool) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 1, + hasWarning: false, + }, + { + name: "valid void return should pass", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::voidMethod(string $message): void +func voidMethod(tc *TestClass, message *C.zend_string) { + // Do something +}`, + expectedClasses: 1, + expectedMethods: 1, + hasWarning: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(fileName, []byte(tt.input), 0644)) + + parser := &classParser{} + classes, err := parser.parse(fileName) + require.NoError(t, err) + + assert.Len(t, classes, tt.expectedClasses, "parse() got wrong number of classes") + if len(classes) > 0 { + assert.Len(t, classes[0].Methods, tt.expectedMethods, "parse() got wrong number of methods") + } + }) + } +} + +func TestClassParserGoTypeMismatch(t *testing.T) { + tests := []struct { + name string + input string + expectedClasses int + expectedMethods int + hasWarning bool + }{ + { + name: "method parameter count mismatch should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::countMismatch(string $name, int $count): string +func (tc *TestClass) countMismatch(name *C.zend_string) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "method parameter type mismatch should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::typeMismatch(string $name, int $count): string +func (tc *TestClass) typeMismatch(name *C.zend_string, count string) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "method return type mismatch should be rejected", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::returnMismatch(string $name): int +func (tc *TestClass) returnMismatch(name *C.zend_string) string { + return "" +}`, + expectedClasses: 1, + expectedMethods: 0, + hasWarning: true, + }, + { + name: "valid matching types should pass", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::validMatch(string $name, int $count): string +func validMatch(tc *TestClass, name *C.zend_string, count int64) unsafe.Pointer { + return nil +}`, + expectedClasses: 1, + expectedMethods: 1, + hasWarning: false, + }, + { + name: "valid bool types should pass", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::validBool(bool $flag): bool +func validBool(tc *TestClass, flag bool) bool { + return flag +}`, + expectedClasses: 1, + expectedMethods: 1, + hasWarning: false, + }, + { + name: "valid float types should pass", + input: `package main + +//export_php:class TestClass +type TestClass struct { + Name string +} + +//export_php:method TestClass::validFloat(float $value): float +func validFloat(tc *TestClass, value float64) float64 { + return value +}`, + expectedClasses: 1, + expectedMethods: 1, + hasWarning: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(fileName, []byte(tt.input), 0644)) + + parser := &classParser{} + classes, err := parser.parse(fileName) + require.NoError(t, err) + + assert.Len(t, classes, tt.expectedClasses, "parse() got wrong number of classes") + if len(classes) > 0 { + assert.Len(t, classes[0].Methods, tt.expectedMethods, "parse() got wrong number of methods") + } + }) + } +} diff --git a/internal/extgen/constants_test.go b/internal/extgen/constants_test.go new file mode 100644 index 00000000..cf14fcac --- /dev/null +++ b/internal/extgen/constants_test.go @@ -0,0 +1,160 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConstantsIntegration(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + + content := `package main + +//export_php:const +const STATUS_OK = iota + +//export_php:const +const MAX_CONNECTIONS = 100 + +//export_php:const: function test(): void +func Test() { + // Implementation +} + +func main() {} +` + + require.NoError(t, os.WriteFile(testFile, []byte(content), 0644)) + + generator := &Generator{ + BaseName: "testext", + SourceFile: testFile, + BuildDir: filepath.Join(tmpDir, "build"), + } + + require.NoError(t, generator.parseSource()) + assert.Len(t, generator.Constants, 2, "Expected 2 constants") + + expectedConstants := map[string]struct { + Value string + IsIota bool + }{ + "STATUS_OK": {"0", true}, + "MAX_CONNECTIONS": {"100", false}, + } + + for _, constant := range generator.Constants { + expected, exists := expectedConstants[constant.Name] + assert.True(t, exists, "Unexpected constant: %s", constant.Name) + if !exists { + continue + } + + assert.Equal(t, expected.Value, constant.Value, "Constant %s: value mismatch", constant.Name) + assert.Equal(t, expected.IsIota, constant.IsIota, "Constant %s: isIota mismatch", constant.Name) + } + + require.NoError(t, generator.setupBuildDirectory()) + require.NoError(t, generator.generateStubFile()) + + stubPath := filepath.Join(generator.BuildDir, generator.BaseName+".stub.php") + stubContent, err := os.ReadFile(stubPath) + require.NoError(t, err) + + stubStr := string(stubContent) + + assert.Contains(t, stubStr, "* @cvalue", "Stub does not contain @cvalue annotation for iota constant") + assert.Contains(t, stubStr, "const STATUS_OK = UNKNOWN;", "Stub does not contain STATUS_OK constant with UNKNOWN value") + assert.Contains(t, stubStr, "const MAX_CONNECTIONS = 100;", "Stub does not contain MAX_CONNECTIONS constant with explicit value") + + require.NoError(t, generator.generateCFile()) + + cPath := filepath.Join(generator.BuildDir, generator.BaseName+".c") + cContent, err := os.ReadFile(cPath) + require.NoError(t, err) + + cStr := string(cContent) + + assert.Contains(t, cStr, `REGISTER_LONG_CONSTANT("STATUS_OK", STATUS_OK, CONST_CS | CONST_PERSISTENT);`, "C file does not contain STATUS_OK registration") + assert.Contains(t, cStr, `REGISTER_LONG_CONSTANT("MAX_CONNECTIONS", 100, CONST_CS | CONST_PERSISTENT);`, "C file does not contain MAX_CONNECTIONS registration") +} + +func TestConstantsIntegrationOctal(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + + content := `package main + +//export_php:const +const FILE_PERM = 0o755 + +//export_php:const +const OTHER_PERM = 0o644 + +//export_php:const +const REGULAR_INT = 42 + +func main() {} +` + + require.NoError(t, os.WriteFile(testFile, []byte(content), 0644)) + + generator := &Generator{ + BaseName: "octalstest", + SourceFile: testFile, + BuildDir: filepath.Join(tmpDir, "build"), + } + + require.NoError(t, generator.parseSource()) + assert.Len(t, generator.Constants, 3, "Expected 3 constants") + + // Verify CValue conversion + for _, constant := range generator.Constants { + switch constant.Name { + case "FILE_PERM": + assert.Equal(t, "0o755", constant.Value, "FILE_PERM value mismatch") + assert.Equal(t, "493", constant.CValue(), "FILE_PERM CValue mismatch") + case "OTHER_PERM": + assert.Equal(t, "0o644", constant.Value, "OTHER_PERM value mismatch") + assert.Equal(t, "420", constant.CValue(), "OTHER_PERM CValue mismatch") + case "REGULAR_INT": + assert.Equal(t, "42", constant.Value, "REGULAR_INT value mismatch") + assert.Equal(t, "42", constant.CValue(), "REGULAR_INT CValue mismatch") + } + } + + require.NoError(t, generator.setupBuildDirectory()) + + // Test C file generation + require.NoError(t, generator.generateCFile()) + + cPath := filepath.Join(generator.BuildDir, generator.BaseName+".c") + cContent, err := os.ReadFile(cPath) + require.NoError(t, err) + + cStr := string(cContent) + + // Verify C file uses decimal values for octal constants + assert.Contains(t, cStr, `REGISTER_LONG_CONSTANT("FILE_PERM", 493, CONST_CS | CONST_PERSISTENT);`, "C file does not contain FILE_PERM registration with decimal value 493") + assert.Contains(t, cStr, `REGISTER_LONG_CONSTANT("OTHER_PERM", 420, CONST_CS | CONST_PERSISTENT);`, "C file does not contain OTHER_PERM registration with decimal value 420") + assert.Contains(t, cStr, `REGISTER_LONG_CONSTANT("REGULAR_INT", 42, CONST_CS | CONST_PERSISTENT);`, "C file does not contain REGULAR_INT registration with value 42") + + // Test header file generation + require.NoError(t, generator.generateHeaderFile()) + + hPath := filepath.Join(generator.BuildDir, generator.BaseName+".h") + hContent, err := os.ReadFile(hPath) + require.NoError(t, err) + + hStr := string(hContent) + + // Verify header file uses decimal values for octal constants in #define + assert.Contains(t, hStr, "#define FILE_PERM 493", "Header file does not contain FILE_PERM #define with decimal value 493") + assert.Contains(t, hStr, "#define OTHER_PERM 420", "Header file does not contain OTHER_PERM #define with decimal value 420") + assert.Contains(t, hStr, "#define REGULAR_INT 42", "Header file does not contain REGULAR_INT #define with value 42") +} diff --git a/internal/extgen/constparser.go b/internal/extgen/constparser.go new file mode 100644 index 00000000..b7bb3cdb --- /dev/null +++ b/internal/extgen/constparser.go @@ -0,0 +1,133 @@ +package extgen + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strconv" + "strings" +) + +var constRegex = regexp.MustCompile(`//\s*export_php:const$`) +var classConstRegex = regexp.MustCompile(`//\s*export_php:classconst\s+(\w+)$`) +var constDeclRegex = regexp.MustCompile(`const\s+(\w+)\s*=\s*(.+)`) + +type ConstantParser struct { + constRegex *regexp.Regexp + classConstRegex *regexp.Regexp + constDeclRegex *regexp.Regexp +} + +func NewConstantParserWithDefRegex() *ConstantParser { + return &ConstantParser{ + constRegex: constRegex, + classConstRegex: classConstRegex, + constDeclRegex: constDeclRegex, + } +} + +func (cp *ConstantParser) parse(filename string) (constants []phpConstant, err error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer func() { + e := file.Close() + if err == nil { + err = e + } + }() + + scanner := bufio.NewScanner(file) + + lineNumber := 0 + expectConstDecl := false + expectClassConstDecl := false + currentClassName := "" + currentConstantValue := 0 + + for scanner.Scan() { + lineNumber++ + line := strings.TrimSpace(scanner.Text()) + + if cp.constRegex.MatchString(line) { + expectConstDecl = true + expectClassConstDecl = false + currentClassName = "" + + continue + } + + if matches := cp.classConstRegex.FindStringSubmatch(line); len(matches) == 2 { + expectClassConstDecl = true + expectConstDecl = false + currentClassName = matches[1] + + continue + } + + if (expectConstDecl || expectClassConstDecl) && strings.HasPrefix(line, "const ") { + matches := cp.constDeclRegex.FindStringSubmatch(line) + if len(matches) == 3 { + name := matches[1] + value := strings.TrimSpace(matches[2]) + + constant := phpConstant{ + Name: name, + Value: value, + IsIota: value == "iota", + lineNumber: lineNumber, + ClassName: currentClassName, + } + + constant.PhpType = determineConstantType(value) + + if constant.IsIota { + // affect a default value because user didn't give one + constant.Value = fmt.Sprintf("%d", currentConstantValue) + constant.PhpType = "int" + currentConstantValue++ + } + + constants = append(constants, constant) + } else { + return nil, fmt.Errorf("invalid constant declaration at line %d: %s", lineNumber, line) + } + expectConstDecl = false + expectClassConstDecl = false + } else if (expectConstDecl || expectClassConstDecl) && !strings.HasPrefix(line, "//") && line != "" { + // we expected a const declaration but found something else, reset + expectConstDecl = false + expectClassConstDecl = false + currentClassName = "" + } + } + + return constants, scanner.Err() +} + +// determineConstantType analyzes the value and determines its type +func determineConstantType(value string) string { + value = strings.TrimSpace(value) + + if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) || + (strings.HasPrefix(value, "`") && strings.HasSuffix(value, "`")) { + return "string" + } + + if value == "true" || value == "false" { + return "bool" + } + + // check for integer literals, including hex, octal, binary + if _, err := strconv.ParseInt(value, 0, 64); err == nil { + return "int" + } + + if _, err := strconv.ParseFloat(value, 64); err == nil { + return "float" + } + + return "int" +} diff --git a/internal/extgen/constparser_test.go b/internal/extgen/constparser_test.go new file mode 100644 index 00000000..34aa6d81 --- /dev/null +++ b/internal/extgen/constparser_test.go @@ -0,0 +1,558 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConstantParser(t *testing.T) { + tests := []struct { + name string + input string + expected int + }{ + { + name: "single constant", + input: `package main + +//export_php:const +const MyConstant = "test_value"`, + expected: 1, + }, + { + name: "multiple constants", + input: `package main + +//export_php:const +const FirstConstant = "first" + +//export_php:const +const SecondConstant = 42 + +//export_php:const +const ThirdConstant = true`, + expected: 3, + }, + { + name: "iota constant", + input: `package main + +//export_php:const +const IotaConstant = iota`, + expected: 1, + }, + { + name: "mixed constants and iota", + input: `package main + +//export_php:const +const StringConst = "hello" + +//export_php:const +const IotaConst = iota + +//export_php:const +const IntConst = 123`, + expected: 3, + }, + { + name: "no php constants", + input: `package main + +const RegularConstant = "not exported" + +func someFunction() { + // Just regular code +}`, + expected: 0, + }, + { + name: "constant with complex value", + input: `package main + +//export_php:const +const ComplexConstant = "string with spaces and symbols !@#$%"`, + expected: 1, + }, + { + name: "directive without constant", + input: `package main + +//export_php:const +var notAConstant = "this is a variable"`, + expected: 0, + }, + { + name: "mixed export and non-export constants", + input: `package main + +const RegularConst = "regular" + +//export_php:const +const ExportedConst = "exported" + +const AnotherRegular = 456 + +//export_php:const +const AnotherExported = 789`, + expected: 2, + }, + { + name: "numeric constants", + input: `package main + +//export_php:const +const IntConstant = 42 + +//export_php:const +const FloatConstant = 3.14 + +//export_php:const +const HexConstant = 0xFF`, + expected: 3, + }, + { + name: "boolean constants", + input: `package main + +//export_php:const +const TrueConstant = true + +//export_php:const +const FalseConstant = false`, + expected: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(tmpFile, []byte(tt.input), 0644)) + + parser := NewConstantParserWithDefRegex() + constants, err := parser.parse(tmpFile) + assert.NoError(t, err, "parse() error") + + assert.Len(t, constants, tt.expected, "parse() got wrong number of constants") + + if tt.name == "single constant" && len(constants) > 0 { + c := constants[0] + assert.Equal(t, "MyConstant", c.Name, "Expected constant name 'MyConstant'") + assert.Equal(t, `"test_value"`, c.Value, `Expected constant value '"test_value"'`) + assert.Equal(t, "string", c.PhpType, "Expected constant type 'string'") + assert.False(t, c.IsIota, "Expected isIota to be false for string constant") + } + + if tt.name == "iota constant" && len(constants) > 0 { + c := constants[0] + assert.Equal(t, "IotaConstant", c.Name, "Expected constant name 'IotaConstant'") + assert.True(t, c.IsIota, "Expected isIota to be true") + assert.Equal(t, "0", c.Value, "Expected iota constant value to be '0'") + } + + if tt.name == "multiple constants" && len(constants) == 3 { + expectedNames := []string{"FirstConstant", "SecondConstant", "ThirdConstant"} + expectedValues := []string{`"first"`, "42", "true"} + expectedTypes := []string{"string", "int", "bool"} + + for i, c := range constants { + assert.Equal(t, expectedNames[i], c.Name, "Expected constant name '%s'", expectedNames[i]) + assert.Equal(t, expectedValues[i], c.Value, "Expected constant value '%s'", expectedValues[i]) + assert.Equal(t, expectedTypes[i], c.PhpType, "Expected constant type '%s'", expectedTypes[i]) + } + } + }) + } +} + +func TestConstantParserErrors(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + }{ + { + name: "invalid constant declaration", + input: `package main + +//export_php:const +const = "missing name"`, + expectError: true, + }, + { + name: "malformed constant", + input: `package main + +//export_php:const +const InvalidSyntax`, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(tmpFile, []byte(tt.input), 0644)) + + parser := NewConstantParserWithDefRegex() + _, err := parser.parse(tmpFile) + require.NotNil(t, err) + + if tt.expectError { + assert.Error(t, err, "Expected error but got none") + + return + } + + assert.NoError(t, err) + }) + } +} + +func TestConstantParserIotaSequence(t *testing.T) { + input := `package main + +//export_php:const +const FirstIota = iota + +//export_php:const +const SecondIota = iota + +//export_php:const +const ThirdIota = iota` + + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, "test.go") + require.NoError(t, os.WriteFile(fileName, []byte(input), 0644)) + + parser := NewConstantParserWithDefRegex() + constants, err := parser.parse(fileName) + assert.NoError(t, err, "parse() error") + + assert.Len(t, constants, 3, "Expected 3 constants") + + expectedValues := []string{"0", "1", "2"} + for i, c := range constants { + assert.True(t, c.IsIota, "Expected constant %d to be iota", i) + assert.Equal(t, expectedValues[i], c.Value, "Expected constant %d value to be '%s'", i, expectedValues[i]) + } +} + +func TestConstantParserTypeDetection(t *testing.T) { + tests := []struct { + name string + value string + expectedType string + }{ + {"string with double quotes", "\"hello world\"", "string"}, + {"string with backticks", "`hello world`", "string"}, + {"boolean true", "true", "bool"}, + {"boolean false", "false", "bool"}, + {"integer", "42", "int"}, + {"negative integer", "-42", "int"}, + {"hex integer", "0xFF", "int"}, + {"octal integer", "0755", "int"}, + {"go octal integer", "0o755", "int"}, + {"binary integer", "0b1010", "int"}, + {"float", "3.14", "float"}, + {"negative float", "-3.14", "float"}, + {"scientific notation", "1e10", "float"}, + {"unknown type", "someFunction()", "int"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := determineConstantType(tt.value) + assert.Equal(t, tt.expectedType, result, "determineConstantType(%s) expected %s", tt.value, tt.expectedType) + }) + } +} + +func TestConstantParserClassConstants(t *testing.T) { + tests := []struct { + name string + input string + expected int + }{ + { + name: "single class constant", + input: `package main + +//export_php:classconst MyClass +const STATUS_ACTIVE = 1`, + expected: 1, + }, + { + name: "multiple class constants", + input: `package main + +//export_php:classconst User +const STATUS_ACTIVE = "active" + +//export_php:classconst User +const STATUS_INACTIVE = "inactive" + +//export_php:classconst Order +const STATE_PENDING = 0`, + expected: 3, + }, + { + name: "mixed global and class constants", + input: `package main + +//export_php:const +const GLOBAL_CONST = "global" + +//export_php:classconst MyClass +const CLASS_CONST = 42 + +//export_php:const +const ANOTHER_GLOBAL = true`, + expected: 3, + }, + { + name: "class constant with iota", + input: `package main + +//export_php:classconst Status +const FIRST = iota + +//export_php:classconst Status +const SECOND = iota`, + expected: 2, + }, + { + name: "invalid class constant directive", + input: `package main + +//export_php:classconst +const INVALID = "missing class name"`, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(tmpFile, []byte(tt.input), 0644)) + + parser := NewConstantParserWithDefRegex() + constants, err := parser.parse(tmpFile) + assert.NoError(t, err, "parse() error") + + assert.Len(t, constants, tt.expected, "parse() got wrong number of constants") + + if tt.name == "single class constant" && len(constants) > 0 { + c := constants[0] + assert.Equal(t, "STATUS_ACTIVE", c.Name, "Expected constant name 'STATUS_ACTIVE'") + assert.Equal(t, "MyClass", c.ClassName, "Expected class name 'MyClass'") + assert.Equal(t, "1", c.Value, "Expected constant value '1'") + assert.Equal(t, "int", c.PhpType, "Expected constant type 'int'") + } + + if tt.name == "multiple class constants" && len(constants) == 3 { + expectedClasses := []string{"User", "User", "Order"} + expectedNames := []string{"STATUS_ACTIVE", "STATUS_INACTIVE", "STATE_PENDING"} + expectedValues := []string{`"active"`, `"inactive"`, "0"} + + for i, c := range constants { + assert.Equal(t, expectedClasses[i], c.ClassName, "Expected class name '%s'", expectedClasses[i]) + assert.Equal(t, expectedNames[i], c.Name, "Expected constant name '%s'", expectedNames[i]) + assert.Equal(t, expectedValues[i], c.Value, "Expected constant value '%s'", expectedValues[i]) + } + } + + if tt.name == "mixed global and class constants" && len(constants) == 3 { + assert.Empty(t, constants[0].ClassName, "First constant should be global") + assert.Equal(t, "MyClass", constants[1].ClassName, "Second constant should belong to MyClass") + assert.Empty(t, constants[2].ClassName, "Third constant should be global") + } + }) + } +} + +func TestConstantParserRegexMatch(t *testing.T) { + parser := NewConstantParserWithDefRegex() + + testCases := []struct { + line string + expected bool + }{ + {"//export_php:const", true}, + {"// export_php:const", true}, + {"// export_php:const", true}, + {"//export_php:const ", false}, // should not match with trailing content + {"//export_php", false}, + {"//export_php:function", false}, + {"//export_php:class", false}, + {"// some other comment", false}, + } + + for _, tc := range testCases { + t.Run(tc.line, func(t *testing.T) { + matches := parser.constRegex.MatchString(tc.line) + assert.Equal(t, tc.expected, matches, "Expected regex match for line '%s'", tc.line) + }) + } +} + +func TestConstantParserClassConstRegex(t *testing.T) { + parser := NewConstantParserWithDefRegex() + + testCases := []struct { + line string + shouldMatch bool + className string + }{ + {"//export_php:classconst MyClass", true, "MyClass"}, + {"// export_php:classconst User", true, "User"}, + {"// export_php:classconst Status", true, "Status"}, + {"//export_php:classconst Order123", true, "Order123"}, + {"//export_php:classconst", false, ""}, + {"//export_php:classconst ", false, ""}, + {"//export_php:classconst MyClass extra", false, ""}, + {"//export_php:const", false, ""}, + {"//export_php:function", false, ""}, + {"// some other comment", false, ""}, + } + + for _, tc := range testCases { + t.Run(tc.line, func(t *testing.T) { + matches := parser.classConstRegex.FindStringSubmatch(tc.line) + + if tc.shouldMatch { + assert.Len(t, matches, 2, "Expected 2 matches for line '%s'", tc.line) + if len(matches) != 2 { + return + } + assert.Equal(t, tc.className, matches[1], "Expected class name '%s'", tc.className) + } else { + assert.Empty(t, matches, "Expected no matches for line '%s'", tc.line) + } + }) + } +} + +func TestConstantParserDeclRegex(t *testing.T) { + parser := NewConstantParserWithDefRegex() + + testCases := []struct { + line string + shouldMatch bool + name string + value string + }{ + {"const MyConst = \"value\"", true, "MyConst", "\"value\""}, + {"const IntConst = 42", true, "IntConst", "42"}, + {"const BoolConst = true", true, "BoolConst", "true"}, + {"const IotaConst = iota", true, "IotaConst", "iota"}, + {"const ComplexValue = someFunction()", true, "ComplexValue", "someFunction()"}, + {"const SpacedName = \"with spaces\"", true, "SpacedName", "\"with spaces\""}, + {"var notAConst = \"value\"", false, "", ""}, + {"const", false, "", ""}, + {"const =", false, "", ""}, + } + + for _, tc := range testCases { + t.Run(tc.line, func(t *testing.T) { + matches := parser.constDeclRegex.FindStringSubmatch(tc.line) + + if tc.shouldMatch { + assert.Len(t, matches, 3, "Expected 3 matches for line '%s'", tc.line) + if len(matches) != 3 { + return + } + assert.Equal(t, tc.name, matches[1], "Expected name '%s'", tc.name) + assert.Equal(t, tc.value, matches[2], "Expected value '%s'", tc.value) + } else { + assert.Empty(t, matches, "Expected no matches for line '%s'", tc.line) + } + }) + } +} + +func TestPHPConstantCValue(t *testing.T) { + tests := []struct { + name string + constant phpConstant + expected string + }{ + { + name: "octal notation 0o35", + constant: phpConstant{ + Name: "OctalConst", + Value: "0o35", + PhpType: "int", + }, + expected: "29", // 0o35 = 29 in decimal + }, + { + name: "octal notation 0o755", + constant: phpConstant{ + Name: "OctalPerm", + Value: "0o755", + PhpType: "int", + }, + expected: "493", // 0o755 = 493 in decimal + }, + { + name: "regular integer", + constant: phpConstant{ + Name: "RegularInt", + Value: "42", + PhpType: "int", + }, + expected: "42", + }, + { + name: "hex integer", + constant: phpConstant{ + Name: "HexInt", + Value: "0xFF", + PhpType: "int", + }, + expected: "0xFF", // hex should remain unchanged + }, + { + name: "string constant", + constant: phpConstant{ + Name: "StringConst", + Value: "\"hello\"", + PhpType: "string", + }, + expected: "\"hello\"", // strings should remain unchanged + }, + { + name: "boolean constant", + constant: phpConstant{ + Name: "BoolConst", + Value: "true", + PhpType: "bool", + }, + expected: "true", // booleans should remain unchanged + }, + { + name: "float constant", + constant: phpConstant{ + Name: "FloatConst", + Value: "3.14", + PhpType: "float", + }, + expected: "3.14", // floats should remain unchanged + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.constant.CValue() + assert.Equal(t, tt.expected, result, "CValue() expected %s", tt.expected) + }) + } +} diff --git a/internal/extgen/docs.go b/internal/extgen/docs.go new file mode 100644 index 00000000..24040957 --- /dev/null +++ b/internal/extgen/docs.go @@ -0,0 +1,46 @@ +package extgen + +import ( + "bytes" + _ "embed" + "path/filepath" + "text/template" +) + +//go:embed templates/README.md.tpl +var docFileContent string + +type DocumentationGenerator struct { + generator *Generator +} + +type DocTemplateData struct { + BaseName string + Functions []phpFunction + Classes []phpClass +} + +func (dg *DocumentationGenerator) generate() error { + filename := filepath.Join(dg.generator.BuildDir, "README.md") + content, err := dg.generateMarkdown() + if err != nil { + return err + } + + return WriteFile(filename, content) +} + +func (dg *DocumentationGenerator) generateMarkdown() (string, error) { + tmpl := template.Must(template.New("readme").Parse(docFileContent)) + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, DocTemplateData{ + BaseName: dg.generator.BaseName, + Functions: dg.generator.Functions, + Classes: dg.generator.Classes, + }); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/internal/extgen/docs_test.go b/internal/extgen/docs_test.go new file mode 100644 index 00000000..78a8abfd --- /dev/null +++ b/internal/extgen/docs_test.go @@ -0,0 +1,386 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDocumentationGenerator_Generate(t *testing.T) { + tests := []struct { + name string + generator *Generator + expectError bool + }{ + { + name: "simple extension with functions", + generator: &Generator{ + BaseName: "testextension", + BuildDir: "", + Functions: []phpFunction{ + { + Name: "greet", + ReturnType: "string", + Params: []phpParameter{ + {Name: "name", PhpType: "string"}, + }, + Signature: "greet(string $name): string", + }, + }, + Classes: []phpClass{}, + }, + expectError: false, + }, + { + name: "extension with classes", + generator: &Generator{ + BaseName: "classextension", + BuildDir: "", + Functions: []phpFunction{}, + Classes: []phpClass{ + { + Name: "TestClass", + Properties: []phpClassProperty{ + {Name: "name", PhpType: "string"}, + {Name: "count", PhpType: "int", IsNullable: true}, + }, + }, + }, + }, + expectError: false, + }, + { + name: "extension with both functions and classes", + generator: &Generator{ + BaseName: "fullextension", + BuildDir: "", + Functions: []phpFunction{ + { + Name: "calculate", + ReturnType: "int", + IsReturnNullable: true, + Params: []phpParameter{ + {Name: "base", PhpType: "int"}, + {Name: "multiplier", PhpType: "int", HasDefault: true, DefaultValue: "2", IsNullable: true}, + }, + Signature: "calculate(int $base, ?int $multiplier = 2): ?int", + }, + }, + Classes: []phpClass{ + { + Name: "Calculator", + Properties: []phpClassProperty{ + {Name: "precision", PhpType: "int"}, + }, + }, + }, + }, + expectError: false, + }, + { + name: "empty extension", + generator: &Generator{ + BaseName: "emptyextension", + BuildDir: "", + Functions: []phpFunction{}, + Classes: []phpClass{}, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + tt.generator.BuildDir = tempDir + + docGen := &DocumentationGenerator{ + generator: tt.generator, + } + + err := docGen.generate() + + if tt.expectError { + assert.Error(t, err, "generate() expected error but got none") + return + } + + assert.NoError(t, err, "generate() unexpected error") + + readmePath := filepath.Join(tempDir, "README.md") + require.FileExists(t, readmePath) + + content, err := os.ReadFile(readmePath) + require.NoError(t, err, "Failed to read generated README.md") + + contentStr := string(content) + + assert.Contains(t, contentStr, "# "+tt.generator.BaseName+" Extension", "README should contain extension title") + assert.Contains(t, contentStr, "Auto-generated PHP extension from Go code.", "README should contain description") + + if len(tt.generator.Functions) > 0 { + assert.Contains(t, contentStr, "## Functions", "README should contain functions section when functions exist") + + for _, fn := range tt.generator.Functions { + assert.Contains(t, contentStr, "### "+fn.Name, "README should contain function %s", fn.Name) + assert.Contains(t, contentStr, fn.Signature, "README should contain function signature for %s", fn.Name) + } + } + + if len(tt.generator.Classes) > 0 { + assert.Contains(t, contentStr, "## Classes", "README should contain classes section when classes exist") + + for _, class := range tt.generator.Classes { + assert.Contains(t, contentStr, "### "+class.Name, "README should contain class %s", class.Name) + } + } + }) + } +} + +func TestDocumentationGenerator_GenerateMarkdown(t *testing.T) { + tests := []struct { + name string + generator *Generator + contains []string + notContains []string + }{ + { + name: "function with parameters", + generator: &Generator{ + BaseName: "testextension", + Functions: []phpFunction{ + { + Name: "processData", + ReturnType: "array", + Params: []phpParameter{ + {Name: "data", PhpType: "string"}, + {Name: "options", PhpType: "array", IsNullable: true}, + {Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "10"}, + }, + Signature: "processData(string $data, ?array $options, int $count = 10): array", + }, + }, + Classes: []phpClass{}, + }, + contains: []string{ + "# testextension Extension", + "## Functions", + "### processData", + "**Parameters:**", + "- `data` (string)", + "- `options` (array) (nullable)", + "- `count` (int) (default: 10)", + "**Returns:** array", + }, + }, + { + name: "nullable return type", + generator: &Generator{ + BaseName: "nullableext", + Functions: []phpFunction{ + { + Name: "maybeGetValue", + ReturnType: "string", + IsReturnNullable: true, + Params: []phpParameter{}, + Signature: "maybeGetValue(): ?string", + }, + }, + Classes: []phpClass{}, + }, + contains: []string{ + "**Returns:** string (nullable)", + }, + }, + { + name: "class with properties", + generator: &Generator{ + BaseName: "classext", + Functions: []phpFunction{}, + Classes: []phpClass{ + { + Name: "DataProcessor", + Properties: []phpClassProperty{ + {Name: "name", PhpType: "string"}, + {Name: "config", PhpType: "array", IsNullable: true}, + {Name: "enabled", PhpType: "bool"}, + }, + }, + }, + }, + contains: []string{ + "## Classes", + "### DataProcessor", + "**Properties:**", + "- `name`: string", + "- `config`: array (nullable)", + "- `enabled`: bool", + }, + }, + { + name: "extension with no functions or classes", + generator: &Generator{ + BaseName: "emptyext", + Functions: []phpFunction{}, + Classes: []phpClass{}, + }, + contains: []string{ + "# emptyext Extension", + "Auto-generated PHP extension from Go code.", + }, + notContains: []string{ + "## Functions", + "## Classes", + }, + }, + { + name: "function with no parameters", + generator: &Generator{ + BaseName: "noparamext", + Functions: []phpFunction{ + { + Name: "getCurrentTime", + ReturnType: "int", + Params: []phpParameter{}, + Signature: "getCurrentTime(): int", + }, + }, + Classes: []phpClass{}, + }, + contains: []string{ + "### getCurrentTime", + "**Returns:** int", + }, + notContains: []string{ + "**Parameters:**", + }, + }, + { + name: "class with no properties", + generator: &Generator{ + BaseName: "nopropsext", + Functions: []phpFunction{}, + Classes: []phpClass{ + { + Name: "EmptyClass", + Properties: []phpClassProperty{}, + }, + }, + }, + contains: []string{ + "### EmptyClass", + }, + notContains: []string{ + "**Properties:**", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + docGen := &DocumentationGenerator{ + generator: tt.generator, + } + + result, err := docGen.generateMarkdown() + if !assert.NoError(t, err, "generateMarkdown() unexpected error") { + return + } + + for _, expected := range tt.contains { + assert.Contains(t, result, expected, "generateMarkdown() should contain '%s'", expected) + } + + for _, notExpected := range tt.notContains { + assert.NotContains(t, result, notExpected, "generateMarkdown() should NOT contain '%s'", notExpected) + } + }) + } +} + +func TestDocumentationGenerator_Generate_InvalidDirectory(t *testing.T) { + generator := &Generator{ + BaseName: "test", + BuildDir: "/nonexistent/directory", + Functions: []phpFunction{}, + Classes: []phpClass{}, + } + + docGen := &DocumentationGenerator{ + generator: generator, + } + + err := docGen.generate() + assert.Error(t, err, "generate() expected error for invalid directory but got none") +} + +func TestDocumentationGenerator_TemplateError(t *testing.T) { + generator := &Generator{ + BaseName: "test", + Functions: []phpFunction{ + { + Name: "test", + ReturnType: "string", + Signature: "test(): string", + }, + }, + Classes: []phpClass{}, + } + + docGen := &DocumentationGenerator{ + generator: generator, + } + + result, err := docGen.generateMarkdown() + assert.NoError(t, err, "generateMarkdown() unexpected error") + assert.NotEmpty(t, result, "generateMarkdown() returned empty result") +} + +func BenchmarkDocumentationGenerator_GenerateMarkdown(b *testing.B) { + generator := &Generator{ + BaseName: "benchext", + Functions: []phpFunction{ + { + Name: "function1", + ReturnType: "string", + Params: []phpParameter{ + {Name: "param1", PhpType: "string"}, + {Name: "param2", PhpType: "int", HasDefault: true, DefaultValue: "0"}, + }, + Signature: "function1(string $param1, int $param2 = 0): string", + }, + { + Name: "function2", + ReturnType: "array", + IsReturnNullable: true, + Params: []phpParameter{ + {Name: "data", PhpType: "array", IsNullable: true}, + }, + Signature: "function2(?array $data): ?array", + }, + }, + Classes: []phpClass{ + { + Name: "TestClass", + Properties: []phpClassProperty{ + {Name: "prop1", PhpType: "string"}, + {Name: "prop2", PhpType: "int", IsNullable: true}, + }, + }, + }, + } + + docGen := &DocumentationGenerator{ + generator: generator, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := docGen.generateMarkdown() + assert.NoError(b, err) + } +} diff --git a/internal/extgen/errors.go b/internal/extgen/errors.go new file mode 100644 index 00000000..d4d0f145 --- /dev/null +++ b/internal/extgen/errors.go @@ -0,0 +1,17 @@ +package extgen + +import "fmt" + +type GeneratorError struct { + Stage string + Message string + Err error +} + +func (e *GeneratorError) Error() string { + if e.Err == nil { + return fmt.Sprintf("generator error at %s: %s", e.Stage, e.Message) + } + + return fmt.Sprintf("generator error at %s: %s: %v", e.Stage, e.Message, e.Err) +} diff --git a/internal/extgen/funcparser.go b/internal/extgen/funcparser.go new file mode 100644 index 00000000..eb9275d6 --- /dev/null +++ b/internal/extgen/funcparser.go @@ -0,0 +1,191 @@ +package extgen + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" +) + +var phpFuncRegex = regexp.MustCompile(`//\s*export_php:function\s+([^{}\n]+)(?:\s*{\s*})?`) +var signatureRegex = regexp.MustCompile(`(\w+)\s*\(([^)]*)\)\s*:\s*(\??[\w|]+)`) +var typeNameRegex = regexp.MustCompile(`(\??[\w|]+)\s+\$?(\w+)`) + +type FuncParser struct { + phpFuncRegex *regexp.Regexp +} + +func NewFuncParserDefRegex() *FuncParser { + return &FuncParser{ + phpFuncRegex: phpFuncRegex, + } +} + +func (fp *FuncParser) parse(filename string) (functions []phpFunction, err error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer func() { + e := file.Close() + if err == nil { + err = e + } + }() + + scanner := bufio.NewScanner(file) + var currentPHPFunc *phpFunction + validator := Validator{} + + lineNumber := 0 + for scanner.Scan() { + lineNumber++ + line := strings.TrimSpace(scanner.Text()) + + if matches := fp.phpFuncRegex.FindStringSubmatch(line); matches != nil { + signature := strings.TrimSpace(matches[1]) + phpFunc, err := fp.parseSignature(signature) + if err != nil { + fmt.Printf("Warning: Error parsing signature '%s': %v\n", signature, err) + + continue + } + + if err := validator.validateFunction(*phpFunc); err != nil { + fmt.Printf("Warning: Invalid function '%s': %v\n", phpFunc.Name, err) + + continue + } + + if err := validator.validateScalarTypes(*phpFunc); err != nil { + fmt.Printf("Warning: Function '%s' uses unsupported types: %v\n", phpFunc.Name, err) + + continue + } + + phpFunc.lineNumber = lineNumber + currentPHPFunc = phpFunc + } + + if currentPHPFunc != nil && strings.HasPrefix(line, "func ") { + goFunc, err := fp.extractGoFunction(scanner, line) + if err != nil { + return nil, fmt.Errorf("extracting Go function: %w", err) + } + + currentPHPFunc.GoFunction = goFunc + + if err := validator.validateGoFunctionSignatureWithOptions(*currentPHPFunc, false); err != nil { + fmt.Printf("Warning: Go function signature mismatch for %q: %v\n", currentPHPFunc.Name, err) + currentPHPFunc = nil + + continue + } + + functions = append(functions, *currentPHPFunc) + currentPHPFunc = nil + } + } + + if currentPHPFunc != nil { + return nil, fmt.Errorf("//export_php function directive at line %d is not followed by a function declaration", currentPHPFunc.lineNumber) + } + + return functions, scanner.Err() +} + +func (fp *FuncParser) extractGoFunction(scanner *bufio.Scanner, firstLine string) (string, error) { + goFunc := firstLine + "\n" + braceCount := 1 + + for scanner.Scan() { + line := scanner.Text() + goFunc += line + "\n" + + for _, char := range line { + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + } + } + + if braceCount == 0 { + break + } + } + + return goFunc, nil +} + +func (fp *FuncParser) parseSignature(signature string) (*phpFunction, error) { + matches := signatureRegex.FindStringSubmatch(signature) + + if len(matches) != 4 { + return nil, fmt.Errorf("invalid signature format") + } + + name := matches[1] + paramsStr := strings.TrimSpace(matches[2]) + returnTypeStr := strings.TrimSpace(matches[3]) + + isReturnNullable := strings.HasPrefix(returnTypeStr, "?") + returnType := strings.TrimPrefix(returnTypeStr, "?") + + var params []phpParameter + if paramsStr != "" { + paramParts := strings.Split(paramsStr, ",") + for _, part := range paramParts { + param, err := fp.parseParameter(strings.TrimSpace(part)) + if err != nil { + return nil, fmt.Errorf("parsing parameter '%s': %w", part, err) + } + params = append(params, param) + } + } + + return &phpFunction{ + Name: name, + Signature: signature, + Params: params, + ReturnType: returnType, + IsReturnNullable: isReturnNullable, + }, nil +} + +func (fp *FuncParser) parseParameter(paramStr string) (phpParameter, error) { + parts := strings.Split(paramStr, "=") + typePart := strings.TrimSpace(parts[0]) + + param := phpParameter{HasDefault: len(parts) > 1} + + if param.HasDefault { + param.DefaultValue = fp.sanitizeDefaultValue(strings.TrimSpace(parts[1])) + } + + matches := typeNameRegex.FindStringSubmatch(typePart) + + if len(matches) < 3 { + return phpParameter{}, fmt.Errorf("invalid parameter format: %s", paramStr) + } + + typeStr := strings.TrimSpace(matches[1]) + param.Name = strings.TrimSpace(matches[2]) + param.IsNullable = strings.HasPrefix(typeStr, "?") + param.PhpType = strings.TrimPrefix(typeStr, "?") + + return param, nil +} + +func (fp *FuncParser) sanitizeDefaultValue(value string) string { + if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + return value + } + if strings.ToLower(value) == "null" { + return "null" + } + + return strings.Trim(value, `'"`) +} diff --git a/internal/extgen/funcparser_test.go b/internal/extgen/funcparser_test.go new file mode 100644 index 00000000..3af5088f --- /dev/null +++ b/internal/extgen/funcparser_test.go @@ -0,0 +1,486 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFunctionParser(t *testing.T) { + tests := []struct { + name string + input string + expected int + }{ + { + name: "single function", + input: `package main + +//export_php:function testFunc(string $name): string +func testFunc(name *C.zend_string) unsafe.Pointer { + return String("Hello " + CStringToGoString(name)) +}`, + expected: 1, + }, + { + name: "multiple functions", + input: `package main + +//export_php:function func1(int $a): int +func func1(a int64) int64 { + return a * 2 +} + +//export_php:function func2(string $b): string +func func2(b *C.zend_string) unsafe.Pointer { + return String("processed: " + CStringToGoString(b)) +}`, + expected: 2, + }, + { + name: "no php functions", + input: `package main + +func regularFunc() { + // Just a regular Go function +}`, + expected: 0, + }, + { + name: "mixed functions", + input: `package main + +//export_php:function phpFunc(string $data): string +func phpFunc(data *C.zend_string) unsafe.Pointer { + return String("PHP: " + CStringToGoString(data)) +} + +func internalFunc() { + // Internal function without export_php comment +} + +//export_php:function anotherPhpFunc(int $num): int +func anotherPhpFunc(num int64) int64 { + return num * 10 +}`, + expected: 2, + }, + { + name: "wrong args syntax", + input: `package main + +//export_php function phpFunc(data string): string +func phpFunc(data *C.zend_string) unsafe.Pointer { + return String("PHP: " + CStringToGoString(data)) +}`, + expected: 0, + }, + { + name: "decoupled function names", + input: `package main + +//export_php:function my_php_function(string $name): string +func myGoFunction(name *C.zend_string) unsafe.Pointer { + return String("Hello " + CStringToGoString(name)) +} + +//export_php:function another_php_func(int $num): int +func someOtherGoName(num int64) int64 { + return num * 5 +}`, + expected: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(fileName, []byte(tt.input), 0644)) + + parser := NewFuncParserDefRegex() + functions, err := parser.parse(fileName) + require.NoError(t, err) + assert.Len(t, functions, tt.expected, "parse() got wrong number of functions") + + if tt.name == "single function" && len(functions) > 0 { + fn := functions[0] + assert.Equal(t, "testFunc", fn.Name, "Expected function name 'testFunc'") + assert.Equal(t, "string", fn.ReturnType, "Expected return type 'string'") + assert.Len(t, fn.Params, 1, "Expected 1 parameter") + if len(fn.Params) > 0 { + assert.Equal(t, "name", fn.Params[0].Name, "Expected parameter name 'name'") + } + } + + if tt.name == "decoupled function names" && len(functions) >= 2 { + fn1 := functions[0] + assert.Equal(t, "my_php_function", fn1.Name, "Expected PHP function name 'my_php_function'") + fn2 := functions[1] + assert.Equal(t, "another_php_func", fn2.Name, "Expected PHP function name 'another_php_func'") + } + }) + } +} + +func TestSignatureParsing(t *testing.T) { + tests := []struct { + name string + signature string + expectError bool + funcName string + paramCount int + returnType string + nullable bool + }{ + { + name: "simple function", + signature: "test(name string): string", + funcName: "test", + paramCount: 1, + returnType: "string", + nullable: false, + }, + { + name: "nullable return", + signature: "test(id int): ?string", + funcName: "test", + paramCount: 1, + returnType: "string", + nullable: true, + }, + { + name: "multiple params", + signature: "calculate(a int, b float, name string): float", + funcName: "calculate", + paramCount: 3, + returnType: "float", + nullable: false, + }, + { + name: "no parameters", + signature: "getValue(): int", + funcName: "getValue", + paramCount: 0, + returnType: "int", + nullable: false, + }, + { + name: "nullable parameters", + signature: "process(?string data, ?int count): bool", + funcName: "process", + paramCount: 2, + returnType: "bool", + nullable: false, + }, + { + name: "invalid signature", + signature: "invalid syntax here", + expectError: true, + }, + { + name: "missing return type", + signature: "test(name string)", + expectError: true, + }, + } + + parser := NewFuncParserDefRegex() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, err := parser.parseSignature(tt.signature) + + if tt.expectError { + assert.Error(t, err, "parseSignature() expected error but got none") + return + } + + assert.NoError(t, err, "parseSignature() unexpected error") + assert.Equal(t, tt.funcName, fn.Name, "parseSignature() name mismatch") + assert.Len(t, fn.Params, tt.paramCount, "parseSignature() param count mismatch") + assert.Equal(t, tt.returnType, fn.ReturnType, "parseSignature() return type mismatch") + assert.Equal(t, tt.nullable, fn.IsReturnNullable, "parseSignature() nullable mismatch") + + if tt.name == "nullable parameters" { + if len(fn.Params) >= 2 { + assert.True(t, fn.Params[0].IsNullable, "First parameter should be nullable") + assert.True(t, fn.Params[1].IsNullable, "Second parameter should be nullable") + } + } + }) + } +} + +func TestParameterParsing(t *testing.T) { + tests := []struct { + name string + paramStr string + expectedName string + expectedType string + expectedNullable bool + expectedDefault string + hasDefault bool + expectError bool + }{ + { + name: "simple string param", + paramStr: "string name", + expectedName: "name", + expectedType: "string", + }, + { + name: "nullable int param", + paramStr: "?int count", + expectedName: "count", + expectedType: "int", + expectedNullable: true, + }, + { + name: "param with default", + paramStr: "string message = 'hello'", + expectedName: "message", + expectedType: "string", + expectedDefault: "hello", + hasDefault: true, + }, + { + name: "int with default", + paramStr: "int limit = 10", + expectedName: "limit", + expectedType: "int", + expectedDefault: "10", + hasDefault: true, + }, + { + name: "nullable with default", + paramStr: "?string data = null", + expectedName: "data", + expectedType: "string", + expectedNullable: true, + expectedDefault: "null", + hasDefault: true, + }, + { + name: "invalid format", + paramStr: "invalid", + expectError: true, + }, + } + + parser := NewFuncParserDefRegex() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + param, err := parser.parseParameter(tt.paramStr) + + if tt.expectError { + assert.Error(t, err, "parseParameter() expected error but got none") + + return + } + + assert.NoError(t, err, "parseParameter() unexpected error") + assert.Equal(t, tt.expectedName, param.Name, "parseParameter() name mismatch") + assert.Equal(t, tt.expectedType, param.PhpType, "parseParameter() type mismatch") + assert.Equal(t, tt.expectedNullable, param.IsNullable, "parseParameter() nullable mismatch") + assert.Equal(t, tt.hasDefault, param.HasDefault, "parseParameter() hasDefault mismatch") + + if tt.hasDefault { + assert.Equal(t, tt.expectedDefault, param.DefaultValue, "parseParameter() defaultValue mismatch") + } + }) + } +} + +func TestFunctionParserUnsupportedTypes(t *testing.T) { + tests := []struct { + name string + input string + expected int + hasWarning bool + }{ + { + name: "function with array parameter should be rejected", + input: `package main + +//export_php:function arrayFunc(array $data): string +func arrayFunc(data interface{}) unsafe.Pointer { + return String("processed") +}`, + expected: 0, + hasWarning: true, + }, + { + name: "function with object parameter should be rejected", + input: `package main + +//export_php:function objectFunc(object $obj): string +func objectFunc(obj interface{}) unsafe.Pointer { + return String("processed") +}`, + expected: 0, + hasWarning: true, + }, + { + name: "function with mixed parameter should be rejected", + input: `package main + +//export_php:function mixedFunc(mixed $value): string +func mixedFunc(value interface{}) unsafe.Pointer { + return String("processed") +}`, + expected: 0, + hasWarning: true, + }, + { + name: "function with array return type should be rejected", + input: `package main + +//export_php:function arrayReturnFunc(string $name): array +func arrayReturnFunc(name *C.zend_string) interface{} { + return []string{"result"} +}`, + expected: 0, + hasWarning: true, + }, + { + name: "function with object return type should be rejected", + input: `package main + +//export_php:function objectReturnFunc(string $name): object +func objectReturnFunc(name *C.zend_string) interface{} { + return map[string]interface{}{"key": "value"} +}`, + expected: 0, + hasWarning: true, + }, + { + name: "valid scalar types should pass", + input: `package main + +//export_php:function validFunc(string $name, int $count, float $rate, bool $active): string +func validFunc(name *C.zend_string, count int64, rate float64, active bool) unsafe.Pointer { + return nil +}`, + expected: 1, + hasWarning: false, + }, + { + name: "valid void return should pass", + input: `package main + +//export_php:function voidFunc(string $message): void +func voidFunc(message *C.zend_string) { + // Do something +}`, + expected: 1, + hasWarning: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(tmpFile, []byte(tt.input), 0644)) + + parser := NewFuncParserDefRegex() + functions, err := parser.parse(tmpFile) + require.NoError(t, err) + + assert.Len(t, functions, tt.expected, "parse() got wrong number of functions") + }) + } +} + +func TestFunctionParserGoTypeMismatch(t *testing.T) { + tests := []struct { + name string + input string + expected int + hasWarning bool + }{ + { + name: "parameter count mismatch should be rejected", + input: `package main + +//export_php:function countMismatch(string $name, int $count): string +func countMismatch(name *C.zend_string) unsafe.Pointer { + return nil +}`, + expected: 0, + hasWarning: true, + }, + { + name: "parameter type mismatch should be rejected", + input: `package main + +//export_php:function typeMismatch(string $name, int $count): string +func typeMismatch(name *C.zend_string, count string) unsafe.Pointer { + return nil +}`, + expected: 0, + hasWarning: true, + }, + { + name: "return type mismatch should be rejected", + input: `package main + +//export_php:function returnMismatch(string $name): int +func returnMismatch(name *C.zend_string) string { + return "" +}`, + expected: 0, + hasWarning: true, + }, + { + name: "valid matching types should pass", + input: `package main + +//export_php:function validMatch(string $name, int $count): string +func validMatch(name *C.zend_string, count int64) unsafe.Pointer { + return nil +}`, + expected: 1, + hasWarning: false, + }, + { + name: "valid bool types should pass", + input: `package main + +//export_php:function validBool(bool $flag): bool +func validBool(flag bool) bool { + return flag +}`, + expected: 1, + hasWarning: false, + }, + { + name: "valid float types should pass", + input: `package main + +//export_php:function validFloat(float $value): float +func validFloat(value float64) float64 { + return value +}`, + expected: 1, + hasWarning: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + fileName := filepath.Join(tmpDir, tt.name+".go") + require.NoError(t, os.WriteFile(fileName, []byte(tt.input), 0644)) + + parser := NewFuncParserDefRegex() + functions, err := parser.parse(fileName) + require.NoError(t, err) + + assert.Len(t, functions, tt.expected, "parse() got wrong number of functions") + }) + } +} diff --git a/internal/extgen/generator.go b/internal/extgen/generator.go new file mode 100644 index 00000000..c728e61e --- /dev/null +++ b/internal/extgen/generator.go @@ -0,0 +1,137 @@ +package extgen + +import ( + "fmt" + "os" +) + +const BuildDir = "build" + +type Generator struct { + BaseName string + SourceFile string + BuildDir string + Functions []phpFunction + Classes []phpClass + Constants []phpConstant +} + +// EXPERIMENTAL +func (g *Generator) Generate() error { + if err := g.setupBuildDirectory(); err != nil { + return fmt.Errorf("setup build directory: %w", err) + } + if err := g.parseSource(); err != nil { + return fmt.Errorf("parse source: %w", err) + } + + if len(g.Functions) == 0 && len(g.Classes) == 0 && len(g.Constants) == 0 { + return fmt.Errorf("no PHP functions, classes, or constants found in source file") + } + + generators := []struct { + name string + fn func() error + }{ + {"stub file", g.generateStubFile}, + {"arginfo", g.generateArginfo}, + {"header file", g.generateHeaderFile}, + {"C file", g.generateCFile}, + {"Go file", g.generateGoFile}, + {"documentation", g.generateDocumentation}, + } + + for _, gen := range generators { + if err := gen.fn(); err != nil { + return err + } + } + + return nil +} + +func (g *Generator) setupBuildDirectory() error { + if err := os.RemoveAll(g.BuildDir); err != nil { + return fmt.Errorf("removing build directory: %w", err) + } + + return os.MkdirAll(g.BuildDir, 0755) +} + +func (g *Generator) parseSource() error { + parser := SourceParser{} + + functions, err := parser.ParseFunctions(g.SourceFile) + if err != nil { + return fmt.Errorf("parsing functions: %w", err) + } + g.Functions = functions + + classes, err := parser.ParseClasses(g.SourceFile) + if err != nil { + return fmt.Errorf("parsing classes: %w", err) + } + g.Classes = classes + + constants, err := parser.ParseConstants(g.SourceFile) + if err != nil { + return fmt.Errorf("parsing constants: %w", err) + } + g.Constants = constants + + return nil +} + +func (g *Generator) generateStubFile() error { + generator := StubGenerator{g} + if err := generator.generate(); err != nil { + return &GeneratorError{"stub generation", "failed to generate stub file", err} + } + + return nil +} + +func (g *Generator) generateArginfo() error { + generator := arginfoGenerator{generator: g} + if err := generator.generate(); err != nil { + return &GeneratorError{"arginfo generation", "failed to generate arginfo", err} + } + + return nil +} + +func (g *Generator) generateHeaderFile() error { + generator := HeaderGenerator{g} + if err := generator.generate(); err != nil { + return &GeneratorError{"header generation", "failed to generate header file", err} + } + + return nil +} + +func (g *Generator) generateCFile() error { + generator := cFileGenerator{g} + if err := generator.generate(); err != nil { + return &GeneratorError{"C file generation", "failed to generate C file", err} + } + + return nil +} + +func (g *Generator) generateGoFile() error { + generator := GoFileGenerator{g} + if err := generator.generate(); err != nil { + return &GeneratorError{"Go file generation", "failed to generate Go file", err} + } + + return nil +} + +func (g *Generator) generateDocumentation() error { + docGen := DocumentationGenerator{g} + if err := docGen.generate(); err != nil { + return &GeneratorError{"documentation generation", "failed to generate documentation", err} + } + + return nil +} diff --git a/internal/extgen/gofile.go b/internal/extgen/gofile.go new file mode 100644 index 00000000..ed450552 --- /dev/null +++ b/internal/extgen/gofile.go @@ -0,0 +1,192 @@ +package extgen + +import ( + "bytes" + _ "embed" + "fmt" + "path/filepath" + "strings" + "text/template" + + "github.com/Masterminds/sprig/v3" +) + +//go:embed templates/extension.go.tpl +var goFileContent string + +type GoFileGenerator struct { + generator *Generator +} + +type goTemplateData struct { + PackageName string + BaseName string + Imports []string + Constants []phpConstant + InternalFunctions []string + Functions []phpFunction + Classes []phpClass +} + +func (gg *GoFileGenerator) generate() error { + filename := filepath.Join(gg.generator.BuildDir, gg.generator.BaseName+".go") + content, err := gg.buildContent() + if err != nil { + return fmt.Errorf("building Go file content: %w", err) + } + + return WriteFile(filename, content) +} + +func (gg *GoFileGenerator) buildContent() (string, error) { + sourceAnalyzer := SourceAnalyzer{} + imports, internalFunctions, err := sourceAnalyzer.analyze(gg.generator.SourceFile) + if err != nil { + return "", fmt.Errorf("analyzing source file: %w", err) + } + + filteredImports := make([]string, 0, len(imports)) + for _, imp := range imports { + if imp != `"C"` { + filteredImports = append(filteredImports, imp) + } + } + + classes := make([]phpClass, len(gg.generator.Classes)) + copy(classes, gg.generator.Classes) + for i, class := range classes { + for j, method := range class.Methods { + classes[i].Methods[j].Wrapper = gg.generateMethodWrapper(method, class) + } + } + + templateContent, err := gg.getTemplateContent(goTemplateData{ + PackageName: SanitizePackageName(gg.generator.BaseName), + BaseName: gg.generator.BaseName, + Imports: filteredImports, + Constants: gg.generator.Constants, + InternalFunctions: internalFunctions, + Functions: gg.generator.Functions, + Classes: classes, + }) + + if err != nil { + return "", fmt.Errorf("executing template: %w", err) + } + + return templateContent, nil +} + +func (gg *GoFileGenerator) getTemplateContent(data goTemplateData) (string, error) { + tmpl := template.Must(template.New("gofile").Funcs(sprig.FuncMap()).Parse(goFileContent)) + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return "", err + } + + return buf.String(), nil +} + +func (gg *GoFileGenerator) generateMethodWrapper(method phpClassMethod, class phpClass) string { + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("func %s_wrapper(handle C.uintptr_t", method.Name)) + + for _, param := range method.Params { + if param.PhpType == "string" { + builder.WriteString(fmt.Sprintf(", %s *C.zend_string", param.Name)) + + continue + } + + goType := gg.phpTypeToGoType(param.PhpType) + if param.IsNullable { + goType = "*" + goType + } + builder.WriteString(fmt.Sprintf(", %s %s", param.Name, goType)) + } + + if method.ReturnType != "void" { + if method.ReturnType == "string" { + builder.WriteString(") unsafe.Pointer {\n") + } else { + goReturnType := gg.phpTypeToGoType(method.ReturnType) + builder.WriteString(fmt.Sprintf(") %s {\n", goReturnType)) + } + } else { + builder.WriteString(") {\n") + } + + builder.WriteString(" obj := getGoObject(handle)\n") + builder.WriteString(" if obj == nil {\n") + if method.ReturnType != "void" { + if method.ReturnType == "string" { + builder.WriteString(" return nil\n") + } else { + builder.WriteString(fmt.Sprintf(" var zero %s\n", gg.phpTypeToGoType(method.ReturnType))) + builder.WriteString(" return zero\n") + } + } else { + builder.WriteString(" return\n") + } + builder.WriteString(" }\n") + builder.WriteString(fmt.Sprintf(" structObj := obj.(*%s)\n", class.GoStruct)) + + builder.WriteString(" ") + if method.ReturnType != "void" { + builder.WriteString("return ") + } + + builder.WriteString(fmt.Sprintf("structObj.%s(", gg.goMethodName(method.Name))) + + for i, param := range method.Params { + if i > 0 { + builder.WriteString(", ") + } + + builder.WriteString(param.Name) + } + + builder.WriteString(")\n") + builder.WriteString("}") + + return builder.String() +} + +type GoMethodSignature struct { + MethodName string + Params []GoParameter + ReturnType string +} + +type GoParameter struct { + Name string + Type string +} + +func (gg *GoFileGenerator) phpTypeToGoType(phpType string) string { + typeMap := map[string]string{ + "string": "string", + "int": "int64", + "float": "float64", + "bool": "bool", + "array": "[]interface{}", + "mixed": "interface{}", + "void": "", + } + + if goType, exists := typeMap[phpType]; exists { + return goType + } + + return "interface{}" +} + +func (gg *GoFileGenerator) goMethodName(phpMethodName string) string { + if len(phpMethodName) == 0 { + return phpMethodName + } + + return strings.ToUpper(phpMethodName[:1]) + phpMethodName[1:] +} diff --git a/internal/extgen/gofile_test.go b/internal/extgen/gofile_test.go new file mode 100644 index 00000000..c1510655 --- /dev/null +++ b/internal/extgen/gofile_test.go @@ -0,0 +1,564 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGoFileGenerator_Generate(t *testing.T) { + tmpDir := t.TempDir() + + sourceContent := `package main + +import ( + "fmt" + "strings" + "github.com/dunglas/frankenphp/internal/extensions/types" +) + +//export_php: greet(name string): string +func greet(name *go_string) *go_value { + return types.String("Hello " + CStringToGoString(name)) +} + +//export_php: calculate(a int, b int): int +func calculate(a long, b long) *go_value { + result := a + b + return types.Int(result) +} + +func internalHelper(data string) string { + return strings.ToUpper(data) +} + +func anotherHelper() { + fmt.Println("Internal helper") +}` + + sourceFile := filepath.Join(tmpDir, "test.go") + require.NoError(t, os.WriteFile(sourceFile, []byte(sourceContent), 0644)) + + generator := &Generator{ + BaseName: "test", + SourceFile: sourceFile, + BuildDir: tmpDir, + Functions: []phpFunction{ + { + Name: "greet", + ReturnType: "string", + GoFunction: `func greet(name *go_string) *go_value { + return types.String("Hello " + CStringToGoString(name)) +}`, + }, + { + Name: "calculate", + ReturnType: "int", + GoFunction: `func calculate(a long, b long) *go_value { + result := a + b + return types.Int(result) +}`, + }, + }, + } + + goGen := GoFileGenerator{generator} + require.NoError(t, goGen.generate()) + + expectedFile := filepath.Join(tmpDir, "test.go") + require.FileExists(t, expectedFile) + + content, err := ReadFile(expectedFile) + require.NoError(t, err) + + testGoFileBasicStructure(t, content, "test") + testGoFileImports(t, content) + testGoFileExportedFunctions(t, content, generator.Functions) + testGoFileInternalFunctions(t, content) +} + +func TestGoFileGenerator_BuildContent(t *testing.T) { + tests := []struct { + name string + baseName string + sourceFile string + functions []phpFunction + contains []string + notContains []string + }{ + { + name: "simple extension", + baseName: "simple", + sourceFile: createTempSourceFile(t, `package main + +//export_php: test(): void +func test() { + // simple function +}`), + functions: []phpFunction{ + { + Name: "test", + ReturnType: "void", + GoFunction: "func test() {\n\t// simple function\n}", + }, + }, + contains: []string{ + "package simple", + `#include "simple.h"`, + "import \"C\"", + "func init()", + "frankenphp.RegisterExtension(", + "//export test", + "func test()", + }, + }, + { + name: "extension with complex imports", + baseName: "complex", + sourceFile: createTempSourceFile(t, `package main + +import ( + "fmt" + "strings" + "encoding/json" + "github.com/dunglas/frankenphp/internal/extensions/types" +) + +//export_php: process(data string): string +func process(data *go_string) *go_value { + return types.String(fmt.Sprintf("processed: %s", CStringToGoString(data))) +}`), + functions: []phpFunction{ + { + Name: "process", + ReturnType: "string", + GoFunction: `func process(data *go_string) *go_value { + return String(fmt.Sprintf("processed: %s", CStringToGoString(data))) +}`, + }, + }, + contains: []string{ + "package complex", + `import "fmt"`, + `import "strings"`, + `import "encoding/json"`, + "//export process", + `import "C"`, + }, + }, + { + name: "extension with internal functions", + baseName: "internal", + sourceFile: createTempSourceFile(t, `package main + +//export_php: publicFunc(): void +func publicFunc() {} + +func internalFunc1() string { + return "internal" +} + +func internalFunc2(data string) { + // process data internally +}`), + functions: []phpFunction{ + { + Name: "publicFunc", + ReturnType: "void", + GoFunction: "func publicFunc() {}", + }, + }, + contains: []string{ + "func internalFunc1() string", + "func internalFunc2(data string)", + "//export publicFunc", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &Generator{ + BaseName: tt.baseName, + SourceFile: tt.sourceFile, + Functions: tt.functions, + } + + goGen := GoFileGenerator{generator} + content, err := goGen.buildContent() + require.NoError(t, err) + + for _, expected := range tt.contains { + assert.Contains(t, content, expected, "Generated Go content should contain '%s'", expected) + } + }) + } +} + +func TestGoFileGenerator_PackageNameSanitization(t *testing.T) { + tests := []struct { + baseName string + expectedPackage string + }{ + {"simple", "simple"}, + {"my-extension", "my_extension"}, + {"ext.with.dots", "ext_with_dots"}, + {"123invalid", "_123invalid"}, + {"valid_name", "valid_name"}, + } + + for _, tt := range tests { + t.Run(tt.baseName, func(t *testing.T) { + sourceFile := createTempSourceFile(t, "package main\n//export_php: test(): void\nfunc test() {}") + + generator := &Generator{ + BaseName: tt.baseName, + SourceFile: sourceFile, + Functions: []phpFunction{ + {Name: "test", ReturnType: "void", GoFunction: "func test() {}"}, + }, + } + + goGen := GoFileGenerator{generator} + content, err := goGen.buildContent() + require.NoError(t, err) + + expectedPackage := "package " + tt.expectedPackage + assert.Contains(t, content, expectedPackage, "Generated content should contain '%s'", expectedPackage) + }) + } +} + +func TestGoFileGenerator_ErrorHandling(t *testing.T) { + tests := []struct { + name string + sourceFile string + expectErr bool + }{ + { + name: "nonexistent file", + sourceFile: "/nonexistent/file.go", + expectErr: true, + }, + { + name: "invalid Go syntax", + sourceFile: createTempSourceFile(t, "invalid go syntax here"), + expectErr: true, + }, + { + name: "valid file", + sourceFile: createTempSourceFile(t, "package main\nfunc test() {}"), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &Generator{ + BaseName: "test", + SourceFile: tt.sourceFile, + } + + goGen := GoFileGenerator{generator} + _, err := goGen.buildContent() + + if tt.expectErr { + assert.Error(t, err, "Expected error but got none") + } else { + assert.NoError(t, err, "Unexpected error") + } + }) + } +} + +func TestGoFileGenerator_ImportFiltering(t *testing.T) { + sourceContent := `package main + +import ( + "C" + "fmt" + "strings" + "github.com/dunglas/frankenphp/internal/extensions/types" + "github.com/other/package" + originalPkg "github.com/test/original" +) + +//export_php: test(): void +func test() {}` + + sourceFile := createTempSourceFile(t, sourceContent) + + generator := &Generator{ + BaseName: "importtest", + SourceFile: sourceFile, + Functions: []phpFunction{ + {Name: "test", ReturnType: "void", GoFunction: "func test() {}"}, + }, + } + + goGen := GoFileGenerator{generator} + content, err := goGen.buildContent() + require.NoError(t, err) + + expectedImports := []string{ + `import "fmt"`, + `import "strings"`, + `import "github.com/other/package"`, + } + + for _, imp := range expectedImports { + assert.Contains(t, content, imp, "Generated content should contain import: %s", imp) + } + + forbiddenImports := []string{ + `import "C"`, + } + + cImportCount := strings.Count(content, `import "C"`) + assert.Equal(t, 1, cImportCount, "Expected exactly 1 occurrence of 'import \"C\"'") + + for _, imp := range forbiddenImports[1:] { + assert.NotContains(t, content, imp, "Generated content should NOT contain import: %s", imp) + } +} + +func TestGoFileGenerator_ComplexScenario(t *testing.T) { + sourceContent := `package example + +import ( + "fmt" + "strings" + "encoding/json" + "github.com/dunglas/frankenphp/internal/extensions/types" +) + +//export_php: processData(input string, options array): array +func processData(input *go_string, options *go_nullable) *go_value { + data := CStringToGoString(input) + processed := internalProcess(data) + return types.Array([]interface{}{processed}) +} + +//export_php: validateInput(data string): bool +func validateInput(data *go_string) *go_value { + input := CStringToGoString(data) + isValid := len(input) > 0 && validateFormat(input) + return types.Bool(isValid) +} + +func internalProcess(data string) string { + return strings.ToUpper(data) +} + +func validateFormat(input string) bool { + return !strings.Contains(input, "invalid") +} + +func jsonHelper(data interface{}) ([]byte, error) { + return json.Marshal(data) +} + +func debugPrint(msg string) { + fmt.Printf("DEBUG: %s\n", msg) +}` + + sourceFile := createTempSourceFile(t, sourceContent) + + functions := []phpFunction{ + { + Name: "processData", + ReturnType: "array", + GoFunction: `func processData(input *go_string, options *go_nullable) *go_value { + data := CStringToGoString(input) + processed := internalProcess(data) + return Array([]interface{}{processed}) +}`, + }, + { + Name: "validateInput", + ReturnType: "bool", + GoFunction: `func validateInput(data *go_string) *go_value { + input := CStringToGoString(data) + isValid := len(input) > 0 && validateFormat(input) + return Bool(isValid) +}`, + }, + } + + generator := &Generator{ + BaseName: "complex-example", + SourceFile: sourceFile, + Functions: functions, + } + + goGen := GoFileGenerator{generator} + content, err := goGen.buildContent() + require.NoError(t, err) + assert.Contains(t, content, "package complex_example", "Package name should be sanitized") + + internalFuncs := []string{ + "func internalProcess(data string) string", + "func validateFormat(input string) bool", + "func jsonHelper(data interface{}) ([]byte, error)", + "func debugPrint(msg string)", + } + + for _, fn := range internalFuncs { + assert.Contains(t, content, fn, "Generated content should contain internal function: %s", fn) + } + + for _, fn := range functions { + exportDirective := "//export " + fn.Name + assert.Contains(t, content, exportDirective, "Generated content should contain export directive: %s", exportDirective) + } + + assert.False(t, strings.Contains(content, "types.Array") || strings.Contains(content, "types.Bool"), "Types should be replaced (types.* should not appear)") + assert.True(t, strings.Contains(content, "return Array(") && strings.Contains(content, "return Bool("), "Replaced types should appear without types prefix") +} + +func TestGoFileGenerator_MethodWrapperWithNullableParams(t *testing.T) { + tmpDir := t.TempDir() + + sourceContent := `package main + +import "fmt" + +//export_php:class TestClass +type TestStruct struct { + name string +} + +//export_php:method TestClass::processData(string $name, ?int $count, ?bool $enabled): string +func (ts *TestStruct) ProcessData(name string, count *int64, enabled *bool) string { + result := fmt.Sprintf("name=%s", name) + if count != nil { + result += fmt.Sprintf(", count=%d", *count) + } + if enabled != nil { + result += fmt.Sprintf(", enabled=%t", *enabled) + } + return result +}` + + sourceFile := filepath.Join(tmpDir, "test.go") + require.NoError(t, os.WriteFile(sourceFile, []byte(sourceContent), 0644)) + + methods := []phpClassMethod{ + { + Name: "ProcessData", + PhpName: "processData", + ClassName: "TestClass", + Signature: "processData(string $name, ?int $count, ?bool $enabled): string", + ReturnType: "string", + Params: []phpParameter{ + {Name: "name", PhpType: "string", IsNullable: false}, + {Name: "count", PhpType: "int", IsNullable: true}, + {Name: "enabled", PhpType: "bool", IsNullable: true}, + }, + GoFunction: `func (ts *TestStruct) ProcessData(name string, count *int64, enabled *bool) string { + result := fmt.Sprintf("name=%s", name) + if count != nil { + result += fmt.Sprintf(", count=%d", *count) + } + if enabled != nil { + result += fmt.Sprintf(", enabled=%t", *enabled) + } + return result +}`, + }, + } + + classes := []phpClass{ + { + Name: "TestClass", + GoStruct: "TestStruct", + Methods: methods, + }, + } + + generator := &Generator{ + BaseName: "nullable_test", + SourceFile: sourceFile, + Classes: classes, + BuildDir: tmpDir, + } + + goGen := GoFileGenerator{generator} + content, err := goGen.buildContent() + require.NoError(t, err) + + expectedWrapperSignature := "func ProcessData_wrapper(handle C.uintptr_t, name *C.zend_string, count *int64, enabled *bool)" + assert.Contains(t, content, expectedWrapperSignature, "Generated content should contain wrapper with nullable pointer types: %s", expectedWrapperSignature) + + expectedCall := "structObj.ProcessData(name, count, enabled)" + assert.Contains(t, content, expectedCall, "Generated content should contain correct method call: %s", expectedCall) + + exportDirective := "//export ProcessData_wrapper" + assert.Contains(t, content, exportDirective, "Generated content should contain export directive: %s", exportDirective) +} + +func createTempSourceFile(t *testing.T, content string) string { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "source.go") + + require.NoError(t, os.WriteFile(tmpFile, []byte(content), 0644)) + + return tmpFile +} + +func testGoFileBasicStructure(t *testing.T, content, baseName string) { + requiredElements := []string{ + "package " + SanitizePackageName(baseName), + "/*", + "#include ", + `#include "` + baseName + `.h"`, + "*/", + `import "C"`, + "func init() {", + "frankenphp.RegisterExtension(", + "}", + } + + for _, element := range requiredElements { + assert.Contains(t, content, element, "Go file should contain: %s", element) + } +} + +func testGoFileImports(t *testing.T, content string) { + cImportCount := strings.Count(content, `import "C"`) + assert.Equal(t, 1, cImportCount, "Expected exactly 1 C import") +} + +func testGoFileExportedFunctions(t *testing.T, content string, functions []phpFunction) { + for _, fn := range functions { + exportDirective := "//export " + fn.Name + assert.Contains(t, content, exportDirective, "Go file should contain export directive: %s", exportDirective) + + funcStart := "func " + fn.Name + "(" + assert.Contains(t, content, funcStart, "Go file should contain function definition: %s", funcStart) + } +} + +func testGoFileInternalFunctions(t *testing.T, content string) { + internalIndicators := []string{ + "func internalHelper", + "func anotherHelper", + } + + foundInternal := false + for _, indicator := range internalIndicators { + if strings.Contains(content, indicator) { + foundInternal = true + + break + } + } + + if !foundInternal { + t.Log("No internal functions found (this may be expected)") + } +} diff --git a/internal/extgen/hfile.go b/internal/extgen/hfile.go new file mode 100644 index 00000000..85371b75 --- /dev/null +++ b/internal/extgen/hfile.go @@ -0,0 +1,63 @@ +// header.go +package extgen + +import ( + "bytes" + _ "embed" + "path/filepath" + "strings" + "text/template" +) + +//go:embed templates/extension.h.tpl +var hFileContent string + +type HeaderGenerator struct { + generator *Generator +} + +type TemplateData struct { + HeaderGuard string + Constants []phpConstant + Classes []phpClass +} + +func (hg *HeaderGenerator) generate() error { + filename := filepath.Join(hg.generator.BuildDir, hg.generator.BaseName+".h") + content, err := hg.buildContent() + if err != nil { + return err + } + + return WriteFile(filename, content) +} + +func (hg *HeaderGenerator) buildContent() (string, error) { + headerGuard := strings.Map(func(r rune) rune { + if r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' { + return r + } + + return '_' + }, hg.generator.BaseName) + + headerGuard = strings.ToUpper(headerGuard) + "_H" + + tmpl, err := template.New("header").Parse(hFileContent) + if err != nil { + return "", err + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, TemplateData{ + HeaderGuard: headerGuard, + Constants: hg.generator.Constants, + Classes: hg.generator.Classes, + }) + + if err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/internal/extgen/hfile_test.go b/internal/extgen/hfile_test.go new file mode 100644 index 00000000..c6d3edab --- /dev/null +++ b/internal/extgen/hfile_test.go @@ -0,0 +1,334 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHeaderGenerator_Generate(t *testing.T) { + tmpDir := t.TempDir() + + generator := &Generator{ + BaseName: "test_extension", + BuildDir: tmpDir, + } + + headerGen := HeaderGenerator{generator} + require.NoError(t, headerGen.generate()) + + expectedFile := filepath.Join(tmpDir, "test_extension.h") + require.FileExists(t, expectedFile) + + content, err := ReadFile(expectedFile) + require.NoError(t, err) + + testHeaderBasicStructure(t, content, "test_extension") + testHeaderIncludeGuards(t, content, "TEST_EXTENSION_H") +} + +func TestHeaderGenerator_BuildContent(t *testing.T) { + tests := []struct { + name string + baseName string + contains []string + }{ + { + name: "simple extension", + baseName: "simple", + contains: []string{ + "#ifndef _SIMPLE_H", + "#define _SIMPLE_H", + "#include ", + "extern zend_module_entry ext_module_entry;", + "typedef struct go_value go_value;", + "typedef struct go_string {", + "size_t len;", + "char *data;", + "} go_string;", + "#endif", + }, + }, + { + name: "extension with hyphens", + baseName: "my-extension", + contains: []string{ + "#ifndef _MY_EXTENSION_H", + "#define _MY_EXTENSION_H", + "#endif", + }, + }, + { + name: "extension with underscores", + baseName: "my_extension_name", + contains: []string{ + "#ifndef _MY_EXTENSION_NAME_H", + "#define _MY_EXTENSION_NAME_H", + "#endif", + }, + }, + { + name: "complex extension name", + baseName: "complex.name-with_symbols", + contains: []string{ + "#ifndef _COMPLEX_NAME_WITH_SYMBOLS_H", + "#define _COMPLEX_NAME_WITH_SYMBOLS_H", + "#endif", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := &Generator{BaseName: tt.baseName} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + for _, expected := range tt.contains { + assert.Contains(t, content, expected, "Generated header content should contain '%s'", expected) + } + }) + } +} + +func TestHeaderGenerator_HeaderGuardGeneration(t *testing.T) { + tests := []struct { + baseName string + expectedGuard string + }{ + {"simple", "_SIMPLE_H"}, + {"my-extension", "_MY_EXTENSION_H"}, + {"complex.name", "_COMPLEX_NAME_H"}, + {"under_score", "_UNDER_SCORE_H"}, + {"MixedCase", "_MIXEDCASE_H"}, + {"123numeric", "_123NUMERIC_H"}, + {"special!@#chars", "_SPECIAL___CHARS_H"}, + } + + for _, tt := range tests { + t.Run(tt.baseName, func(t *testing.T) { + generator := &Generator{BaseName: tt.baseName} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + expectedIfndef := "#ifndef " + tt.expectedGuard + expectedDefine := "#define " + tt.expectedGuard + + assert.Contains(t, content, expectedIfndef, "Expected #ifndef %s, but not found in content", tt.expectedGuard) + assert.Contains(t, content, expectedDefine, "Expected #define %s, but not found in content", tt.expectedGuard) + }) + } +} + +func TestHeaderGenerator_BasicStructure(t *testing.T) { + generator := &Generator{BaseName: "structtest"} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + expectedElements := []string{ + "#include ", + "extern zend_module_entry ext_module_entry;", + "typedef struct go_value go_value;", + "typedef struct go_string {", + "size_t len;", + "char *data;", + "} go_string;", + } + + for _, element := range expectedElements { + assert.Contains(t, content, element, "Header should contain: %s", element) + } +} + +func TestHeaderGenerator_CompleteStructure(t *testing.T) { + generator := &Generator{BaseName: "complete_test"} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + lines := strings.Split(content, "\n") + + assert.GreaterOrEqual(t, len(lines), 5, "Header file should have multiple lines") + + var foundIfndef, foundDefine, foundEndif bool + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + if strings.HasPrefix(line, "#ifndef") && !foundIfndef { + foundIfndef = true + } else if strings.HasPrefix(line, "#define") && foundIfndef && !foundDefine { + foundDefine = true + } else if line == "#endif" { + foundEndif = true + } + } + + assert.True(t, foundIfndef, "Header should start with #ifndef guard") + assert.True(t, foundDefine, "Header should have #define after #ifndef") + assert.True(t, foundEndif, "Header should end with #endif") +} + +func TestHeaderGenerator_ErrorHandling(t *testing.T) { + generator := &Generator{ + BaseName: "test", + BuildDir: "/invalid/readonly/path", + } + + headerGen := HeaderGenerator{generator} + err := headerGen.generate() + assert.Error(t, err, "Expected error when writing to invalid directory") +} + +func TestHeaderGenerator_EmptyBaseName(t *testing.T) { + generator := &Generator{BaseName: ""} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + assert.Contains(t, content, "#ifndef __H", "Header with empty basename should have __H guard") + assert.Contains(t, content, "#define __H", "Header with empty basename should have __H define") +} + +func TestHeaderGenerator_ContentValidation(t *testing.T) { + generator := &Generator{BaseName: "validation_test"} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + assert.Equal(t, 1, strings.Count(content, "#ifndef"), "Header should have exactly one #ifndef") + assert.Equal(t, 1, strings.Count(content, "#define"), "Header should have exactly one #define") + assert.Equal(t, 1, strings.Count(content, "#endif"), "Header should have exactly one #endif") + assert.False(t, strings.Contains(content, "{{") || strings.Contains(content, "}}"), "Generated header contains unresolved template syntax") + assert.Contains(t, content, "typedef struct go_string {", "Header should contain go_string typedef") + assert.Contains(t, content, "size_t len;", "Header should contain len field in go_string") + assert.Contains(t, content, "char *data;", "Header should contain data field in go_string") +} + +func TestHeaderGenerator_SpecialCharacterHandling(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"normal", "NORMAL"}, + {"with-hyphens", "WITH_HYPHENS"}, + {"with.dots", "WITH_DOTS"}, + {"with_underscores", "WITH_UNDERSCORES"}, + {"MixedCASE", "MIXEDCASE"}, + {"123numbers", "123NUMBERS"}, + {"special!@#$%", "SPECIAL_____"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + generator := &Generator{BaseName: tt.input} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + expectedGuard := "_" + tt.expected + "_H" + expectedIfndef := "#ifndef " + expectedGuard + expectedDefine := "#define " + expectedGuard + + assert.Contains(t, content, expectedIfndef, "Expected #ifndef %s for input %s", expectedGuard, tt.input) + assert.Contains(t, content, expectedDefine, "Expected #define %s for input %s", expectedGuard, tt.input) + }) + } +} + +func TestHeaderGenerator_TemplateErrorHandling(t *testing.T) { + generator := &Generator{BaseName: "error_test"} + headerGen := HeaderGenerator{generator} + + _, err := headerGen.buildContent() + assert.NoError(t, err, "buildContent() should not fail with valid template") +} + +func TestHeaderGenerator_GuardConsistency(t *testing.T) { + baseName := "test_consistency" + generator := &Generator{BaseName: baseName} + headerGen := HeaderGenerator{generator} + + content1, err := headerGen.buildContent() + require.NoError(t, err, "First buildContent() failed: %v", err) + + content2, err := headerGen.buildContent() + require.NoError(t, err, "Second buildContent() failed: %v", err) + + assert.Equal(t, content1, content2, "Multiple calls to buildContent() should produce identical results") +} + +func TestHeaderGenerator_MinimalContent(t *testing.T) { + generator := &Generator{BaseName: "minimal"} + headerGen := HeaderGenerator{generator} + content, err := headerGen.buildContent() + require.NoError(t, err) + + essentialElements := []string{ + "#ifndef _MINIMAL_H", + "#define _MINIMAL_H", + "#include ", + "extern zend_module_entry ext_module_entry;", + "typedef struct go_value go_value;", + "#endif", + } + + for _, element := range essentialElements { + assert.Contains(t, content, element, "Minimal header should contain: %s", element) + } +} + +func testHeaderBasicStructure(t *testing.T, content, baseName string) { + headerGuard := strings.Map(func(r rune) rune { + if r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' { + return r + } + + return '_' + }, baseName) + headerGuard = strings.ToUpper(headerGuard) + "_H" + + requiredElements := []string{ + "#ifndef _" + headerGuard, + "#define _" + headerGuard, + "#include ", + "extern zend_module_entry ext_module_entry;", + "typedef struct go_value go_value;", + "typedef struct go_string {", + "size_t len;", + "char *data;", + "} go_string;", + "#endif", + } + + for _, element := range requiredElements { + assert.Contains(t, content, element, "Header file should contain: %s", element) + } +} + +func testHeaderIncludeGuards(t *testing.T, content, expectedGuard string) { + expectedIfndef := "#ifndef _" + expectedGuard + expectedDefine := "#define _" + expectedGuard + + assert.Contains(t, content, expectedIfndef, "Header should contain: %s", expectedIfndef) + assert.Contains(t, content, expectedDefine, "Header should contain: %s", expectedDefine) + assert.Contains(t, content, "#endif", "Header should end with #endif") + + ifndefPos := strings.Index(content, expectedIfndef) + definePos := strings.Index(content, expectedDefine) + + assert.Less(t, ifndefPos, definePos, "#ifndef should come before #define") + + endifPos := strings.LastIndex(content, "#endif") + assert.NotEqual(t, -1, endifPos, "Header should end with #endif") + assert.Greater(t, endifPos, definePos, "#endif should come after #define") +} diff --git a/internal/extgen/nodes.go b/internal/extgen/nodes.go new file mode 100644 index 00000000..b585089d --- /dev/null +++ b/internal/extgen/nodes.go @@ -0,0 +1,75 @@ +package extgen + +import ( + "strconv" + "strings" +) + +type phpFunction struct { + Name string + Signature string + GoFunction string + Params []phpParameter + ReturnType string + IsReturnNullable bool + lineNumber int +} + +type phpParameter struct { + Name string + PhpType string + IsNullable bool + DefaultValue string + HasDefault bool +} + +type phpClass struct { + Name string + GoStruct string + Properties []phpClassProperty + Methods []phpClassMethod +} + +type phpClassMethod struct { + Name string + PhpName string + Signature string + GoFunction string + Wrapper string + Params []phpParameter + ReturnType string + isReturnNullable bool + lineNumber int + ClassName string // used by the "//export_php:method" directive +} + +type phpClassProperty struct { + Name string + PhpType string + GoType string + IsNullable bool +} + +type phpConstant struct { + Name string + Value string + PhpType string // "int", "string", "bool", "float" + IsIota bool + lineNumber int + ClassName string // empty for global constants, set for class constants +} + +// CValue returns the constant value in C-compatible format +func (c phpConstant) CValue() string { + if c.PhpType != "int" { + return c.Value + } + + if strings.HasPrefix(c.Value, "0o") { + if val, err := strconv.ParseInt(c.Value, 0, 64); err == nil { + return strconv.FormatInt(val, 10) + } + } + + return c.Value +} diff --git a/internal/extgen/paramparser.go b/internal/extgen/paramparser.go new file mode 100644 index 00000000..9fa42119 --- /dev/null +++ b/internal/extgen/paramparser.go @@ -0,0 +1,178 @@ +package extgen + +import ( + "fmt" + "strings" +) + +type ParameterParser struct{} + +type ParameterInfo struct { + RequiredCount int + TotalCount int +} + +func (pp *ParameterParser) analyzeParameters(params []phpParameter) ParameterInfo { + info := ParameterInfo{TotalCount: len(params)} + + for _, param := range params { + if !param.HasDefault { + info.RequiredCount++ + } + } + + return info +} + +func (pp *ParameterParser) generateParamDeclarations(params []phpParameter) string { + if len(params) == 0 { + return "" + } + + var declarations []string + + for _, param := range params { + declarations = append(declarations, pp.generateSingleParamDeclaration(param)...) + } + + return " " + strings.Join(declarations, "\n ") +} + +func (pp *ParameterParser) generateSingleParamDeclaration(param phpParameter) []string { + var decls []string + + switch param.PhpType { + case "string": + decls = append(decls, fmt.Sprintf("zend_string *%s = NULL;", param.Name)) + if param.IsNullable { + decls = append(decls, fmt.Sprintf("zend_bool %s_is_null = 0;", param.Name)) + } + case "int": + defaultVal := pp.getDefaultValue(param, "0") + decls = append(decls, fmt.Sprintf("zend_long %s = %s;", param.Name, defaultVal)) + if param.IsNullable { + decls = append(decls, fmt.Sprintf("zend_bool %s_is_null = 0;", param.Name)) + } + case "float": + defaultVal := pp.getDefaultValue(param, "0.0") + decls = append(decls, fmt.Sprintf("double %s = %s;", param.Name, defaultVal)) + if param.IsNullable { + decls = append(decls, fmt.Sprintf("zend_bool %s_is_null = 0;", param.Name)) + } + case "bool": + defaultVal := pp.getDefaultValue(param, "0") + if param.HasDefault && param.DefaultValue == "true" { + defaultVal = "1" + } + decls = append(decls, fmt.Sprintf("zend_bool %s = %s;", param.Name, defaultVal)) + if param.IsNullable { + decls = append(decls, fmt.Sprintf("zend_bool %s_is_null = 0;", param.Name)) + } + } + + return decls +} + +func (pp *ParameterParser) getDefaultValue(param phpParameter, fallback string) string { + if !param.HasDefault || param.DefaultValue == "" { + return fallback + } + return param.DefaultValue +} + +func (pp *ParameterParser) generateParamParsing(params []phpParameter, requiredCount int) string { + if len(params) == 0 { + return ` if (zend_parse_parameters_none() == FAILURE) { + RETURN_THROWS(); + }` + } + + var builder strings.Builder + builder.WriteString(fmt.Sprintf(" ZEND_PARSE_PARAMETERS_START(%d, %d)", requiredCount, len(params))) + + optionalStarted := false + for _, param := range params { + if param.HasDefault && !optionalStarted { + builder.WriteString("\n Z_PARAM_OPTIONAL") + optionalStarted = true + } + + builder.WriteString(pp.generateParamParsingMacro(param)) + } + + builder.WriteString("\n ZEND_PARSE_PARAMETERS_END();") + return builder.String() +} + +func (pp *ParameterParser) generateParamParsingMacro(param phpParameter) string { + if param.IsNullable { + switch param.PhpType { + case "string": + return fmt.Sprintf("\n Z_PARAM_STR_OR_NULL(%s, %s_is_null)", param.Name, param.Name) + case "int": + return fmt.Sprintf("\n Z_PARAM_LONG_OR_NULL(%s, %s_is_null)", param.Name, param.Name) + case "float": + return fmt.Sprintf("\n Z_PARAM_DOUBLE_OR_NULL(%s, %s_is_null)", param.Name, param.Name) + case "bool": + return fmt.Sprintf("\n Z_PARAM_BOOL_OR_NULL(%s, %s_is_null)", param.Name, param.Name) + default: + return "" + } + } else { + switch param.PhpType { + case "string": + return fmt.Sprintf("\n Z_PARAM_STR(%s)", param.Name) + case "int": + return fmt.Sprintf("\n Z_PARAM_LONG(%s)", param.Name) + case "float": + return fmt.Sprintf("\n Z_PARAM_DOUBLE(%s)", param.Name) + case "bool": + return fmt.Sprintf("\n Z_PARAM_BOOL(%s)", param.Name) + default: + return "" + } + } +} + +func (pp *ParameterParser) generateGoCallParams(params []phpParameter) string { + if len(params) == 0 { + return "" + } + + var goParams []string + for _, param := range params { + goParams = append(goParams, pp.generateSingleGoCallParam(param)) + } + + return strings.Join(goParams, ", ") +} + +func (pp *ParameterParser) generateSingleGoCallParam(param phpParameter) string { + if param.IsNullable { + switch param.PhpType { + case "string": + return fmt.Sprintf("%s_is_null ? NULL : %s", param.Name, param.Name) + case "int": + return fmt.Sprintf("%s_is_null ? NULL : &%s", param.Name, param.Name) + case "float": + return fmt.Sprintf("%s_is_null ? NULL : &%s", param.Name, param.Name) + case "bool": + return fmt.Sprintf("%s_is_null ? NULL : &%s", param.Name, param.Name) + default: + return param.Name + } + } else { + switch param.PhpType { + case "string": + return param.Name + case "int": + return fmt.Sprintf("(long) %s", param.Name) + case "float": + return fmt.Sprintf("(double) %s", param.Name) + case "bool": + return fmt.Sprintf("(int) %s", param.Name) + default: + return param.Name + } + } +} diff --git a/internal/extgen/paramparser_test.go b/internal/extgen/paramparser_test.go new file mode 100644 index 00000000..254b9646 --- /dev/null +++ b/internal/extgen/paramparser_test.go @@ -0,0 +1,500 @@ +package extgen + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParameterParser_AnalyzeParameters(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + params []phpParameter + expected ParameterInfo + }{ + { + name: "no parameters", + params: []phpParameter{}, + expected: ParameterInfo{ + RequiredCount: 0, + TotalCount: 0, + }, + }, + { + name: "all required parameters", + params: []phpParameter{ + {Name: "name", PhpType: "string", HasDefault: false}, + {Name: "count", PhpType: "int", HasDefault: false}, + }, + expected: ParameterInfo{ + RequiredCount: 2, + TotalCount: 2, + }, + }, + { + name: "mixed required and optional parameters", + params: []phpParameter{ + {Name: "name", PhpType: "string", HasDefault: false}, + {Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "10"}, + {Name: "enabled", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + }, + expected: ParameterInfo{ + RequiredCount: 1, + TotalCount: 3, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.analyzeParameters(tt.params) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GenerateParamDeclarations(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + params []phpParameter + expected string + }{ + { + name: "no parameters", + params: []phpParameter{}, + expected: "", + }, + { + name: "string parameter", + params: []phpParameter{ + {Name: "message", PhpType: "string", HasDefault: false}, + }, + expected: " zend_string *message = NULL;", + }, + { + name: "nullable string parameter", + params: []phpParameter{ + {Name: "message", PhpType: "string", HasDefault: false, IsNullable: true}, + }, + expected: " zend_string *message = NULL;\n zend_bool message_is_null = 0;", + }, + { + name: "int parameter with default", + params: []phpParameter{ + {Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "42"}, + }, + expected: " zend_long count = 42;", + }, + { + name: "nullable int parameter", + params: []phpParameter{ + {Name: "count", PhpType: "int", HasDefault: false, IsNullable: true}, + }, + expected: " zend_long count = 0;\n zend_bool count_is_null = 0;", + }, + { + name: "bool parameter with true default", + params: []phpParameter{ + {Name: "enabled", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + }, + expected: " zend_bool enabled = 1;", + }, + { + name: "nullable bool parameter", + params: []phpParameter{ + {Name: "enabled", PhpType: "bool", HasDefault: false, IsNullable: true}, + }, + expected: " zend_bool enabled = 0;\n zend_bool enabled_is_null = 0;", + }, + { + name: "float parameter", + params: []phpParameter{ + {Name: "ratio", PhpType: "float", HasDefault: false}, + }, + expected: " double ratio = 0.0;", + }, + { + name: "nullable float parameter", + params: []phpParameter{ + {Name: "ratio", PhpType: "float", HasDefault: false, IsNullable: true}, + }, + expected: " double ratio = 0.0;\n zend_bool ratio_is_null = 0;", + }, + { + name: "multiple parameters", + params: []phpParameter{ + {Name: "name", PhpType: "string", HasDefault: false}, + {Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "10"}, + }, + expected: " zend_string *name = NULL;\n zend_long count = 10;", + }, + { + name: "mixed nullable and non-nullable parameters", + params: []phpParameter{ + {Name: "name", PhpType: "string", HasDefault: false, IsNullable: false}, + {Name: "count", PhpType: "int", HasDefault: false, IsNullable: true}, + }, + expected: " zend_string *name = NULL;\n zend_long count = 0;\n zend_bool count_is_null = 0;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.generateParamDeclarations(tt.params) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GenerateParamParsing(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + params []phpParameter + requiredCount int + expected string + }{ + { + name: "no parameters", + params: []phpParameter{}, + requiredCount: 0, + expected: ` if (zend_parse_parameters_none() == FAILURE) { + RETURN_THROWS(); + }`, + }, + { + name: "single required string parameter", + params: []phpParameter{ + {Name: "message", PhpType: "string", HasDefault: false}, + }, + requiredCount: 1, + expected: ` ZEND_PARSE_PARAMETERS_START(1, 1) + Z_PARAM_STR(message) + ZEND_PARSE_PARAMETERS_END();`, + }, + { + name: "mixed required and optional parameters", + params: []phpParameter{ + {Name: "name", PhpType: "string", HasDefault: false}, + {Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "10"}, + {Name: "enabled", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + }, + requiredCount: 1, + expected: ` ZEND_PARSE_PARAMETERS_START(1, 3) + Z_PARAM_STR(name) + Z_PARAM_OPTIONAL + Z_PARAM_LONG(count) + Z_PARAM_BOOL(enabled) + ZEND_PARSE_PARAMETERS_END();`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.generateParamParsing(tt.params, tt.requiredCount) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GenerateGoCallParams(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + params []phpParameter + expected string + }{ + { + name: "no parameters", + params: []phpParameter{}, + expected: "", + }, + { + name: "single string parameter", + params: []phpParameter{ + {Name: "message", PhpType: "string"}, + }, + expected: "message", + }, + { + name: "multiple parameters of different types", + params: []phpParameter{ + {Name: "name", PhpType: "string"}, + {Name: "count", PhpType: "int"}, + {Name: "ratio", PhpType: "float"}, + {Name: "enabled", PhpType: "bool"}, + }, + expected: "name, (long) count, (double) ratio, (int) enabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.generateGoCallParams(tt.params) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GenerateParamParsingMacro(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + param phpParameter + expected string + }{ + { + name: "string parameter", + param: phpParameter{Name: "message", PhpType: "string"}, + expected: "\n Z_PARAM_STR(message)", + }, + { + name: "nullable string parameter", + param: phpParameter{Name: "message", PhpType: "string", IsNullable: true}, + expected: "\n Z_PARAM_STR_OR_NULL(message, message_is_null)", + }, + { + name: "int parameter", + param: phpParameter{Name: "count", PhpType: "int"}, + expected: "\n Z_PARAM_LONG(count)", + }, + { + name: "nullable int parameter", + param: phpParameter{Name: "count", PhpType: "int", IsNullable: true}, + expected: "\n Z_PARAM_LONG_OR_NULL(count, count_is_null)", + }, + { + name: "float parameter", + param: phpParameter{Name: "ratio", PhpType: "float"}, + expected: "\n Z_PARAM_DOUBLE(ratio)", + }, + { + name: "nullable float parameter", + param: phpParameter{Name: "ratio", PhpType: "float", IsNullable: true}, + expected: "\n Z_PARAM_DOUBLE_OR_NULL(ratio, ratio_is_null)", + }, + { + name: "bool parameter", + param: phpParameter{Name: "enabled", PhpType: "bool"}, + expected: "\n Z_PARAM_BOOL(enabled)", + }, + { + name: "nullable bool parameter", + param: phpParameter{Name: "enabled", PhpType: "bool", IsNullable: true}, + expected: "\n Z_PARAM_BOOL_OR_NULL(enabled, enabled_is_null)", + }, + { + name: "unknown type", + param: phpParameter{Name: "unknown", PhpType: "unknown"}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.generateParamParsingMacro(tt.param) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GetDefaultValue(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + param phpParameter + fallback string + expected string + }{ + { + name: "parameter without default", + param: phpParameter{Name: "count", PhpType: "int", HasDefault: false}, + fallback: "0", + expected: "0", + }, + { + name: "parameter with default value", + param: phpParameter{Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "42"}, + fallback: "0", + expected: "42", + }, + { + name: "parameter with empty default value", + param: phpParameter{Name: "count", PhpType: "int", HasDefault: true, DefaultValue: ""}, + fallback: "0", + expected: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.getDefaultValue(tt.param, tt.fallback) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GenerateSingleGoCallParam(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + param phpParameter + expected string + }{ + { + name: "string parameter", + param: phpParameter{Name: "message", PhpType: "string"}, + expected: "message", + }, + { + name: "nullable string parameter", + param: phpParameter{Name: "message", PhpType: "string", IsNullable: true}, + expected: "message_is_null ? NULL : message", + }, + { + name: "int parameter", + param: phpParameter{Name: "count", PhpType: "int"}, + expected: "(long) count", + }, + { + name: "nullable int parameter", + param: phpParameter{Name: "count", PhpType: "int", IsNullable: true}, + expected: "count_is_null ? NULL : &count", + }, + { + name: "float parameter", + param: phpParameter{Name: "ratio", PhpType: "float"}, + expected: "(double) ratio", + }, + { + name: "nullable float parameter", + param: phpParameter{Name: "ratio", PhpType: "float", IsNullable: true}, + expected: "ratio_is_null ? NULL : &ratio", + }, + { + name: "bool parameter", + param: phpParameter{Name: "enabled", PhpType: "bool"}, + expected: "(int) enabled", + }, + { + name: "nullable bool parameter", + param: phpParameter{Name: "enabled", PhpType: "bool", IsNullable: true}, + expected: "enabled_is_null ? NULL : &enabled", + }, + { + name: "unknown type", + param: phpParameter{Name: "unknown", PhpType: "unknown"}, + expected: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.generateSingleGoCallParam(tt.param) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_GenerateSingleParamDeclaration(t *testing.T) { + pp := &ParameterParser{} + + tests := []struct { + name string + param phpParameter + expected []string + }{ + { + name: "string parameter", + param: phpParameter{Name: "message", PhpType: "string", HasDefault: false}, + expected: []string{"zend_string *message = NULL;"}, + }, + { + name: "nullable string parameter", + param: phpParameter{Name: "message", PhpType: "string", HasDefault: false, IsNullable: true}, + expected: []string{"zend_string *message = NULL;", "zend_bool message_is_null = 0;"}, + }, + { + name: "int parameter with default", + param: phpParameter{Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "42"}, + expected: []string{"zend_long count = 42;"}, + }, + { + name: "nullable int parameter", + param: phpParameter{Name: "count", PhpType: "int", HasDefault: false, IsNullable: true}, + expected: []string{"zend_long count = 0;", "zend_bool count_is_null = 0;"}, + }, + { + name: "bool parameter with true default", + param: phpParameter{Name: "enabled", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + expected: []string{"zend_bool enabled = 1;"}, + }, + { + name: "nullable bool parameter", + param: phpParameter{Name: "enabled", PhpType: "bool", HasDefault: false, IsNullable: true}, + expected: []string{"zend_bool enabled = 0;", "zend_bool enabled_is_null = 0;"}, + }, + { + name: "bool parameter with false default", + param: phpParameter{Name: "disabled", PhpType: "bool", HasDefault: true, DefaultValue: "false"}, + expected: []string{"zend_bool disabled = false;"}, + }, + { + name: "float parameter", + param: phpParameter{Name: "ratio", PhpType: "float", HasDefault: false}, + expected: []string{"double ratio = 0.0;"}, + }, + { + name: "nullable float parameter", + param: phpParameter{Name: "ratio", PhpType: "float", HasDefault: false, IsNullable: true}, + expected: []string{"double ratio = 0.0;", "zend_bool ratio_is_null = 0;"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pp.generateSingleParamDeclaration(tt.param) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterParser_Integration(t *testing.T) { + pp := &ParameterParser{} + + params := []phpParameter{ + {Name: "name", PhpType: "string", HasDefault: false}, + {Name: "count", PhpType: "int", HasDefault: true, DefaultValue: "10"}, + {Name: "enabled", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + } + + info := pp.analyzeParameters(params) + assert.Equal(t, 1, info.RequiredCount) + assert.Equal(t, 3, info.TotalCount) + + declarations := pp.generateParamDeclarations(params) + expectedDeclarations := []string{ + "zend_string *name = NULL;", + "zend_long count = 10;", + "zend_bool enabled = 1;", + } + for _, expected := range expectedDeclarations { + assert.Contains(t, declarations, expected) + } + + parsing := pp.generateParamParsing(params, info.RequiredCount) + assert.Contains(t, parsing, "ZEND_PARSE_PARAMETERS_START(1, 3)") + assert.Contains(t, parsing, "Z_PARAM_OPTIONAL") + + goCallParams := pp.generateGoCallParams(params) + assert.Equal(t, "name, (long) count, (int) enabled", goCallParams) +} diff --git a/internal/extgen/parser.go b/internal/extgen/parser.go new file mode 100644 index 00000000..f6cb70a4 --- /dev/null +++ b/internal/extgen/parser.go @@ -0,0 +1,21 @@ +package extgen + +type SourceParser struct{} + +// EXPERIMENTAL +func (p *SourceParser) ParseFunctions(filename string) ([]phpFunction, error) { + functionParser := NewFuncParserDefRegex() + return functionParser.parse(filename) +} + +// EXPERIMENTAL +func (p *SourceParser) ParseClasses(filename string) ([]phpClass, error) { + classParser := classParser{} + return classParser.parse(filename) +} + +// EXPERIMENTAL +func (p *SourceParser) ParseConstants(filename string) ([]phpConstant, error) { + constantParser := NewConstantParserWithDefRegex() + return constantParser.parse(filename) +} diff --git a/internal/extgen/phpfunc.go b/internal/extgen/phpfunc.go new file mode 100644 index 00000000..f369eacf --- /dev/null +++ b/internal/extgen/phpfunc.go @@ -0,0 +1,82 @@ +package extgen + +import ( + "fmt" + "strings" +) + +type PHPFuncGenerator struct { + paramParser *ParameterParser +} + +func (pfg *PHPFuncGenerator) generate(fn phpFunction) string { + var builder strings.Builder + + paramInfo := pfg.paramParser.analyzeParameters(fn.Params) + + builder.WriteString(fmt.Sprintf("PHP_FUNCTION(%s)\n{\n", fn.Name)) + + if decl := pfg.paramParser.generateParamDeclarations(fn.Params); decl != "" { + builder.WriteString(decl + "\n") + } + + builder.WriteString(pfg.paramParser.generateParamParsing(fn.Params, paramInfo.RequiredCount) + "\n") + + builder.WriteString(pfg.generateGoCall(fn) + "\n") + + if returnCode := pfg.generateReturnCode(fn.ReturnType); returnCode != "" { + builder.WriteString(returnCode + "\n") + } + + builder.WriteString("}\n\n") + + return builder.String() +} + +func (pfg *PHPFuncGenerator) generateGoCall(fn phpFunction) string { + callParams := pfg.paramParser.generateGoCallParams(fn.Params) + + if fn.ReturnType == "void" { + return fmt.Sprintf(" %s(%s);", fn.Name, callParams) + } + + if fn.ReturnType == "string" { + return fmt.Sprintf(" zend_string *result = %s(%s);", fn.Name, callParams) + } + + return fmt.Sprintf(" %s result = %s(%s);", pfg.getCReturnType(fn.ReturnType), fn.Name, callParams) +} + +func (pfg *PHPFuncGenerator) getCReturnType(returnType string) string { + switch returnType { + case "string": + return "zend_string*" + case "int": + return "long" + case "float": + return "double" + case "bool": + return "int" + default: + return "void" + } +} + +func (pfg *PHPFuncGenerator) generateReturnCode(returnType string) string { + switch returnType { + case "string": + return ` if (result) { + RETURN_STR(result); + } else { + RETURN_EMPTY_STRING(); + }` + case "int": + return ` RETURN_LONG(result);` + case "float": + return ` RETURN_DOUBLE(result);` + case "bool": + return ` RETURN_BOOL(result);` + default: + return "" + } +} diff --git a/internal/extgen/phpfunc_test.go b/internal/extgen/phpfunc_test.go new file mode 100644 index 00000000..03281eee --- /dev/null +++ b/internal/extgen/phpfunc_test.go @@ -0,0 +1,335 @@ +package extgen + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPHPFunctionGenerator_Generate(t *testing.T) { + tests := []struct { + name string + function phpFunction + contains []string // Strings that should be present in the output + }{ + { + name: "simple string function", + function: phpFunction{ + Name: "greet", + ReturnType: "string", + Params: []phpParameter{ + {Name: "name", PhpType: "string"}, + }, + }, + contains: []string{ + "PHP_FUNCTION(greet)", + "zend_string *name = NULL;", + "Z_PARAM_STR(name)", + "zend_string *result = greet(name);", + "RETURN_STR(result)", + }, + }, + { + name: "function with default parameter", + function: phpFunction{ + Name: "calculate", + ReturnType: "int", + Params: []phpParameter{ + {Name: "base", PhpType: "int"}, + {Name: "multiplier", PhpType: "int", HasDefault: true, DefaultValue: "2"}, + }, + }, + contains: []string{ + "PHP_FUNCTION(calculate)", + "zend_long base = 0;", + "zend_long multiplier = 2;", + "ZEND_PARSE_PARAMETERS_START(1, 2)", + "Z_PARAM_OPTIONAL", + "Z_PARAM_LONG(base)", + "Z_PARAM_LONG(multiplier)", + }, + }, + { + name: "void function", + function: phpFunction{ + Name: "doSomething", + ReturnType: "void", + Params: []phpParameter{ + {Name: "action", PhpType: "string"}, + }, + }, + contains: []string{ + "PHP_FUNCTION(doSomething)", + "doSomething(action);", + }, + }, + { + name: "bool function with default", + function: phpFunction{ + Name: "isEnabled", + ReturnType: "bool", + Params: []phpParameter{ + {Name: "flag", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + }, + }, + contains: []string{ + "PHP_FUNCTION(isEnabled)", + "zend_bool flag = 1;", + "Z_PARAM_BOOL(flag)", + "RETURN_BOOL(result)", + }, + }, + { + name: "float function", + function: phpFunction{ + Name: "calculate", + ReturnType: "float", + Params: []phpParameter{ + {Name: "value", PhpType: "float"}, + }, + }, + contains: []string{ + "PHP_FUNCTION(calculate)", + "double value = 0.0;", + "Z_PARAM_DOUBLE(value)", + "RETURN_DOUBLE(result)", + }, + }, + } + + generator := PHPFuncGenerator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := generator.generate(tt.function) + + for _, expected := range tt.contains { + assert.Contains(t, result, expected, "Generated code should contain '%s'", expected) + } + + assert.True(t, strings.HasPrefix(result, "PHP_FUNCTION("), "Generated code should start with PHP_FUNCTION") + assert.True(t, strings.HasSuffix(strings.TrimSpace(result), "}"), "Generated code should end with closing brace") + }) + } +} + +func TestPHPFunctionGenerator_GenerateParamDeclarations(t *testing.T) { + tests := []struct { + name string + params []phpParameter + contains []string + }{ + { + name: "string parameter", + params: []phpParameter{ + {Name: "message", PhpType: "string"}, + }, + contains: []string{ + "zend_string *message = NULL;", + }, + }, + { + name: "int parameter", + params: []phpParameter{ + {Name: "count", PhpType: "int"}, + }, + contains: []string{ + "zend_long count = 0;", + }, + }, + { + name: "bool with default", + params: []phpParameter{ + {Name: "enabled", PhpType: "bool", HasDefault: true, DefaultValue: "true"}, + }, + contains: []string{ + "zend_bool enabled = 1;", + }, + }, + { + name: "float parameter with default", + params: []phpParameter{ + {Name: "rate", PhpType: "float", HasDefault: true, DefaultValue: "1.5"}, + }, + contains: []string{ + "double rate = 1.5;", + }, + }, + } + + parser := ParameterParser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parser.generateParamDeclarations(tt.params) + + for _, expected := range tt.contains { + assert.Contains(t, result, expected, "phpParameter declarations should contain '%s'", expected) + } + }) + } +} + +func TestPHPFunctionGenerator_GenerateReturnCode(t *testing.T) { + tests := []struct { + name string + returnType string + contains []string + }{ + { + name: "string return", + returnType: "string", + contains: []string{ + "RETURN_STR(result)", + "RETURN_EMPTY_STRING()", + }, + }, + { + name: "int return", + returnType: "int", + contains: []string{ + "RETURN_LONG(result)", + }, + }, + { + name: "bool return", + returnType: "bool", + contains: []string{ + "RETURN_BOOL(result)", + }, + }, + { + name: "float return", + returnType: "float", + contains: []string{ + "RETURN_DOUBLE(result)", + }, + }, + { + name: "void return", + returnType: "void", + contains: []string{}, + }, + } + + generator := PHPFuncGenerator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := generator.generateReturnCode(tt.returnType) + + if len(tt.contains) == 0 { + assert.Empty(t, result, "Return code should be empty for void") + return + } + + for _, expected := range tt.contains { + assert.Contains(t, result, expected, "Return code should contain '%s'", expected) + } + }) + } +} + +func TestPHPFunctionGenerator_GenerateGoCallParams(t *testing.T) { + tests := []struct { + name string + params []phpParameter + expected string + }{ + { + name: "no parameters", + params: []phpParameter{}, + expected: "", + }, + { + name: "simple string parameter", + params: []phpParameter{ + {Name: "message", PhpType: "string"}, + }, + expected: "message", + }, + { + name: "int parameter", + params: []phpParameter{ + {Name: "count", PhpType: "int"}, + }, + expected: "(long) count", + }, + { + name: "multiple parameters", + params: []phpParameter{ + {Name: "name", PhpType: "string"}, + {Name: "age", PhpType: "int"}, + }, + expected: "name, (long) age", + }, + { + name: "bool and float parameters", + params: []phpParameter{ + {Name: "enabled", PhpType: "bool"}, + {Name: "rate", PhpType: "float"}, + }, + expected: "(int) enabled, (double) rate", + }, + } + + parser := ParameterParser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parser.generateGoCallParams(tt.params) + + assert.Equal(t, tt.expected, result, "generateGoCallParams() mismatch") + }) + } +} + +func TestPHPFunctionGenerator_AnalyzeParameters(t *testing.T) { + tests := []struct { + name string + params []phpParameter + expectedReq int + expectedTotal int + }{ + { + name: "no parameters", + params: []phpParameter{}, + expectedReq: 0, + expectedTotal: 0, + }, + { + name: "all required", + params: []phpParameter{ + {Name: "a", PhpType: "string"}, + {Name: "b", PhpType: "int"}, + }, + expectedReq: 2, + expectedTotal: 2, + }, + { + name: "mixed required and optional", + params: []phpParameter{ + {Name: "required", PhpType: "string"}, + {Name: "optional", PhpType: "int", HasDefault: true, DefaultValue: "10"}, + }, + expectedReq: 1, + expectedTotal: 2, + }, + { + name: "all optional", + params: []phpParameter{ + {Name: "opt1", PhpType: "string", HasDefault: true, DefaultValue: "hello"}, + {Name: "opt2", PhpType: "int", HasDefault: true, DefaultValue: "0"}, + }, + expectedReq: 0, + expectedTotal: 2, + }, + } + + parser := ParameterParser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := parser.analyzeParameters(tt.params) + + assert.Equal(t, tt.expectedReq, info.RequiredCount, "analyzeParameters() RequiredCount mismatch") + assert.Equal(t, tt.expectedTotal, info.TotalCount, "analyzeParameters() TotalCount mismatch") + }) + } +} diff --git a/internal/extgen/srcanalyzer.go b/internal/extgen/srcanalyzer.go new file mode 100644 index 00000000..2177e64a --- /dev/null +++ b/internal/extgen/srcanalyzer.go @@ -0,0 +1,104 @@ +package extgen + +import ( + "fmt" + "go/parser" + "go/token" + "os" + "strings" +) + +type SourceAnalyzer struct{} + +func (sa *SourceAnalyzer) analyze(filename string) (imports []string, internalFunctions []string, err error) { + fset := token.NewFileSet() + node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + return nil, nil, fmt.Errorf("parsing file: %w", err) + } + + for _, imp := range node.Imports { + if imp.Path != nil { + importPath := imp.Path.Value + if imp.Name != nil { + imports = append(imports, fmt.Sprintf("%s %s", imp.Name.Name, importPath)) + } else { + imports = append(imports, importPath) + } + } + } + + sourceContent, err := os.ReadFile(filename) + if err != nil { + return nil, nil, fmt.Errorf("reading source file: %w", err) + } + + internalFunctions = sa.extractInternalFunctions(string(sourceContent)) + + return imports, internalFunctions, nil +} + +func (sa *SourceAnalyzer) extractInternalFunctions(content string) []string { + lines := strings.Split(content, "\n") + var ( + functions []string + currentFunc strings.Builder + inFunction, hasPHPFunc bool + braceCount int + ) + + for i, line := range lines { + trimmedLine := strings.TrimSpace(line) + + if strings.HasPrefix(trimmedLine, "func ") && !inFunction { + inFunction = true + braceCount = 0 + hasPHPFunc = false + currentFunc.Reset() + + // look backwards for export_php comment + for j := i - 1; j >= 0 && j >= i-5; j-- { + prevLine := strings.TrimSpace(lines[j]) + if prevLine == "" { + continue + } + + if strings.Contains(prevLine, "export_php:") { + hasPHPFunc = true + + break + } + + if !strings.HasPrefix(prevLine, "//") { + break + } + } + } + + if inFunction { + currentFunc.WriteString(line + "\n") + + for _, char := range line { + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + } + } + + if braceCount == 0 && strings.Contains(line, "}") { + funcContent := currentFunc.String() + + if !hasPHPFunc { + functions = append(functions, strings.TrimSpace(funcContent)) + } + + inFunction = false + currentFunc.Reset() + } + } + } + + return functions +} diff --git a/internal/extgen/srcanalyzer_test.go b/internal/extgen/srcanalyzer_test.go new file mode 100644 index 00000000..fc649c04 --- /dev/null +++ b/internal/extgen/srcanalyzer_test.go @@ -0,0 +1,398 @@ +package extgen + +import ( + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSourceAnalyzer_Analyze(t *testing.T) { + tests := []struct { + name string + sourceContent string + expectedImports []string + expectedFunctions []string + expectError bool + }{ + { + name: "simple file with imports and functions", + sourceContent: `package main + +import ( + "fmt" + "strings" +) + +func regularFunction() { + fmt.Println("hello") +} + +//export_php:function +func exportedFunction() string { + return "exported" +}`, + expectedImports: []string{`"fmt"`, `"strings"`}, + expectedFunctions: []string{ + `func regularFunction() { + fmt.Println("hello") +}`, + }, + expectError: false, + }, + { + name: "file with named imports", + sourceContent: `package main + +import ( + custom "fmt" + . "strings" + _ "os" +) + +func test() {}`, + expectedImports: []string{`custom "fmt"`, `. "strings"`, `_ "os"`}, + expectedFunctions: []string{ + `func test() {}`, + }, + expectError: false, + }, + { + name: "file with multiple functions and export comments", + sourceContent: `package main + +func internalOne() { + // some code +} + +// This function is exported to PHP +//export_php:function +func exportedOne() int { + return 42 +} + +func internalTwo() string { + return "internal" +} + +// Another exported function +//export_php:function +func exportedTwo() bool { + return true +}`, + expectedImports: []string{}, + expectedFunctions: []string{ + `func internalOne() { + // some code +}`, + `func internalTwo() string { + return "internal" +}`, + }, + expectError: false, + }, + { + name: "file with nested braces", + sourceContent: `package main + +func complexFunction() { + if true { + for i := 0; i < 10; i++ { + if i%2 == 0 { + fmt.Println(i) + } + } + } +} + +//export_php:function +func exportedComplex() { + obj := struct{ + field string + }{ + field: "value", + } + fmt.Println(obj) +}`, + expectedImports: []string{}, + expectedFunctions: []string{ + `func complexFunction() { + if true { + for i := 0; i < 10; i++ { + if i%2 == 0 { + fmt.Println(i) + } + } + } +}`, + }, + expectError: false, + }, + { + name: "empty file", + sourceContent: `package main`, + expectedImports: []string{}, + expectedFunctions: []string{}, + expectError: false, + }, + { + name: "file with only exported functions", + sourceContent: `package main + +//export_php:function +func onlyExported() {} + +//export_php:function +func anotherExported() string { + return "test" +}`, + expectedImports: []string{}, + expectedFunctions: []string{}, + expectError: false, + }, + { + name: "file with export comment not immediately before function", + sourceContent: `package main + +//export_php:function +// Some other comment +func shouldNotBeExported() {} + +func normalFunction() { + //export_php:function inside function should not count +}`, + expectedImports: []string{}, + expectedFunctions: []string{ + `func normalFunction() { + //export_php:function inside function should not count +}`, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + filename := filepath.Join(tempDir, "test.go") + + require.NoError(t, os.WriteFile(filename, []byte(tt.sourceContent), 0644)) + + analyzer := &SourceAnalyzer{} + imports, functions, err := analyzer.analyze(filename) + + if tt.expectError { + assert.Error(t, err, "expected error") + return + } + + assert.NoError(t, err, "unexpected error") + + if len(imports) != 0 && len(tt.expectedImports) != 0 { + assert.Equal(t, tt.expectedImports, imports, "imports mismatch") + } + + assert.Len(t, functions, len(tt.expectedFunctions), "function count mismatch") + + for i, expected := range tt.expectedFunctions { + assert.Equal(t, expected, functions[i], "function %d mismatch", i) + } + }) + } +} + +func TestSourceAnalyzer_Analyze_InvalidFile(t *testing.T) { + analyzer := &SourceAnalyzer{} + + t.Run("nonexistent file", func(t *testing.T) { + _, _, err := analyzer.analyze("/nonexistent/file.go") + assert.Error(t, err, "expected error for nonexistent file") + }) + + t.Run("invalid Go syntax", func(t *testing.T) { + tempDir := t.TempDir() + filename := filepath.Join(tempDir, "invalid.go") + + invalidContent := `package main + func incomplete( { + // invalid syntax + ` + + require.NoError(t, os.WriteFile(filename, []byte(invalidContent), 0644)) + + _, _, err := analyzer.analyze(filename) + assert.Error(t, err, "expected error for invalid syntax") + }) +} + +func TestSourceAnalyzer_ExtractInternalFunctions(t *testing.T) { + tests := []struct { + name string + content string + expected []string + }{ + { + name: "single function without export", + content: `func test() { + fmt.Println("test") +}`, + expected: []string{ + `func test() { + fmt.Println("test") +}`, + }, + }, + { + name: "function with export comment", + content: `//export_php:function +func exported() {}`, + expected: []string{}, + }, + { + name: "mixed functions", + content: `func internal() {} + +//export_php:function +func exported() {} + +func anotherInternal() { + return "test" +}`, + expected: []string{ + "func internal() {}", + `func anotherInternal() { + return "test" +}`, + }, + }, + { + name: "export comment with spacing", + content: `//export_php:function +func exported1() {} + +//export_php:function +func exported2() {} + +// export_php:function +func exported3() {}`, + expected: []string{}, + }, + { + name: "complex function with nested braces", + content: `func complex() { + if true { + for { + switch x { + case 1: + { + // nested block + } + } + } + } +}`, + expected: []string{ + `func complex() { + if true { + for { + switch x { + case 1: + { + // nested block + } + } + } + } +}`, + }, + }, + { + name: "empty content", + content: "", + expected: []string{}, + }, + { + name: "no functions", + content: `package main + +import "fmt" + +var x = 10`, + expected: []string{}, + }, + } + + analyzer := &SourceAnalyzer{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := analyzer.extractInternalFunctions(tt.content) + + assert.Len(t, result, len(tt.expected), "function count mismatch") + + for i, expected := range tt.expected { + assert.Equal(t, expected, result[i], "function %d mismatch", i) + } + }) + } +} + +func BenchmarkSourceAnalyzer_Analyze(b *testing.B) { + content := `package main + +import ( + "fmt" + "strings" + "os" +) + +func internalOne() { + fmt.Println("test") +} + +//export_php:function +func exported() string { + return "exported" +} + +func internalTwo() { + for i := 0; i < 100; i++ { + if i%2 == 0 { + fmt.Println(i) + } + } +}` + + tempDir := b.TempDir() + filename := filepath.Join(tempDir, "bench.go") + + require.NoError(b, os.WriteFile(filename, []byte(content), 0644)) + + analyzer := &SourceAnalyzer{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := analyzer.analyze(filename) + require.NoError(b, err) + } +} + +func BenchmarkSourceAnalyzer_ExtractInternalFunctions(b *testing.B) { + content := `func test1() { fmt.Println("1") } +func test2() { fmt.Println("2") } +//export_php:function +func exported() {} +func test3() { + for i := 0; i < 10; i++ { + fmt.Println(i) + } +}` + + analyzer := &SourceAnalyzer{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + analyzer.extractInternalFunctions(content) + } +} diff --git a/internal/extgen/stub.go b/internal/extgen/stub.go new file mode 100644 index 00000000..3a34dad6 --- /dev/null +++ b/internal/extgen/stub.go @@ -0,0 +1,51 @@ +package extgen + +import ( + _ "embed" + "path/filepath" + "strings" + "text/template" +) + +//go:embed templates/stub.php.tpl +var templateContent string + +type StubGenerator struct { + Generator *Generator +} + +func (sg *StubGenerator) generate() error { + filename := filepath.Join(sg.Generator.BuildDir, sg.Generator.BaseName+".stub.php") + content, err := sg.buildContent() + if err != nil { + return err + } + + return WriteFile(filename, content) +} + +func (sg *StubGenerator) buildContent() (string, error) { + tmpl, err := template.New("stub.php.tpl").Funcs(template.FuncMap{ + "phpType": getPhpTypeAnnotation, + }).Parse(templateContent) + if err != nil { + return "", err + } + + var buf strings.Builder + if err := tmpl.Execute(&buf, sg.Generator); err != nil { + return "", err + } + + return buf.String(), nil +} + +// getPhpTypeAnnotation converts Go constant type to PHP type annotation +func getPhpTypeAnnotation(goType string) string { + switch goType { + case "string", "bool", "float", "int": + return goType + default: + return "int" + } +} diff --git a/internal/extgen/stub_test.go b/internal/extgen/stub_test.go new file mode 100644 index 00000000..4ec52885 --- /dev/null +++ b/internal/extgen/stub_test.go @@ -0,0 +1,612 @@ +package extgen + +import ( + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStubGenerator_Generate(t *testing.T) { + tmpDir := t.TempDir() + + generator := &Generator{ + BaseName: "test_extension", + BuildDir: tmpDir, + Functions: []phpFunction{ + { + Name: "greet", + Signature: "greet(string $name): string", + Params: []phpParameter{ + {Name: "name", PhpType: "string"}, + }, + ReturnType: "string", + }, + { + Name: "calculate", + Signature: "calculate(int $a, int $b): int", + Params: []phpParameter{ + {Name: "a", PhpType: "int"}, + {Name: "b", PhpType: "int"}, + }, + ReturnType: "int", + }, + }, + Classes: []phpClass{ + { + Name: "User", + GoStruct: "UserStruct", + }, + }, + Constants: []phpConstant{ + { + Name: "GLOBAL_CONST", + Value: "42", + PhpType: "int", + }, + { + Name: "USER_STATUS_ACTIVE", + Value: "1", + PhpType: "int", + ClassName: "User", + }, + }, + } + + stubGen := StubGenerator{generator} + assert.NoError(t, stubGen.generate(), "generate() failed") + + expectedFile := filepath.Join(tmpDir, "test_extension.stub.php") + assert.FileExists(t, expectedFile, "Expected stub file was not created: %s", expectedFile) + + content, err := ReadFile(expectedFile) + assert.NoError(t, err, "Failed to read generated stub file") + + testStubBasicStructure(t, content) + testStubFunctions(t, content, generator.Functions) + testStubClasses(t, content, generator.Classes) + testStubConstants(t, content, generator.Constants) +} + +func TestStubGenerator_BuildContent(t *testing.T) { + tests := []struct { + name string + functions []phpFunction + classes []phpClass + constants []phpConstant + contains []string + }{ + { + name: "empty extension", + functions: []phpFunction{}, + classes: []phpClass{}, + constants: []phpConstant{}, + contains: []string{ + " 0 { + assert.Equal(t, " +#include +#include + +#include "{{.BaseName}}.h" +#include "{{.BaseName}}_arginfo.h" +#include "_cgo_export.h" + +{{- if .Classes}} + +static zend_object_handlers object_handlers_{{.BaseName}}; + +typedef struct { + uintptr_t go_handle; + char* class_name; + zend_object std; /* This MUST be the last struct field to memory alignement problems */ +} {{.BaseName}}_object; + +static inline {{.BaseName}}_object *{{.BaseName}}_object_from_obj(zend_object *obj) { + return ({{.BaseName}}_object*)((char*)(obj) - offsetof({{.BaseName}}_object, std)); +} + +static zend_object *{{.BaseName}}_create_object(zend_class_entry *ce) { + {{.BaseName}}_object *intern = ecalloc(1, sizeof({{.BaseName}}_object) + zend_object_properties_size(ce)); + + zend_object_std_init(&intern->std, ce); + object_properties_init(&intern->std, ce); + + intern->std.handlers = &object_handlers_{{.BaseName}}; + intern->go_handle = 0; /* will be set in __construct */ + intern->class_name = estrdup(ZSTR_VAL(ce->name)); + + return &intern->std; +} + +static void {{.BaseName}}_free_object(zend_object *object) { + {{.BaseName}}_object *intern = {{.BaseName}}_object_from_obj(object); + + if (intern->class_name) { + efree(intern->class_name); + } + + if (intern->go_handle != 0) { + removeGoObject(intern->go_handle); + } + + zend_object_std_dtor(&intern->std); +} + +static zend_function *{{.BaseName}}_get_method(zend_object **object, zend_string *method, const zval *key) { + return zend_std_get_method(object, method, key); +} + +void init_object_handlers() { + memcpy(&object_handlers_{{.BaseName}}, &std_object_handlers, sizeof(zend_object_handlers)); + object_handlers_{{.BaseName}}.get_method = {{.BaseName}}_get_method; + object_handlers_{{.BaseName}}.free_obj = {{.BaseName}}_free_object; + object_handlers_{{.BaseName}}.offset = offsetof({{.BaseName}}_object, std); +} +{{- end}} +{{ range .Classes}} +static zend_class_entry *{{.Name}}_ce = NULL; + +PHP_METHOD({{.Name}}, __construct) { + if (zend_parse_parameters_none() == FAILURE) { + RETURN_THROWS(); + } + + {{$.BaseName}}_object *intern = {{$.BaseName}}_object_from_obj(Z_OBJ_P(ZEND_THIS)); + + intern->go_handle = create_{{.GoStruct}}_object(); +} + +{{ range .Methods}} +PHP_METHOD({{.ClassName}}, {{.PhpName}}) { + {{$.BaseName}}_object *intern = {{$.BaseName}}_object_from_obj(Z_OBJ_P(ZEND_THIS)); + + if (intern->go_handle == 0) { + zend_throw_error(NULL, "Go object not found in registry"); + RETURN_THROWS(); + } + + {{- if .Params -}} + {{range $i, $param := .Params -}} + {{- if eq $param.PhpType "string"}} + zend_string *{{$param.Name}} = NULL;{{if $param.IsNullable}} + zend_bool {{$param.Name}}_is_null = 0;{{end}} + {{- else if eq $param.PhpType "int"}} + zend_long {{$param.Name}} = {{if $param.HasDefault}}{{$param.DefaultValue}}{{else}}0{{end}};{{if $param.IsNullable}} + zend_bool {{$param.Name}}_is_null = 0;{{end}} + {{- else if eq $param.PhpType "float"}} + double {{$param.Name}} = {{if $param.HasDefault}}{{$param.DefaultValue}}{{else}}0.0{{end}};{{if $param.IsNullable}} + zend_bool {{$param.Name}}_is_null = 0;{{end}} + {{- else if eq $param.PhpType "bool"}} + zend_bool {{$param.Name}} = {{if $param.HasDefault}}{{if eq $param.DefaultValue "true"}}1{{else}}0{{end}}{{else}}0{{end}};{{if $param.IsNullable}} + zend_bool {{$param.Name}}_is_null = 0;{{end}} + {{- end}} + {{- end}} + + {{$requiredCount := 0}}{{range .Params}}{{if not .HasDefault}}{{$requiredCount = add1 $requiredCount}}{{end}}{{end -}} + ZEND_PARSE_PARAMETERS_START({{$requiredCount}}, {{len .Params}}); + {{$optionalStarted := false}}{{range .Params}}{{if .HasDefault}}{{if not $optionalStarted -}} + Z_PARAM_OPTIONAL + {{$optionalStarted = true}}{{end}}{{end -}} + {{if .IsNullable}}{{if eq .PhpType "string"}}Z_PARAM_STR_OR_NULL({{.Name}}, {{.Name}}_is_null){{else if eq .PhpType "int"}}Z_PARAM_LONG_OR_NULL({{.Name}}, {{.Name}}_is_null){{else if eq .PhpType "float"}}Z_PARAM_DOUBLE_OR_NULL({{.Name}}, {{.Name}}_is_null){{else if eq .PhpType "bool"}}Z_PARAM_BOOL_OR_NULL({{.Name}}, {{.Name}}_is_null){{end}}{{else}}{{if eq .PhpType "string"}}Z_PARAM_STR({{.Name}}){{else if eq .PhpType "int"}}Z_PARAM_LONG({{.Name}}){{else if eq .PhpType "float"}}Z_PARAM_DOUBLE({{.Name}}){{else if eq .PhpType "bool"}}Z_PARAM_BOOL({{.Name}}){{end}}{{end}} + {{end -}} + ZEND_PARSE_PARAMETERS_END(); + {{else}} + if (zend_parse_parameters_none() == FAILURE) { + RETURN_THROWS(); + } + {{end}} + + {{- if ne .ReturnType "void"}} + {{- if eq .ReturnType "string"}} + zend_string* result = {{.Name}}_wrapper(intern->go_handle{{if .Params}}{{range .Params}}, {{if .IsNullable}}{{if eq .PhpType "string"}}{{.Name}}_is_null ? NULL : {{.Name}}{{else if eq .PhpType "int"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "float"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "bool"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{end}}{{else}}{{.Name}}{{end}}{{end}}{{end}}); + RETURN_STR(result); + {{- else if eq .ReturnType "int"}} + zend_long result = {{.Name}}_wrapper(intern->go_handle{{if .Params}}{{range .Params}}, {{if .IsNullable}}{{if eq .PhpType "string"}}{{.Name}}_is_null ? NULL : {{.Name}}{{else if eq .PhpType "int"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "float"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "bool"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{end}}{{else}}(long){{.Name}}{{end}}{{end}}{{end}}); + RETURN_LONG(result); + {{- else if eq .ReturnType "float"}} + double result = {{.Name}}_wrapper(intern->go_handle{{if .Params}}{{range .Params}}, {{if .IsNullable}}{{if eq .PhpType "string"}}{{.Name}}_is_null ? NULL : {{.Name}}{{else if eq .PhpType "int"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "float"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "bool"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{end}}{{else}}(double){{.Name}}{{end}}{{end}}{{end}}); + RETURN_DOUBLE(result); + {{- else if eq .ReturnType "bool"}} + int result = {{.Name}}_wrapper(intern->go_handle{{if .Params}}{{range .Params}}, {{if .IsNullable}}{{if eq .PhpType "string"}}{{.Name}}_is_null ? NULL : {{.Name}}{{else if eq .PhpType "int"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "float"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "bool"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{end}}{{else}}(int){{.Name}}{{end}}{{end}}{{end}}); + RETURN_BOOL(result); + {{- end}} + {{- else}} + {{.Name}}_wrapper(intern->go_handle{{if .Params}}{{range .Params}}, {{if .IsNullable}}{{if eq .PhpType "string"}}{{.Name}}_is_null ? NULL : {{.Name}}{{else if eq .PhpType "int"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "float"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{else if eq .PhpType "bool"}}{{.Name}}_is_null ? NULL : &{{.Name}}{{end}}{{else}}{{if eq .PhpType "string"}}{{.Name}}{{else if eq .PhpType "int"}}(long){{.Name}}{{else if eq .PhpType "float"}}(double){{.Name}}{{else if eq .PhpType "bool"}}(int){{.Name}}{{end}}{{end}}{{end}}{{end}}); + {{- end}} +} +{{end}}{{end}} + +{{- if .Classes}} +void register_all_classes() { + init_object_handlers(); + + {{- range .Classes}} + {{.Name}}_ce = register_class_{{.Name}}(); + if (!{{.Name}}_ce) { + php_error_docref(NULL, E_ERROR, "Failed to register class {{.Name}}"); + return; + } + {{.Name}}_ce->create_object = {{$.BaseName}}_create_object; + {{- end}} +} +{{- end}} + +PHP_MINIT_FUNCTION({{.BaseName}}) { + {{ if .Classes}}register_all_classes();{{end}} + + {{- range .Constants}} + {{- if eq .ClassName ""}} + {{if .IsIota}}REGISTER_LONG_CONSTANT("{{.Name}}", {{.Name}}, CONST_CS | CONST_PERSISTENT); + {{else if eq .PhpType "string"}}REGISTER_STRING_CONSTANT("{{.Name}}", {{.CValue}}, CONST_CS | CONST_PERSISTENT); + {{else if eq .PhpType "bool"}}REGISTER_LONG_CONSTANT("{{.Name}}", {{if eq .Value "true"}}1{{else}}0{{end}}, CONST_CS | CONST_PERSISTENT); + {{else if eq .PhpType "float"}}REGISTER_DOUBLE_CONSTANT("{{.Name}}", {{.CValue}}, CONST_CS | CONST_PERSISTENT); + {{else}}REGISTER_LONG_CONSTANT("{{.Name}}", {{.CValue}}, CONST_CS | CONST_PERSISTENT); + {{- end}} + {{- end}} + {{- end}} + return SUCCESS; +} + +zend_module_entry {{.BaseName}}_module_entry = {STANDARD_MODULE_HEADER, + "{{.BaseName}}", + ext_functions, /* Functions */ + PHP_MINIT({{.BaseName}}), /* MINIT */ + NULL, /* MSHUTDOWN */ + NULL, /* RINIT */ + NULL, /* RSHUTDOWN */ + NULL, /* MINFO */ + "1.0.0", /* Version */ + STANDARD_MODULE_PROPERTIES}; + diff --git a/internal/extgen/templates/extension.go.tpl b/internal/extgen/templates/extension.go.tpl new file mode 100644 index 00000000..f1f00555 --- /dev/null +++ b/internal/extgen/templates/extension.go.tpl @@ -0,0 +1,75 @@ +package {{.PackageName}} + +/* +#include +#include "{{.BaseName}}.h" +*/ +import "C" +import "runtime/cgo" +{{- range .Imports}} +import {{.}} +{{- end}} + +func init() { + frankenphp.RegisterExtension(unsafe.Pointer(&C.ext_module_entry)) +} +{{range .Constants}} +const {{.Name}} = {{.Value}} +{{- end}} +{{range .InternalFunctions}} +{{.}} +{{- end}} + +{{- range .Functions}} +//export {{.Name}} +{{.GoFunction}} +{{- end}} + +{{- range .Classes}} +type {{.GoStruct}} struct { +{{- range .Properties}} + {{.Name}} {{.GoType}} +{{- end}} +} +{{- end}} + +{{- if .Classes}} + +//export registerGoObject +func registerGoObject(obj interface{}) C.uintptr_t { + handle := cgo.NewHandle(obj) + return C.uintptr_t(handle) +} + +//export getGoObject +func getGoObject(handle C.uintptr_t) interface{} { + h := cgo.Handle(handle) + return h.value() +} + +//export removeGoObject +func removeGoObject(handle C.uintptr_t) { + h := cgo.Handle(handle) + h.Delete() +} + +{{- end}} + +{{- range .Classes}} +//export create_{{.GoStruct}}_object +func create_{{.GoStruct}}_object() C.uintptr_t { + obj := &{{.GoStruct}}{} + return registerGoObject(obj) +} + +{{- range .Methods}} +{{- if .GoFunction}} +{{.GoFunction}} +{{- end}} +{{- end}} + +{{- range .Methods}} +//export {{.Name}}_wrapper +{{.Wrapper}} +{{end}} +{{- end}} diff --git a/internal/extgen/templates/extension.h.tpl b/internal/extgen/templates/extension.h.tpl new file mode 100644 index 00000000..49a55e9f --- /dev/null +++ b/internal/extgen/templates/extension.h.tpl @@ -0,0 +1,20 @@ +#ifndef _{{.HeaderGuard}} +#define _{{.HeaderGuard}} + +#include +#include + +extern zend_module_entry ext_module_entry; + +typedef struct go_value go_value; + +typedef struct go_string { + size_t len; + char *data; +} go_string; + +{{if .Constants}} +/* User defined constants */{{end}} +{{range .Constants}}#define {{.Name}} {{.CValue}} +{{end}} +#endif diff --git a/internal/extgen/templates/stub.php.tpl b/internal/extgen/templates/stub.php.tpl new file mode 100644 index 00000000..9c50d177 --- /dev/null +++ b/internal/extgen/templates/stub.php.tpl @@ -0,0 +1,37 @@ + 0 && !unicode.IsLetter(rune(sanitized[0])) && sanitized[0] != '_' { + sanitized = "_" + sanitized + } + + return sanitized +} diff --git a/internal/extgen/utils_test.go b/internal/extgen/utils_test.go new file mode 100644 index 00000000..756d9290 --- /dev/null +++ b/internal/extgen/utils_test.go @@ -0,0 +1,242 @@ +package extgen + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteFile(t *testing.T) { + tests := []struct { + name string + filename string + content string + expectError bool + }{ + { + name: "write simple file", + filename: "test.txt", + content: "hello world", + expectError: false, + }, + { + name: "write empty file", + filename: "empty.txt", + content: "", + expectError: false, + }, + { + name: "write file with special characters", + filename: "special.txt", + content: "hello\nworld\t!@#$%^&*()", + expectError: false, + }, + { + name: "write to invalid directory", + filename: "/nonexistent/directory/file.txt", + content: "test", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var filename string + if !tt.expectError { + tempDir := t.TempDir() + filename = filepath.Join(tempDir, tt.filename) + } else { + filename = tt.filename + } + + err := WriteFile(filename, tt.content) + + if tt.expectError { + assert.Error(t, err, "WriteFile() should return an error") + return + } + + assert.NoError(t, err, "WriteFile() should not return an error") + + content, err := os.ReadFile(filename) + assert.NoError(t, err, "Failed to read written file") + assert.Equal(t, tt.content, string(content), "WriteFile() content mismatch") + + info, err := os.Stat(filename) + assert.NoError(t, err, "Failed to stat file") + + expectedMode := os.FileMode(0644) + assert.Equal(t, expectedMode, info.Mode().Perm(), "WriteFile() wrong permissions") + }) + } +} + +func TestReadFile(t *testing.T) { + tests := []struct { + name string + content string + expectError bool + }{ + { + name: "read simple file", + content: "hello world", + expectError: false, + }, + { + name: "read empty file", + content: "", + expectError: false, + }, + { + name: "read file with special characters", + content: "hello\nworld\t!@#$%^&*()", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + filename := filepath.Join(tempDir, "test.txt") + + err := os.WriteFile(filename, []byte(tt.content), 0644) + assert.NoError(t, err, "Failed to create test file") + + content, err := ReadFile(filename) + + if tt.expectError { + assert.Error(t, err, "ReadFile() should return an error") + return + } + + assert.NoError(t, err, "ReadFile() should not return an error") + assert.Equal(t, tt.content, content, "ReadFile() content mismatch") + }) + } + + t.Run("read nonexistent file", func(t *testing.T) { + _, err := ReadFile("/nonexistent/file.txt") + assert.Error(t, err, "ReadFile() should return an error for nonexistent file") + }) +} + +func TestSanitizePackageName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple valid name", + input: "mypackage", + expected: "mypackage", + }, + { + name: "name with hyphens", + input: "my-package", + expected: "my_package", + }, + { + name: "name with dots", + input: "my.package", + expected: "my_package", + }, + { + name: "name with both hyphens and dots", + input: "my-package.name", + expected: "my_package_name", + }, + { + name: "name starting with number", + input: "123package", + expected: "_123package", + }, + { + name: "name starting with underscore", + input: "_package", + expected: "_package", + }, + { + name: "name starting with letter", + input: "Package", + expected: "Package", + }, + { + name: "name starting with special character", + input: "@package", + expected: "_@package", + }, + { + name: "complex name", + input: "123my-complex.package@name", + expected: "_123my_complex_package@name", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "single character letter", + input: "a", + expected: "a", + }, + { + name: "single character number", + input: "1", + expected: "_1", + }, + { + name: "single character underscore", + input: "_", + expected: "_", + }, + { + name: "single character special", + input: "@", + expected: "_@", + }, + { + name: "multiple consecutive hyphens", + input: "my--package", + expected: "my__package", + }, + { + name: "multiple consecutive dots", + input: "my..package", + expected: "my__package", + }, + { + name: "mixed case with special chars", + input: "MyPackage-name.version", + expected: "MyPackage_name_version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizePackageName(tt.input) + assert.Equal(t, tt.expected, result, "SanitizePackageName(%q)", tt.input) + }) + } +} + +func BenchmarkSanitizePackageName(b *testing.B) { + testCases := []string{ + "simple", + "my-package", + "my.package.name", + "123complex-package.name@version", + "very-long-package-name-with-many-special-characters.and.dots", + } + + for _, tc := range testCases { + b.Run(tc, func(b *testing.B) { + for i := 0; i < b.N; i++ { + SanitizePackageName(tc) + } + }) + } +} diff --git a/internal/extgen/validator.go b/internal/extgen/validator.go new file mode 100644 index 00000000..b4e89727 --- /dev/null +++ b/internal/extgen/validator.go @@ -0,0 +1,294 @@ +package extgen + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "regexp" + "strings" +) + +var functionNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) +var parameterNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) +var classNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) +var propNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +type Validator struct{} + +func (v *Validator) validateFunction(fn phpFunction) error { + if fn.Name == "" { + return fmt.Errorf("function name cannot be empty") + } + + if !functionNameRegex.MatchString(fn.Name) { + return fmt.Errorf("invalid function name: %s", fn.Name) + } + + for i, param := range fn.Params { + if err := v.validateParameter(param); err != nil { + return fmt.Errorf("parameter %d (%s): %w", i, param.Name, err) + } + } + + if err := v.validateReturnType(fn.ReturnType); err != nil { + return fmt.Errorf("return type: %w", err) + } + + return nil +} + +func (v *Validator) validateParameter(param phpParameter) error { + if param.Name == "" { + return fmt.Errorf("parameter name cannot be empty") + } + + if !parameterNameRegex.MatchString(param.Name) { + return fmt.Errorf("invalid parameter name: %s", param.Name) + } + + validTypes := []string{"string", "int", "float", "bool", "array", "object", "mixed"} + if !v.isValidType(param.PhpType, validTypes) { + return fmt.Errorf("invalid parameter type: %s", param.PhpType) + } + + return nil +} + +func (v *Validator) validateReturnType(returnType string) error { + validReturnTypes := []string{"void", "string", "int", "float", "bool", "array", "object", "mixed", "null", "true", "false"} + if !v.isValidType(returnType, validReturnTypes) { + return fmt.Errorf("invalid return type: %s", returnType) + } + return nil +} + +func (v *Validator) validateClass(class phpClass) error { + if class.Name == "" { + return fmt.Errorf("class name cannot be empty") + } + + if !classNameRegex.MatchString(class.Name) { + return fmt.Errorf("invalid class name: %s", class.Name) + } + + for i, prop := range class.Properties { + if err := v.validateClassProperty(prop); err != nil { + return fmt.Errorf("property %d (%s): %w", i, prop.Name, err) + } + } + + return nil +} + +func (v *Validator) validateClassProperty(prop phpClassProperty) error { + if prop.Name == "" { + return fmt.Errorf("property name cannot be empty") + } + + if !propNameRegex.MatchString(prop.Name) { + return fmt.Errorf("invalid property name: %s", prop.Name) + } + + validTypes := []string{"string", "int", "float", "bool", "array", "object", "mixed"} + if !v.isValidType(prop.PhpType, validTypes) { + return fmt.Errorf("invalid property type: %s", prop.PhpType) + } + + return nil +} + +func (v *Validator) isValidType(typeStr string, validTypes []string) bool { + for _, valid := range validTypes { + if typeStr == valid { + return true + } + } + return false +} + +// validateScalarTypes checks if PHP signature contains only supported scalar types +func (v *Validator) validateScalarTypes(fn phpFunction) error { + supportedTypes := []string{"string", "int", "float", "bool"} + + for i, param := range fn.Params { + if !v.isScalarType(param.PhpType, supportedTypes) { + return fmt.Errorf("parameter %d (%s) has unsupported type '%s'. Only scalar types (string, int, float, bool) and their nullable variants are supported", i+1, param.Name, param.PhpType) + } + } + + if fn.ReturnType != "void" && !v.isScalarType(fn.ReturnType, supportedTypes) { + return fmt.Errorf("return type '%s' is not supported. Only scalar types (string, int, float, bool), void, and their nullable variants are supported", fn.ReturnType) + } + + return nil +} + +func (v *Validator) isScalarType(phpType string, supportedTypes []string) bool { + for _, supported := range supportedTypes { + if phpType == supported { + return true + } + } + return false +} + +// validateGoFunctionSignatureWithOptions validates with option for method vs function +func (v *Validator) validateGoFunctionSignatureWithOptions(phpFunc phpFunction, isMethod bool) error { + if phpFunc.GoFunction == "" { + return fmt.Errorf("no Go function found for PHP function '%s'", phpFunc.Name) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "", "package main\n"+phpFunc.GoFunction, 0) + if err != nil { + return fmt.Errorf("failed to parse Go function: %w", err) + } + + var goFunc *ast.FuncDecl + for _, decl := range file.Decls { + if funcDecl, ok := decl.(*ast.FuncDecl); ok { + goFunc = funcDecl + break + } + } + + if goFunc == nil { + return fmt.Errorf("no function declaration found in Go function") + } + + goParamCount := 0 + if goFunc.Type.Params != nil { + goParamCount = len(goFunc.Type.Params.List) + } + + hasReceiver := goFunc.Recv != nil && len(goFunc.Recv.List) > 0 + paramOffset := 0 + effectiveGoParamCount := goParamCount + + if hasReceiver { + paramOffset = 0 + effectiveGoParamCount = goParamCount + } else if isMethod && goParamCount > 0 { + // this is a method-like function, first parameter should be the struct + paramOffset = 1 + effectiveGoParamCount = goParamCount - 1 + } + + if len(phpFunc.Params) != effectiveGoParamCount { + return fmt.Errorf("parameter count mismatch: PHP function has %d parameters but Go function has %d", len(phpFunc.Params), effectiveGoParamCount) + } + + if goFunc.Type.Params != nil && len(phpFunc.Params) > 0 { + for i, phpParam := range phpFunc.Params { + goParamIndex := i + paramOffset + + if goParamIndex >= len(goFunc.Type.Params.List) { + break + } + + goParam := goFunc.Type.Params.List[goParamIndex] + expectedGoType := v.phpTypeToGoType(phpParam.PhpType, phpParam.IsNullable) + actualGoType := v.goTypeToString(goParam.Type) + + if !v.isCompatibleGoType(expectedGoType, actualGoType) { + return fmt.Errorf("parameter %d type mismatch: PHP '%s' requires Go type '%s' but found '%s'", i+1, phpParam.PhpType, expectedGoType, actualGoType) + } + } + } + + expectedGoReturnType := v.phpReturnTypeToGoType(phpFunc.ReturnType, phpFunc.IsReturnNullable) + actualGoReturnType := v.goReturnTypeToString(goFunc.Type.Results) + + if !v.isCompatibleGoType(expectedGoReturnType, actualGoReturnType) { + return fmt.Errorf("return type mismatch: PHP '%s' requires Go return type '%s' but found '%s'", phpFunc.ReturnType, expectedGoReturnType, actualGoReturnType) + } + + return nil +} + +func (v *Validator) phpTypeToGoType(phpType string, isNullable bool) string { + var baseType string + switch phpType { + case "string": + baseType = "*C.zend_string" + case "int": + baseType = "int64" + case "float": + baseType = "float64" + case "bool": + baseType = "bool" + default: + baseType = "interface{}" + } + + if isNullable && phpType != "string" { + return "*" + baseType + } + + return baseType +} + +// isCompatibleGoType checks if the actual Go type is compatible with the expected type. +func (v *Validator) isCompatibleGoType(expectedType, actualType string) bool { + if expectedType == actualType { + return true + } + + switch expectedType { + case "int64": + return actualType == "int" + case "*int64": + return actualType == "*int" + case "*float64": + return actualType == "*float32" + } + + return false +} + +func (v *Validator) phpReturnTypeToGoType(phpReturnType string, isNullable bool) string { + switch phpReturnType { + case "void": + return "" + case "string": + return "unsafe.Pointer" + case "int": + return "int64" + case "float": + return "float64" + case "bool": + return "bool" + default: + return "interface{}" + } +} + +func (v *Validator) goTypeToString(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + return "*" + v.goTypeToString(t.X) + case *ast.SelectorExpr: + return v.goTypeToString(t.X) + "." + t.Sel.Name + default: + return "unknown" + } +} + +func (v *Validator) goReturnTypeToString(results *ast.FieldList) string { + if results == nil || len(results.List) == 0 { + return "" + } + + if len(results.List) == 1 { + return v.goTypeToString(results.List[0].Type) + } + + var types []string + for _, field := range results.List { + types = append(types, v.goTypeToString(field.Type)) + } + return "(" + strings.Join(types, ", ") + ")" +} diff --git a/internal/extgen/validator_test.go b/internal/extgen/validator_test.go new file mode 100644 index 00000000..3e1b54c0 --- /dev/null +++ b/internal/extgen/validator_test.go @@ -0,0 +1,705 @@ +package extgen + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateFunction(t *testing.T) { + tests := []struct { + name string + function phpFunction + expectError bool + }{ + { + name: "valid function", + function: phpFunction{ + Name: "validFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "param1", PhpType: "string"}, + {Name: "param2", PhpType: "int"}, + }, + }, + expectError: false, + }, + { + name: "valid function with nullable return", + function: phpFunction{ + Name: "nullableReturn", + ReturnType: "string", + IsReturnNullable: true, + Params: []phpParameter{ + {Name: "data", PhpType: "array"}, + }, + }, + expectError: false, + }, + { + name: "empty function name", + function: phpFunction{ + Name: "", + ReturnType: "string", + }, + expectError: true, + }, + { + name: "invalid function name - starts with number", + function: phpFunction{ + Name: "123invalid", + ReturnType: "string", + }, + expectError: true, + }, + { + name: "invalid function name - contains special chars", + function: phpFunction{ + Name: "invalid-name", + ReturnType: "string", + }, + expectError: true, + }, + { + name: "invalid parameter name", + function: phpFunction{ + Name: "validName", + ReturnType: "string", + Params: []phpParameter{ + {Name: "123invalid", PhpType: "string"}, + }, + }, + expectError: true, + }, + { + name: "empty parameter name", + function: phpFunction{ + Name: "validName", + ReturnType: "string", + Params: []phpParameter{ + {Name: "", PhpType: "string"}, + }, + }, + expectError: true, + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateFunction(tt.function) + + if tt.expectError { + assert.Error(t, err, "validateFunction() should return an error for function %s", tt.function.Name) + } else { + assert.NoError(t, err, "validateFunction() should not return an error for function %s", tt.function.Name) + } + }) + } +} + +func TestValidateReturnType(t *testing.T) { + tests := []struct { + name string + returnType string + expectError bool + }{ + { + name: "valid string type", + returnType: "string", + expectError: false, + }, + { + name: "valid int type", + returnType: "int", + expectError: false, + }, + { + name: "valid array type", + returnType: "array", + expectError: false, + }, + { + name: "valid bool type", + returnType: "bool", + expectError: false, + }, + { + name: "valid float type", + returnType: "float", + expectError: false, + }, + { + name: "valid void type", + returnType: "void", + expectError: false, + }, + { + name: "invalid return type", + returnType: "invalidType", + expectError: true, + }, + { + name: "empty return type", + returnType: "", + expectError: true, + }, + { + name: "case sensitive - String should be invalid", + returnType: "String", + expectError: true, + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateReturnType(tt.returnType) + + if tt.expectError { + assert.Error(t, err, "validateReturnType(%s) should return an error", tt.returnType) + } else { + assert.NoError(t, err, "validateReturnType(%s) should not return an error", tt.returnType) + } + }) + } +} + +func TestValidateClassProperty(t *testing.T) { + tests := []struct { + name string + prop phpClassProperty + expectError bool + }{ + { + name: "valid property", + prop: phpClassProperty{ + Name: "validProperty", + PhpType: "string", + GoType: "string", + }, + expectError: false, + }, + { + name: "valid nullable property", + prop: phpClassProperty{ + Name: "nullableProperty", + PhpType: "int", + GoType: "*int", + IsNullable: true, + }, + expectError: false, + }, + { + name: "empty property name", + prop: phpClassProperty{ + Name: "", + PhpType: "string", + }, + expectError: true, + }, + { + name: "invalid property name", + prop: phpClassProperty{ + Name: "123invalid", + PhpType: "string", + }, + expectError: true, + }, + { + name: "invalid property type", + prop: phpClassProperty{ + Name: "validName", + PhpType: "invalidType", + }, + expectError: true, + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateClassProperty(tt.prop) + + if tt.expectError { + assert.Error(t, err, "validateClassProperty() should return an error") + } else { + assert.NoError(t, err, "validateClassProperty() should not return an error") + } + }) + } +} + +func TestValidateParameter(t *testing.T) { + tests := []struct { + name string + param phpParameter + expectError bool + }{ + { + name: "valid string parameter", + param: phpParameter{ + Name: "validParam", + PhpType: "string", + }, + expectError: false, + }, + { + name: "valid nullable parameter", + param: phpParameter{ + Name: "nullableParam", + PhpType: "int", + IsNullable: true, + }, + expectError: false, + }, + { + name: "valid parameter with default", + param: phpParameter{ + Name: "defaultParam", + PhpType: "string", + HasDefault: true, + DefaultValue: "hello", + }, + expectError: false, + }, + { + name: "empty parameter name", + param: phpParameter{ + Name: "", + PhpType: "string", + }, + expectError: true, + }, + { + name: "invalid parameter name", + param: phpParameter{ + Name: "123invalid", + PhpType: "string", + }, + expectError: true, + }, + { + name: "invalid parameter type", + param: phpParameter{ + Name: "validName", + PhpType: "invalidType", + }, + expectError: true, + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateParameter(tt.param) + + if tt.expectError { + assert.Error(t, err, "validateParameter() should return an error") + } else { + assert.NoError(t, err, "validateParameter() should not return an error") + } + }) + } +} + +func TestValidateClass(t *testing.T) { + tests := []struct { + name string + class phpClass + expectError bool + }{ + { + name: "valid class", + class: phpClass{ + Name: "ValidClass", + GoStruct: "ValidStruct", + Properties: []phpClassProperty{ + {Name: "name", PhpType: "string"}, + {Name: "age", PhpType: "int"}, + }, + }, + expectError: false, + }, + { + name: "valid class with nullable properties", + class: phpClass{ + Name: "NullableClass", + GoStruct: "NullableStruct", + Properties: []phpClassProperty{ + {Name: "required", PhpType: "string", IsNullable: false}, + {Name: "optional", PhpType: "string", IsNullable: true}, + }, + }, + expectError: false, + }, + { + name: "empty class name", + class: phpClass{ + Name: "", + GoStruct: "ValidStruct", + }, + expectError: true, + }, + { + name: "invalid class name", + class: phpClass{ + Name: "123InvalidClass", + GoStruct: "ValidStruct", + }, + expectError: true, + }, + { + name: "invalid property", + class: phpClass{ + Name: "ValidClass", + GoStruct: "ValidStruct", + Properties: []phpClassProperty{ + {Name: "123invalid", PhpType: "string"}, + }, + }, + expectError: true, + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateClass(tt.class) + + if tt.expectError { + assert.Error(t, err, "validateClass() should return an error") + } else { + assert.NoError(t, err, "validateClass() should not return an error") + } + }) + } +} + +func TestValidateScalarTypes(t *testing.T) { + tests := []struct { + name string + function phpFunction + expectError bool + errorMsg string + }{ + { + name: "valid scalar parameters only", + function: phpFunction{ + Name: "validFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "stringParam", PhpType: "string"}, + {Name: "intParam", PhpType: "int"}, + {Name: "floatParam", PhpType: "float"}, + {Name: "boolParam", PhpType: "bool"}, + }, + }, + expectError: false, + }, + { + name: "valid nullable scalar parameters", + function: phpFunction{ + Name: "nullableFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "stringParam", PhpType: "string", IsNullable: true}, + {Name: "intParam", PhpType: "int", IsNullable: true}, + }, + }, + expectError: false, + }, + { + name: "valid void return type", + function: phpFunction{ + Name: "voidFunction", + ReturnType: "void", + Params: []phpParameter{ + {Name: "stringParam", PhpType: "string"}, + }, + }, + expectError: false, + }, + { + name: "invalid array parameter", + function: phpFunction{ + Name: "arrayFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "arrayParam", PhpType: "array"}, + }, + }, + expectError: true, + errorMsg: "parameter 1 (arrayParam) has unsupported type 'array'", + }, + { + name: "invalid object parameter", + function: phpFunction{ + Name: "objectFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "objectParam", PhpType: "object"}, + }, + }, + expectError: true, + errorMsg: "parameter 1 (objectParam) has unsupported type 'object'", + }, + { + name: "invalid mixed parameter", + function: phpFunction{ + Name: "mixedFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "mixedParam", PhpType: "mixed"}, + }, + }, + expectError: true, + errorMsg: "parameter 1 (mixedParam) has unsupported type 'mixed'", + }, + { + name: "invalid array return type", + function: phpFunction{ + Name: "arrayReturnFunction", + ReturnType: "array", + Params: []phpParameter{ + {Name: "stringParam", PhpType: "string"}, + }, + }, + expectError: true, + errorMsg: "return type 'array' is not supported", + }, + { + name: "invalid object return type", + function: phpFunction{ + Name: "objectReturnFunction", + ReturnType: "object", + Params: []phpParameter{ + {Name: "stringParam", PhpType: "string"}, + }, + }, + expectError: true, + errorMsg: "return type 'object' is not supported", + }, + { + name: "mixed scalar and invalid parameters", + function: phpFunction{ + Name: "mixedFunction", + ReturnType: "string", + Params: []phpParameter{ + {Name: "validParam", PhpType: "string"}, + {Name: "invalidParam", PhpType: "array"}, + {Name: "anotherValidParam", PhpType: "int"}, + }, + }, + expectError: true, + errorMsg: "parameter 2 (invalidParam) has unsupported type 'array'", + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateScalarTypes(tt.function) + + if tt.expectError { + assert.Error(t, err, "validateScalarTypes() should return an error for function %s", tt.function.Name) + assert.Contains(t, err.Error(), tt.errorMsg, "Error message should contain expected text") + } else { + assert.NoError(t, err, "validateScalarTypes() should not return an error for function %s", tt.function.Name) + } + }) + } +} + +func TestValidateGoFunctionSignature(t *testing.T) { + tests := []struct { + name string + phpFunc phpFunction + expectError bool + errorMsg string + }{ + { + name: "valid Go function signature", + phpFunc: phpFunction{ + Name: "testFunc", + ReturnType: "string", + Params: []phpParameter{ + {Name: "name", PhpType: "string"}, + {Name: "count", PhpType: "int"}, + }, + GoFunction: `func testFunc(name *C.zend_string, count int64) unsafe.Pointer { + return nil +}`, + }, + expectError: false, + }, + { + name: "valid void return type", + phpFunc: phpFunction{ + Name: "voidFunc", + ReturnType: "void", + Params: []phpParameter{ + {Name: "message", PhpType: "string"}, + }, + GoFunction: `func voidFunc(message *C.zend_string) { + // Do something +}`, + }, + expectError: false, + }, + { + name: "no Go function provided", + phpFunc: phpFunction{ + Name: "noGoFunc", + ReturnType: "string", + Params: []phpParameter{}, + GoFunction: "", + }, + expectError: true, + errorMsg: "no Go function found", + }, + { + name: "parameter count mismatch", + phpFunc: phpFunction{ + Name: "countMismatch", + ReturnType: "string", + Params: []phpParameter{ + {Name: "param1", PhpType: "string"}, + {Name: "param2", PhpType: "int"}, + }, + GoFunction: `func countMismatch(param1 *C.zend_string) unsafe.Pointer { + return nil +}`, + }, + expectError: true, + errorMsg: "parameter count mismatch: PHP function has 2 parameters but Go function has 1", + }, + { + name: "parameter type mismatch", + phpFunc: phpFunction{ + Name: "typeMismatch", + ReturnType: "string", + Params: []phpParameter{ + {Name: "name", PhpType: "string"}, + {Name: "count", PhpType: "int"}, + }, + GoFunction: `func typeMismatch(name *C.zend_string, count string) unsafe.Pointer { + return nil +}`, + }, + expectError: true, + errorMsg: "parameter 2 type mismatch: PHP 'int' requires Go type 'int64' but found 'string'", + }, + { + name: "return type mismatch", + phpFunc: phpFunction{ + Name: "returnMismatch", + ReturnType: "int", + Params: []phpParameter{ + {Name: "value", PhpType: "string"}, + }, + GoFunction: `func returnMismatch(value *C.zend_string) string { + return "" +}`, + }, + expectError: true, + errorMsg: "return type mismatch: PHP 'int' requires Go return type 'int64' but found 'string'", + }, + { + name: "valid bool parameter and return", + phpFunc: phpFunction{ + Name: "boolFunc", + ReturnType: "bool", + Params: []phpParameter{ + {Name: "flag", PhpType: "bool"}, + }, + GoFunction: `func boolFunc(flag bool) bool { + return flag +}`, + }, + expectError: false, + }, + { + name: "valid float parameter and return", + phpFunc: phpFunction{ + Name: "floatFunc", + ReturnType: "float", + Params: []phpParameter{ + {Name: "value", PhpType: "float"}, + }, + GoFunction: `func floatFunc(value float64) float64 { + return value * 2.0 +}`, + }, + expectError: false, + }, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.validateGoFunctionSignatureWithOptions(tt.phpFunc, false) + + if tt.expectError { + assert.Error(t, err, "validateGoFunctionSignature() should return an error for function %s", tt.phpFunc.Name) + assert.Contains(t, err.Error(), tt.errorMsg, "Error message should contain expected text") + } else { + assert.NoError(t, err, "validateGoFunctionSignature() should not return an error for function %s", tt.phpFunc.Name) + } + }) + } +} + +func TestPhpTypeToGoType(t *testing.T) { + tests := []struct { + phpType string + isNullable bool + expected string + }{ + {"string", false, "*C.zend_string"}, + {"string", true, "*C.zend_string"}, // String is already a pointer, no change for nullable + {"int", false, "int64"}, + {"int", true, "*int64"}, // Nullable int becomes pointer to int64 + {"float", false, "float64"}, + {"float", true, "*float64"}, // Nullable float becomes pointer to float64 + {"bool", false, "bool"}, + {"bool", true, "*bool"}, // Nullable bool becomes pointer to bool + {"unknown", false, "interface{}"}, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.phpType, func(t *testing.T) { + result := validator.phpTypeToGoType(tt.phpType, tt.isNullable) + assert.Equal(t, tt.expected, result, "phpTypeToGoType(%s, %v) should return %s", tt.phpType, tt.isNullable, tt.expected) + }) + } +} + +func TestPhpReturnTypeToGoType(t *testing.T) { + tests := []struct { + phpReturnType string + isNullable bool + expected string + }{ + {"void", false, ""}, + {"void", true, ""}, + {"string", false, "unsafe.Pointer"}, + {"string", true, "unsafe.Pointer"}, + {"int", false, "int64"}, + {"int", true, "int64"}, + {"float", false, "float64"}, + {"float", true, "float64"}, + {"bool", false, "bool"}, + {"bool", true, "bool"}, + {"unknown", false, "interface{}"}, + } + + validator := Validator{} + for _, tt := range tests { + t.Run(tt.phpReturnType, func(t *testing.T) { + result := validator.phpReturnTypeToGoType(tt.phpReturnType, tt.isNullable) + assert.Equal(t, tt.expected, result, "phpReturnTypeToGoType(%s, %v) should return %s", tt.phpReturnType, tt.isNullable, tt.expected) + }) + } +} diff --git a/internal/testext/ext_test.go b/internal/testext/ext_test.go new file mode 100644 index 00000000..3e9cfa14 --- /dev/null +++ b/internal/testext/ext_test.go @@ -0,0 +1,7 @@ +package testext + +import "testing" + +func TestRegisterExtension(t *testing.T) { + testRegisterExtension(t) +} diff --git a/internal/testext/extension.h b/internal/testext/extension.h new file mode 100644 index 00000000..57fa60d6 --- /dev/null +++ b/internal/testext/extension.h @@ -0,0 +1,9 @@ +#ifndef _EXTENSIONS_H +#define _EXTENSIONS_H + +#include + +extern zend_module_entry module1_entry; +extern zend_module_entry module2_entry; + +#endif diff --git a/internal/testext/extensions.c b/internal/testext/extensions.c new file mode 100644 index 00000000..721955f6 --- /dev/null +++ b/internal/testext/extensions.c @@ -0,0 +1,26 @@ +#include +#include + +#include "_cgo_export.h" + +zend_module_entry module1_entry = {STANDARD_MODULE_HEADER, + "ext1", + NULL, /* Functions */ + NULL, /* MINIT */ + NULL, /* MSHUTDOWN */ + NULL, /* RINIT */ + NULL, /* RSHUTDOWN */ + NULL, /* MINFO */ + "0.1.0", + STANDARD_MODULE_PROPERTIES}; + +zend_module_entry module2_entry = {STANDARD_MODULE_HEADER, + "ext2", + NULL, /* Functions */ + NULL, /* MINIT */ + NULL, /* MSHUTDOWN */ + NULL, /* RINIT */ + NULL, /* RSHUTDOWN */ + NULL, /* MINFO */ + "0.1.0", + STANDARD_MODULE_PROPERTIES}; diff --git a/internal/testext/exttest.go b/internal/testext/exttest.go new file mode 100644 index 00000000..abebee4c --- /dev/null +++ b/internal/testext/exttest.go @@ -0,0 +1,44 @@ +package testext + +// #cgo darwin pkg-config: libxml-2.0 +// #cgo CFLAGS: -Wall -Werror +// #cgo CFLAGS: -I/usr/local/include -I/usr/local/include/php -I/usr/local/include/php/main -I/usr/local/include/php/TSRM -I/usr/local/include/php/Zend -I/usr/local/include/php/ext -I/usr/local/include/php/ext/date/lib +// #cgo linux CFLAGS: -D_GNU_SOURCE +// #cgo darwin CFLAGS: -I/opt/homebrew/include +// #cgo LDFLAGS: -L/usr/local/lib -L/usr/lib -lphp -lm -lutil +// #cgo linux LDFLAGS: -ldl -lresolv +// #cgo darwin LDFLAGS: -Wl,-rpath,/usr/local/lib -L/opt/homebrew/lib -L/opt/homebrew/opt/libiconv/lib -liconv -ldl +// #include "extension.h" +import "C" +import ( + "github.com/dunglas/frankenphp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net/http/httptest" + "testing" + "unsafe" +) + +func testRegisterExtension(t *testing.T) { + frankenphp.RegisterExtension(unsafe.Pointer(&C.module1_entry)) + frankenphp.RegisterExtension(unsafe.Pointer(&C.module2_entry)) + + err := frankenphp.Init() + require.Nil(t, err) + defer frankenphp.Shutdown() + + req := httptest.NewRequest("GET", "http://example.com/index.php", nil) + w := httptest.NewRecorder() + + req, err = frankenphp.NewRequestWithContext(req, frankenphp.WithRequestDocumentRoot("./testdata", false)) + assert.NoError(t, err) + + err = frankenphp.ServeHTTP(w, req) + assert.NoError(t, err) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + assert.Contains(t, string(body), "ext1") + assert.Contains(t, string(body), "ext2") +} diff --git a/internal/testext/testdata/index.php b/internal/testext/testdata/index.php new file mode 100644 index 00000000..96c7dc65 --- /dev/null +++ b/internal/testext/testdata/index.php @@ -0,0 +1,3 @@ + +import "C" +import "unsafe" + +// EXPERIMENTAL: GoString copies a zend_string to a Go string. +func GoString(s unsafe.Pointer) string { + if s == nil { + return "" + } + + zendStr := (*C.zend_string)(s) + + return C.GoStringN((*C.char)(unsafe.Pointer(&zendStr.val)), C.int(zendStr.len)) +} + +// EXPERIMENTAL: PHPString converts a Go string to a zend_string with copy. The string can be +// non-persistent (automatically freed after the request by the ZMM) or persistent. If you choose +// the second mode, it is your repsonsability to free the allocated memory. +func PHPString(s string, persistent bool) unsafe.Pointer { + if s == "" { + return nil + } + + zendStr := C.zend_string_init( + (*C.char)(unsafe.Pointer(unsafe.StringData(s))), + C.size_t(len(s)), + C._Bool(persistent), + ) + + return unsafe.Pointer(zendStr) +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 00000000..be4559a4 --- /dev/null +++ b/types_test.go @@ -0,0 +1,7 @@ +package frankenphp + +import "testing" + +func TestGoString(t *testing.T) { + testGoString(t) +} diff --git a/typestest.go b/typestest.go new file mode 100644 index 00000000..178dae22 --- /dev/null +++ b/typestest.go @@ -0,0 +1,18 @@ +package frankenphp + +//#include +// +//zend_string *hello_string() { +// return zend_string_init("Hello", 5, 1); +//} +import "C" +import ( + "github.com/stretchr/testify/assert" + "testing" + "unsafe" +) + +func testGoString(t *testing.T) { + assert.Equal(t, "", GoString(nil)) + assert.Equal(t, "Hello", GoString(unsafe.Pointer(C.hello_string()))) +}