Add GoSRT & improvements (repo-merge)

Commits (Ingo Oppermann):
- Add experimental SRT connection stats and logs
- Hide /config/reload endpoint in read-only mode
- Add SRT server
- Create v16 in go.mod
- Fix data races, tests, lint, and update dependencies
- Add trailing slash for routed directories (datarhei/restreamer#340)
- Allow relative URLs in content in static routes

Co-Authored-By: Ingo Oppermann <57445+ioppermann@users.noreply.github.com>
This commit is contained in:
Jan Stabenow
2022-06-23 22:13:58 +02:00
parent d7db9e4efe
commit eb1cc37456
323 changed files with 17524 additions and 10050 deletions

View File

@@ -5,10 +5,86 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
<a name="unreleased"></a>
## [Unreleased](https://github.com/99designs/gqlgen/compare/v0.17.7...HEAD)
## [Unreleased](https://github.com/99designs/gqlgen/compare/v0.17.9...HEAD)
<!-- end of if -->
<!-- end of CommitGroups -->
<a name="v0.17.9"></a>
## [v0.17.9](https://github.com/99designs/gqlgen/compare/v0.17.8...v0.17.9) - 2022-05-26
- <a href="https://github.com/99designs/gqlgen/commit/7f0611b2d19833a740afcfaf5708febff942da2d"><tt>7f0611b2</tt></a> release v0.17.9
- <a href="https://github.com/99designs/gqlgen/commit/738209b26337bc1116be7b0afacc83eae6bb93b0"><tt>738209b2</tt></a> Update gqlparser (<a href="https://github.com/99designs/gqlgen/pull/16">#2216</a>)
<dl><dd><details><summary><a href="https://github.com/99designs/gqlgen/commit/6855b7290cab62a1fc6a26a2b633e0b5bbf248da"><tt>6855b729</tt></a> fix: prevent goroutine leak and CPU spinning at websocket transport (<a href="https://github.com/99designs/gqlgen/pull/09">#2209</a>) (closes <a href="https://github.com/99designs/gqlgen/issues/2168"> #2168</a>)</summary>
* Added goroutine leak test for chat example
* Improved chat example with proper concurrency
This reverts commit eef7bfaad1b524f9e2fc0c1150fdb321c276069e.
* Improved subscription channel usage
* Regenerated examples and codegen
* Add support for subscription keepalives in websocket client
* Update chat example test
* if else chain to switch
* Revert "Add support for subscription keepalives in websocket client"
This reverts commits 64b882c3c9901f25edc0684ce2a1f9b63443416b and 670cf22272b490005d46dc2bee1634de1cd06d68.
* Fixed chat example race condition
* Fixed chatroom#Messages type
</details></dd></dl>
<dl><dd><details><summary><a href="https://github.com/99designs/gqlgen/commit/5f5bfcb97fdb01026cf35a5dc46f1246a30f9b26"><tt>5f5bfcb9</tt></a> fix <a href="https://github.com/99designs/gqlgen/pull/14">#2204](https://github.com/99designs/gqlgen/issues/2204) - don't try to embed builtin sources ([#2214</a>)</summary>
* don't try to embed builtins
* add test
* generated code
* fix error message string
</details></dd></dl>
- <a href="https://github.com/99designs/gqlgen/commit/8d9d3f125f13dcd19f59072d3c38366dc520758b"><tt>8d9d3f12</tt></a> Check only direct dependencies (<a href="https://github.com/99designs/gqlgen/pull/05">#2205</a>)
- <a href="https://github.com/99designs/gqlgen/commit/b262e40a485f67d2659e239a156418938d0fe2e9"><tt>b262e40a</tt></a> v0.17.8 postrelease bump
<!-- end of Commits -->
<!-- end of Else -->
<!-- end of If NoteGroups -->
<a name="v0.17.8"></a>
## [v0.17.8](https://github.com/99designs/gqlgen/compare/v0.17.7...v0.17.8) - 2022-05-25
- <a href="https://github.com/99designs/gqlgen/commit/25367e0a24998aea40f09218f60d1d0e6d1cce4a"><tt>25367e0a</tt></a> release v0.17.8
- <a href="https://github.com/99designs/gqlgen/commit/5a56b69d89c7414e21b2f01e0e5042a26b69c5cb"><tt>5a56b69d</tt></a> Add security workflow with nancy (<a href="https://github.com/99designs/gqlgen/pull/02">#2202</a>)
- <a href="https://github.com/99designs/gqlgen/commit/482f4ce08e65458cec2dbfaf7d184f1c8fccb129"><tt>482f4ce0</tt></a> Run CI tests on windows (<a href="https://github.com/99designs/gqlgen/pull/99">#2199</a>)
- <a href="https://github.com/99designs/gqlgen/commit/656045d3fa643b898932c3f5332544b0baed1af4"><tt>656045d3</tt></a> This works on Windows too! (<a href="https://github.com/99designs/gqlgen/pull/97">#2197</a>)
- <a href="https://github.com/99designs/gqlgen/commit/f6aeed60a508dae102b2b821d3a947e24e5e0826"><tt>f6aeed60</tt></a> Merge branch 'master' of github.com:99designs/gqlgen
- <a href="https://github.com/99designs/gqlgen/commit/d91080be396af96266941499d369d0f8279761b0"><tt>d91080be</tt></a> Update changelog
- <a href="https://github.com/99designs/gqlgen/commit/752d2d7e9fff08c82a6d3ffc1c8c7ffe2a2e9fe2"><tt>752d2d7e</tt></a> v0.17.7 postrelease bump
<!-- end of Commits -->
<!-- end of Else -->
<!-- end of If NoteGroups -->
<a name="v0.17.7"></a>
## [v0.17.7](https://github.com/99designs/gqlgen/compare/v0.17.6...v0.17.7) - 2022-05-24
- <a href="https://github.com/99designs/gqlgen/commit/2b1dff1b71f89c95e946bbe5948b7061f9c47aa8"><tt>2b1dff1b</tt></a> release v0.17.7

View File

@@ -16,21 +16,23 @@ import (
)
type Config struct {
SchemaFilename StringList `yaml:"schema,omitempty"`
Exec ExecConfig `yaml:"exec"`
Model PackageConfig `yaml:"model,omitempty"`
Federation PackageConfig `yaml:"federation,omitempty"`
Resolver ResolverConfig `yaml:"resolver,omitempty"`
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
Packages *code.Packages `yaml:"-"`
Schema *ast.Schema `yaml:"-"`
SchemaFilename StringList `yaml:"schema,omitempty"`
Exec ExecConfig `yaml:"exec"`
Model PackageConfig `yaml:"model,omitempty"`
Federation PackageConfig `yaml:"federation,omitempty"`
Resolver ResolverConfig `yaml:"resolver,omitempty"`
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"`
ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
Packages *code.Packages `yaml:"-"`
Schema *ast.Schema `yaml:"-"`
// Deprecated: use Federation instead. Will be removed next release
Federated bool `yaml:"federated,omitempty"`
@@ -41,11 +43,13 @@ var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
// DefaultConfig creates a copy of the default config
func DefaultConfig() *Config {
return &Config{
SchemaFilename: StringList{"schema.graphql"},
Model: PackageConfig{Filename: "models_gen.go"},
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
SchemaFilename: StringList{"schema.graphql"},
Model: PackageConfig{Filename: "models_gen.go"},
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
StructFieldsAlwaysPointers: true,
ResolversAlwaysReturnPointers: true,
}
}
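The two options added in this diff default to `true` in `DefaultConfig`, preserving the existing pointer-based behavior. A hedged sketch of how a project could opt out in `gqlgen.yml` (keys as introduced above, values illustrative):

```yaml
# gqlgen.yml — restore value (non-pointer) generation for both cases;
# both keys default to true when omitted
struct_fields_always_pointers: false
resolvers_always_return_pointers: false
```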

View File

@@ -71,7 +71,7 @@ func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, e
log.Println(err.Error())
}
if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
f.TypeReference = b.Binder.PointerTo(f.TypeReference)
}
@@ -557,7 +557,20 @@ func (f *Field) CallArgs() string {
}
for _, arg := range f.Args {
args = append(args, "fc.Args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
if types.IsInterface(arg.TypeReference.GO) {
tmp = fmt.Sprintf(`
func () interface{} {
if fc.Args["%s"] == nil {
return nil
}
return fc.Args["%s"].(interface{})
}()`, arg.Name, arg.Name,
)
}
args = append(args, tmp)
}
return strings.Join(args, ", ")
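The closure added above matters because a plain type assertion on a missing (nil) argument panics at runtime. A minimal sketch of the generated pattern, with illustrative names (`lookupArg`, `fcArgs` stand in for the generated code, they are not gqlgen APIs):

```go
package main

import "fmt"

// lookupArg mirrors the generated closure: for interface-typed arguments,
// a direct assertion on a nil map entry would panic with
// "interface conversion: interface is nil", so nil is checked first.
func lookupArg(fcArgs map[string]interface{}, name string) interface{} {
	return func() interface{} {
		if fcArgs[name] == nil {
			return nil
		}
		return fcArgs[name].(interface{})
	}()
}

func main() {
	fcArgs := map[string]interface{}{"present": "hello"}
	fmt.Println(lookupArg(fcArgs, "present")) // hello
	fmt.Println(lookupArg(fcArgs, "missing")) // <nil>, no panic
}
```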

View File

@@ -13,7 +13,7 @@ require (
github.com/mitchellh/mapstructure v1.3.1
github.com/stretchr/testify v1.7.1
github.com/urfave/cli/v2 v2.8.1
github.com/vektah/gqlparser/v2 v2.4.4
github.com/vektah/gqlparser/v2 v2.4.5
golang.org/x/tools v0.1.10
google.golang.org/protobuf v1.28.0
gopkg.in/yaml.v2 v2.4.0

View File

@@ -49,8 +49,8 @@ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMT
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/urfave/cli/v2 v2.8.1 h1:CGuYNZF9IKZY/rfBe3lJpccSoIY1ytfvmgQT90cNOl4=
github.com/urfave/cli/v2 v2.8.1/go.mod h1:Z41J9TPoffeoqP0Iza0YbAhGvymRdZAd2uPmZ5JxRdY=
github.com/vektah/gqlparser/v2 v2.4.4 h1:rh9hwZ5Jx9cCq88zXz2YHKmuQBuwY1JErHU8GywFdwE=
github.com/vektah/gqlparser/v2 v2.4.4/go.mod h1:flJWIR04IMQPGz+BXLrORkrARBxv/rtyIAFvd/MceW0=
github.com/vektah/gqlparser/v2 v2.4.5 h1:C02NsyEsL4TXJB7ndonqTfuQOL4XPIu0aAWugdmTgmc=
github.com/vektah/gqlparser/v2 v2.4.5/go.mod h1:flJWIR04IMQPGz+BXLrORkrARBxv/rtyIAFvd/MceW0=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

View File

@@ -132,7 +132,7 @@ func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graph
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := &graphql.Response{
Errors: list,
Errors: graphql.GetErrors(ctx),
}
resp.Extensions = graphql.GetExtensions(ctx)
return resp

View File

@@ -1,3 +1,3 @@
package graphql
const Version = "v0.17.9"
const Version = "v0.17.10"

View File

@@ -29,6 +29,13 @@ resolver:
# Optional: turn on to use []Thing instead of []*Thing
# omit_slice_element_pointers: false
# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true
# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true
# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true
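To make the new `struct_fields_always_pointers` option concrete, here is a hypothetical sketch of the two shapes of generated model for the same schema type (the `Thing`/`OtherThing` names come from the comment above; the code is illustrative, not real gqlgen output):

```go
package main

import "fmt"

type OtherThing struct{ N int }

// With struct_fields_always_pointers: true (the default),
// struct-typed fields are generated as pointers.
type ThingPtr struct{ FieldA *OtherThing }

// With struct_fields_always_pointers: false,
// the same field is generated as a plain value.
type ThingVal struct{ FieldA OtherThing }

// Sum touches both shapes to show they are used identically at call sites.
func Sum() int {
	p := ThingPtr{FieldA: &OtherThing{N: 1}}
	v := ThingVal{FieldA: OtherThing{N: 2}}
	return p.FieldA.N + v.FieldA.N
}

func main() { fmt.Println(Sum()) }
```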

View File

@@ -48,9 +48,12 @@ type Object struct {
type Field struct {
Description string
Name string
Type types.Type
Tag string
// Name is the field's name as it appears in the schema
Name string
// GoName is the field's name as it appears in the generated Go code
GoName string
Type types.Type
Tag string
}
type Enum struct {
@@ -178,19 +181,22 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
}
}
name := field.Name
name := templates.ToGo(field.Name)
if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
name = nameOveride
}
typ = binder.CopyModifiersFromAst(field.Type, typ)
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
}
f := &Field{
Name: name,
Name: field.Name,
GoName: name,
Type: typ,
Description: field.Description,
Tag: `json:"` + field.Name + `"`,
@@ -230,6 +236,12 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
// check for cyclical relationships and recursive structs
if !cfg.StructFieldsAlwaysPointers {
findAndHandleCyclicalRelationships(b)
}
for _, it := range b.Enums {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
@@ -303,3 +315,55 @@ func isStruct(t types.Type) bool {
_, is := t.Underlying().(*types.Struct)
return is
}
// findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
// with pointers. These relationships will produce compilation errors if they are not pointers.
// Also handles recursive structs.
func findAndHandleCyclicalRelationships(b *ModelBuild) {
for ii, structA := range b.Models {
for _, fieldA := range structA.Fields {
if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
fmt.Print()
}
if !isStruct(fieldA.Type) {
continue
}
// the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA"
// we only want the part after the last dot: "LoopA"
// this could lead to false positives, as we are only checking the name of the struct type, but these
// should be extremely rare, if it is even possible at all.
fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]
// find this struct type amongst the generated structs
for jj, structB := range b.Models {
if structB.Name != fieldAStructName {
continue
}
// check if structB contains a cyclical reference back to structA
var cyclicalReferenceFound bool
for _, fieldB := range structB.Fields {
if !isStruct(fieldB.Type) {
continue
}
fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
if fieldBStructName == structA.Name {
cyclicalReferenceFound = true
fieldB.Type = types.NewPointer(fieldB.Type)
// keep looping in case this struct has additional fields of this type
}
}
// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
if cyclicalReferenceFound && ii != jj {
fieldA.Type = types.NewPointer(fieldA.Type)
break
}
}
}
}
}
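An illustration of the case this helper handles when `struct_fields_always_pointers` is false: two generated models that reference each other cannot both hold the other by value, since Go rejects recursive value types, so one side's field must become a pointer. The `LoopA`/`LoopB` names echo the comments above; the snippet itself is a sketch, not generated output:

```go
package main

import "fmt"

type LoopA struct {
	B LoopB
}

type LoopB struct {
	A *LoopA // rewritten to a pointer, as findAndHandleCyclicalRelationships would do
}

// MakeLoop builds the cycle that a value-typed field could never represent.
func MakeLoop() *LoopA {
	a := &LoopA{}
	a.B.A = a
	return a
}

func main() {
	a := MakeLoop()
	fmt.Println(a.B.A == a) // true: the cycle closes through the pointer field
}
```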

View File

@@ -29,7 +29,7 @@
{{- with .Description }}
{{.|prefixLines "// "}}
{{- end}}
{{ $field.Name|go }} {{$field.Type | ref}} `{{$field.Tag}}`
{{ $field.GoName }} {{$field.Type | ref}} `{{$field.Tag}}`
{{- end }}
}
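The template change above works because `GoName` now carries the already-exported identifier, while `Name` (and the JSON tag) keep the schema-side spelling, so the template no longer pipes the name through the `go` helper. A minimal sketch with illustrative field values (not taken from a real run):

```go
package main

import "fmt"

// Field mirrors the split introduced in the models plugin.
type Field struct {
	Name   string // field name as it appears in the schema
	GoName string // field name as emitted in generated Go code
	Tag    string
}

// Render approximates the template line: GoName plus the tag built from Name.
func Render(f Field) string {
	return fmt.Sprintf("%s string `%s`", f.GoName, f.Tag)
}

func main() {
	f := Field{Name: "fullName", GoName: "FullName", Tag: `json:"fullName"`}
	fmt.Println(Render(f))
}
```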

View File

@@ -1,311 +0,0 @@
# Go JSON Schema Reflection
[![CI](https://github.com/alecthomas/jsonschema/actions/workflows/ci.yml/badge.svg)](https://github.com/alecthomas/jsonschema/actions/workflows/ci.yml)
[![Go Report Card](https://goreportcard.com/badge/github.com/alecthomas/jsonschema)](https://goreportcard.com/report/github.com/alecthomas/jsonschema)
[![GoDoc](https://godoc.org/github.com/alecthomas/jsonschema?status.svg)](https://godoc.org/github.com/alecthomas/jsonschema)
This package can be used to generate [JSON Schemas](http://json-schema.org/latest/json-schema-validation.html) from Go types through reflection.
- Supports arbitrarily complex types, including `interface{}`, maps, slices, etc.
- Supports json-schema features such as minLength, maxLength, pattern, format, etc.
- Supports simple string and numeric enums.
- Supports custom property fields via the `jsonschema_extras` struct tag.
## Example
The following Go type:
```go
type TestUser struct {
ID int `json:"id"`
Name string `json:"name" jsonschema:"title=the name,description=The name of a friend,example=joe,example=lucy,default=alex"`
Friends []int `json:"friends,omitempty" jsonschema_description:"The list of IDs, omitted when empty"`
Tags map[string]interface{} `json:"tags,omitempty" jsonschema_extras:"a=b,foo=bar,foo=bar1"`
BirthDate time.Time `json:"birth_date,omitempty" jsonschema:"oneof_required=date"`
YearOfBirth string `json:"year_of_birth,omitempty" jsonschema:"oneof_required=year"`
Metadata interface{} `json:"metadata,omitempty" jsonschema:"oneof_type=string;array"`
FavColor string `json:"fav_color,omitempty" jsonschema:"enum=red,enum=green,enum=blue"`
}
```
Results in following JSON Schema:
```go
jsonschema.Reflect(&TestUser{})
```
```json
{
"$schema": "http://json-schema.org/draft-04/schema#",
"$ref": "#/definitions/TestUser",
"definitions": {
"TestUser": {
"type": "object",
"properties": {
"metadata": {
"oneOf": [
{
"type": "string"
},
{
"type": "array"
}
]
},
"birth_date": {
"type": "string",
"format": "date-time"
},
"friends": {
"type": "array",
"items": {
"type": "integer"
},
"description": "The list of IDs, omitted when empty"
},
"id": {
"type": "integer"
},
"name": {
"type": "string",
"title": "the name",
"description": "The name of a friend",
"default": "alex",
"examples": [
"joe",
"lucy"
]
},
"tags": {
"type": "object",
"patternProperties": {
".*": {
"additionalProperties": true
}
},
"a": "b",
"foo": [
"bar",
"bar1"
]
},
"fav_color": {
"type": "string",
"enum": [
"red",
"green",
"blue"
]
}
},
"additionalProperties": false,
"required": ["id", "name"],
"oneOf": [
{
"required": [
"birth_date"
],
"title": "date"
},
{
"required": [
"year_of_birth"
],
"title": "year"
}
]
}
}
}
```
## Configurable behaviour
The behaviour of the schema generator can be altered with parameters when a `jsonschema.Reflector`
instance is created.
### ExpandedStruct
If set to `true`, the top-level struct is expanded in place instead of being referenced from the definitions. The type passed must be a struct type.
e.g.
```go
type GrandfatherType struct {
FamilyName string `json:"family_name" jsonschema:"required"`
}
type SomeBaseType struct {
SomeBaseProperty int `json:"some_base_property"`
// The jsonschema required tag is nonsensical for private and ignored properties.
// Their presence here tests that the fields *will not* be required in the output
// schema, even if they are tagged required.
somePrivateBaseProperty string `json:"i_am_private" jsonschema:"required"`
SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"`
SomeSchemaIgnoredProperty string `jsonschema:"-,required"`
SomeUntaggedBaseProperty bool `jsonschema:"required"`
someUnexportedUntaggedBaseProperty bool
Grandfather GrandfatherType `json:"grand"`
}
```
will output:
```json
{
"$schema": "http://json-schema.org/draft-04/schema#",
"required": [
"some_base_property",
"grand",
"SomeUntaggedBaseProperty"
],
"properties": {
"SomeUntaggedBaseProperty": {
"type": "boolean"
},
"grand": {
"$schema": "http://json-schema.org/draft-04/schema#",
"$ref": "#/definitions/GrandfatherType"
},
"some_base_property": {
"type": "integer"
}
},
"type": "object",
"definitions": {
"GrandfatherType": {
"required": [
"family_name"
],
"properties": {
"family_name": {
"type": "string"
}
},
"additionalProperties": false,
"type": "object"
}
}
}
```
### PreferYAMLSchema
JSON schemas can also be used to validate YAML; however, YAML frequently uses
different identifiers from JSON, indicated by the `yaml:` tag. By default the
`Reflector` prefers `json:` tags over `yaml:` tags (and only uses the latter
when the former are absent). Setting the `PreferYAMLSchema` flag reverses
this: `yaml:` tags will be preferred over `json:` tags.
With `PreferYAMLSchema: true`, the following struct:
```go
type Person struct {
FirstName string `json:"FirstName" yaml:"first_name"`
}
```
would result in this schema:
```json
{
"$schema": "http://json-schema.org/draft-04/schema#",
"$ref": "#/definitions/TestYamlAndJson",
"definitions": {
"Person": {
"required": ["first_name"],
"properties": {
"first_name": {
"type": "string"
}
},
"additionalProperties": false,
"type": "object"
}
}
}
```
whereas without the flag one obtains:
```json
{
"$schema": "http://json-schema.org/draft-04/schema#",
"$ref": "#/definitions/TestYamlAndJson",
"definitions": {
"Person": {
"required": ["FirstName"],
"properties": {
"first_name": {
"type": "string"
}
},
"additionalProperties": false,
"type": "object"
}
}
}
```
### Custom Type Definitions
Sometimes it can be useful to have custom JSON Marshal and Unmarshal methods in your structs that automatically convert, for example, a string into an object.
To override the auto-generated object type for your type, implement the `JSONSchemaType() *Type` method; whatever it returns will be used in the schema definitions.
Take the following simplified example of a `CompactDate` that only includes the Year and Month:
```go
type CompactDate struct {
Year int
Month int
}
func (d *CompactDate) UnmarshalJSON(data []byte) error {
if len(data) != 9 {
return errors.New("invalid compact date length")
}
var err error
d.Year, err = strconv.Atoi(string(data[1:5]))
if err != nil {
return err
}
d.Month, err = strconv.Atoi(string(data[7:8]))
if err != nil {
return err
}
return nil
}
func (d *CompactDate) MarshalJSON() ([]byte, error) {
buf := new(bytes.Buffer)
buf.WriteByte('"')
buf.WriteString(fmt.Sprintf("%d-%02d", d.Year, d.Month))
buf.WriteByte('"')
return buf.Bytes(), nil
}
func (CompactDate) JSONSchemaType() *Type {
return &Type{
Type: "string",
Title: "Compact Date",
Description: "Short date that only includes year and month",
Pattern: "^[0-9]{4}-[0-1][0-9]$",
}
}
```
The resulting schema generated for this struct would look like:
```json
{
"$schema": "http://json-schema.org/draft-04/schema#",
"$ref": "#/definitions/CompactDate",
"definitions": {
"CompactDate": {
"pattern": "^[0-9]{4}-[0-1][0-9]$",
"type": "string",
"title": "Compact Date",
"description": "Short date that only includes year and month"
}
}
}
```

View File

@@ -1,849 +0,0 @@
// Package jsonschema uses reflection to generate JSON Schemas from Go types [1].
//
// If json tags are present on struct fields, they will be used to infer
// property names and if a property is required (omitempty is present).
//
// [1] http://json-schema.org/latest/json-schema-validation.html
package jsonschema
import (
"encoding/json"
"net"
"net/url"
"reflect"
"strconv"
"strings"
"time"
"github.com/iancoleman/orderedmap"
)
// Version is the JSON Schema version.
// If extending JSON Schema with custom values use a custom URI.
// RFC draft-wright-json-schema-00, section 6
var Version = "http://json-schema.org/draft-04/schema#"
// Schema is the root schema.
// RFC draft-wright-json-schema-00, section 4.5
type Schema struct {
*Type
Definitions Definitions
}
// customSchemaType is used to detect if the type provides its own
// custom Schema Type definition to use instead. Very useful for situations
// where there are custom JSON Marshal and Unmarshal methods.
type customSchemaType interface {
JSONSchemaType() *Type
}
var customType = reflect.TypeOf((*customSchemaType)(nil)).Elem()
// customSchemaGetFieldDocString
type customSchemaGetFieldDocString interface {
GetFieldDocString(fieldName string) string
}
type customGetFieldDocString func(fieldName string) string
var customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).Elem()
// Type represents a JSON Schema object type.
type Type struct {
// RFC draft-wright-json-schema-00
Version string `json:"$schema,omitempty"` // section 6.1
Ref string `json:"$ref,omitempty"` // section 7
// RFC draft-wright-json-schema-validation-00, section 5
MultipleOf int `json:"multipleOf,omitempty"` // section 5.1
Maximum int `json:"maximum,omitempty"` // section 5.2
ExclusiveMaximum bool `json:"exclusiveMaximum,omitempty"` // section 5.3
Minimum int `json:"minimum,omitempty"` // section 5.4
ExclusiveMinimum bool `json:"exclusiveMinimum,omitempty"` // section 5.5
MaxLength int `json:"maxLength,omitempty"` // section 5.6
MinLength int `json:"minLength,omitempty"` // section 5.7
Pattern string `json:"pattern,omitempty"` // section 5.8
AdditionalItems *Type `json:"additionalItems,omitempty"` // section 5.9
Items *Type `json:"items,omitempty"` // section 5.9
MaxItems int `json:"maxItems,omitempty"` // section 5.10
MinItems int `json:"minItems,omitempty"` // section 5.11
UniqueItems bool `json:"uniqueItems,omitempty"` // section 5.12
MaxProperties int `json:"maxProperties,omitempty"` // section 5.13
MinProperties int `json:"minProperties,omitempty"` // section 5.14
Required []string `json:"required,omitempty"` // section 5.15
Properties *orderedmap.OrderedMap `json:"properties,omitempty"` // section 5.16
PatternProperties map[string]*Type `json:"patternProperties,omitempty"` // section 5.17
AdditionalProperties json.RawMessage `json:"additionalProperties,omitempty"` // section 5.18
Dependencies map[string]*Type `json:"dependencies,omitempty"` // section 5.19
Enum []interface{} `json:"enum,omitempty"` // section 5.20
Type string `json:"type,omitempty"` // section 5.21
AllOf []*Type `json:"allOf,omitempty"` // section 5.22
AnyOf []*Type `json:"anyOf,omitempty"` // section 5.23
OneOf []*Type `json:"oneOf,omitempty"` // section 5.24
Not *Type `json:"not,omitempty"` // section 5.25
Definitions Definitions `json:"definitions,omitempty"` // section 5.26
// RFC draft-wright-json-schema-validation-00, section 6, 7
Title string `json:"title,omitempty"` // section 6.1
Description string `json:"description,omitempty"` // section 6.1
Default interface{} `json:"default,omitempty"` // section 6.2
Format string `json:"format,omitempty"` // section 7
Examples []interface{} `json:"examples,omitempty"` // section 7.4
// RFC draft-wright-json-schema-hyperschema-00, section 4
Media *Type `json:"media,omitempty"` // section 4.3
BinaryEncoding string `json:"binaryEncoding,omitempty"` // section 4.3
Extras map[string]interface{} `json:"-"`
}
// Reflect reflects to Schema from a value using the default Reflector
func Reflect(v interface{}) *Schema {
return ReflectFromType(reflect.TypeOf(v))
}
// ReflectFromType generates root schema using the default Reflector
func ReflectFromType(t reflect.Type) *Schema {
r := &Reflector{}
return r.ReflectFromType(t)
}
// A Reflector reflects values into a Schema.
type Reflector struct {
// AllowAdditionalProperties will cause the Reflector to generate a schema
// with additionalProperties to 'true' for all struct types. This means
// the presence of additional keys in JSON objects will not cause validation
// to fail. Note said additional keys will simply be dropped when the
// validated JSON is unmarshaled.
AllowAdditionalProperties bool
// RequiredFromJSONSchemaTags will cause the Reflector to generate a schema
// that requires any key tagged with `jsonschema:required`, overriding the
// default of requiring any key *not* tagged with `json:,omitempty`.
RequiredFromJSONSchemaTags bool
// YAMLEmbeddedStructs will cause the Reflector to generate a schema that does
// not inline embedded structs. This should be enabled if the JSON schemas are
// used with yaml.Marshal/Unmarshal.
YAMLEmbeddedStructs bool
// Prefer yaml: tags over json: tags to generate the schema even if json: tags
// are present
PreferYAMLSchema bool
// ExpandedStruct will cause the top-level struct of the schema to be
// expanded in place instead of being referenced through a definition.
ExpandedStruct bool
// Do not reference definitions.
// All types are still registered under the "definitions" top-level object,
// but instead of $ref fields in containing types, the entire definition
// of the contained type is inserted.
// This will cause the entire structure of types to be output in one tree.
DoNotReference bool
// Use package paths as well as type names, to avoid conflicts.
// Without this setting, if two packages contain a type with the same name,
// and both are present in a schema, they will conflict and overwrite in
// the definition map and produce bad output. This is particularly
// noticeable when using DoNotReference.
FullyQualifyTypeNames bool
// IgnoredTypes defines a slice of types that should be ignored in the schema,
// switching to just allowing additional properties instead.
IgnoredTypes []interface{}
// TypeMapper is a function that can be used to map custom Go types to jsonschema types.
TypeMapper func(reflect.Type) *Type
// TypeNamer allows customizing of type names
TypeNamer func(reflect.Type) string
// AdditionalFields allows adding structfields for a given type
AdditionalFields func(reflect.Type) []reflect.StructField
}
// Reflect reflects to Schema from a value.
func (r *Reflector) Reflect(v interface{}) *Schema {
return r.ReflectFromType(reflect.TypeOf(v))
}
// ReflectFromType generates root schema
func (r *Reflector) ReflectFromType(t reflect.Type) *Schema {
definitions := Definitions{}
if r.ExpandedStruct {
st := &Type{
Version: Version,
Type: "object",
Properties: orderedmap.New(),
AdditionalProperties: []byte("false"),
}
if r.AllowAdditionalProperties {
st.AdditionalProperties = []byte("true")
}
r.reflectStructFields(st, definitions, t)
r.reflectStruct(definitions, t)
delete(definitions, r.typeName(t))
return &Schema{Type: st, Definitions: definitions}
}
s := &Schema{
Type: r.reflectTypeToSchema(definitions, t),
Definitions: definitions,
}
return s
}
// Definitions hold schema definitions.
// http://json-schema.org/latest/json-schema-validation.html#rfc.section.5.26
// RFC draft-wright-json-schema-validation-00, section 5.26
type Definitions map[string]*Type
// Available Go defined types for JSON Schema Validation.
// RFC draft-wright-json-schema-validation-00, section 7.3
var (
timeType = reflect.TypeOf(time.Time{}) // date-time RFC section 7.3.1
ipType = reflect.TypeOf(net.IP{}) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5
uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6
)
// Byte slices will be encoded as base64
var byteSliceType = reflect.TypeOf([]byte(nil))
// Except for json.RawMessage
var rawMessageType = reflect.TypeOf(json.RawMessage{})
// Go code generated from protobuf enum types should fulfil this interface.
type protoEnum interface {
EnumDescriptor() ([]byte, []int)
}
var protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem()
func (r *Reflector) reflectTypeToSchema(definitions Definitions, t reflect.Type) *Type {
// Already added to definitions?
if _, ok := definitions[r.typeName(t)]; ok && !r.DoNotReference {
return &Type{Ref: "#/definitions/" + r.typeName(t)}
}
if r.TypeMapper != nil {
if t := r.TypeMapper(t); t != nil {
return t
}
}
if rt := r.reflectCustomType(definitions, t); rt != nil {
return rt
}
// jsonpb will marshal protobuf enum options as either strings or integers.
// It will unmarshal either.
if t.Implements(protoEnumType) {
return &Type{OneOf: []*Type{
{Type: "string"},
{Type: "integer"},
}}
}
// Defined format types for JSON Schema Validation
// RFC draft-wright-json-schema-validation-00, section 7.3
// TODO email RFC section 7.3.2, hostname RFC section 7.3.3, uriref RFC section 7.3.7
if t == ipType {
// TODO differentiate ipv4 and ipv6 RFC section 7.3.4, 7.3.5
return &Type{Type: "string", Format: "ipv4"} // ipv4 RFC section 7.3.4
}
switch t.Kind() {
case reflect.Struct:
switch t {
case timeType: // date-time RFC section 7.3.1
return &Type{Type: "string", Format: "date-time"}
case uriType: // uri RFC section 7.3.6
return &Type{Type: "string", Format: "uri"}
default:
return r.reflectStruct(definitions, t)
}
case reflect.Map:
switch t.Key().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rt := &Type{
Type: "object",
PatternProperties: map[string]*Type{
"^[0-9]+$": r.reflectTypeToSchema(definitions, t.Elem()),
},
AdditionalProperties: []byte("false"),
}
return rt
}
rt := &Type{
Type: "object",
PatternProperties: map[string]*Type{
".*": r.reflectTypeToSchema(definitions, t.Elem()),
},
}
delete(rt.PatternProperties, "additionalProperties")
return rt
case reflect.Slice, reflect.Array:
returnType := &Type{}
if t == rawMessageType {
return &Type{
AdditionalProperties: []byte("true"),
}
}
if t.Kind() == reflect.Array {
returnType.MinItems = t.Len()
returnType.MaxItems = returnType.MinItems
}
if t.Kind() == reflect.Slice && t.Elem() == byteSliceType.Elem() {
returnType.Type = "string"
returnType.Media = &Type{BinaryEncoding: "base64"}
return returnType
}
returnType.Type = "array"
returnType.Items = r.reflectTypeToSchema(definitions, t.Elem())
return returnType
case reflect.Interface:
return &Type{
AdditionalProperties: []byte("true"),
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return &Type{Type: "integer"}
case reflect.Float32, reflect.Float64:
return &Type{Type: "number"}
case reflect.Bool:
return &Type{Type: "boolean"}
case reflect.String:
return &Type{Type: "string"}
case reflect.Ptr:
return r.reflectTypeToSchema(definitions, t.Elem())
}
panic("unsupported type " + t.String())
}
func (r *Reflector) reflectCustomType(definitions Definitions, t reflect.Type) *Type {
if t.Kind() == reflect.Ptr {
return r.reflectCustomType(definitions, t.Elem())
}
if t.Implements(customType) {
v := reflect.New(t)
o := v.Interface().(customSchemaType)
st := o.JSONSchemaType()
definitions[r.typeName(t)] = st
if r.DoNotReference {
return st
} else {
return &Type{
Version: Version,
Ref: "#/definitions/" + r.typeName(t),
}
}
}
return nil
}
// reflectStruct reflects a struct to a JSON Schema type.
func (r *Reflector) reflectStruct(definitions Definitions, t reflect.Type) *Type {
if st := r.reflectCustomType(definitions, t); st != nil {
return st
}
for _, ignored := range r.IgnoredTypes {
if reflect.TypeOf(ignored) == t {
st := &Type{
Type: "object",
Properties: orderedmap.New(),
AdditionalProperties: []byte("true"),
}
definitions[r.typeName(t)] = st
if r.DoNotReference {
return st
} else {
return &Type{
Version: Version,
Ref: "#/definitions/" + r.typeName(t),
}
}
}
}
st := &Type{
Type: "object",
Properties: orderedmap.New(),
AdditionalProperties: []byte("false"),
}
if r.AllowAdditionalProperties {
st.AdditionalProperties = []byte("true")
}
definitions[r.typeName(t)] = st
r.reflectStructFields(st, definitions, t)
if r.DoNotReference {
return st
} else {
return &Type{
Version: Version,
Ref: "#/definitions/" + r.typeName(t),
}
}
}
func (r *Reflector) reflectStructFields(st *Type, definitions Definitions, t reflect.Type) {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return
}
var getFieldDocString customGetFieldDocString
if t.Implements(customStructGetFieldDocString) {
v := reflect.New(t)
o := v.Interface().(customSchemaGetFieldDocString)
getFieldDocString = o.GetFieldDocString
}
handleField := func(f reflect.StructField) {
name, shouldEmbed, required, nullable := r.reflectFieldName(f)
// An anonymous field of exported type is processed recursively so that
// the current type inherits its properties.
if name == "" {
if shouldEmbed {
r.reflectStructFields(st, definitions, f.Type)
}
return
}
property := r.reflectTypeToSchema(definitions, f.Type)
property.structKeywordsFromTags(f, st, name)
if getFieldDocString != nil {
property.Description = getFieldDocString(f.Name)
}
if nullable {
property = &Type{
OneOf: []*Type{
property,
{
Type: "null",
},
},
}
}
st.Properties.Set(name, property)
if required {
st.Required = append(st.Required, name)
}
}
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
handleField(f)
}
if r.AdditionalFields != nil {
if af := r.AdditionalFields(t); af != nil {
for _, sf := range af {
handleField(sf)
}
}
}
}
func (t *Type) structKeywordsFromTags(f reflect.StructField, parentType *Type, propertyName string) {
t.Description = f.Tag.Get("jsonschema_description")
tags := strings.Split(f.Tag.Get("jsonschema"), ",")
t.genericKeywords(tags, parentType, propertyName)
switch t.Type {
case "string":
t.stringKeywords(tags)
case "number":
t.numbericKeywords(tags)
case "integer":
t.numbericKeywords(tags)
case "array":
t.arrayKeywords(tags)
}
extras := strings.Split(f.Tag.Get("jsonschema_extras"), ",")
t.extraKeywords(extras)
}
// read struct tags for generic keywords
func (t *Type) genericKeywords(tags []string, parentType *Type, propertyName string) {
for _, tag := range tags {
nameValue := strings.Split(tag, "=")
if len(nameValue) == 2 {
name, val := nameValue[0], nameValue[1]
switch name {
case "title":
t.Title = val
case "description":
t.Description = val
case "type":
t.Type = val
case "oneof_required":
var typeFound *Type
for i := range parentType.OneOf {
if parentType.OneOf[i].Title == nameValue[1] {
typeFound = parentType.OneOf[i]
}
}
if typeFound == nil {
typeFound = &Type{
Title: nameValue[1],
Required: []string{},
}
parentType.OneOf = append(parentType.OneOf, typeFound)
}
typeFound.Required = append(typeFound.Required, propertyName)
case "oneof_type":
if t.OneOf == nil {
t.OneOf = make([]*Type, 0, 1)
}
t.Type = ""
types := strings.Split(nameValue[1], ";")
for _, ty := range types {
t.OneOf = append(t.OneOf, &Type{
Type: ty,
})
}
case "enum":
switch t.Type {
case "string":
t.Enum = append(t.Enum, val)
case "integer":
i, _ := strconv.Atoi(val)
t.Enum = append(t.Enum, i)
case "number":
f, _ := strconv.ParseFloat(val, 64)
t.Enum = append(t.Enum, f)
}
}
}
}
}
// read struct tags for string type keywords
func (t *Type) stringKeywords(tags []string) {
for _, tag := range tags {
nameValue := strings.Split(tag, "=")
if len(nameValue) == 2 {
name, val := nameValue[0], nameValue[1]
switch name {
case "minLength":
i, _ := strconv.Atoi(val)
t.MinLength = i
case "maxLength":
i, _ := strconv.Atoi(val)
t.MaxLength = i
case "pattern":
t.Pattern = val
case "format":
switch val {
case "date-time", "email", "hostname", "ipv4", "ipv6", "uri":
t.Format = val
}
case "default":
t.Default = val
case "example":
t.Examples = append(t.Examples, val)
}
}
}
}
// read struct tags for numeric type keywords
func (t *Type) numbericKeywords(tags []string) {
for _, tag := range tags {
nameValue := strings.Split(tag, "=")
if len(nameValue) == 2 {
name, val := nameValue[0], nameValue[1]
switch name {
case "multipleOf":
i, _ := strconv.Atoi(val)
t.MultipleOf = i
case "minimum":
i, _ := strconv.Atoi(val)
t.Minimum = i
case "maximum":
i, _ := strconv.Atoi(val)
t.Maximum = i
case "exclusiveMaximum":
b, _ := strconv.ParseBool(val)
t.ExclusiveMaximum = b
case "exclusiveMinimum":
b, _ := strconv.ParseBool(val)
t.ExclusiveMinimum = b
case "default":
i, _ := strconv.Atoi(val)
t.Default = i
case "example":
if i, err := strconv.Atoi(val); err == nil {
t.Examples = append(t.Examples, i)
}
}
}
}
}
// read struct tags for object type keywords
// func (t *Type) objectKeywords(tags []string) {
// for _, tag := range tags{
// nameValue := strings.Split(tag, "=")
// name, val := nameValue[0], nameValue[1]
// switch name{
// case "dependencies":
// t.Dependencies = val
// break;
// case "patternProperties":
// t.PatternProperties = val
// break;
// }
// }
// }
// read struct tags for array type keywords
func (t *Type) arrayKeywords(tags []string) {
var defaultValues []interface{}
for _, tag := range tags {
nameValue := strings.Split(tag, "=")
if len(nameValue) == 2 {
name, val := nameValue[0], nameValue[1]
switch name {
case "minItems":
i, _ := strconv.Atoi(val)
t.MinItems = i
case "maxItems":
i, _ := strconv.Atoi(val)
t.MaxItems = i
case "uniqueItems":
t.UniqueItems = true
case "default":
defaultValues = append(defaultValues, val)
case "enum":
switch t.Items.Type {
case "string":
t.Items.Enum = append(t.Items.Enum, val)
case "integer":
i, _ := strconv.Atoi(val)
t.Items.Enum = append(t.Items.Enum, i)
case "number":
f, _ := strconv.ParseFloat(val, 64)
t.Items.Enum = append(t.Items.Enum, f)
}
}
}
}
if len(defaultValues) > 0 {
t.Default = defaultValues
}
}
func (t *Type) extraKeywords(tags []string) {
for _, tag := range tags {
nameValue := strings.Split(tag, "=")
if len(nameValue) == 2 {
t.setExtra(nameValue[0], nameValue[1])
}
}
}
func (t *Type) setExtra(key, val string) {
if t.Extras == nil {
t.Extras = map[string]interface{}{}
}
if existingVal, ok := t.Extras[key]; ok {
switch existingVal := existingVal.(type) {
case string:
t.Extras[key] = []string{existingVal, val}
case []string:
t.Extras[key] = append(existingVal, val)
case int:
t.Extras[key], _ = strconv.Atoi(val)
}
} else {
switch key {
case "minimum":
t.Extras[key], _ = strconv.Atoi(val)
default:
t.Extras[key] = val
}
}
}
func requiredFromJSONTags(tags []string) bool {
if ignoredByJSONTags(tags) {
return false
}
for _, tag := range tags[1:] {
if tag == "omitempty" {
return false
}
}
return true
}
func requiredFromJSONSchemaTags(tags []string) bool {
if ignoredByJSONSchemaTags(tags) {
return false
}
for _, tag := range tags {
if tag == "required" {
return true
}
}
return false
}
func nullableFromJSONSchemaTags(tags []string) bool {
if ignoredByJSONSchemaTags(tags) {
return false
}
for _, tag := range tags {
if tag == "nullable" {
return true
}
}
return false
}
func inlineYAMLTags(tags []string) bool {
for _, tag := range tags {
if tag == "inline" {
return true
}
}
return false
}
func ignoredByJSONTags(tags []string) bool {
return tags[0] == "-"
}
func ignoredByJSONSchemaTags(tags []string) bool {
return tags[0] == "-"
}
// reflectFieldName returns the schema property name for f, plus three
// flags: whether the field should be embedded, whether it is required,
// and whether it is nullable.
func (r *Reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, bool) {
jsonTags, exist := f.Tag.Lookup("json")
yamlTags, yamlExist := f.Tag.Lookup("yaml")
if !exist || r.PreferYAMLSchema {
jsonTags = yamlTags
exist = yamlExist
}
jsonTagsList := strings.Split(jsonTags, ",")
yamlTagsList := strings.Split(yamlTags, ",")
if ignoredByJSONTags(jsonTagsList) {
return "", false, false, false
}
jsonSchemaTags := strings.Split(f.Tag.Get("jsonschema"), ",")
if ignoredByJSONSchemaTags(jsonSchemaTags) {
return "", false, false, false
}
name := f.Name
required := requiredFromJSONTags(jsonTagsList)
if r.RequiredFromJSONSchemaTags {
required = requiredFromJSONSchemaTags(jsonSchemaTags)
}
nullable := nullableFromJSONSchemaTags(jsonSchemaTags)
if jsonTagsList[0] != "" {
name = jsonTagsList[0]
}
// a field that is neither anonymous nor exported has no exported name
if !f.Anonymous && f.PkgPath != "" {
name = ""
}
embed := false
// an anonymous field without a json tag is inherited by the current type
if f.Anonymous && !exist {
if !r.YAMLEmbeddedStructs {
name = ""
embed = true
} else {
name = strings.ToLower(name)
}
}
if yamlExist && inlineYAMLTags(yamlTagsList) {
name = ""
embed = true
}
return name, embed, required, nullable
}
func (s *Schema) MarshalJSON() ([]byte, error) {
b, err := json.Marshal(s.Type)
if err != nil {
return nil, err
}
if len(s.Definitions) == 0 {
return b, nil
}
d, err := json.Marshal(struct {
Definitions Definitions `json:"definitions,omitempty"`
}{s.Definitions})
if err != nil {
return nil, err
}
// len(b) == 2 means the base type marshaled to "{}"; return the
// definitions object alone.
if len(b) == 2 {
return d, nil
}
// Splice the two JSON objects: swap the trailing '}' for ',' and drop
// the leading '{' of the definitions object.
b[len(b)-1] = ','
return append(b, d[1:]...), nil
}
func (t *Type) MarshalJSON() ([]byte, error) {
type Type_ Type // local alias so json.Marshal does not recurse into MarshalJSON
b, err := json.Marshal((*Type_)(t))
if err != nil {
return nil, err
}
if len(t.Extras) == 0 {
return b, nil
}
m, err := json.Marshal(t.Extras)
if err != nil {
return nil, err
}
if len(b) == 2 {
return m, nil
}
b[len(b)-1] = ','
return append(b, m[1:]...), nil
}
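The byte-splicing trick used by the two `MarshalJSON` methods above is easy to miss, so here is a standalone sketch of the same technique. `mergeJSON` is a hypothetical helper name, not part of the package:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// mergeJSON marshals two values separately and joins them into a single
// JSON object by replacing the closing '}' of the first with ',' and
// dropping the opening '{' of the second, exactly as Schema.MarshalJSON
// and Type.MarshalJSON do above.
func mergeJSON(a, b interface{}) ([]byte, error) {
	ab, err := json.Marshal(a)
	if err != nil {
		return nil, err
	}
	bb, err := json.Marshal(b)
	if err != nil {
		return nil, err
	}
	if len(ab) == 2 { // "{}": first object is empty, use the second as-is
		return bb, nil
	}
	ab[len(ab)-1] = ','               // replace trailing '}' with ','
	return append(ab, bb[1:]...), nil // skip leading '{' of the second
}

func main() {
	type core struct {
		Type string `json:"type"`
	}
	type defs struct {
		Definitions map[string]string `json:"definitions,omitempty"`
	}
	out, err := mergeJSON(core{Type: "object"}, defs{Definitions: map[string]string{"Foo": "bar"}})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"type":"object","definitions":{"Foo":"bar"}}
}
```

This avoids a second round-trip through a `map[string]interface{}`, at the cost of relying on both values marshaling to JSON objects.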
func (r *Reflector) typeName(t reflect.Type) string {
if r.TypeNamer != nil {
if name := r.TypeNamer(t); name != "" {
return name
}
}
if r.FullyQualifyTypeNames {
return t.PkgPath() + "." + t.Name()
}
return t.Name()
}

View File

@@ -1,27 +0,0 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,25 +0,0 @@
# Go's `text/template` package with newline elision
This is a fork of Go 1.4's [text/template](http://golang.org/pkg/text/template/) package with one addition: a backslash immediately after a closing delimiter will delete all subsequent newlines until a non-newline.
e.g.
```
{{if true}}\
hello
{{end}}\
```
Will result in:
```
hello\n
```
Rather than:
```
\n
hello\n
\n
```

View File

@@ -1,406 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package template implements data-driven templates for generating textual output.
To generate HTML output, see package html/template, which has the same interface
as this package but automatically secures HTML output against certain attacks.
Templates are executed by applying them to a data structure. Annotations in the
template refer to elements of the data structure (typically a field of a struct
or a key in a map) to control execution and derive values to be displayed.
Execution of the template walks the structure and sets the cursor, represented
by a period '.' and called "dot", to the value at the current location in the
structure as execution proceeds.
The input text for a template is UTF-8-encoded text in any format.
"Actions"--data evaluations or control structures--are delimited by
"{{" and "}}"; all text outside actions is copied to the output unchanged.
Actions may not span newlines, although comments can.
Once parsed, a template may be executed safely in parallel.
Here is a trivial example that prints "17 items are made of wool".
type Inventory struct {
Material string
Count uint
}
sweaters := Inventory{"wool", 17}
tmpl, err := template.New("test").Parse("{{.Count}} items are made of {{.Material}}")
if err != nil { panic(err) }
err = tmpl.Execute(os.Stdout, sweaters)
if err != nil { panic(err) }
More intricate examples appear below.
Actions
Here is the list of actions. "Arguments" and "pipelines" are evaluations of
data, defined in detail below.
*/
// {{/* a comment */}}
// A comment; discarded. May contain newlines.
// Comments do not nest and must start and end at the
// delimiters, as shown here.
/*
{{pipeline}}
The default textual representation of the value of the pipeline
is copied to the output.
{{if pipeline}} T1 {{end}}
If the value of the pipeline is empty, no output is generated;
otherwise, T1 is executed. The empty values are false, 0, any
nil pointer or interface value, and any array, slice, map, or
string of length zero.
Dot is unaffected.
{{if pipeline}} T1 {{else}} T0 {{end}}
If the value of the pipeline is empty, T0 is executed;
otherwise, T1 is executed. Dot is unaffected.
{{if pipeline}} T1 {{else if pipeline}} T0 {{end}}
To simplify the appearance of if-else chains, the else action
of an if may include another if directly; the effect is exactly
the same as writing
{{if pipeline}} T1 {{else}}{{if pipeline}} T0 {{end}}{{end}}
{{range pipeline}} T1 {{end}}
The value of the pipeline must be an array, slice, map, or channel.
If the value of the pipeline has length zero, nothing is output;
otherwise, dot is set to the successive elements of the array,
slice, or map and T1 is executed. If the value is a map and the
keys are of basic type with a defined order ("comparable"), the
elements will be visited in sorted key order.
{{range pipeline}} T1 {{else}} T0 {{end}}
The value of the pipeline must be an array, slice, map, or channel.
If the value of the pipeline has length zero, dot is unaffected and
T0 is executed; otherwise, dot is set to the successive elements
of the array, slice, or map and T1 is executed.
{{template "name"}}
The template with the specified name is executed with nil data.
{{template "name" pipeline}}
The template with the specified name is executed with dot set
to the value of the pipeline.
{{with pipeline}} T1 {{end}}
If the value of the pipeline is empty, no output is generated;
otherwise, dot is set to the value of the pipeline and T1 is
executed.
{{with pipeline}} T1 {{else}} T0 {{end}}
If the value of the pipeline is empty, dot is unaffected and T0
is executed; otherwise, dot is set to the value of the pipeline
and T1 is executed.
Arguments
An argument is a simple value, denoted by one of the following.
- A boolean, string, character, integer, floating-point, imaginary
or complex constant in Go syntax. These behave like Go's untyped
constants, although raw strings may not span newlines.
- The keyword nil, representing an untyped Go nil.
- The character '.' (period):
.
The result is the value of dot.
- A variable name, which is a (possibly empty) alphanumeric string
preceded by a dollar sign, such as
$piOver2
or
$
The result is the value of the variable.
Variables are described below.
- The name of a field of the data, which must be a struct, preceded
by a period, such as
.Field
The result is the value of the field. Field invocations may be
chained:
.Field1.Field2
Fields can also be evaluated on variables, including chaining:
$x.Field1.Field2
- The name of a key of the data, which must be a map, preceded
by a period, such as
.Key
The result is the map element value indexed by the key.
Key invocations may be chained and combined with fields to any
depth:
.Field1.Key1.Field2.Key2
Although the key must be an alphanumeric identifier, unlike with
field names they do not need to start with an upper case letter.
Keys can also be evaluated on variables, including chaining:
$x.key1.key2
- The name of a niladic method of the data, preceded by a period,
such as
.Method
The result is the value of invoking the method with dot as the
receiver, dot.Method(). Such a method must have one return value (of
any type) or two return values, the second of which is an error.
If it has two and the returned error is non-nil, execution terminates
and an error is returned to the caller as the value of Execute.
Method invocations may be chained and combined with fields and keys
to any depth:
.Field1.Key1.Method1.Field2.Key2.Method2
Methods can also be evaluated on variables, including chaining:
$x.Method1.Field
- The name of a niladic function, such as
fun
The result is the value of invoking the function, fun(). The return
types and values behave as in methods. Functions and function
names are described below.
- A parenthesized instance of one of the above, for grouping. The result
may be accessed by a field or map key invocation.
print (.F1 arg1) (.F2 arg2)
(.StructValuedMethod "arg").Field
Arguments may evaluate to any type; if they are pointers the implementation
automatically indirects to the base type when required.
If an evaluation yields a function value, such as a function-valued
field of a struct, the function is not invoked automatically, but it
can be used as a truth value for an if action and the like. To invoke
it, use the call function, defined below.
A pipeline is a possibly chained sequence of "commands". A command is a simple
value (argument) or a function or method call, possibly with multiple arguments:
Argument
The result is the value of evaluating the argument.
.Method [Argument...]
The method can be alone or the last element of a chain but,
unlike methods in the middle of a chain, it can take arguments.
The result is the value of calling the method with the
arguments:
dot.Method(Argument1, etc.)
functionName [Argument...]
The result is the value of calling the function associated
with the name:
function(Argument1, etc.)
Functions and function names are described below.
Pipelines
A pipeline may be "chained" by separating a sequence of commands with pipeline
characters '|'. In a chained pipeline, the result of each command is
passed as the last argument of the following command. The output of the final
command in the pipeline is the value of the pipeline.
The output of a command will be either one value or two values, the second of
which has type error. If that second value is present and evaluates to
non-nil, execution terminates and the error is returned to the caller of
Execute.
Variables
A pipeline inside an action may initialize a variable to capture the result.
The initialization has syntax
$variable := pipeline
where $variable is the name of the variable. An action that declares a
variable produces no output.
If a "range" action initializes a variable, the variable is set to the
successive elements of the iteration. Also, a "range" may declare two
variables, separated by a comma:
range $index, $element := pipeline
in which case $index and $element are set to the successive values of the
array/slice index or map key and element, respectively. Note that if there is
only one variable, it is assigned the element; this is opposite to the
convention in Go range clauses.
A variable's scope extends to the "end" action of the control structure ("if",
"with", or "range") in which it is declared, or to the end of the template if
there is no such control structure. A template invocation does not inherit
variables from the point of its invocation.
When execution begins, $ is set to the data argument passed to Execute, that is,
to the starting value of dot.
Examples
Here are some example one-line templates demonstrating pipelines and variables.
All produce the quoted word "output":
{{"\"output\""}}
A string constant.
{{`"output"`}}
A raw string constant.
{{printf "%q" "output"}}
A function call.
{{"output" | printf "%q"}}
A function call whose final argument comes from the previous
command.
{{printf "%q" (print "out" "put")}}
A parenthesized argument.
{{"put" | printf "%s%s" "out" | printf "%q"}}
A more elaborate call.
{{"output" | printf "%s" | printf "%q"}}
A longer chain.
{{with "output"}}{{printf "%q" .}}{{end}}
A with action using dot.
{{with $x := "output" | printf "%q"}}{{$x}}{{end}}
A with action that creates and uses a variable.
{{with $x := "output"}}{{printf "%q" $x}}{{end}}
A with action that uses the variable in another action.
{{with $x := "output"}}{{$x | printf "%q"}}{{end}}
The same, but pipelined.
Functions
During execution functions are found in two function maps: first in the
template, then in the global function map. By default, no functions are defined
in the template but the Funcs method can be used to add them.
Predefined global functions are named as follows.
and
Returns the boolean AND of its arguments by returning the
first empty argument or the last argument, that is,
"and x y" behaves as "if x then y else x". All the
arguments are evaluated.
call
Returns the result of calling the first argument, which
must be a function, with the remaining arguments as parameters.
Thus "call .X.Y 1 2" is, in Go notation, dot.X.Y(1, 2) where
Y is a func-valued field, map entry, or the like.
The first argument must be the result of an evaluation
that yields a value of function type (as distinct from
a predefined function such as print). The function must
return either one or two result values, the second of which
is of type error. If the arguments don't match the function
or the returned error value is non-nil, execution stops.
html
Returns the escaped HTML equivalent of the textual
representation of its arguments.
index
Returns the result of indexing its first argument by the
following arguments. Thus "index x 1 2 3" is, in Go syntax,
x[1][2][3]. Each indexed item must be a map, slice, or array.
js
Returns the escaped JavaScript equivalent of the textual
representation of its arguments.
len
Returns the integer length of its argument.
not
Returns the boolean negation of its single argument.
or
Returns the boolean OR of its arguments by returning the
first non-empty argument or the last argument, that is,
"or x y" behaves as "if x then x else y". All the
arguments are evaluated.
print
An alias for fmt.Sprint
printf
An alias for fmt.Sprintf
println
An alias for fmt.Sprintln
urlquery
Returns the escaped value of the textual representation of
its arguments in a form suitable for embedding in a URL query.
The boolean functions take any zero value to be false and a non-zero
value to be true.
There is also a set of binary comparison operators defined as
functions:
eq
Returns the boolean truth of arg1 == arg2
ne
Returns the boolean truth of arg1 != arg2
lt
Returns the boolean truth of arg1 < arg2
le
Returns the boolean truth of arg1 <= arg2
gt
Returns the boolean truth of arg1 > arg2
ge
Returns the boolean truth of arg1 >= arg2
For simpler multi-way equality tests, eq (only) accepts two or more
arguments and compares the second and subsequent to the first,
returning in effect
arg1==arg2 || arg1==arg3 || arg1==arg4 ...
(Unlike with || in Go, however, eq is a function call and all the
arguments will be evaluated.)
The comparison functions work on basic types only (or named basic
types, such as "type Celsius float32"). They implement the Go rules
for comparison of values, except that size and exact type are
ignored, so any integer value, signed or unsigned, may be compared
with any other integer value. (The arithmetic value is compared,
not the bit pattern, so all negative integers are less than all
unsigned integers.) However, as usual, one may not compare an int
with a float32 and so on.
Associated templates
Each template is named by a string specified when it is created. Also, each
template is associated with zero or more other templates that it may invoke by
name; such associations are transitive and form a name space of templates.
A template may use a template invocation to instantiate another associated
template; see the explanation of the "template" action above. The name must be
that of a template associated with the template that contains the invocation.
Nested template definitions
When parsing a template, another template may be defined and associated with the
template being parsed. Template definitions must appear at the top level of the
template, much like global variables in a Go program.
The syntax of such definitions is to surround each template declaration with a
"define" and "end" action.
The define action names the template being created by providing a string
constant. Here is a simple example:
`{{define "T1"}}ONE{{end}}
{{define "T2"}}TWO{{end}}
{{define "T3"}}{{template "T1"}} {{template "T2"}}{{end}}
{{template "T3"}}`
This defines two templates, T1 and T2, and a third T3 that invokes the other two
when it is executed. Finally it invokes T3. If executed this template will
produce the text
ONE TWO
By construction, a template may reside in only one association. If it's
necessary to have a template addressable from multiple associations, the
template definition must be parsed multiple times to create distinct *Template
values, or must be copied with the Clone or AddParseTree method.
Parse may be called multiple times to assemble the various associated templates;
see the ParseFiles and ParseGlob functions and methods for simple ways to parse
related templates stored in files.
A template may be executed directly or through ExecuteTemplate, which executes
an associated template identified by name. To invoke our example above, we
might write,
err := tmpl.Execute(os.Stdout, "no data needed")
if err != nil {
log.Fatalf("execution failed: %s", err)
}
or to invoke a particular template explicitly by name,
err := tmpl.ExecuteTemplate(os.Stdout, "T2", "no data needed")
if err != nil {
log.Fatalf("execution failed: %s", err)
}
*/
package template

View File

@@ -1,845 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"io"
"reflect"
"runtime"
"sort"
"strings"
"github.com/alecthomas/template/parse"
)
// state represents the state of an execution. It's not part of the
// template so that multiple executions of the same template
// can execute in parallel.
type state struct {
tmpl *Template
wr io.Writer
node parse.Node // current node, for errors
vars []variable // push-down stack of variable values.
}
// variable holds the dynamic value of a variable such as $, $x etc.
type variable struct {
name string
value reflect.Value
}
// push pushes a new variable on the stack.
func (s *state) push(name string, value reflect.Value) {
s.vars = append(s.vars, variable{name, value})
}
// mark returns the length of the variable stack.
func (s *state) mark() int {
return len(s.vars)
}
// pop pops the variable stack up to the mark.
func (s *state) pop(mark int) {
s.vars = s.vars[0:mark]
}
// setVar overwrites the top-nth variable on the stack. Used by range iterations.
func (s *state) setVar(n int, value reflect.Value) {
s.vars[len(s.vars)-n].value = value
}
// varValue returns the value of the named variable.
func (s *state) varValue(name string) reflect.Value {
for i := s.mark() - 1; i >= 0; i-- {
if s.vars[i].name == name {
return s.vars[i].value
}
}
s.errorf("undefined variable: %s", name)
return zero
}
var zero reflect.Value
// at marks the state to be on node n, for error reporting.
func (s *state) at(node parse.Node) {
s.node = node
}
// doublePercent returns the string with %'s replaced by %%, if necessary,
// so it can be used safely inside a Printf format string.
func doublePercent(str string) string {
if strings.Contains(str, "%") {
str = strings.Replace(str, "%", "%%", -1)
}
return str
}
// errorf formats the error and terminates processing.
func (s *state) errorf(format string, args ...interface{}) {
name := doublePercent(s.tmpl.Name())
if s.node == nil {
format = fmt.Sprintf("template: %s: %s", name, format)
} else {
location, context := s.tmpl.ErrorContext(s.node)
format = fmt.Sprintf("template: %s: executing %q at <%s>: %s", location, name, doublePercent(context), format)
}
panic(fmt.Errorf(format, args...))
}
// errRecover is the handler that turns panics into returns from the top
// level of Parse.
func errRecover(errp *error) {
e := recover()
if e != nil {
switch err := e.(type) {
case runtime.Error:
panic(e)
case error:
*errp = err
default:
panic(e)
}
}
}
// ExecuteTemplate applies the template associated with t that has the given name
// to the specified data object and writes the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel.
func (t *Template) ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
tmpl := t.tmpl[name]
if tmpl == nil {
return fmt.Errorf("template: no template %q associated with template %q", name, t.name)
}
return tmpl.Execute(wr, data)
}
// Execute applies a parsed template to the specified data object,
// and writes the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel.
func (t *Template) Execute(wr io.Writer, data interface{}) (err error) {
defer errRecover(&err)
value := reflect.ValueOf(data)
state := &state{
tmpl: t,
wr: wr,
vars: []variable{{"$", value}},
}
t.init()
if t.Tree == nil || t.Root == nil {
var b bytes.Buffer
for name, tmpl := range t.tmpl {
if tmpl.Tree == nil || tmpl.Root == nil {
continue
}
if b.Len() > 0 {
b.WriteString(", ")
}
fmt.Fprintf(&b, "%q", name)
}
var s string
if b.Len() > 0 {
s = "; defined templates are: " + b.String()
}
state.errorf("%q is an incomplete or empty template%s", t.Name(), s)
}
state.walk(value, t.Root)
return
}
// Walk functions step through the major pieces of the template structure,
// generating output as they go.
func (s *state) walk(dot reflect.Value, node parse.Node) {
s.at(node)
switch node := node.(type) {
case *parse.ActionNode:
// Do not pop variables so they persist until next end.
// Also, if the action declares variables, don't print the result.
val := s.evalPipeline(dot, node.Pipe)
if len(node.Pipe.Decl) == 0 {
s.printValue(node, val)
}
case *parse.IfNode:
s.walkIfOrWith(parse.NodeIf, dot, node.Pipe, node.List, node.ElseList)
case *parse.ListNode:
for _, node := range node.Nodes {
s.walk(dot, node)
}
case *parse.RangeNode:
s.walkRange(dot, node)
case *parse.TemplateNode:
s.walkTemplate(dot, node)
case *parse.TextNode:
if _, err := s.wr.Write(node.Text); err != nil {
s.errorf("%s", err)
}
case *parse.WithNode:
s.walkIfOrWith(parse.NodeWith, dot, node.Pipe, node.List, node.ElseList)
default:
s.errorf("unknown node: %s", node)
}
}
// walkIfOrWith walks an 'if' or 'with' node. The two control structures
// are identical in behavior except that 'with' sets dot.
func (s *state) walkIfOrWith(typ parse.NodeType, dot reflect.Value, pipe *parse.PipeNode, list, elseList *parse.ListNode) {
defer s.pop(s.mark())
val := s.evalPipeline(dot, pipe)
truth, ok := isTrue(val)
if !ok {
s.errorf("if/with can't use %v", val)
}
if truth {
if typ == parse.NodeWith {
s.walk(val, list)
} else {
s.walk(dot, list)
}
} else if elseList != nil {
s.walk(dot, elseList)
}
}
// isTrue reports whether the value is 'true', in the sense of not the zero of its type,
// and whether the value has a meaningful truth value.
func isTrue(val reflect.Value) (truth, ok bool) {
if !val.IsValid() {
// Something like var x interface{}, never set. It's a form of nil.
return false, true
}
switch val.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
truth = val.Len() > 0
case reflect.Bool:
truth = val.Bool()
case reflect.Complex64, reflect.Complex128:
truth = val.Complex() != 0
case reflect.Chan, reflect.Func, reflect.Ptr, reflect.Interface:
truth = !val.IsNil()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
truth = val.Int() != 0
case reflect.Float32, reflect.Float64:
truth = val.Float() != 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
truth = val.Uint() != 0
case reflect.Struct:
truth = true // Struct values are always true.
default:
return
}
return truth, true
}
func (s *state) walkRange(dot reflect.Value, r *parse.RangeNode) {
s.at(r)
defer s.pop(s.mark())
val, _ := indirect(s.evalPipeline(dot, r.Pipe))
// mark top of stack before any variables in the body are pushed.
mark := s.mark()
oneIteration := func(index, elem reflect.Value) {
// Set top var (lexically the second if there are two) to the element.
if len(r.Pipe.Decl) > 0 {
s.setVar(1, elem)
}
// Set next var (lexically the first if there are two) to the index.
if len(r.Pipe.Decl) > 1 {
s.setVar(2, index)
}
s.walk(elem, r.List)
s.pop(mark)
}
switch val.Kind() {
case reflect.Array, reflect.Slice:
if val.Len() == 0 {
break
}
for i := 0; i < val.Len(); i++ {
oneIteration(reflect.ValueOf(i), val.Index(i))
}
return
case reflect.Map:
if val.Len() == 0 {
break
}
for _, key := range sortKeys(val.MapKeys()) {
oneIteration(key, val.MapIndex(key))
}
return
case reflect.Chan:
if val.IsNil() {
break
}
i := 0
for ; ; i++ {
elem, ok := val.Recv()
if !ok {
break
}
oneIteration(reflect.ValueOf(i), elem)
}
if i == 0 {
break
}
return
case reflect.Invalid:
break // An invalid value is likely a nil map, etc. and acts like an empty map.
default:
s.errorf("range can't iterate over %v", val)
}
if r.ElseList != nil {
s.walk(dot, r.ElseList)
}
}
func (s *state) walkTemplate(dot reflect.Value, t *parse.TemplateNode) {
s.at(t)
tmpl := s.tmpl.tmpl[t.Name]
if tmpl == nil {
s.errorf("template %q not defined", t.Name)
}
// Variables declared by the pipeline persist.
dot = s.evalPipeline(dot, t.Pipe)
newState := *s
newState.tmpl = tmpl
// No dynamic scoping: template invocations inherit no variables.
newState.vars = []variable{{"$", dot}}
newState.walk(dot, tmpl.Root)
}
// Eval functions evaluate pipelines, commands, and their elements and extract
// values from the data structure by examining fields, calling methods, and so on.
// The printing of those values happens only through walk functions.
// evalPipeline returns the value acquired by evaluating a pipeline. If the
// pipeline has a variable declaration, the variable will be pushed on the
// stack. Callers should therefore pop the stack after they are finished
// executing commands depending on the pipeline value.
func (s *state) evalPipeline(dot reflect.Value, pipe *parse.PipeNode) (value reflect.Value) {
if pipe == nil {
return
}
s.at(pipe)
for _, cmd := range pipe.Cmds {
value = s.evalCommand(dot, cmd, value) // previous value is this one's final arg.
// If the object has type interface{}, dig down one level to the thing inside.
if value.Kind() == reflect.Interface && value.Type().NumMethod() == 0 {
value = reflect.ValueOf(value.Interface()) // lovely!
}
}
for _, variable := range pipe.Decl {
s.push(variable.Ident[0], value)
}
return value
}
func (s *state) notAFunction(args []parse.Node, final reflect.Value) {
if len(args) > 1 || final.IsValid() {
s.errorf("can't give argument to non-function %s", args[0])
}
}
func (s *state) evalCommand(dot reflect.Value, cmd *parse.CommandNode, final reflect.Value) reflect.Value {
firstWord := cmd.Args[0]
switch n := firstWord.(type) {
case *parse.FieldNode:
return s.evalFieldNode(dot, n, cmd.Args, final)
case *parse.ChainNode:
return s.evalChainNode(dot, n, cmd.Args, final)
case *parse.IdentifierNode:
// Must be a function.
return s.evalFunction(dot, n, cmd, cmd.Args, final)
case *parse.PipeNode:
// Parenthesized pipeline. The arguments are all inside the pipeline; final is ignored.
return s.evalPipeline(dot, n)
case *parse.VariableNode:
return s.evalVariableNode(dot, n, cmd.Args, final)
}
s.at(firstWord)
s.notAFunction(cmd.Args, final)
switch word := firstWord.(type) {
case *parse.BoolNode:
return reflect.ValueOf(word.True)
case *parse.DotNode:
return dot
case *parse.NilNode:
s.errorf("nil is not a command")
case *parse.NumberNode:
return s.idealConstant(word)
case *parse.StringNode:
return reflect.ValueOf(word.Text)
}
s.errorf("can't evaluate command %q", firstWord)
panic("not reached")
}
// idealConstant is called to return the value of a number in a context where
// we don't know the type. In that case, the syntax of the number tells us
// its type, and we use Go rules to resolve. Note there is no such thing as
// a uint ideal constant in this situation - the value must be of int type.
func (s *state) idealConstant(constant *parse.NumberNode) reflect.Value {
// These are ideal constants but we don't know the type
// and we have no context. (If it was a method argument,
// we'd know what we need.) The syntax guides us to some extent.
s.at(constant)
switch {
case constant.IsComplex:
return reflect.ValueOf(constant.Complex128) // incontrovertible.
case constant.IsFloat && !isHexConstant(constant.Text) && strings.IndexAny(constant.Text, ".eE") >= 0:
return reflect.ValueOf(constant.Float64)
case constant.IsInt:
n := int(constant.Int64)
if int64(n) != constant.Int64 {
s.errorf("%s overflows int", constant.Text)
}
return reflect.ValueOf(n)
case constant.IsUint:
s.errorf("%s overflows int", constant.Text)
}
return zero
}
func isHexConstant(s string) bool {
return len(s) > 2 && s[0] == '0' && (s[1] == 'x' || s[1] == 'X')
}
func (s *state) evalFieldNode(dot reflect.Value, field *parse.FieldNode, args []parse.Node, final reflect.Value) reflect.Value {
s.at(field)
return s.evalFieldChain(dot, dot, field, field.Ident, args, final)
}
func (s *state) evalChainNode(dot reflect.Value, chain *parse.ChainNode, args []parse.Node, final reflect.Value) reflect.Value {
s.at(chain)
// (pipe).Field1.Field2 has pipe as .Node, fields as .Field. Eval the pipeline, then the fields.
pipe := s.evalArg(dot, nil, chain.Node)
if len(chain.Field) == 0 {
s.errorf("internal error: no fields in evalChainNode")
}
return s.evalFieldChain(dot, pipe, chain, chain.Field, args, final)
}
func (s *state) evalVariableNode(dot reflect.Value, variable *parse.VariableNode, args []parse.Node, final reflect.Value) reflect.Value {
// $x.Field has $x as the first ident, Field as the second. Eval the var, then the fields.
s.at(variable)
value := s.varValue(variable.Ident[0])
if len(variable.Ident) == 1 {
s.notAFunction(args, final)
return value
}
return s.evalFieldChain(dot, value, variable, variable.Ident[1:], args, final)
}
// evalFieldChain evaluates .X.Y.Z possibly followed by arguments.
// dot is the environment in which to evaluate arguments, while
// receiver is the value being walked along the chain.
func (s *state) evalFieldChain(dot, receiver reflect.Value, node parse.Node, ident []string, args []parse.Node, final reflect.Value) reflect.Value {
n := len(ident)
for i := 0; i < n-1; i++ {
receiver = s.evalField(dot, ident[i], node, nil, zero, receiver)
}
// Now if it's a method, it gets the arguments.
return s.evalField(dot, ident[n-1], node, args, final, receiver)
}
func (s *state) evalFunction(dot reflect.Value, node *parse.IdentifierNode, cmd parse.Node, args []parse.Node, final reflect.Value) reflect.Value {
s.at(node)
name := node.Ident
function, ok := findFunction(name, s.tmpl)
if !ok {
s.errorf("%q is not a defined function", name)
}
return s.evalCall(dot, function, cmd, name, args, final)
}
// evalField evaluates an expression like (.Field) or (.Field arg1 arg2).
// The 'final' argument represents the return value from the preceding
// value of the pipeline, if any.
func (s *state) evalField(dot reflect.Value, fieldName string, node parse.Node, args []parse.Node, final, receiver reflect.Value) reflect.Value {
if !receiver.IsValid() {
return zero
}
typ := receiver.Type()
receiver, _ = indirect(receiver)
// Unless it's an interface, need to get to a value of type *T to guarantee
// we see all methods of T and *T.
ptr := receiver
if ptr.Kind() != reflect.Interface && ptr.CanAddr() {
ptr = ptr.Addr()
}
if method := ptr.MethodByName(fieldName); method.IsValid() {
return s.evalCall(dot, method, node, fieldName, args, final)
}
hasArgs := len(args) > 1 || final.IsValid()
// It's not a method; must be a field of a struct or an element of a map. The receiver must not be nil.
receiver, isNil := indirect(receiver)
if isNil {
s.errorf("nil pointer evaluating %s.%s", typ, fieldName)
}
switch receiver.Kind() {
case reflect.Struct:
tField, ok := receiver.Type().FieldByName(fieldName)
if ok {
field := receiver.FieldByIndex(tField.Index)
if tField.PkgPath != "" { // field is unexported
s.errorf("%s is an unexported field of struct type %s", fieldName, typ)
}
// If it's a function, we must call it.
if hasArgs {
s.errorf("%s has arguments but cannot be invoked as function", fieldName)
}
return field
}
s.errorf("%s is not a field of struct type %s", fieldName, typ)
case reflect.Map:
// If it's a map, attempt to use the field name as a key.
nameVal := reflect.ValueOf(fieldName)
if nameVal.Type().AssignableTo(receiver.Type().Key()) {
if hasArgs {
s.errorf("%s is not a method but has arguments", fieldName)
}
return receiver.MapIndex(nameVal)
}
}
s.errorf("can't evaluate field %s in type %s", fieldName, typ)
panic("not reached")
}
var (
errorType = reflect.TypeOf((*error)(nil)).Elem()
fmtStringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
)
// evalCall executes a function or method call. If it's a method, fun already has the receiver bound, so
// it looks just like a function call. The arg list, if non-nil, includes (in the manner of the shell) arg[0]
// as the function itself.
func (s *state) evalCall(dot, fun reflect.Value, node parse.Node, name string, args []parse.Node, final reflect.Value) reflect.Value {
if args != nil {
args = args[1:] // Zeroth arg is function name/node; not passed to function.
}
typ := fun.Type()
numIn := len(args)
if final.IsValid() {
numIn++
}
numFixed := len(args)
if typ.IsVariadic() {
numFixed = typ.NumIn() - 1 // last arg is the variadic one.
if numIn < numFixed {
s.errorf("wrong number of args for %s: want at least %d got %d", name, typ.NumIn()-1, len(args))
}
} else if numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
s.errorf("wrong number of args for %s: want %d got %d", name, typ.NumIn(), len(args))
}
if !goodFunc(typ) {
// TODO: This could still be a confusing error; maybe goodFunc should provide info.
s.errorf("can't call method/function %q with %d results", name, typ.NumOut())
}
// Build the arg list.
argv := make([]reflect.Value, numIn)
// Args must be evaluated. Fixed args first.
i := 0
for ; i < numFixed && i < len(args); i++ {
argv[i] = s.evalArg(dot, typ.In(i), args[i])
}
// Now the ... args.
if typ.IsVariadic() {
argType := typ.In(typ.NumIn() - 1).Elem() // Argument is a slice.
for ; i < len(args); i++ {
argv[i] = s.evalArg(dot, argType, args[i])
}
}
// Add final value if necessary.
if final.IsValid() {
t := typ.In(typ.NumIn() - 1)
if typ.IsVariadic() {
t = t.Elem()
}
argv[i] = s.validateType(final, t)
}
result := fun.Call(argv)
// If we have an error that is not nil, stop execution and return that error to the caller.
if len(result) == 2 && !result[1].IsNil() {
s.at(node)
s.errorf("error calling %s: %s", name, result[1].Interface().(error))
}
return result[0]
}
// canBeNil reports whether an untyped nil can be assigned to the type. See reflect.Zero.
func canBeNil(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
}
return false
}
// validateType guarantees that the value is valid and assignable to the type.
func (s *state) validateType(value reflect.Value, typ reflect.Type) reflect.Value {
if !value.IsValid() {
if typ == nil || canBeNil(typ) {
// An untyped nil interface{}. Accept as a proper nil value.
return reflect.Zero(typ)
}
s.errorf("invalid value; expected %s", typ)
}
if typ != nil && !value.Type().AssignableTo(typ) {
if value.Kind() == reflect.Interface && !value.IsNil() {
value = value.Elem()
if value.Type().AssignableTo(typ) {
return value
}
// fallthrough
}
// Does one dereference or indirection work? We could do more, as we
// do with method receivers, but that gets messy and method receivers
// are much more constrained, so it makes more sense there than here.
// Besides, one is almost always all you need.
switch {
case value.Kind() == reflect.Ptr && value.Type().Elem().AssignableTo(typ):
value = value.Elem()
if !value.IsValid() {
s.errorf("dereference of nil pointer of type %s", typ)
}
case reflect.PtrTo(value.Type()).AssignableTo(typ) && value.CanAddr():
value = value.Addr()
default:
s.errorf("wrong type for value; expected %s; got %s", typ, value.Type())
}
}
return value
}
func (s *state) evalArg(dot reflect.Value, typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
switch arg := n.(type) {
case *parse.DotNode:
return s.validateType(dot, typ)
case *parse.NilNode:
if canBeNil(typ) {
return reflect.Zero(typ)
}
s.errorf("cannot assign nil to %s", typ)
case *parse.FieldNode:
return s.validateType(s.evalFieldNode(dot, arg, []parse.Node{n}, zero), typ)
case *parse.VariableNode:
return s.validateType(s.evalVariableNode(dot, arg, nil, zero), typ)
case *parse.PipeNode:
return s.validateType(s.evalPipeline(dot, arg), typ)
case *parse.IdentifierNode:
return s.evalFunction(dot, arg, arg, nil, zero)
case *parse.ChainNode:
return s.validateType(s.evalChainNode(dot, arg, nil, zero), typ)
}
switch typ.Kind() {
case reflect.Bool:
return s.evalBool(typ, n)
case reflect.Complex64, reflect.Complex128:
return s.evalComplex(typ, n)
case reflect.Float32, reflect.Float64:
return s.evalFloat(typ, n)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return s.evalInteger(typ, n)
case reflect.Interface:
if typ.NumMethod() == 0 {
return s.evalEmptyInterface(dot, n)
}
case reflect.String:
return s.evalString(typ, n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return s.evalUnsignedInteger(typ, n)
}
s.errorf("can't handle %s for arg of type %s", n, typ)
panic("not reached")
}
func (s *state) evalBool(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.BoolNode); ok {
value := reflect.New(typ).Elem()
value.SetBool(n.True)
return value
}
s.errorf("expected bool; found %s", n)
panic("not reached")
}
func (s *state) evalString(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.StringNode); ok {
value := reflect.New(typ).Elem()
value.SetString(n.Text)
return value
}
s.errorf("expected string; found %s", n)
panic("not reached")
}
func (s *state) evalInteger(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.NumberNode); ok && n.IsInt {
value := reflect.New(typ).Elem()
value.SetInt(n.Int64)
return value
}
s.errorf("expected integer; found %s", n)
panic("not reached")
}
func (s *state) evalUnsignedInteger(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.NumberNode); ok && n.IsUint {
value := reflect.New(typ).Elem()
value.SetUint(n.Uint64)
return value
}
s.errorf("expected unsigned integer; found %s", n)
panic("not reached")
}
func (s *state) evalFloat(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.NumberNode); ok && n.IsFloat {
value := reflect.New(typ).Elem()
value.SetFloat(n.Float64)
return value
}
s.errorf("expected float; found %s", n)
panic("not reached")
}
func (s *state) evalComplex(typ reflect.Type, n parse.Node) reflect.Value {
if n, ok := n.(*parse.NumberNode); ok && n.IsComplex {
value := reflect.New(typ).Elem()
value.SetComplex(n.Complex128)
return value
}
s.errorf("expected complex; found %s", n)
panic("not reached")
}
func (s *state) evalEmptyInterface(dot reflect.Value, n parse.Node) reflect.Value {
s.at(n)
switch n := n.(type) {
case *parse.BoolNode:
return reflect.ValueOf(n.True)
case *parse.DotNode:
return dot
case *parse.FieldNode:
return s.evalFieldNode(dot, n, nil, zero)
case *parse.IdentifierNode:
return s.evalFunction(dot, n, n, nil, zero)
case *parse.NilNode:
// NilNode is handled in evalArg, the only place that calls here.
s.errorf("evalEmptyInterface: nil (can't happen)")
case *parse.NumberNode:
return s.idealConstant(n)
case *parse.StringNode:
return reflect.ValueOf(n.Text)
case *parse.VariableNode:
return s.evalVariableNode(dot, n, nil, zero)
case *parse.PipeNode:
return s.evalPipeline(dot, n)
}
s.errorf("can't handle assignment of %s to empty interface argument", n)
panic("not reached")
}
// indirect returns the item at the end of indirection, and a bool to indicate if it's nil.
// We indirect through pointers and empty interfaces (only) because
// non-empty interfaces have methods we might need.
func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
for ; v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface; v = v.Elem() {
if v.IsNil() {
return v, true
}
if v.Kind() == reflect.Interface && v.NumMethod() > 0 {
break
}
}
return v, false
}
// printValue writes the textual representation of the value to the output of
// the template.
func (s *state) printValue(n parse.Node, v reflect.Value) {
s.at(n)
iface, ok := printableValue(v)
if !ok {
s.errorf("can't print %s of type %s", n, v.Type())
}
fmt.Fprint(s.wr, iface)
}
// printableValue returns the, possibly indirected, interface value inside v that
// is best for a call to a formatted printer.
func printableValue(v reflect.Value) (interface{}, bool) {
if v.Kind() == reflect.Ptr {
v, _ = indirect(v) // fmt.Fprint handles nil.
}
if !v.IsValid() {
return "<no value>", true
}
if !v.Type().Implements(errorType) && !v.Type().Implements(fmtStringerType) {
if v.CanAddr() && (reflect.PtrTo(v.Type()).Implements(errorType) || reflect.PtrTo(v.Type()).Implements(fmtStringerType)) {
v = v.Addr()
} else {
switch v.Kind() {
case reflect.Chan, reflect.Func:
return nil, false
}
}
}
return v.Interface(), true
}
// Types to help sort the keys in a map for reproducible output.
type rvs []reflect.Value
func (x rvs) Len() int { return len(x) }
func (x rvs) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
type rvInts struct{ rvs }
func (x rvInts) Less(i, j int) bool { return x.rvs[i].Int() < x.rvs[j].Int() }
type rvUints struct{ rvs }
func (x rvUints) Less(i, j int) bool { return x.rvs[i].Uint() < x.rvs[j].Uint() }
type rvFloats struct{ rvs }
func (x rvFloats) Less(i, j int) bool { return x.rvs[i].Float() < x.rvs[j].Float() }
type rvStrings struct{ rvs }
func (x rvStrings) Less(i, j int) bool { return x.rvs[i].String() < x.rvs[j].String() }
// sortKeys sorts (if it can) the slice of reflect.Values, which is a slice of map keys.
func sortKeys(v []reflect.Value) []reflect.Value {
if len(v) <= 1 {
return v
}
switch v[0].Kind() {
case reflect.Float32, reflect.Float64:
sort.Sort(rvFloats{v})
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
sort.Sort(rvInts{v})
case reflect.String:
sort.Sort(rvStrings{v})
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
sort.Sort(rvUints{v})
}
return v
}

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"errors"
"fmt"
"io"
"net/url"
"reflect"
"strings"
"unicode"
"unicode/utf8"
)
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values of
// which the second has type error. In that case, if the second (error)
// return value evaluates to non-nil during execution, execution terminates and
// Execute returns that error.
type FuncMap map[string]interface{}
var builtins = FuncMap{
"and": and,
"call": call,
"html": HTMLEscaper,
"index": index,
"js": JSEscaper,
"len": length,
"not": not,
"or": or,
"print": fmt.Sprint,
"printf": fmt.Sprintf,
"println": fmt.Sprintln,
"urlquery": URLQueryEscaper,
// Comparisons
"eq": eq, // ==
"ge": ge, // >=
"gt": gt, // >
"le": le, // <=
"lt": lt, // <
"ne": ne, // !=
}
var builtinFuncs = createValueFuncs(builtins)
// createValueFuncs turns a FuncMap into a map[string]reflect.Value
func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
m := make(map[string]reflect.Value)
addValueFuncs(m, funcMap)
return m
}
// addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
for name, fn := range in {
v := reflect.ValueOf(fn)
if v.Kind() != reflect.Func {
panic("value for " + name + " not a function")
}
if !goodFunc(v.Type()) {
panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
}
out[name] = v
}
}
// addFuncs adds to values the functions in funcs. It does no checking of the input -
// call addValueFuncs first.
func addFuncs(out, in FuncMap) {
for name, fn := range in {
out[name] = fn
}
}
// goodFunc checks that the function or method has the right result signature.
func goodFunc(typ reflect.Type) bool {
// We allow functions with 1 result or 2 results where the second is an error.
switch {
case typ.NumOut() == 1:
return true
case typ.NumOut() == 2 && typ.Out(1) == errorType:
return true
}
return false
}
// findFunction looks for a function in the template's function map, then in the global map.
func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
if tmpl != nil && tmpl.common != nil {
if fn := tmpl.execFuncs[name]; fn.IsValid() {
return fn, true
}
}
if fn := builtinFuncs[name]; fn.IsValid() {
return fn, true
}
return reflect.Value{}, false
}
// Indexing.
// index returns the result of indexing its first argument by the following
// arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
// indexed item must be a map, slice, or array.
func index(item interface{}, indices ...interface{}) (interface{}, error) {
v := reflect.ValueOf(item)
for _, i := range indices {
index := reflect.ValueOf(i)
var isNil bool
if v, isNil = indirect(v); isNil {
return nil, fmt.Errorf("index of nil pointer")
}
switch v.Kind() {
case reflect.Array, reflect.Slice, reflect.String:
var x int64
switch index.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x = index.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
x = int64(index.Uint())
default:
return nil, fmt.Errorf("cannot index slice/array with type %s", index.Type())
}
if x < 0 || x >= int64(v.Len()) {
return nil, fmt.Errorf("index out of range: %d", x)
}
v = v.Index(int(x))
case reflect.Map:
if !index.IsValid() {
index = reflect.Zero(v.Type().Key())
}
if !index.Type().AssignableTo(v.Type().Key()) {
return nil, fmt.Errorf("%s is not index type for %s", index.Type(), v.Type())
}
if x := v.MapIndex(index); x.IsValid() {
v = x
} else {
v = reflect.Zero(v.Type().Elem())
}
default:
return nil, fmt.Errorf("can't index item of type %s", v.Type())
}
}
return v.Interface(), nil
}
// Length
// length returns the length of the item, with an error if it has no defined length.
func length(item interface{}) (int, error) {
v, isNil := indirect(reflect.ValueOf(item))
if isNil {
return 0, fmt.Errorf("len of nil pointer")
}
switch v.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
return v.Len(), nil
}
return 0, fmt.Errorf("len of type %s", v.Type())
}
// Function invocation
// call returns the result of evaluating the first argument as a function.
// The function must return 1 result, or 2 results, the second of which is an error.
func call(fn interface{}, args ...interface{}) (interface{}, error) {
v := reflect.ValueOf(fn)
typ := v.Type()
if typ.Kind() != reflect.Func {
return nil, fmt.Errorf("non-function of type %s", typ)
}
if !goodFunc(typ) {
return nil, fmt.Errorf("function called with %d results; should be 1 or 2", typ.NumOut())
}
numIn := typ.NumIn()
var dddType reflect.Type
if typ.IsVariadic() {
if len(args) < numIn-1 {
return nil, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
}
dddType = typ.In(numIn - 1).Elem()
} else {
if len(args) != numIn {
return nil, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
}
}
argv := make([]reflect.Value, len(args))
for i, arg := range args {
value := reflect.ValueOf(arg)
// Compute the expected type. Clumsy because of variadics.
var argType reflect.Type
if !typ.IsVariadic() || i < numIn-1 {
argType = typ.In(i)
} else {
argType = dddType
}
if !value.IsValid() && canBeNil(argType) {
value = reflect.Zero(argType)
}
if !value.Type().AssignableTo(argType) {
return nil, fmt.Errorf("arg %d has type %s; should be %s", i, value.Type(), argType)
}
argv[i] = value
}
result := v.Call(argv)
if len(result) == 2 && !result[1].IsNil() {
return result[0].Interface(), result[1].Interface().(error)
}
return result[0].Interface(), nil
}
// Boolean logic.
func truth(a interface{}) bool {
t, _ := isTrue(reflect.ValueOf(a))
return t
}
// and computes the Boolean AND of its arguments, returning
// the first false argument it encounters, or the last argument.
func and(arg0 interface{}, args ...interface{}) interface{} {
if !truth(arg0) {
return arg0
}
for i := range args {
arg0 = args[i]
if !truth(arg0) {
break
}
}
return arg0
}
// or computes the Boolean OR of its arguments, returning
// the first true argument it encounters, or the last argument.
func or(arg0 interface{}, args ...interface{}) interface{} {
if truth(arg0) {
return arg0
}
for i := range args {
arg0 = args[i]
if truth(arg0) {
break
}
}
return arg0
}
// not returns the Boolean negation of its argument.
func not(arg interface{}) (truth bool) {
truth, _ = isTrue(reflect.ValueOf(arg))
return !truth
}
// Comparison.
// TODO: Perhaps allow comparison between signed and unsigned integers.
var (
errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison")
)
type kind int
const (
invalidKind kind = iota
boolKind
complexKind
intKind
floatKind
integerKind
stringKind
uintKind
)
func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() {
case reflect.Bool:
return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil
case reflect.Float32, reflect.Float64:
return floatKind, nil
case reflect.Complex64, reflect.Complex128:
return complexKind, nil
case reflect.String:
return stringKind, nil
}
return invalidKind, errBadComparisonType
}
// eq evaluates the comparison a == b || a == c || ...
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
if len(arg2) == 0 {
return false, errNoComparison
}
for _, arg := range arg2 {
v2 := reflect.ValueOf(arg)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
truth := false
if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
case k1 == uintKind && k2 == intKind:
truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
default:
return false, errBadComparison
}
} else {
switch k1 {
case boolKind:
truth = v1.Bool() == v2.Bool()
case complexKind:
truth = v1.Complex() == v2.Complex()
case floatKind:
truth = v1.Float() == v2.Float()
case intKind:
truth = v1.Int() == v2.Int()
case stringKind:
truth = v1.String() == v2.String()
case uintKind:
truth = v1.Uint() == v2.Uint()
default:
panic("invalid kind")
}
}
if truth {
return true, nil
}
}
return false, nil
}
// ne evaluates the comparison a != b.
func ne(arg1, arg2 interface{}) (bool, error) {
// != is the inverse of ==.
equal, err := eq(arg1, arg2)
return !equal, err
}
// lt evaluates the comparison a < b.
func lt(arg1, arg2 interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
v2 := reflect.ValueOf(arg2)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
truth := false
if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
case k1 == uintKind && k2 == intKind:
truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
default:
return false, errBadComparison
}
} else {
switch k1 {
case boolKind, complexKind:
return false, errBadComparisonType
case floatKind:
truth = v1.Float() < v2.Float()
case intKind:
truth = v1.Int() < v2.Int()
case stringKind:
truth = v1.String() < v2.String()
case uintKind:
truth = v1.Uint() < v2.Uint()
default:
panic("invalid kind")
}
}
return truth, nil
}
// le evaluates the comparison a <= b.
func le(arg1, arg2 interface{}) (bool, error) {
// <= is < or ==.
lessThan, err := lt(arg1, arg2)
if lessThan || err != nil {
return lessThan, err
}
return eq(arg1, arg2)
}
// gt evaluates the comparison a > b.
func gt(arg1, arg2 interface{}) (bool, error) {
// > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2)
if err != nil {
return false, err
}
return !lessOrEqual, nil
}
// ge evaluates the comparison a >= b.
func ge(arg1, arg2 interface{}) (bool, error) {
// >= is the inverse of <.
lessThan, err := lt(arg1, arg2)
if err != nil {
return false, err
}
return !lessThan, nil
}
// HTML escaping.
var (
htmlQuot = []byte("&#34;") // shorter than "&quot;"
htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
htmlAmp = []byte("&amp;")
htmlLt = []byte("&lt;")
htmlGt = []byte("&gt;")
)
// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
func HTMLEscape(w io.Writer, b []byte) {
last := 0
for i, c := range b {
var html []byte
switch c {
case '"':
html = htmlQuot
case '\'':
html = htmlApos
case '&':
html = htmlAmp
case '<':
html = htmlLt
case '>':
html = htmlGt
default:
continue
}
w.Write(b[last:i])
w.Write(html)
last = i + 1
}
w.Write(b[last:])
}
// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
func HTMLEscapeString(s string) string {
// Avoid allocation if we can.
if strings.IndexAny(s, `'"&<>`) < 0 {
return s
}
var b bytes.Buffer
HTMLEscape(&b, []byte(s))
return b.String()
}
// HTMLEscaper returns the escaped HTML equivalent of the textual
// representation of its arguments.
func HTMLEscaper(args ...interface{}) string {
return HTMLEscapeString(evalArgs(args))
}
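A quick sketch of the escaping above via the stdlib mirror; note the shorter numeric entities (`&#34;`, `&#39;`) from the byte tables rather than `&quot;`/`&apos;`:

```go
package main

import (
	"fmt"
	"text/template"
)

// escapeSample runs a snippet containing all five special characters
// through HTMLEscapeString.
func escapeSample() string {
	return template.HTMLEscapeString(`<a href="it's">`)
}

func main() {
	fmt.Println(escapeSample()) // &lt;a href=&#34;it&#39;s&#34;&gt;
}
```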
// JavaScript escaping.
var (
jsLowUni = []byte(`\u00`)
hex = []byte("0123456789ABCDEF")
jsBackslash = []byte(`\\`)
jsApos = []byte(`\'`)
jsQuot = []byte(`\"`)
jsLt = []byte(`\x3C`)
jsGt = []byte(`\x3E`)
)
// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
func JSEscape(w io.Writer, b []byte) {
last := 0
for i := 0; i < len(b); i++ {
c := b[i]
if !jsIsSpecial(rune(c)) {
// fast path: nothing to do
continue
}
w.Write(b[last:i])
if c < utf8.RuneSelf {
// Quotes, slashes and angle brackets get quoted.
// Control characters get written as \u00XX.
switch c {
case '\\':
w.Write(jsBackslash)
case '\'':
w.Write(jsApos)
case '"':
w.Write(jsQuot)
case '<':
w.Write(jsLt)
case '>':
w.Write(jsGt)
default:
w.Write(jsLowUni)
t, b := c>>4, c&0x0f
w.Write(hex[t : t+1])
w.Write(hex[b : b+1])
}
} else {
// Unicode rune.
r, size := utf8.DecodeRune(b[i:])
if unicode.IsPrint(r) {
w.Write(b[i : i+size])
} else {
fmt.Fprintf(w, "\\u%04X", r)
}
i += size - 1
}
last = i + 1
}
w.Write(b[last:])
}
// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
func JSEscapeString(s string) string {
// Avoid allocation if we can.
if strings.IndexFunc(s, jsIsSpecial) < 0 {
return s
}
var b bytes.Buffer
JSEscape(&b, []byte(s))
return b.String()
}
func jsIsSpecial(r rune) bool {
switch r {
case '\\', '\'', '"', '<', '>':
return true
}
return r < ' ' || utf8.RuneSelf <= r
}
// JSEscaper returns the escaped JavaScript equivalent of the textual
// representation of its arguments.
func JSEscaper(args ...interface{}) string {
return JSEscapeString(evalArgs(args))
}
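The JavaScript escaper can be sketched the same way. The example below sticks to quotes and backslashes, whose escapes are stable; angle-bracket output (`\x3C` here) has varied across stdlib versions, so it is left out of the assertion:

```go
package main

import (
	"fmt"
	"text/template"
)

// jsSample escapes a string containing both quote styles and a backslash.
func jsSample() string {
	return template.JSEscapeString(`say "hi" to O'Brien\`)
}

func main() {
	fmt.Println(jsSample()) // say \"hi\" to O\'Brien\\
}
```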
// URLQueryEscaper returns the escaped value of the textual representation of
// its arguments in a form suitable for embedding in a URL query.
func URLQueryEscaper(args ...interface{}) string {
return url.QueryEscape(evalArgs(args))
}
// evalArgs formats the list of arguments into a string. It is therefore equivalent to
// fmt.Sprint(args...)
// except that each argument is indirected (if a pointer), as required,
// using the same rules as the default string evaluation during template
// execution.
func evalArgs(args []interface{}) string {
ok := false
var s string
// Fast path for simple common case.
if len(args) == 1 {
s, ok = args[0].(string)
}
if !ok {
for i, arg := range args {
a, ok := printableValue(reflect.ValueOf(arg))
if ok {
args[i] = a
} // else let fmt do its thing
}
s = fmt.Sprint(args...)
}
return s
}


@@ -1 +0,0 @@
module github.com/alecthomas/template


@@ -1,108 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Helper functions to make constructing templates easier.
package template
import (
"fmt"
"io/ioutil"
"path/filepath"
)
// Functions and methods to parse templates.
// Must is a helper that wraps a call to a function returning (*Template, error)
// and panics if the error is non-nil. It is intended for use in variable
// initializations such as
// var t = template.Must(template.New("name").Parse("text"))
func Must(t *Template, err error) *Template {
if err != nil {
panic(err)
}
return t
}
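The typical use of Must is a package-level variable initialization, where a bad template should fail fast at startup rather than force an error check at every call site. A minimal sketch using the stdlib mirror:

```go
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

// Must panics during package init if the template source is invalid,
// so greeting is guaranteed usable everywhere below.
var greeting = template.Must(template.New("greet").Parse("Hello, {{.}}!"))

func render(name string) string {
	var buf bytes.Buffer
	if err := greeting.Execute(&buf, name); err != nil {
		panic(err)
	}
	return buf.String()
}

func main() {
	fmt.Println(render("world")) // Hello, world!
}
```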
// ParseFiles creates a new Template and parses the template definitions from
// the named files. The returned template's name will have the (base) name and
// (parsed) contents of the first file. There must be at least one file.
// If an error occurs, parsing stops and the returned *Template is nil.
func ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(nil, filenames...)
}
// ParseFiles parses the named files and associates the resulting templates with
// t. If an error occurs, parsing stops and the returned template is nil;
// otherwise it is t. There must be at least one file.
func (t *Template) ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(t, filenames...)
}
// parseFiles is the helper for the method and function. If the argument
// template is nil, it is created from the first file.
func parseFiles(t *Template, filenames ...string) (*Template, error) {
if len(filenames) == 0 {
// Not really a problem, but be consistent.
return nil, fmt.Errorf("template: no files named in call to ParseFiles")
}
for _, filename := range filenames {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
s := string(b)
name := filepath.Base(filename)
// First template becomes return value if not already defined,
// and we use that one for subsequent New calls to associate
// all the templates together. Also, if this file has the same name
// as t, this file becomes the contents of t, so
// t, err := New(name).Funcs(xxx).ParseFiles(name)
// works. Otherwise we create a new template associated with t.
var tmpl *Template
if t == nil {
t = New(name)
}
if name == t.Name() {
tmpl = t
} else {
tmpl = t.New(name)
}
_, err = tmpl.Parse(s)
if err != nil {
return nil, err
}
}
return t, nil
}
// ParseGlob creates a new Template and parses the template definitions from the
// files identified by the pattern, which must match at least one file. The
// returned template will have the (base) name and (parsed) contents of the
// first file matched by the pattern. ParseGlob is equivalent to calling
// ParseFiles with the list of files matched by the pattern.
func ParseGlob(pattern string) (*Template, error) {
return parseGlob(nil, pattern)
}
// ParseGlob parses the template definitions in the files identified by the
// pattern and associates the resulting templates with t. The pattern is
// processed by filepath.Glob and must match at least one file. ParseGlob is
// equivalent to calling t.ParseFiles with the list of files matched by the
// pattern.
func (t *Template) ParseGlob(pattern string) (*Template, error) {
return parseGlob(t, pattern)
}
// parseGlob is the implementation of the function and method ParseGlob.
func parseGlob(t *Template, pattern string) (*Template, error) {
filenames, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
if len(filenames) == 0 {
return nil, fmt.Errorf("template: pattern matches no files: %#q", pattern)
}
return parseFiles(t, filenames...)
}
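A self-contained sketch of ParseGlob (again via the stdlib mirror), writing a template file into a temporary directory and executing it by the base-name rule described above:

```go
package main

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"text/template"
)

// renderFromGlob creates a temp dir with one template file, loads it with
// ParseGlob, and executes it by its base file name.
func renderFromGlob() string {
	dir, err := os.MkdirTemp("", "tmpl")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(dir)
	path := filepath.Join(dir, "page.tmpl")
	if err := os.WriteFile(path, []byte("page={{.}}"), 0o600); err != nil {
		panic(err)
	}
	t, err := template.ParseGlob(filepath.Join(dir, "*.tmpl"))
	if err != nil {
		panic(err)
	}
	var buf bytes.Buffer
	// Each parsed file is addressable by its (base) file name.
	if err := t.ExecuteTemplate(&buf, "page.tmpl", 42); err != nil {
		panic(err)
	}
	return buf.String()
}

func main() {
	fmt.Println(renderFromGlob()) // page=42
}
```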


@@ -1,556 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package parse
import (
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
// item represents a token or text string returned from the scanner.
type item struct {
typ itemType // The type of this item.
pos Pos // The starting position, in bytes, of this item in the input string.
val string // The value of this item.
}
func (i item) String() string {
switch {
case i.typ == itemEOF:
return "EOF"
case i.typ == itemError:
return i.val
case i.typ > itemKeyword:
return fmt.Sprintf("<%s>", i.val)
case len(i.val) > 10:
return fmt.Sprintf("%.10q...", i.val)
}
return fmt.Sprintf("%q", i.val)
}
// itemType identifies the type of lex items.
type itemType int
const (
itemError itemType = iota // error occurred; value is text of error
itemBool // boolean constant
itemChar // printable ASCII character; grab bag for comma etc.
itemCharConstant // character constant
itemComplex // complex constant (1+2i); imaginary is just a number
itemColonEquals // colon-equals (':=') introducing a declaration
itemEOF
itemField // alphanumeric identifier starting with '.'
itemIdentifier // alphanumeric identifier not starting with '.'
itemLeftDelim // left action delimiter
itemLeftParen // '(' inside action
itemNumber // simple number, including imaginary
itemPipe // pipe symbol
itemRawString // raw quoted string (includes quotes)
itemRightDelim // right action delimiter
itemElideNewline // elide newline after right delim
itemRightParen // ')' inside action
itemSpace // run of spaces separating arguments
itemString // quoted string (includes quotes)
itemText // plain text
itemVariable // variable starting with '$', such as '$' or '$1' or '$hello'
// Keywords appear after all the rest.
itemKeyword // used only to delimit the keywords
itemDot // the cursor, spelled '.'
itemDefine // define keyword
itemElse // else keyword
itemEnd // end keyword
itemIf // if keyword
itemNil // the untyped nil constant, easiest to treat as a keyword
itemRange // range keyword
itemTemplate // template keyword
itemWith // with keyword
)
var key = map[string]itemType{
".": itemDot,
"define": itemDefine,
"else": itemElse,
"end": itemEnd,
"if": itemIf,
"range": itemRange,
"nil": itemNil,
"template": itemTemplate,
"with": itemWith,
}
const eof = -1
// stateFn represents the state of the scanner as a function that returns the next state.
type stateFn func(*lexer) stateFn
// lexer holds the state of the scanner.
type lexer struct {
name string // the name of the input; used only for error reports
input string // the string being scanned
leftDelim string // start of action
rightDelim string // end of action
state stateFn // the next lexing function to enter
pos Pos // current position in the input
start Pos // start position of this item
width Pos // width of last rune read from input
lastPos Pos // position of most recent item returned by nextItem
items chan item // channel of scanned items
parenDepth int // nesting depth of ( ) exprs
}
// next returns the next rune in the input.
func (l *lexer) next() rune {
if int(l.pos) >= len(l.input) {
l.width = 0
return eof
}
r, w := utf8.DecodeRuneInString(l.input[l.pos:])
l.width = Pos(w)
l.pos += l.width
return r
}
// peek returns but does not consume the next rune in the input.
func (l *lexer) peek() rune {
r := l.next()
l.backup()
return r
}
// backup steps back one rune. Can only be called once per call of next.
func (l *lexer) backup() {
l.pos -= l.width
}
// emit passes an item back to the client.
func (l *lexer) emit(t itemType) {
l.items <- item{t, l.start, l.input[l.start:l.pos]}
l.start = l.pos
}
// ignore skips over the pending input before this point.
func (l *lexer) ignore() {
l.start = l.pos
}
// accept consumes the next rune if it's from the valid set.
func (l *lexer) accept(valid string) bool {
if strings.IndexRune(valid, l.next()) >= 0 {
return true
}
l.backup()
return false
}
// acceptRun consumes a run of runes from the valid set.
func (l *lexer) acceptRun(valid string) {
for strings.IndexRune(valid, l.next()) >= 0 {
}
l.backup()
}
// lineNumber reports which line we're on, based on the position of
// the previous item returned by nextItem. Doing it this way
// means we don't have to worry about peek double counting.
func (l *lexer) lineNumber() int {
return 1 + strings.Count(l.input[:l.lastPos], "\n")
}
// errorf returns an error token and terminates the scan by passing
// back a nil pointer that will be the next state, terminating l.nextItem.
func (l *lexer) errorf(format string, args ...interface{}) stateFn {
l.items <- item{itemError, l.start, fmt.Sprintf(format, args...)}
return nil
}
// nextItem returns the next item from the input.
func (l *lexer) nextItem() item {
item := <-l.items
l.lastPos = item.pos
return item
}
// lex creates a new scanner for the input string.
func lex(name, input, left, right string) *lexer {
if left == "" {
left = leftDelim
}
if right == "" {
right = rightDelim
}
l := &lexer{
name: name,
input: input,
leftDelim: left,
rightDelim: right,
items: make(chan item),
}
go l.run()
return l
}
// run runs the state machine for the lexer.
func (l *lexer) run() {
for l.state = lexText; l.state != nil; {
l.state = l.state(l)
}
}
// state functions
const (
leftDelim = "{{"
rightDelim = "}}"
leftComment = "/*"
rightComment = "*/"
)
// lexText scans until an opening action delimiter, "{{".
func lexText(l *lexer) stateFn {
for {
if strings.HasPrefix(l.input[l.pos:], l.leftDelim) {
if l.pos > l.start {
l.emit(itemText)
}
return lexLeftDelim
}
if l.next() == eof {
break
}
}
// Correctly reached EOF.
if l.pos > l.start {
l.emit(itemText)
}
l.emit(itemEOF)
return nil
}
// lexLeftDelim scans the left delimiter, which is known to be present.
func lexLeftDelim(l *lexer) stateFn {
l.pos += Pos(len(l.leftDelim))
if strings.HasPrefix(l.input[l.pos:], leftComment) {
return lexComment
}
l.emit(itemLeftDelim)
l.parenDepth = 0
return lexInsideAction
}
// lexComment scans a comment. The left comment marker is known to be present.
func lexComment(l *lexer) stateFn {
l.pos += Pos(len(leftComment))
i := strings.Index(l.input[l.pos:], rightComment)
if i < 0 {
return l.errorf("unclosed comment")
}
l.pos += Pos(i + len(rightComment))
if !strings.HasPrefix(l.input[l.pos:], l.rightDelim) {
return l.errorf("comment ends before closing delimiter")
}
l.pos += Pos(len(l.rightDelim))
l.ignore()
return lexText
}
// lexRightDelim scans the right delimiter, which is known to be present.
func lexRightDelim(l *lexer) stateFn {
l.pos += Pos(len(l.rightDelim))
l.emit(itemRightDelim)
if l.peek() == '\\' {
l.pos++
l.emit(itemElideNewline)
}
return lexText
}
// lexInsideAction scans the elements inside action delimiters.
func lexInsideAction(l *lexer) stateFn {
// Either number, quoted string, or identifier.
// Spaces separate arguments; runs of spaces turn into itemSpace.
// Pipe symbols separate and are emitted.
if strings.HasPrefix(l.input[l.pos:], l.rightDelim+"\\") || strings.HasPrefix(l.input[l.pos:], l.rightDelim) {
if l.parenDepth == 0 {
return lexRightDelim
}
return l.errorf("unclosed left paren")
}
switch r := l.next(); {
case r == eof || isEndOfLine(r):
return l.errorf("unclosed action")
case isSpace(r):
return lexSpace
case r == ':':
if l.next() != '=' {
return l.errorf("expected :=")
}
l.emit(itemColonEquals)
case r == '|':
l.emit(itemPipe)
case r == '"':
return lexQuote
case r == '`':
return lexRawQuote
case r == '$':
return lexVariable
case r == '\'':
return lexChar
case r == '.':
// special look-ahead for ".field" so we don't break l.backup().
if l.pos < Pos(len(l.input)) {
r := l.input[l.pos]
if r < '0' || '9' < r {
return lexField
}
}
fallthrough // '.' can start a number.
case r == '+' || r == '-' || ('0' <= r && r <= '9'):
l.backup()
return lexNumber
case isAlphaNumeric(r):
l.backup()
return lexIdentifier
case r == '(':
l.emit(itemLeftParen)
l.parenDepth++
return lexInsideAction
case r == ')':
l.emit(itemRightParen)
l.parenDepth--
if l.parenDepth < 0 {
return l.errorf("unexpected right paren %#U", r)
}
return lexInsideAction
case r <= unicode.MaxASCII && unicode.IsPrint(r):
l.emit(itemChar)
return lexInsideAction
default:
return l.errorf("unrecognized character in action: %#U", r)
}
return lexInsideAction
}
// lexSpace scans a run of space characters.
// One space has already been seen.
func lexSpace(l *lexer) stateFn {
for isSpace(l.peek()) {
l.next()
}
l.emit(itemSpace)
return lexInsideAction
}
// lexIdentifier scans an alphanumeric.
func lexIdentifier(l *lexer) stateFn {
Loop:
for {
switch r := l.next(); {
case isAlphaNumeric(r):
// absorb.
default:
l.backup()
word := l.input[l.start:l.pos]
if !l.atTerminator() {
return l.errorf("bad character %#U", r)
}
switch {
case key[word] > itemKeyword:
l.emit(key[word])
case word[0] == '.':
l.emit(itemField)
case word == "true", word == "false":
l.emit(itemBool)
default:
l.emit(itemIdentifier)
}
break Loop
}
}
return lexInsideAction
}
// lexField scans a field: .Alphanumeric.
// The . has been scanned.
func lexField(l *lexer) stateFn {
return lexFieldOrVariable(l, itemField)
}
// lexVariable scans a Variable: $Alphanumeric.
// The $ has been scanned.
func lexVariable(l *lexer) stateFn {
if l.atTerminator() { // Nothing interesting follows -> "$".
l.emit(itemVariable)
return lexInsideAction
}
return lexFieldOrVariable(l, itemVariable)
}
// lexFieldOrVariable scans a field or variable: [.$]Alphanumeric.
// The . or $ has been scanned.
func lexFieldOrVariable(l *lexer, typ itemType) stateFn {
if l.atTerminator() { // Nothing interesting follows -> "." or "$".
if typ == itemVariable {
l.emit(itemVariable)
} else {
l.emit(itemDot)
}
return lexInsideAction
}
var r rune
for {
r = l.next()
if !isAlphaNumeric(r) {
l.backup()
break
}
}
if !l.atTerminator() {
return l.errorf("bad character %#U", r)
}
l.emit(typ)
return lexInsideAction
}
// atTerminator reports whether the input is at a valid termination character to
// appear after an identifier. Breaks .X.Y into two pieces. Also catches cases
// like "$x+2" not being acceptable without a space, in case we decide one
// day to implement arithmetic.
func (l *lexer) atTerminator() bool {
r := l.peek()
if isSpace(r) || isEndOfLine(r) {
return true
}
switch r {
case eof, '.', ',', '|', ':', ')', '(':
return true
}
// Does r start the delimiter? This can be ambiguous (with delim=="//", $x/2 will
// succeed but should fail) but only in extremely rare cases caused by willfully
// bad choice of delimiter.
if rd, _ := utf8.DecodeRuneInString(l.rightDelim); rd == r {
return true
}
return false
}
// lexChar scans a character constant. The initial quote is already
// scanned. Syntax checking is done by the parser.
func lexChar(l *lexer) stateFn {
Loop:
for {
switch l.next() {
case '\\':
if r := l.next(); r != eof && r != '\n' {
break
}
fallthrough
case eof, '\n':
return l.errorf("unterminated character constant")
case '\'':
break Loop
}
}
l.emit(itemCharConstant)
return lexInsideAction
}
// lexNumber scans a number: decimal, octal, hex, float, or imaginary. This
// isn't a perfect number scanner - for instance it accepts "." and "0x0.2"
// and "089" - but when it's wrong the input is invalid and the parser (via
// strconv) will notice.
func lexNumber(l *lexer) stateFn {
if !l.scanNumber() {
return l.errorf("bad number syntax: %q", l.input[l.start:l.pos])
}
if sign := l.peek(); sign == '+' || sign == '-' {
// Complex: 1+2i. No spaces, must end in 'i'.
if !l.scanNumber() || l.input[l.pos-1] != 'i' {
return l.errorf("bad number syntax: %q", l.input[l.start:l.pos])
}
l.emit(itemComplex)
} else {
l.emit(itemNumber)
}
return lexInsideAction
}
func (l *lexer) scanNumber() bool {
// Optional leading sign.
l.accept("+-")
// Is it hex?
digits := "0123456789"
if l.accept("0") && l.accept("xX") {
digits = "0123456789abcdefABCDEF"
}
l.acceptRun(digits)
if l.accept(".") {
l.acceptRun(digits)
}
if l.accept("eE") {
l.accept("+-")
l.acceptRun("0123456789")
}
// Is it imaginary?
l.accept("i")
// Next thing mustn't be alphanumeric.
if isAlphaNumeric(l.peek()) {
l.next()
return false
}
return true
}
// lexQuote scans a quoted string.
func lexQuote(l *lexer) stateFn {
Loop:
for {
switch l.next() {
case '\\':
if r := l.next(); r != eof && r != '\n' {
break
}
fallthrough
case eof, '\n':
return l.errorf("unterminated quoted string")
case '"':
break Loop
}
}
l.emit(itemString)
return lexInsideAction
}
// lexRawQuote scans a raw quoted string.
func lexRawQuote(l *lexer) stateFn {
Loop:
for {
switch l.next() {
case eof, '\n':
return l.errorf("unterminated raw quoted string")
case '`':
break Loop
}
}
l.emit(itemRawString)
return lexInsideAction
}
// isSpace reports whether r is a space character.
func isSpace(r rune) bool {
return r == ' ' || r == '\t'
}
// isEndOfLine reports whether r is an end-of-line character.
func isEndOfLine(r rune) bool {
return r == '\r' || r == '\n'
}
// isAlphaNumeric reports whether r is an alphabetic, digit, or underscore.
func isAlphaNumeric(r rune) bool {
return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
}


@@ -1,834 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Parse nodes.
package parse
import (
"bytes"
"fmt"
"strconv"
"strings"
)
var textFormat = "%s" // Changed to "%q" in tests for better error messages.
// A Node is an element in the parse tree. The interface is trivial.
// The interface contains an unexported method so that only
// types local to this package can satisfy it.
type Node interface {
Type() NodeType
String() string
// Copy does a deep copy of the Node and all its components.
// To avoid type assertions, some XxxNodes also have specialized
// CopyXxx methods that return *XxxNode.
Copy() Node
Position() Pos // byte position of start of node in full original input string
// tree returns the containing *Tree.
// It is unexported so all implementations of Node are in this package.
tree() *Tree
}
// NodeType identifies the type of a parse tree node.
type NodeType int
// Pos represents a byte position in the original input text from which
// this template was parsed.
type Pos int
func (p Pos) Position() Pos {
return p
}
// Type returns itself and provides an easy default implementation
// for embedding in a Node. Embedded in all non-trivial Nodes.
func (t NodeType) Type() NodeType {
return t
}
const (
NodeText NodeType = iota // Plain text.
NodeAction // A non-control action such as a field evaluation.
NodeBool // A boolean constant.
NodeChain // A sequence of field accesses.
NodeCommand // An element of a pipeline.
NodeDot // The cursor, dot.
nodeElse // An else action. Not added to tree.
nodeEnd // An end action. Not added to tree.
NodeField // A field or method name.
NodeIdentifier // An identifier; always a function name.
NodeIf // An if action.
NodeList // A list of Nodes.
NodeNil // An untyped nil constant.
NodeNumber // A numerical constant.
NodePipe // A pipeline of commands.
NodeRange // A range action.
NodeString // A string constant.
NodeTemplate // A template invocation action.
NodeVariable // A $ variable.
NodeWith // A with action.
)
// Nodes.
// ListNode holds a sequence of nodes.
type ListNode struct {
NodeType
Pos
tr *Tree
Nodes []Node // The element nodes in lexical order.
}
func (t *Tree) newList(pos Pos) *ListNode {
return &ListNode{tr: t, NodeType: NodeList, Pos: pos}
}
func (l *ListNode) append(n Node) {
l.Nodes = append(l.Nodes, n)
}
func (l *ListNode) tree() *Tree {
return l.tr
}
func (l *ListNode) String() string {
b := new(bytes.Buffer)
for _, n := range l.Nodes {
fmt.Fprint(b, n)
}
return b.String()
}
func (l *ListNode) CopyList() *ListNode {
if l == nil {
return l
}
n := l.tr.newList(l.Pos)
for _, elem := range l.Nodes {
n.append(elem.Copy())
}
return n
}
func (l *ListNode) Copy() Node {
return l.CopyList()
}
// TextNode holds plain text.
type TextNode struct {
NodeType
Pos
tr *Tree
Text []byte // The text; may span newlines.
}
func (t *Tree) newText(pos Pos, text string) *TextNode {
return &TextNode{tr: t, NodeType: NodeText, Pos: pos, Text: []byte(text)}
}
func (t *TextNode) String() string {
return fmt.Sprintf(textFormat, t.Text)
}
func (t *TextNode) tree() *Tree {
return t.tr
}
func (t *TextNode) Copy() Node {
return &TextNode{tr: t.tr, NodeType: NodeText, Pos: t.Pos, Text: append([]byte{}, t.Text...)}
}
// PipeNode holds a pipeline with optional declaration
type PipeNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input (deprecated; kept for compatibility)
Decl []*VariableNode // Variable declarations in lexical order.
Cmds []*CommandNode // The commands in lexical order.
}
func (t *Tree) newPipeline(pos Pos, line int, decl []*VariableNode) *PipeNode {
return &PipeNode{tr: t, NodeType: NodePipe, Pos: pos, Line: line, Decl: decl}
}
func (p *PipeNode) append(command *CommandNode) {
p.Cmds = append(p.Cmds, command)
}
func (p *PipeNode) String() string {
s := ""
if len(p.Decl) > 0 {
for i, v := range p.Decl {
if i > 0 {
s += ", "
}
s += v.String()
}
s += " := "
}
for i, c := range p.Cmds {
if i > 0 {
s += " | "
}
s += c.String()
}
return s
}
func (p *PipeNode) tree() *Tree {
return p.tr
}
func (p *PipeNode) CopyPipe() *PipeNode {
if p == nil {
return p
}
var decl []*VariableNode
for _, d := range p.Decl {
decl = append(decl, d.Copy().(*VariableNode))
}
n := p.tr.newPipeline(p.Pos, p.Line, decl)
for _, c := range p.Cmds {
n.append(c.Copy().(*CommandNode))
}
return n
}
func (p *PipeNode) Copy() Node {
return p.CopyPipe()
}
// ActionNode holds an action (something bounded by delimiters).
// Control actions have their own nodes; ActionNode represents simple
// ones such as field evaluations and parenthesized pipelines.
type ActionNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input (deprecated; kept for compatibility)
Pipe *PipeNode // The pipeline in the action.
}
func (t *Tree) newAction(pos Pos, line int, pipe *PipeNode) *ActionNode {
return &ActionNode{tr: t, NodeType: NodeAction, Pos: pos, Line: line, Pipe: pipe}
}
func (a *ActionNode) String() string {
return fmt.Sprintf("{{%s}}", a.Pipe)
}
func (a *ActionNode) tree() *Tree {
return a.tr
}
func (a *ActionNode) Copy() Node {
return a.tr.newAction(a.Pos, a.Line, a.Pipe.CopyPipe())
}
// CommandNode holds a command (a pipeline inside an evaluating action).
type CommandNode struct {
NodeType
Pos
tr *Tree
Args []Node // Arguments in lexical order: Identifier, field, or constant.
}
func (t *Tree) newCommand(pos Pos) *CommandNode {
return &CommandNode{tr: t, NodeType: NodeCommand, Pos: pos}
}
func (c *CommandNode) append(arg Node) {
c.Args = append(c.Args, arg)
}
func (c *CommandNode) String() string {
s := ""
for i, arg := range c.Args {
if i > 0 {
s += " "
}
if arg, ok := arg.(*PipeNode); ok {
s += "(" + arg.String() + ")"
continue
}
s += arg.String()
}
return s
}
func (c *CommandNode) tree() *Tree {
return c.tr
}
func (c *CommandNode) Copy() Node {
if c == nil {
return c
}
n := c.tr.newCommand(c.Pos)
for _, c := range c.Args {
n.append(c.Copy())
}
return n
}
// IdentifierNode holds an identifier.
type IdentifierNode struct {
NodeType
Pos
tr *Tree
Ident string // The identifier's name.
}
// NewIdentifier returns a new IdentifierNode with the given identifier name.
func NewIdentifier(ident string) *IdentifierNode {
return &IdentifierNode{NodeType: NodeIdentifier, Ident: ident}
}
// SetPos sets the position. NewIdentifier is a public method so we can't modify its signature.
// Chained for convenience.
// TODO: fix one day?
func (i *IdentifierNode) SetPos(pos Pos) *IdentifierNode {
i.Pos = pos
return i
}
// SetTree sets the parent tree for the node. NewIdentifier is a public method so we can't modify its signature.
// Chained for convenience.
// TODO: fix one day?
func (i *IdentifierNode) SetTree(t *Tree) *IdentifierNode {
i.tr = t
return i
}
func (i *IdentifierNode) String() string {
return i.Ident
}
func (i *IdentifierNode) tree() *Tree {
return i.tr
}
func (i *IdentifierNode) Copy() Node {
return NewIdentifier(i.Ident).SetTree(i.tr).SetPos(i.Pos)
}
// VariableNode holds a list of variable names, possibly with chained field
// accesses. The dollar sign is part of the (first) name.
type VariableNode struct {
NodeType
Pos
tr *Tree
Ident []string // Variable name and fields in lexical order.
}
func (t *Tree) newVariable(pos Pos, ident string) *VariableNode {
return &VariableNode{tr: t, NodeType: NodeVariable, Pos: pos, Ident: strings.Split(ident, ".")}
}
func (v *VariableNode) String() string {
s := ""
for i, id := range v.Ident {
if i > 0 {
s += "."
}
s += id
}
return s
}
func (v *VariableNode) tree() *Tree {
return v.tr
}
func (v *VariableNode) Copy() Node {
return &VariableNode{tr: v.tr, NodeType: NodeVariable, Pos: v.Pos, Ident: append([]string{}, v.Ident...)}
}
// DotNode holds the special identifier '.'.
type DotNode struct {
NodeType
Pos
tr *Tree
}
func (t *Tree) newDot(pos Pos) *DotNode {
return &DotNode{tr: t, NodeType: NodeDot, Pos: pos}
}
func (d *DotNode) Type() NodeType {
// Override method on embedded NodeType for API compatibility.
// TODO: Not really a problem; could change API without effect but
// api tool complains.
return NodeDot
}
func (d *DotNode) String() string {
return "."
}
func (d *DotNode) tree() *Tree {
return d.tr
}
func (d *DotNode) Copy() Node {
return d.tr.newDot(d.Pos)
}
// NilNode holds the special identifier 'nil' representing an untyped nil constant.
type NilNode struct {
NodeType
Pos
tr *Tree
}
func (t *Tree) newNil(pos Pos) *NilNode {
return &NilNode{tr: t, NodeType: NodeNil, Pos: pos}
}
func (n *NilNode) Type() NodeType {
// Override method on embedded NodeType for API compatibility.
// TODO: Not really a problem; could change API without effect but
// api tool complains.
return NodeNil
}
func (n *NilNode) String() string {
return "nil"
}
func (n *NilNode) tree() *Tree {
return n.tr
}
func (n *NilNode) Copy() Node {
return n.tr.newNil(n.Pos)
}
// FieldNode holds a field (identifier starting with '.').
// The names may be chained ('.x.y').
// The period is dropped from each ident.
type FieldNode struct {
NodeType
Pos
tr *Tree
Ident []string // The identifiers in lexical order.
}
func (t *Tree) newField(pos Pos, ident string) *FieldNode {
return &FieldNode{tr: t, NodeType: NodeField, Pos: pos, Ident: strings.Split(ident[1:], ".")} // [1:] to drop leading period
}
func (f *FieldNode) String() string {
s := ""
for _, id := range f.Ident {
s += "." + id
}
return s
}
func (f *FieldNode) tree() *Tree {
return f.tr
}
func (f *FieldNode) Copy() Node {
return &FieldNode{tr: f.tr, NodeType: NodeField, Pos: f.Pos, Ident: append([]string{}, f.Ident...)}
}
// ChainNode holds a term followed by a chain of field accesses (identifier starting with '.').
// The names may be chained ('.x.y').
// The periods are dropped from each ident.
type ChainNode struct {
NodeType
Pos
tr *Tree
Node Node
Field []string // The identifiers in lexical order.
}
func (t *Tree) newChain(pos Pos, node Node) *ChainNode {
return &ChainNode{tr: t, NodeType: NodeChain, Pos: pos, Node: node}
}
// Add adds the named field (which should start with a period) to the end of the chain.
func (c *ChainNode) Add(field string) {
if len(field) == 0 || field[0] != '.' {
panic("no dot in field")
}
field = field[1:] // Remove leading dot.
if field == "" {
panic("empty field")
}
c.Field = append(c.Field, field)
}
func (c *ChainNode) String() string {
s := c.Node.String()
if _, ok := c.Node.(*PipeNode); ok {
s = "(" + s + ")"
}
for _, field := range c.Field {
s += "." + field
}
return s
}
func (c *ChainNode) tree() *Tree {
return c.tr
}
func (c *ChainNode) Copy() Node {
return &ChainNode{tr: c.tr, NodeType: NodeChain, Pos: c.Pos, Node: c.Node, Field: append([]string{}, c.Field...)}
}
// BoolNode holds a boolean constant.
type BoolNode struct {
NodeType
Pos
tr *Tree
True bool // The value of the boolean constant.
}
func (t *Tree) newBool(pos Pos, true bool) *BoolNode {
return &BoolNode{tr: t, NodeType: NodeBool, Pos: pos, True: true}
}
func (b *BoolNode) String() string {
if b.True {
return "true"
}
return "false"
}
func (b *BoolNode) tree() *Tree {
return b.tr
}
func (b *BoolNode) Copy() Node {
return b.tr.newBool(b.Pos, b.True)
}
// NumberNode holds a number: signed or unsigned integer, float, or complex.
// The value is parsed and stored under all the types that can represent the value.
// This simulates in a small amount of code the behavior of Go's ideal constants.
type NumberNode struct {
NodeType
Pos
tr *Tree
IsInt bool // Number has an integral value.
IsUint bool // Number has an unsigned integral value.
IsFloat bool // Number has a floating-point value.
IsComplex bool // Number is complex.
Int64 int64 // The signed integer value.
Uint64 uint64 // The unsigned integer value.
Float64 float64 // The floating-point value.
Complex128 complex128 // The complex value.
Text string // The original textual representation from the input.
}
func (t *Tree) newNumber(pos Pos, text string, typ itemType) (*NumberNode, error) {
n := &NumberNode{tr: t, NodeType: NodeNumber, Pos: pos, Text: text}
switch typ {
case itemCharConstant:
rune, _, tail, err := strconv.UnquoteChar(text[1:], text[0])
if err != nil {
return nil, err
}
if tail != "'" {
return nil, fmt.Errorf("malformed character constant: %s", text)
}
n.Int64 = int64(rune)
n.IsInt = true
n.Uint64 = uint64(rune)
n.IsUint = true
n.Float64 = float64(rune) // odd but those are the rules.
n.IsFloat = true
return n, nil
case itemComplex:
// fmt.Sscan can parse the pair, so let it do the work.
if _, err := fmt.Sscan(text, &n.Complex128); err != nil {
return nil, err
}
n.IsComplex = true
n.simplifyComplex()
return n, nil
}
// Imaginary constants can only be complex unless they are zero.
if len(text) > 0 && text[len(text)-1] == 'i' {
f, err := strconv.ParseFloat(text[:len(text)-1], 64)
if err == nil {
n.IsComplex = true
n.Complex128 = complex(0, f)
n.simplifyComplex()
return n, nil
}
}
// Do integer test first so we get 0x123 etc.
u, err := strconv.ParseUint(text, 0, 64) // will fail for -0; fixed below.
if err == nil {
n.IsUint = true
n.Uint64 = u
}
i, err := strconv.ParseInt(text, 0, 64)
if err == nil {
n.IsInt = true
n.Int64 = i
if i == 0 {
n.IsUint = true // in case of -0.
n.Uint64 = u
}
}
// If an integer extraction succeeded, promote the float.
if n.IsInt {
n.IsFloat = true
n.Float64 = float64(n.Int64)
} else if n.IsUint {
n.IsFloat = true
n.Float64 = float64(n.Uint64)
} else {
f, err := strconv.ParseFloat(text, 64)
if err == nil {
n.IsFloat = true
n.Float64 = f
// If a floating-point extraction succeeded, extract the int if needed.
if !n.IsInt && float64(int64(f)) == f {
n.IsInt = true
n.Int64 = int64(f)
}
if !n.IsUint && float64(uint64(f)) == f {
n.IsUint = true
n.Uint64 = uint64(f)
}
}
}
if !n.IsInt && !n.IsUint && !n.IsFloat {
return nil, fmt.Errorf("illegal number syntax: %q", text)
}
return n, nil
}
// simplifyComplex pulls out any other types that are represented by the complex number.
// These all require that the imaginary part be zero.
func (n *NumberNode) simplifyComplex() {
n.IsFloat = imag(n.Complex128) == 0
if n.IsFloat {
n.Float64 = real(n.Complex128)
n.IsInt = float64(int64(n.Float64)) == n.Float64
if n.IsInt {
n.Int64 = int64(n.Float64)
}
n.IsUint = float64(uint64(n.Float64)) == n.Float64
if n.IsUint {
n.Uint64 = uint64(n.Float64)
}
}
}
func (n *NumberNode) String() string {
return n.Text
}
func (n *NumberNode) tree() *Tree {
return n.tr
}
func (n *NumberNode) Copy() Node {
nn := new(NumberNode)
*nn = *n // Easy, fast, correct.
return nn
}
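The doc comment above says a NumberNode stores the value under every type that can represent it, and `newNumber` deliberately runs the integer test first so that prefixed forms like `0x123` are accepted. The sketch below mirrors that classification order with plain `strconv` calls; `classify` is an illustrative name, not part of the parse package's API.

```go
package main

import (
	"fmt"
	"strconv"
)

// classify mirrors, in simplified form, how newNumber records a literal
// under every numeric type that can represent it.
func classify(text string) (isInt, isUint, isFloat bool) {
	// Integer test first, so prefixed forms like "0x123" are handled
	// before the float parser gets a chance to reject them.
	if _, err := strconv.ParseUint(text, 0, 64); err == nil {
		isUint = true
	}
	if _, err := strconv.ParseInt(text, 0, 64); err == nil {
		isInt = true
	}
	if isInt || isUint {
		isFloat = true // any integral value is also a valid float
	} else if _, err := strconv.ParseFloat(text, 64); err == nil {
		isFloat = true
	}
	return
}

func main() {
	for _, s := range []string{"0x123", "-7", "1.5"} {
		i, u, f := classify(s)
		fmt.Printf("%-6s int=%v uint=%v float=%v\n", s, i, u, f)
	}
}
```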
// StringNode holds a string constant. The value has been "unquoted".
type StringNode struct {
NodeType
Pos
tr *Tree
Quoted string // The original text of the string, with quotes.
Text string // The string, after quote processing.
}
func (t *Tree) newString(pos Pos, orig, text string) *StringNode {
return &StringNode{tr: t, NodeType: NodeString, Pos: pos, Quoted: orig, Text: text}
}
func (s *StringNode) String() string {
return s.Quoted
}
func (s *StringNode) tree() *Tree {
return s.tr
}
func (s *StringNode) Copy() Node {
return s.tr.newString(s.Pos, s.Quoted, s.Text)
}
// endNode represents an {{end}} action.
// It does not appear in the final parse tree.
type endNode struct {
NodeType
Pos
tr *Tree
}
func (t *Tree) newEnd(pos Pos) *endNode {
return &endNode{tr: t, NodeType: nodeEnd, Pos: pos}
}
func (e *endNode) String() string {
return "{{end}}"
}
func (e *endNode) tree() *Tree {
return e.tr
}
func (e *endNode) Copy() Node {
return e.tr.newEnd(e.Pos)
}
// elseNode represents an {{else}} action. Does not appear in the final tree.
type elseNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input (deprecated; kept for compatibility)
}
func (t *Tree) newElse(pos Pos, line int) *elseNode {
return &elseNode{tr: t, NodeType: nodeElse, Pos: pos, Line: line}
}
func (e *elseNode) Type() NodeType {
return nodeElse
}
func (e *elseNode) String() string {
return "{{else}}"
}
func (e *elseNode) tree() *Tree {
return e.tr
}
func (e *elseNode) Copy() Node {
return e.tr.newElse(e.Pos, e.Line)
}
// BranchNode is the common representation of if, range, and with.
type BranchNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input (deprecated; kept for compatibility)
Pipe *PipeNode // The pipeline to be evaluated.
List *ListNode // What to execute if the value is non-empty.
ElseList *ListNode // What to execute if the value is empty (nil if absent).
}
func (b *BranchNode) String() string {
name := ""
switch b.NodeType {
case NodeIf:
name = "if"
case NodeRange:
name = "range"
case NodeWith:
name = "with"
default:
panic("unknown branch type")
}
if b.ElseList != nil {
return fmt.Sprintf("{{%s %s}}%s{{else}}%s{{end}}", name, b.Pipe, b.List, b.ElseList)
}
return fmt.Sprintf("{{%s %s}}%s{{end}}", name, b.Pipe, b.List)
}
func (b *BranchNode) tree() *Tree {
return b.tr
}
func (b *BranchNode) Copy() Node {
switch b.NodeType {
case NodeIf:
return b.tr.newIf(b.Pos, b.Line, b.Pipe, b.List, b.ElseList)
case NodeRange:
return b.tr.newRange(b.Pos, b.Line, b.Pipe, b.List, b.ElseList)
case NodeWith:
return b.tr.newWith(b.Pos, b.Line, b.Pipe, b.List, b.ElseList)
default:
panic("unknown branch type")
}
}
// IfNode represents an {{if}} action and its commands.
type IfNode struct {
BranchNode
}
func (t *Tree) newIf(pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) *IfNode {
return &IfNode{BranchNode{tr: t, NodeType: NodeIf, Pos: pos, Line: line, Pipe: pipe, List: list, ElseList: elseList}}
}
func (i *IfNode) Copy() Node {
return i.tr.newIf(i.Pos, i.Line, i.Pipe.CopyPipe(), i.List.CopyList(), i.ElseList.CopyList())
}
// RangeNode represents a {{range}} action and its commands.
type RangeNode struct {
BranchNode
}
func (t *Tree) newRange(pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) *RangeNode {
return &RangeNode{BranchNode{tr: t, NodeType: NodeRange, Pos: pos, Line: line, Pipe: pipe, List: list, ElseList: elseList}}
}
func (r *RangeNode) Copy() Node {
return r.tr.newRange(r.Pos, r.Line, r.Pipe.CopyPipe(), r.List.CopyList(), r.ElseList.CopyList())
}
// WithNode represents a {{with}} action and its commands.
type WithNode struct {
BranchNode
}
func (t *Tree) newWith(pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) *WithNode {
return &WithNode{BranchNode{tr: t, NodeType: NodeWith, Pos: pos, Line: line, Pipe: pipe, List: list, ElseList: elseList}}
}
func (w *WithNode) Copy() Node {
return w.tr.newWith(w.Pos, w.Line, w.Pipe.CopyPipe(), w.List.CopyList(), w.ElseList.CopyList())
}
// TemplateNode represents a {{template}} action.
type TemplateNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input (deprecated; kept for compatibility)
Name string // The name of the template (unquoted).
Pipe *PipeNode // The command to evaluate as dot for the template.
}
func (t *Tree) newTemplate(pos Pos, line int, name string, pipe *PipeNode) *TemplateNode {
return &TemplateNode{tr: t, NodeType: NodeTemplate, Pos: pos, Line: line, Name: name, Pipe: pipe}
}
func (t *TemplateNode) String() string {
if t.Pipe == nil {
return fmt.Sprintf("{{template %q}}", t.Name)
}
return fmt.Sprintf("{{template %q %s}}", t.Name, t.Pipe)
}
func (t *TemplateNode) tree() *Tree {
return t.tr
}
func (t *TemplateNode) Copy() Node {
return t.tr.newTemplate(t.Pos, t.Line, t.Name, t.Pipe.CopyPipe())
}


@@ -1,700 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package parse builds parse trees for templates as defined by text/template
// and html/template. Clients should use those packages to construct templates
// rather than this one, which provides shared internal data structures not
// intended for general use.
package parse
import (
"bytes"
"fmt"
"runtime"
"strconv"
"strings"
)
// Tree is the representation of a single parsed template.
type Tree struct {
Name string // name of the template represented by the tree.
ParseName string // name of the top-level template during parsing, for error messages.
Root *ListNode // top-level root of the tree.
text string // text parsed to create the template (or its parent)
// Parsing only; cleared after parse.
funcs []map[string]interface{}
lex *lexer
token [3]item // three-token lookahead for parser.
peekCount int
vars []string // variables defined at the moment.
}
// Copy returns a copy of the Tree. Any parsing state is discarded.
func (t *Tree) Copy() *Tree {
if t == nil {
return nil
}
return &Tree{
Name: t.Name,
ParseName: t.ParseName,
Root: t.Root.CopyList(),
text: t.text,
}
}
// Parse returns a map from template name to parse.Tree, created by parsing the
// templates described in the argument string. The top-level template will be
// given the specified name. If an error is encountered, parsing stops and an
// empty map is returned with the error.
func Parse(name, text, leftDelim, rightDelim string, funcs ...map[string]interface{}) (treeSet map[string]*Tree, err error) {
treeSet = make(map[string]*Tree)
t := New(name)
t.text = text
_, err = t.Parse(text, leftDelim, rightDelim, treeSet, funcs...)
return
}
// next returns the next token.
func (t *Tree) next() item {
if t.peekCount > 0 {
t.peekCount--
} else {
t.token[0] = t.lex.nextItem()
}
return t.token[t.peekCount]
}
// backup backs the input stream up one token.
func (t *Tree) backup() {
t.peekCount++
}
// backup2 backs the input stream up two tokens.
// The zeroth token is already there.
func (t *Tree) backup2(t1 item) {
t.token[1] = t1
t.peekCount = 2
}
// backup3 backs the input stream up three tokens.
// The zeroth token is already there.
func (t *Tree) backup3(t2, t1 item) { // Reverse order: we're pushing back.
t.token[1] = t1
t.token[2] = t2
t.peekCount = 3
}
// peek returns but does not consume the next token.
func (t *Tree) peek() item {
if t.peekCount > 0 {
return t.token[t.peekCount-1]
}
t.peekCount = 1
t.token[0] = t.lex.nextItem()
return t.token[0]
}
// nextNonSpace returns the next non-space token.
func (t *Tree) nextNonSpace() (token item) {
for {
token = t.next()
if token.typ != itemSpace {
break
}
}
return token
}
// peekNonSpace returns but does not consume the next non-space token.
func (t *Tree) peekNonSpace() (token item) {
for {
token = t.next()
if token.typ != itemSpace {
break
}
}
t.backup()
return token
}
// Parsing.
// New allocates a new parse tree with the given name.
func New(name string, funcs ...map[string]interface{}) *Tree {
return &Tree{
Name: name,
funcs: funcs,
}
}
// ErrorContext returns a textual representation of the location of the node in the input text.
// The receiver is only used when the node does not have a pointer to the tree inside,
// which can occur in old code.
func (t *Tree) ErrorContext(n Node) (location, context string) {
pos := int(n.Position())
tree := n.tree()
if tree == nil {
tree = t
}
text := tree.text[:pos]
byteNum := strings.LastIndex(text, "\n")
if byteNum == -1 {
byteNum = pos // On first line.
} else {
byteNum++ // After the newline.
byteNum = pos - byteNum
}
lineNum := 1 + strings.Count(text, "\n")
context = n.String()
if len(context) > 20 {
context = fmt.Sprintf("%.20s...", context)
}
return fmt.Sprintf("%s:%d:%d", tree.ParseName, lineNum, byteNum), context
}
// errorf formats the error and terminates processing.
func (t *Tree) errorf(format string, args ...interface{}) {
t.Root = nil
format = fmt.Sprintf("template: %s:%d: %s", t.ParseName, t.lex.lineNumber(), format)
panic(fmt.Errorf(format, args...))
}
// error terminates processing.
func (t *Tree) error(err error) {
t.errorf("%s", err)
}
// expect consumes the next token and guarantees it has the required type.
func (t *Tree) expect(expected itemType, context string) item {
token := t.nextNonSpace()
if token.typ != expected {
t.unexpected(token, context)
}
return token
}
// expectOneOf consumes the next token and guarantees it has one of the required types.
func (t *Tree) expectOneOf(expected1, expected2 itemType, context string) item {
token := t.nextNonSpace()
if token.typ != expected1 && token.typ != expected2 {
t.unexpected(token, context)
}
return token
}
// unexpected complains about the token and terminates processing.
func (t *Tree) unexpected(token item, context string) {
t.errorf("unexpected %s in %s", token, context)
}
// recover is the handler that turns panics into returns from the top level of Parse.
func (t *Tree) recover(errp *error) {
e := recover()
if e != nil {
if _, ok := e.(runtime.Error); ok {
panic(e)
}
if t != nil {
t.stopParse()
}
*errp = e.(error)
}
return
}
// startParse initializes the parser, using the lexer.
func (t *Tree) startParse(funcs []map[string]interface{}, lex *lexer) {
t.Root = nil
t.lex = lex
t.vars = []string{"$"}
t.funcs = funcs
}
// stopParse terminates parsing.
func (t *Tree) stopParse() {
t.lex = nil
t.vars = nil
t.funcs = nil
}
// Parse parses the template definition string to construct a representation of
// the template for execution. If either action delimiter string is empty, the
// default ("{{" or "}}") is used. Embedded template definitions are added to
// the treeSet map.
func (t *Tree) Parse(text, leftDelim, rightDelim string, treeSet map[string]*Tree, funcs ...map[string]interface{}) (tree *Tree, err error) {
defer t.recover(&err)
t.ParseName = t.Name
t.startParse(funcs, lex(t.Name, text, leftDelim, rightDelim))
t.text = text
t.parse(treeSet)
t.add(treeSet)
t.stopParse()
return t, nil
}
// add adds tree to the treeSet.
func (t *Tree) add(treeSet map[string]*Tree) {
tree := treeSet[t.Name]
if tree == nil || IsEmptyTree(tree.Root) {
treeSet[t.Name] = t
return
}
if !IsEmptyTree(t.Root) {
t.errorf("template: multiple definition of template %q", t.Name)
}
}
// IsEmptyTree reports whether this tree (node) is empty of everything but space.
func IsEmptyTree(n Node) bool {
switch n := n.(type) {
case nil:
return true
case *ActionNode:
case *IfNode:
case *ListNode:
for _, node := range n.Nodes {
if !IsEmptyTree(node) {
return false
}
}
return true
case *RangeNode:
case *TemplateNode:
case *TextNode:
return len(bytes.TrimSpace(n.Text)) == 0
case *WithNode:
default:
panic("unknown node: " + n.String())
}
return false
}
// parse is the top-level parser for a template, essentially the same
// as itemList except it also parses {{define}} actions.
// It runs to EOF.
func (t *Tree) parse(treeSet map[string]*Tree) (next Node) {
t.Root = t.newList(t.peek().pos)
for t.peek().typ != itemEOF {
if t.peek().typ == itemLeftDelim {
delim := t.next()
if t.nextNonSpace().typ == itemDefine {
newT := New("definition") // name will be updated once we know it.
newT.text = t.text
newT.ParseName = t.ParseName
newT.startParse(t.funcs, t.lex)
newT.parseDefinition(treeSet)
continue
}
t.backup2(delim)
}
n := t.textOrAction()
if n.Type() == nodeEnd {
t.errorf("unexpected %s", n)
}
t.Root.append(n)
}
return nil
}
// parseDefinition parses a {{define}} ... {{end}} template definition and
// installs the definition in the treeSet map. The "define" keyword has already
// been scanned.
func (t *Tree) parseDefinition(treeSet map[string]*Tree) {
const context = "define clause"
name := t.expectOneOf(itemString, itemRawString, context)
var err error
t.Name, err = strconv.Unquote(name.val)
if err != nil {
t.error(err)
}
t.expect(itemRightDelim, context)
var end Node
t.Root, end = t.itemList()
if end.Type() != nodeEnd {
t.errorf("unexpected %s in %s", end, context)
}
t.add(treeSet)
t.stopParse()
}
// itemList:
// textOrAction*
// Terminates at {{end}} or {{else}}, returned separately.
func (t *Tree) itemList() (list *ListNode, next Node) {
list = t.newList(t.peekNonSpace().pos)
for t.peekNonSpace().typ != itemEOF {
n := t.textOrAction()
switch n.Type() {
case nodeEnd, nodeElse:
return list, n
}
list.append(n)
}
t.errorf("unexpected EOF")
return
}
// textOrAction:
// text | action
func (t *Tree) textOrAction() Node {
switch token := t.nextNonSpace(); token.typ {
case itemElideNewline:
return t.elideNewline()
case itemText:
return t.newText(token.pos, token.val)
case itemLeftDelim:
return t.action()
default:
t.unexpected(token, "input")
}
return nil
}
// elideNewline:
// Remove newlines trailing rightDelim if \\ is present.
func (t *Tree) elideNewline() Node {
token := t.peek()
if token.typ != itemText {
t.unexpected(token, "input")
return nil
}
t.next()
stripped := strings.TrimLeft(token.val, "\n\r")
diff := len(token.val) - len(stripped)
if diff > 0 {
// This is a bit nasty. We mutate the token in-place to remove
// preceding newlines.
token.pos += Pos(diff)
token.val = stripped
}
return t.newText(token.pos, token.val)
}
// Action:
// control
// command ("|" command)*
// Left delim is past. Now get actions.
// First word could be a keyword such as range.
func (t *Tree) action() (n Node) {
switch token := t.nextNonSpace(); token.typ {
case itemElse:
return t.elseControl()
case itemEnd:
return t.endControl()
case itemIf:
return t.ifControl()
case itemRange:
return t.rangeControl()
case itemTemplate:
return t.templateControl()
case itemWith:
return t.withControl()
}
t.backup()
// Do not pop variables; they persist until "end".
return t.newAction(t.peek().pos, t.lex.lineNumber(), t.pipeline("command"))
}
// Pipeline:
// declarations? command ('|' command)*
func (t *Tree) pipeline(context string) (pipe *PipeNode) {
var decl []*VariableNode
pos := t.peekNonSpace().pos
// Are there declarations?
for {
if v := t.peekNonSpace(); v.typ == itemVariable {
t.next()
// Since space is a token, we need 3-token look-ahead here in the worst case:
// in "$x foo" we need to read "foo" (as opposed to ":=") to know that $x is an
// argument variable rather than a declaration. So remember the token
// adjacent to the variable so we can push it back if necessary.
tokenAfterVariable := t.peek()
if next := t.peekNonSpace(); next.typ == itemColonEquals || (next.typ == itemChar && next.val == ",") {
t.nextNonSpace()
variable := t.newVariable(v.pos, v.val)
decl = append(decl, variable)
t.vars = append(t.vars, v.val)
if next.typ == itemChar && next.val == "," {
if context == "range" && len(decl) < 2 {
continue
}
t.errorf("too many declarations in %s", context)
}
} else if tokenAfterVariable.typ == itemSpace {
t.backup3(v, tokenAfterVariable)
} else {
t.backup2(v)
}
}
break
}
pipe = t.newPipeline(pos, t.lex.lineNumber(), decl)
for {
switch token := t.nextNonSpace(); token.typ {
case itemRightDelim, itemRightParen:
if len(pipe.Cmds) == 0 {
t.errorf("missing value for %s", context)
}
if token.typ == itemRightParen {
t.backup()
}
return
case itemBool, itemCharConstant, itemComplex, itemDot, itemField, itemIdentifier,
itemNumber, itemNil, itemRawString, itemString, itemVariable, itemLeftParen:
t.backup()
pipe.append(t.command())
default:
t.unexpected(token, context)
}
}
}
func (t *Tree) parseControl(allowElseIf bool, context string) (pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) {
defer t.popVars(len(t.vars))
line = t.lex.lineNumber()
pipe = t.pipeline(context)
var next Node
list, next = t.itemList()
switch next.Type() {
case nodeEnd: //done
case nodeElse:
if allowElseIf {
// Special case for "else if". If the "else" is followed immediately by an "if",
// the elseControl will have left the "if" token pending. Treat
// {{if a}}_{{else if b}}_{{end}}
// as
// {{if a}}_{{else}}{{if b}}_{{end}}{{end}}.
// To do this, parse the if as usual and stop at its {{end}}; the subsequent {{end}}
// is assumed. This technique works even for long if-else-if chains.
// TODO: Should we allow else-if in with and range?
if t.peek().typ == itemIf {
t.next() // Consume the "if" token.
elseList = t.newList(next.Position())
elseList.append(t.ifControl())
// Do not consume the next item - only one {{end}} required.
break
}
}
elseList, next = t.itemList()
if next.Type() != nodeEnd {
t.errorf("expected end; found %s", next)
}
}
return pipe.Position(), line, pipe, list, elseList
}
// If:
// {{if pipeline}} itemList {{end}}
// {{if pipeline}} itemList {{else}} itemList {{end}}
// If keyword is past.
func (t *Tree) ifControl() Node {
return t.newIf(t.parseControl(true, "if"))
}
// Range:
// {{range pipeline}} itemList {{end}}
// {{range pipeline}} itemList {{else}} itemList {{end}}
// Range keyword is past.
func (t *Tree) rangeControl() Node {
return t.newRange(t.parseControl(false, "range"))
}
// With:
// {{with pipeline}} itemList {{end}}
// {{with pipeline}} itemList {{else}} itemList {{end}}
// If keyword is past.
func (t *Tree) withControl() Node {
return t.newWith(t.parseControl(false, "with"))
}
// End:
// {{end}}
// End keyword is past.
func (t *Tree) endControl() Node {
return t.newEnd(t.expect(itemRightDelim, "end").pos)
}
// Else:
// {{else}}
// Else keyword is past.
func (t *Tree) elseControl() Node {
// Special case for "else if".
peek := t.peekNonSpace()
if peek.typ == itemIf {
// We see "{{else if ... " but in effect rewrite it to {{else}}{{if ... ".
return t.newElse(peek.pos, t.lex.lineNumber())
}
return t.newElse(t.expect(itemRightDelim, "else").pos, t.lex.lineNumber())
}
// Template:
// {{template stringValue pipeline}}
// Template keyword is past. The name must be something that can evaluate
// to a string.
func (t *Tree) templateControl() Node {
var name string
token := t.nextNonSpace()
switch token.typ {
case itemString, itemRawString:
s, err := strconv.Unquote(token.val)
if err != nil {
t.error(err)
}
name = s
default:
t.unexpected(token, "template invocation")
}
var pipe *PipeNode
if t.nextNonSpace().typ != itemRightDelim {
t.backup()
// Do not pop variables; they persist until "end".
pipe = t.pipeline("template")
}
return t.newTemplate(token.pos, t.lex.lineNumber(), name, pipe)
}
// command:
// operand (space operand)*
// space-separated arguments up to a pipeline character or right delimiter.
// we consume the pipe character but leave the right delim to terminate the action.
func (t *Tree) command() *CommandNode {
cmd := t.newCommand(t.peekNonSpace().pos)
for {
t.peekNonSpace() // skip leading spaces.
operand := t.operand()
if operand != nil {
cmd.append(operand)
}
switch token := t.next(); token.typ {
case itemSpace:
continue
case itemError:
t.errorf("%s", token.val)
case itemRightDelim, itemRightParen:
t.backup()
case itemPipe:
default:
t.errorf("unexpected %s in operand; missing space?", token)
}
break
}
if len(cmd.Args) == 0 {
t.errorf("empty command")
}
return cmd
}
// operand:
// term .Field*
// An operand is a space-separated component of a command,
// a term possibly followed by field accesses.
// A nil return means the next item is not an operand.
func (t *Tree) operand() Node {
node := t.term()
if node == nil {
return nil
}
if t.peek().typ == itemField {
chain := t.newChain(t.peek().pos, node)
for t.peek().typ == itemField {
chain.Add(t.next().val)
}
// Compatibility with original API: If the term is of type NodeField
// or NodeVariable, just put more fields on the original.
// Otherwise, keep the Chain node.
// TODO: Switch to Chains always when we can.
switch node.Type() {
case NodeField:
node = t.newField(chain.Position(), chain.String())
case NodeVariable:
node = t.newVariable(chain.Position(), chain.String())
default:
node = chain
}
}
return node
}
// term:
// literal (number, string, nil, boolean)
// function (identifier)
// .
// .Field
// $
// '(' pipeline ')'
// A term is a simple "expression".
// A nil return means the next item is not a term.
func (t *Tree) term() Node {
switch token := t.nextNonSpace(); token.typ {
case itemError:
t.errorf("%s", token.val)
case itemIdentifier:
if !t.hasFunction(token.val) {
t.errorf("function %q not defined", token.val)
}
return NewIdentifier(token.val).SetTree(t).SetPos(token.pos)
case itemDot:
return t.newDot(token.pos)
case itemNil:
return t.newNil(token.pos)
case itemVariable:
return t.useVar(token.pos, token.val)
case itemField:
return t.newField(token.pos, token.val)
case itemBool:
return t.newBool(token.pos, token.val == "true")
case itemCharConstant, itemComplex, itemNumber:
number, err := t.newNumber(token.pos, token.val, token.typ)
if err != nil {
t.error(err)
}
return number
case itemLeftParen:
pipe := t.pipeline("parenthesized pipeline")
if token := t.next(); token.typ != itemRightParen {
t.errorf("unclosed right paren: unexpected %s", token)
}
return pipe
case itemString, itemRawString:
s, err := strconv.Unquote(token.val)
if err != nil {
t.error(err)
}
return t.newString(token.pos, token.val, s)
}
t.backup()
return nil
}
// hasFunction reports if a function name exists in the Tree's maps.
func (t *Tree) hasFunction(name string) bool {
for _, funcMap := range t.funcs {
if funcMap == nil {
continue
}
if funcMap[name] != nil {
return true
}
}
return false
}
// popVars trims the variable list to the specified length
func (t *Tree) popVars(n int) {
t.vars = t.vars[:n]
}
// useVar returns a node for a variable reference. It errors if the
// variable is not defined.
func (t *Tree) useVar(pos Pos, name string) Node {
v := t.newVariable(pos, name)
for _, varName := range t.vars {
if varName == v.Ident[0] {
return v
}
}
t.errorf("undefined variable %q", v.Ident[0])
return nil
}


@@ -1,218 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"reflect"
"github.com/alecthomas/template/parse"
)
// common holds the information shared by related templates.
type common struct {
tmpl map[string]*Template
// We use two maps, one for parsing and one for execution.
// This separation makes the API cleaner since it doesn't
// expose reflection to the client.
parseFuncs FuncMap
execFuncs map[string]reflect.Value
}
// Template is the representation of a parsed template. The *parse.Tree
// field is exported only for use by html/template and should be treated
// as unexported by all other clients.
type Template struct {
name string
*parse.Tree
*common
leftDelim string
rightDelim string
}
// New allocates a new template with the given name.
func New(name string) *Template {
return &Template{
name: name,
}
}
// Name returns the name of the template.
func (t *Template) Name() string {
return t.name
}
// New allocates a new template associated with the given one and with the same
// delimiters. The association, which is transitive, allows one template to
// invoke another with a {{template}} action.
func (t *Template) New(name string) *Template {
t.init()
return &Template{
name: name,
common: t.common,
leftDelim: t.leftDelim,
rightDelim: t.rightDelim,
}
}
func (t *Template) init() {
if t.common == nil {
t.common = new(common)
t.tmpl = make(map[string]*Template)
t.parseFuncs = make(FuncMap)
t.execFuncs = make(map[string]reflect.Value)
}
}
// Clone returns a duplicate of the template, including all associated
// templates. The actual representation is not copied, but the name space of
// associated templates is, so further calls to Parse in the copy will add
// templates to the copy but not to the original. Clone can be used to prepare
// common templates and use them with variant definitions for other templates
// by adding the variants after the clone is made.
func (t *Template) Clone() (*Template, error) {
nt := t.copy(nil)
nt.init()
nt.tmpl[t.name] = nt
for k, v := range t.tmpl {
if k == t.name { // Already installed.
continue
}
// The associated templates share nt's common structure.
tmpl := v.copy(nt.common)
nt.tmpl[k] = tmpl
}
for k, v := range t.parseFuncs {
nt.parseFuncs[k] = v
}
for k, v := range t.execFuncs {
nt.execFuncs[k] = v
}
return nt, nil
}
// copy returns a shallow copy of t, with common set to the argument.
func (t *Template) copy(c *common) *Template {
nt := New(t.name)
nt.Tree = t.Tree
nt.common = c
nt.leftDelim = t.leftDelim
nt.rightDelim = t.rightDelim
return nt
}
// AddParseTree creates a new template with the name and parse tree
// and associates it with t.
func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error) {
if t.common != nil && t.tmpl[name] != nil {
return nil, fmt.Errorf("template: redefinition of template %q", name)
}
nt := t.New(name)
nt.Tree = tree
t.tmpl[name] = nt
return nt, nil
}
// Templates returns a slice of the templates associated with t, including t
// itself.
func (t *Template) Templates() []*Template {
if t.common == nil {
return nil
}
// Return a slice so we don't expose the map.
m := make([]*Template, 0, len(t.tmpl))
for _, v := range t.tmpl {
m = append(m, v)
}
return m
}
// Delims sets the action delimiters to the specified strings, to be used in
// subsequent calls to Parse, ParseFiles, or ParseGlob. Nested template
// definitions will inherit the settings. An empty delimiter stands for the
// corresponding default: {{ or }}.
// The return value is the template, so calls can be chained.
func (t *Template) Delims(left, right string) *Template {
t.leftDelim = left
t.rightDelim = right
return t
}
// Funcs adds the elements of the argument map to the template's function map.
// It panics if a value in the map is not a function with appropriate return
// type. However, it is legal to overwrite elements of the map. The return
// value is the template, so calls can be chained.
func (t *Template) Funcs(funcMap FuncMap) *Template {
t.init()
addValueFuncs(t.execFuncs, funcMap)
addFuncs(t.parseFuncs, funcMap)
return t
}
// Lookup returns the template with the given name that is associated with t,
// or nil if there is no such template.
func (t *Template) Lookup(name string) *Template {
if t.common == nil {
return nil
}
return t.tmpl[name]
}
// Parse parses a string into a template. Nested template definitions will be
// associated with the top-level template t. Parse may be called multiple times
// to parse definitions of templates to associate with t. It is an error if a
// resulting template is non-empty (contains content other than template
// definitions) and would replace a non-empty template with the same name.
// (In multiple calls to Parse with the same receiver template, only one call
// can contain text other than space, comments, and template definitions.)
func (t *Template) Parse(text string) (*Template, error) {
t.init()
trees, err := parse.Parse(t.name, text, t.leftDelim, t.rightDelim, t.parseFuncs, builtins)
if err != nil {
return nil, err
}
// Add the newly parsed trees, including the one for t, into our common structure.
for name, tree := range trees {
// If the name we parsed is the name of this template, overwrite this template.
// The associate method checks it's not a redefinition.
tmpl := t
if name != t.name {
tmpl = t.New(name)
}
// Even if t == tmpl, we need to install it in the common.tmpl map.
if replace, err := t.associate(tmpl, tree); err != nil {
return nil, err
} else if replace {
tmpl.Tree = tree
}
tmpl.leftDelim = t.leftDelim
tmpl.rightDelim = t.rightDelim
}
return t, nil
}
// associate installs the new template into the group of templates associated
// with t. It is an error to reuse a name except to overwrite an empty
// template. The two are already known to share the common structure.
// The boolean return value reports whether to store this tree as t.Tree.
func (t *Template) associate(new *Template, tree *parse.Tree) (bool, error) {
if new.common != t.common {
panic("internal error: associate not common")
}
name := new.name
if old := t.tmpl[name]; old != nil {
oldIsEmpty := parse.IsEmptyTree(old.Root)
newIsEmpty := parse.IsEmptyTree(tree.Root)
if newIsEmpty {
// Whether old is empty or not, new is empty; no reason to replace old.
return false, nil
}
if !oldIsEmpty {
return false, fmt.Errorf("template: redefinition of template %q", name)
}
}
t.tmpl[name] = new
return true, nil
}


@@ -0,0 +1,3 @@
// Package keywrap is an implementation of the RFC 3394 AES key wrapping
// algorithm. This is used in OpenPGP with elliptic curve keys.
package keywrap


@@ -0,0 +1,151 @@
// Copyright 2014 Matthew Endsley
// All rights reserved
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted providing that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
// OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
package keywrap
import (
"crypto/aes"
"encoding/binary"
"errors"
)
var (
// ErrWrapPlaintext is returned if the plaintext is not a multiple
// of 64 bits.
ErrWrapPlaintext = errors.New("keywrap: plainText must be a multiple of 64 bits")
// ErrUnwrapCiphertext is returned if the ciphertext is not a
// multiple of 64 bits.
ErrUnwrapCiphertext = errors.New("keywrap: cipherText must be a multiple of 64 bits")
// ErrUnwrapFailed is returned if unwrapping a key fails.
ErrUnwrapFailed = errors.New("keywrap: failed to unwrap key")
// NB: the AES NewCipher call only fails if the key is an invalid length.
// ErrInvalidKey is returned when the AES key is invalid.
ErrInvalidKey = errors.New("keywrap: invalid AES key")
)
// Wrap a key using the RFC 3394 AES Key Wrap Algorithm.
func Wrap(key, plainText []byte) ([]byte, error) {
if len(plainText)%8 != 0 {
return nil, ErrWrapPlaintext
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, ErrInvalidKey
}
nblocks := len(plainText) / 8
// 1) Initialize variables.
var block [aes.BlockSize]byte
// - Set A = IV, an initial value (see 2.2.3)
for ii := 0; ii < 8; ii++ {
block[ii] = 0xA6
}
// - For i = 1 to n
// - Set R[i] = P[i]
intermediate := make([]byte, len(plainText))
copy(intermediate, plainText)
// 2) Calculate intermediate values.
for ii := 0; ii < 6; ii++ {
for jj := 0; jj < nblocks; jj++ {
// - B = AES(K, A | R[i])
copy(block[8:], intermediate[jj*8:jj*8+8])
c.Encrypt(block[:], block[:])
// - A = MSB(64, B) ^ t where t = (n*j)+i
t := uint64(ii*nblocks + jj + 1)
val := binary.BigEndian.Uint64(block[:8]) ^ t
binary.BigEndian.PutUint64(block[:8], val)
// - R[i] = LSB(64, B)
copy(intermediate[jj*8:jj*8+8], block[8:])
}
}
// 3) Output results.
// - Set C[0] = A
// - For i = 1 to n
// - C[i] = R[i]
return append(block[:8], intermediate...), nil
}
// Unwrap a key using the RFC 3394 AES Key Wrap Algorithm.
func Unwrap(key, cipherText []byte) ([]byte, error) {
if len(cipherText)%8 != 0 {
return nil, ErrUnwrapCiphertext
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, ErrInvalidKey
}
nblocks := len(cipherText)/8 - 1
// 1) Initialize variables.
var block [aes.BlockSize]byte
// - Set A = C[0]
copy(block[:8], cipherText[:8])
// - For i = 1 to n
// - Set R[i] = C[i]
intermediate := make([]byte, len(cipherText)-8)
copy(intermediate, cipherText[8:])
// 2) Compute intermediate values.
for jj := 5; jj >= 0; jj-- {
for ii := nblocks - 1; ii >= 0; ii-- {
// - B = AES-1(K, (A ^ t) | R[i]) where t = n*j+i
// - A = MSB(64, B)
t := uint64(jj*nblocks + ii + 1)
val := binary.BigEndian.Uint64(block[:8]) ^ t
binary.BigEndian.PutUint64(block[:8], val)
copy(block[8:], intermediate[ii*8:ii*8+8])
c.Decrypt(block[:], block[:])
// - R[i] = LSB(64, B)
copy(intermediate[ii*8:ii*8+8], block[8:])
}
}
// 3) Output results.
// - If A is an appropriate initial value (see 2.2.3),
for ii := 0; ii < 8; ii++ {
if block[ii] != 0xA6 {
return nil, ErrUnwrapFailed
}
}
// - For i = 1 to n
// - P[i] = R[i]
return intermediate, nil
}

2
vendor/github.com/datarhei/gosrt/.dockerignore generated vendored Normal file

@@ -0,0 +1,2 @@
Dockerfile
/.git

20
vendor/github.com/datarhei/gosrt/.editorconfig generated vendored Normal file

@@ -0,0 +1,20 @@
# For more information about the properties used in
# this file, please see the EditorConfig documentation:
# http://editorconfig.org/
root = true
[*]
charset = utf-8
end_of_line = lf
indent_size = 4
indent_style = tab
insert_final_newline = true
trim_trailing_whitespace = true
spaces_around_brackets = outside
[*.md]
trim_trailing_whitespace = false
indent_style = space
[*.patch]
indent_style = space

7
vendor/github.com/datarhei/gosrt/.gitignore generated vendored Normal file

@@ -0,0 +1,7 @@
/contrib/server/server*
/contrib/client/client*
*.prof
*.out
*.html
*.ts
*.mp4

15
vendor/github.com/datarhei/gosrt/Dockerfile generated vendored Normal file

@@ -0,0 +1,15 @@
ARG BUILD_IMAGE=golang:1.18.3-alpine3.16
FROM $BUILD_IMAGE as builder
COPY . /build
RUN cd /build/contrib/client && CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -a -o client .
RUN cd /build/contrib/server && CGO_ENABLED=0 GOOS=linux go build -ldflags="-w -s" -a -o server .
FROM scratch
COPY --from=builder /build/contrib/client/client /bin/srt-client
COPY --from=builder /build/contrib/server/server /bin/srt-server
WORKDIR /srt

21
vendor/github.com/datarhei/gosrt/LICENSE generated vendored Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020-2022 FOSS GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

69
vendor/github.com/datarhei/gosrt/Makefile generated vendored Normal file

@@ -0,0 +1,69 @@
COMMIT := $(shell if [ -d .git ]; then git rev-parse HEAD; else echo "unknown"; fi)
SHORTCOMMIT := $(shell echo $(COMMIT) | head -c 7)
all: build
## test: Run all tests
test:
go test -race -coverprofile=/dev/null -timeout 15s -v ./...
## vet: Analyze code for potential errors
vet:
go vet ./...
## fmt: Format code
fmt:
go fmt ./...
## update: Update dependencies
update:
go get -u -t
@-$(MAKE) tidy
@-$(MAKE) vendor
## tidy: Tidy up go.mod
tidy:
go mod tidy
## vendor: Update vendored packages
vendor:
go mod vendor
## lint: Static analysis with staticcheck
lint:
staticcheck ./...
## client: Build example client binary
client:
cd contrib/client && CGO_ENABLED=0 go build -o client -ldflags="-s -w" -a
## server: Build example server binary
server:
cd contrib/server && CGO_ENABLED=0 go build -o server -ldflags="-s -w" -a
## coverage: Generate code coverage analysis
coverage:
go test -race -coverprofile=cover.out -timeout 15s -v ./...
go tool cover -html=cover.out -o cover.html
## commit: Prepare code for commit (vet, fmt, test)
commit: vet fmt lint test
@echo "No errors found. Ready for a commit."
## docker: Build standard Docker image
docker:
docker build -t gosrt:$(SHORTCOMMIT) .
## logtopics: Extract all logging topics
logtopics:
grep -ERho 'log\("([^"]+)' *.go | sed -E -e 's/log\("//' | sort -u
.PHONY: help test vet fmt vendor commit coverage lint client server update logtopics
## help: Show all commands
help: Makefile
@echo
@echo " Choose a command:"
@echo
@sed -n 's/^##//p' $< | column -t -s ':' | sed -e 's/^/ /'
@echo

386
vendor/github.com/datarhei/gosrt/README.md generated vendored Normal file

@@ -0,0 +1,386 @@
Implementation of the SRT protocol in pure Go with minimal dependencies.
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
![Tests](https://github.com/datarhei/gosrt/actions/workflows/go-tests.yml/badge.svg)
[![codecov](https://codecov.io/gh/datarhei/gosrt/branch/main/graph/badge.svg?token=90YMPZRAFK)](https://codecov.io/gh/datarhei/gosrt)
[![Go Report Card](https://goreportcard.com/badge/github.com/datarhei/gosrt)](https://goreportcard.com/report/github.com/datarhei/gosrt)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/datarhei/gosrt)](https://pkg.go.dev/github.com/datarhei/gosrt)
- [SRT reference implementation](https://github.com/Haivision/srt)
- [SRT RFC](https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html)
- [SRT Technical Overview](https://github.com/Haivision/srt/files/2489142/SRT_Protocol_TechnicalOverview_DRAFT_2018-10-17.pdf)
This implementation of the SRT protocol has live streaming of video/audio in mind. Because of this, the buffer mode and File Transfer
Congestion Control (FileCC) are not implemented.
| | |
| --- | ----------------------------------------- |
| ✅ | Message mode |
| ✅ | Caller-Listener Handshake |
| ✅ | Timestamp-Based Packet Delivery (TSBPD) |
| ✅ | Too-Late Packet Drop (TLPKTDROP) |
| ✅ | Live Congestion Control (LiveCC) |
| ✅ | NAK and Periodic NAK |
| ✅ | Encryption |
| ❌ | Buffer mode |
| ❌ | Rendezvous Handshake |
| ❌ | File Transfer Congestion Control (FileCC) |
| ❌ | Connection Bonding |
The parts that are implemented are based on what has been published in the SRT RFC.
# Requirements
A Go version of 1.16+ is required.
# Installation
```
go get github.com/datarhei/gosrt
```
# Caller example
```
import "github.com/datarhei/gosrt"
conn, err := srt.Dial("srt", "golang.org:6000", srt.Config{
StreamId: "...",
})
if err != nil {
// handle error
}
buffer := make([]byte, 2048)
for {
n, err := conn.Read(buffer)
if err != nil {
// handle error
}
// handle received data
}
conn.Close()
```
In the `contrib/client` directory you'll find a complete example of a SRT client.
# Listener example
```
import "github.com/datarhei/gosrt"
ln, err := srt.Listen("srt", ":6000", srt.Config{...})
if err != nil {
// handle error
}
for {
conn, mode, err := ln.Accept(func(req srt.ConnRequest) srt.ConnType {
// check connection request
return srt.REJECT
})
if err != nil {
// handle error
}
if mode == srt.REJECT {
// rejected connection, ignore
continue
}
if mode == srt.PUBLISH {
go handlePublish(conn)
} else { // srt.SUBSCRIBE
go handleSubscribe(conn)
}
}
```
In the `contrib/server` directory you'll find a complete example of a SRT server. For your convenience
this module provides the `Server` type, which is a light framework for creating your own SRT server. The
example server is based on this type.
## PUBLISH / SUBSCRIBE
The `Accept` function from the `Listener` expects a function that handles connection requests. It can
return 3 different values: `srt.PUBLISH`, `srt.SUBSCRIBE`, and `srt.REJECT`. `srt.PUBLISH` means that the
server expects the caller to send data, whereas `srt.SUBSCRIBE` means that the server will send data to
the caller. This is opinionated towards a streaming server; however, in your own listener implementation
you are free to handle connection requests as you like.
# Contributed client
In the `contrib/client` directory you'll find an example implementation of a SRT client.
Build the client application with
```
cd contrib/client && go build
```
The application requires only two options:
| Option | Description |
| ------- | -------------------- |
| `-from` | Address to read from |
| `-to` | Address to write to |
Both options accept an address. Valid addresses are: `-` for `stdin` or `stdout` respectively, a `srt://` address, or a `udp://` address.
## SRT URL
A SRT URL is of the form `srt://[host]:[port]/?[options]` where options are in the form of a `HTTP` query string. These are the
known options (similar to [srt-live-transmit](https://github.com/Haivision/srt/blob/master/docs/apps/srt-live-transmit.md)):
| Option | Values | Description |
| -------------------- | ---------------------- | ----------------------------------------------------------------------- |
| `mode` | `listener` or `caller` | Enforce listener or caller mode. |
| `congestion` | `live` | Congestion control. Currently only `live` is supported. |
| `conntimeo` | `ms` | Connection timeout. |
| `drifttracer` | `bool` | Enable drift tracer. Not implemented. |
| `enforcedencryption` | `bool` | Accept connection only if both parties have encryption enabled. |
| `fc` | `bytes` | Flow control window size. |
| `inputbw` | `bytes` | Input bandwidth. Ignored. |
| `iptos` | 0...255 | IP socket type of service. Broken. |
| `ipttl` | 1...255 | Defines IP socket "time to live" option. Broken. |
| `ipv6only` | `bool` | Use IPv6 only. Not implemented. |
| `kmpreannounce` | `packets` | Duration of Stream Encryption key switchover. |
| `kmrefreshrate` | `packets` | Stream encryption key refresh rate. |
| `latency` | `ms` | Maximum accepted transmission latency. |
| `lossmaxttl` | `ms` | Packet reorder tolerance. Not implemented. |
| `maxbw` | `bytes` | Bandwidth limit. Ignored. |
| `mininputbw` | `bytes` | Minimum allowed estimate of `inputbw`. |
| `messageapi` | `bool` | Enable SRT message mode. Must be `false`. |
| `mss` | 76... | MTU size. |
| `nakreport` | `bool` | Enable periodic NAK reports. |
| `oheadbw` | 10...100 | Limits bandwidth overhead, in percent. Ignored. |
| `packetfilter` | `string` | Set up the packet filter. Not implemented. |
| `passphrase` | `string` | Password for the encrypted transmission. |
| `payloadsize` | `bytes` | Maximum payload size. |
| `pbkeylen` | `16`, `24`, or `32` | Crypto key length in bytes. |
| `peeridletimeo` | `ms` | Peer idle timeout. |
| `peerlatency` | `ms` | Minimum receiver latency to be requested by sender. |
| `rcvbuf` | `bytes` | Receiver buffer size. |
| `rcvlatency` | `ms` | Receiver-side latency. |
| `sndbuf` | `bytes` | Sender buffer size. |
| `snddropdelay` | `ms` | Sender's delay before dropping packets. |
| `streamid` | `string` | Stream ID (settable in caller mode only, visible on the listener peer). |
| `tlpktdrop` | `bool` | Drop too late packets. |
| `transtype` | `live` | Transmission type. Must be `live`. |
| `tsbpdmode` | `bool` | Enable timestamp-based packet delivery mode. |
## Usage
Reading from a SRT sender and play with `ffplay`:
```
./client -from "srt://127.0.0.1:6001/?mode=listener&streamid=..." -to - | ffplay -f mpegts -i -
```
Reading from UDP and sending to a SRT server:
```
./client -from udp://:6000 -to "srt://127.0.0.1:6001/?mode=caller&streamid=..."
```
Simulate point-to-point transfer on localhost. Open one console and start `ffmpeg` (you need at least version 4.3.2, built with SRT enabled) to send to a UDP address:
```
ffmpeg -f lavfi -re -i testsrc2=rate=25:size=640x360 -codec:v libx264 -b:v 1024k -maxrate:v 1024k -bufsize:v 1024k -preset ultrafast -r 25 -g 50 -pix_fmt yuv420p -vsync 1 -flags2 local_header -f mpegts "udp://127.0.0.1:6000?pkt_size=1316"
```
In another console, read from UDP and start a SRT listener:
```
./client -from udp://:6000 -to "srt://127.0.0.1:6001/?mode=listener&streamid=foobar"
```
In the third console connect to that stream and play the video with `ffplay`:
```
./client -from "srt://127.0.0.1:6001/?mode=caller&streamid=foobar" -to - | ffplay -f mpegts -i -
```
# Contributed server
In the `contrib/server` directory you'll find an example implementation of a SRT server. This server allows you to publish
a stream that can be read by many clients.
Build the client application with
```
cd contrib/server && go build
```
The application has these options:
| Option | Default | Description |
| ------------- | --------- | ------------------------------------------ |
| `-addr` | required | Address to listen on |
| `-app` | `/` | Path prefix for streamid |
| `-token` | (not set) | Token query param for streamid |
| `-passphrase` | (not set) | Passphrase for de- and encrypting the data |
| `-logtopics` | (not set) | Topics for the log output |
| `-profile` | `false` | Enable profiling |
This example server expects the streamID (without any prefix) to be a URL path with optional query parameters, e.g. `/live/stream`. If the `-app`
option is used, then the path must start with that value, e.g. with `-app /live` the streamID must start with `/live`. The `-token`
option can be used to define a token for that stream as a simple form of access control, e.g. with `-token foobar` the streamID might look like
`/live/stream?token=foobar`.
Use `-passphrase` in order to enable and enforce encryption.
Use `-logtopics` in order to write debug output. The value is a comma-separated list of topics you want written to `stderr`, e.g. `connection,listen`. Check the [Logging](#logging) section to find out more about the different topics.
Use `-profile` in order to write a CPU profile.
## StreamID
In SRT the StreamID is used to transport somewhat arbitrary information from the caller to the listener. The provided example server uses this
mechanism to decide who is the sender and who is the receiver. The server must know if the connecting client wants to publish a stream or
if it wants to subscribe to a stream.
The example server looks for the `publish:` prefix in the StreamID. If this prefix is present, the server assumes that it is the receiver
and the client will send the data. The subscribing clients must use the same StreamID (without the `publish:` prefix) in order to be able to
receive data.
If you implement your own server you are free to interpret the streamID as you wish.
## Usage
Running a server listening on port 6001 with defaults:
```
./server -addr ":6001"
```
Now you can use the contributed client to publish a stream:
```
./client -from ... -to "srt://127.0.0.1:6001/?mode=caller&streamid=publish:/live/stream"
```
or directly from `ffmpeg`:
```
ffmpeg -f lavfi -re -i testsrc2=rate=25:size=640x360 -codec:v libx264 -b:v 1024k -maxrate:v 1024k -bufsize:v 1024k -preset ultrafast -r 25 -g 50 -pix_fmt yuv420p -vsync 1 -flags2 local_header -f mpegts -transtype live "srt://127.0.0.1:6001?streamid=publish:/live/stream"
```
If the server is not on localhost, you might adjust the `peerlatency` in order to avoid packet loss: `-peerlatency 1000000`.
Now you can play the stream:
```
ffplay -f mpegts -transtype live -i "srt://127.0.0.1:6001?streamid=/live/stream"
```
You will most likely first see some error messages from `ffplay` because it tries to make sense of the received data until a keyframe arrives. If you
get more errors during playback, you might increase the receive buffer by adding e.g. `-rcvlatency 1000000` to the command line.
## Encryption
The stream can be encrypted with a passphrase. First start the server with a passphrase. If you are using `srt-live-transmit`, the passphrase has to be at least 10 characters long otherwise it will not be accepted.
```
./server -addr :6001 -passphrase foobarfoobar
```
Send an encrypted stream to the server:
```
ffmpeg -f lavfi -re -i testsrc2=rate=25:size=640x360 -codec:v libx264 -b:v 1024k -maxrate:v 1024k -bufsize:v 1024k -preset ultrafast -r 25 -g 50 -pix_fmt yuv420p -vsync 1 -flags2 local_header -f mpegts -transtype live "srt://127.0.0.1:6001?streamid=publish:/live/stream&passphrase=foobarfoobar"
```
Receive an encrypted stream from the server:
```
ffplay -f mpegts -transtype live -i "srt://127.0.0.1:6001?streamid=/live/stream&passphrase=foobarfoobar"
```
You will most likely first see some error messages from `ffplay` because it tries to make sense of the received data until a keyframe arrives. If you
get more errors during playback, you might increase the receive buffer by adding e.g. `-rcvlatency 1000000` to the command line.
# Logging
This SRT module has a built-in logging facility for debugging purposes. Check the `Logger` interface and the `NewLogger(topics []string)` function. Because logging everything would produce too much output when you only want to debug something specific, you can limit the logging to specific areas, such as everything regarding a connection or only the handshake. That's why there are various topics.
In the contributed server you see an example of how logging is used. Here's the essence:
```
logger := srt.NewLogger([]string{"connection", "handshake"})
config := srt.DefaultConfig
config.Logger = logger
ln, err := srt.Listen("udp", ":6000", config)
if err != nil {
// handle error
}
go func() {
for m := range logger.Listen() {
fmt.Fprintf(os.Stderr, "%#08x %s (in %s:%d)\n%s \n", m.SocketId, m.Topic, m.File, m.Line, m.Message)
}
}()
for {
conn, mode, err := ln.Accept(acceptFn)
...
}
```
Currently known topics are:
```
connection:close
connection:error
connection:filter
connection:new
connection:rtt
connection:tsbpd
control:recv:ACK:cif
control:recv:ACK:dump
control:recv:ACK:error
control:recv:ACKACK:dump
control:recv:ACKACK:error
control:recv:KM:cif
control:recv:KM:dump
control:recv:KM:error
control:recv:NAK:cif
control:recv:NAK:dump
control:recv:NAK:error
control:recv:keepalive:dump
control:recv:shutdown:dump
control:send:ACK:cif
control:send:ACK:dump
control:send:ACKACK:dump
control:send:KM:cif
control:send:KM:dump
control:send:KM:error
control:send:NAK:cif
control:send:NAK:dump
control:send:keepalive:dump
control:send:shutdown:cif
control:send:shutdown:dump
data:recv:dump
data:send:dump
dial
handshake:recv:cif
handshake:recv:dump
handshake:recv:error
handshake:send:cif
handshake:send:dump
listen
packet:recv:dump
packet:send:dump
```
You can run `make logtopics` in order to extract the list of topics.
# Docker
The Docker image, which you can build with `docker build -t srt .`, provides the example SRT client and server mentioned in the paragraphs above.
E.g. run the server with `docker run -it --rm -p 6001:6001/udp srt srt-server -addr :6001`.

742
vendor/github.com/datarhei/gosrt/config.go generated vendored Normal file

@@ -0,0 +1,742 @@
package srt
import (
"fmt"
"net/url"
"strconv"
"time"
)
const (
UDP_HEADER_SIZE = 28
SRT_HEADER_SIZE = 16
MIN_MSS_SIZE = 76
MAX_MSS_SIZE = 1500
MIN_PAYLOAD_SIZE = MIN_MSS_SIZE - UDP_HEADER_SIZE - SRT_HEADER_SIZE
MAX_PAYLOAD_SIZE = MAX_MSS_SIZE - UDP_HEADER_SIZE - SRT_HEADER_SIZE
MIN_PASSPHRASE_SIZE = 10
MAX_PASSPHRASE_SIZE = 79
MAX_STREAMID_SIZE = 512
SRT_VERSION = 0x010401
)
// Config is the configuration for a SRT connection
type Config struct {
// Type of congestion control. 'live' or 'file'
// SRTO_CONGESTION
Congestion string
// Connection timeout.
// SRTO_CONNTIMEO
ConnectionTimeout time.Duration
// Enable drift tracer.
// SRTO_DRIFTTRACER
DriftTracer bool
// Reject connection if parties set different passphrase.
// SRTO_ENFORCEDENCRYPTION
EnforcedEncryption bool
// Flow control window size. Packets.
// SRTO_FC
FC uint32
// Accept group connections.
// SRTO_GROUPCONNECT
GroupConnect bool
// Group stability timeout.
// SRTO_GROUPSTABTIMEO
GroupStabilityTimeout time.Duration
// Input bandwidth. Bytes.
// SRTO_INPUTBW
InputBW int64
// IP socket type of service
// SRTO_IPTOS
IPTOS int
// Defines IP socket "time to live" option.
// SRTO_IPTTL
IPTTL int
// Allow only IPv6.
// SRTO_IPV6ONLY
IPv6Only int
// Duration of Stream Encryption key switchover. Packets.
// SRTO_KMPREANNOUNCE
KMPreAnnounce uint64
// Stream encryption key refresh rate. Packets.
// SRTO_KMREFRESHRATE
KMRefreshRate uint64
// Defines the maximum accepted transmission latency.
// SRTO_LATENCY
Latency time.Duration
// Packet reorder tolerance.
// SRTO_LOSSMAXTTL
LossMaxTTL uint32
// Bandwidth limit in bytes/s.
// SRTO_MAXBW
MaxBW int64
// Enable SRT message mode.
// SRTO_MESSAGEAPI
MessageAPI bool
// Minimum input bandwidth
// This option is effective only if both SRTO_MAXBW and SRTO_INPUTBW are set to 0. It controls the minimum allowed value of the input bitrate estimate.
// SRTO_MININPUTBW
MinInputBW int64
// Minimum SRT library version of a peer.
// SRTO_MINVERSION
MinVersion uint32
// MTU size
// SRTO_MSS
MSS uint32
// Enable periodic NAK reports
// SRTO_NAKREPORT
NAKReport bool
// Limit bandwidth overhead, percents
// SRTO_OHEADBW
OverheadBW int64
// Set up the packet filter.
// SRTO_PACKETFILTER
PacketFilter string
// Password for the encrypted transmission.
// SRTO_PASSPHRASE
Passphrase string
// Maximum payload size. Bytes.
// SRTO_PAYLOADSIZE
PayloadSize uint32
// Crypto key length in bytes.
// SRTO_PBKEYLEN
PBKeylen int
// Peer idle timeout.
// SRTO_PEERIDLETIMEO
PeerIdleTimeout time.Duration
// Minimum receiver latency to be requested by sender.
// SRTO_PEERLATENCY
PeerLatency time.Duration
// Receiver buffer size. Bytes.
// SRTO_RCVBUF
ReceiverBufferSize uint32
// Receiver-side latency.
// SRTO_RCVLATENCY
ReceiverLatency time.Duration
// Sender buffer size. Bytes.
// SRTO_SNDBUF
SendBufferSize uint32
// Sender's delay before dropping packets.
// SRTO_SNDDROPDELAY
SendDropDelay time.Duration
// Stream ID (settable in caller mode only, visible on the listener peer)
// SRTO_STREAMID
StreamId string
// Drop too late packets.
// SRTO_TLPKTDROP
TooLatePacketDrop bool
// Transmission type. 'live' or 'file'.
// SRTO_TRANSTYPE
TransmissionType string
// Timestamp-based packet delivery mode.
// SRTO_TSBPDMODE
TSBPDMode bool
// An implementation of the Logger interface
Logger Logger
}
// defaultConfig is the default configuration for a SRT connection,
// used if no individual configuration has been provided.
var defaultConfig Config = Config{
Congestion: "live",
ConnectionTimeout: 3 * time.Second,
DriftTracer: true,
EnforcedEncryption: true,
FC: 25600,
GroupConnect: false,
GroupStabilityTimeout: 0,
InputBW: 0,
IPTOS: 0,
IPTTL: 0,
IPv6Only: -1,
KMPreAnnounce: 1 << 12,
KMRefreshRate: 1 << 24,
Latency: -1,
LossMaxTTL: 0,
MaxBW: -1,
MessageAPI: false,
MinVersion: SRT_VERSION,
MSS: MAX_MSS_SIZE,
NAKReport: true,
OverheadBW: 25,
PacketFilter: "",
Passphrase: "",
PayloadSize: MAX_PAYLOAD_SIZE,
PBKeylen: 16,
PeerIdleTimeout: 2 * time.Second,
PeerLatency: 120 * time.Millisecond,
ReceiverBufferSize: 0,
ReceiverLatency: 120 * time.Millisecond,
SendBufferSize: 0,
SendDropDelay: 1 * time.Second,
StreamId: "",
TooLatePacketDrop: true,
TransmissionType: "live",
TSBPDMode: true,
}
// DefaultConfig returns the default configuration for Dial and Listen.
func DefaultConfig() Config {
return defaultConfig
}
// UnmarshalURL takes a SRT URL and parses out the configuration. A SRT URL is
// srt://[host]:[port]?[key1]=[value1]&[key2]=[value2]...
func (c *Config) UnmarshalURL(addr string) error {
u, err := url.Parse(addr)
if err != nil {
return err
}
if u.Scheme != "srt" {
return fmt.Errorf("the URL doesn't seem to be an srt:// URL")
}
return c.UnmarshalQuery(u.RawQuery)
}
// UnmarshalQuery parses a query string and interprets it as a configuration
// for a SRT connection. The key in each key/value pair corresponds to the
// respective field in the Config type, but with only lower case letters. Bool
// values can be represented as "true"/"false", "on"/"off", "yes"/"no", or "0"/"1".
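// Example (illustrative values):
//
//	c := DefaultConfig()
//	_ = c.UnmarshalQuery("latency=200&nakreport=no&passphrase=foobarfoobar")
//	// c.Latency == 200*time.Millisecond
//	// c.NAKReport == false
//	// c.Passphrase == "foobarfoobar"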
func (c *Config) UnmarshalQuery(query string) error {
v, err := url.ParseQuery(query)
if err != nil {
return err
}
// https://github.com/Haivision/srt/blob/master/docs/apps/srt-live-transmit.md
if s := v.Get("congestion"); len(s) != 0 {
c.Congestion = s
}
if s := v.Get("conntimeo"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.ConnectionTimeout = time.Duration(d) * time.Millisecond
}
}
if s := v.Get("drifttracer"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.DriftTracer = true
case "no", "off", "false", "0":
c.DriftTracer = false
}
}
if s := v.Get("enforcedencryption"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.EnforcedEncryption = true
case "no", "off", "false", "0":
c.EnforcedEncryption = false
}
}
if s := v.Get("fc"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 32); err == nil {
c.FC = uint32(d)
}
}
if s := v.Get("groupconnect"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.GroupConnect = true
case "no", "off", "false", "0":
c.GroupConnect = false
}
}
if s := v.Get("groupstabtimeo"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.GroupStabilityTimeout = time.Duration(d) * time.Millisecond
}
}
if s := v.Get("inputbw"); len(s) != 0 {
if d, err := strconv.ParseInt(s, 10, 64); err == nil {
c.InputBW = d
}
}
if s := v.Get("iptos"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.IPTOS = d
}
}
if s := v.Get("ipttl"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.IPTTL = d
}
}
if s := v.Get("ipv6only"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.IPv6Only = d
}
}
if s := v.Get("kmpreannounce"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 64); err == nil {
c.KMPreAnnounce = d
}
}
if s := v.Get("kmrefreshrate"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 64); err == nil {
c.KMRefreshRate = d
}
}
if s := v.Get("latency"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.Latency = time.Duration(d) * time.Millisecond
}
}
if s := v.Get("lossmaxttl"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 32); err == nil {
c.LossMaxTTL = uint32(d)
}
}
if s := v.Get("maxbw"); len(s) != 0 {
if d, err := strconv.ParseInt(s, 10, 64); err == nil {
c.MaxBW = d
}
}
if s := v.Get("mininputbw"); len(s) != 0 {
if d, err := strconv.ParseInt(s, 10, 64); err == nil {
c.MinInputBW = d
}
}
if s := v.Get("messageapi"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.MessageAPI = true
case "no", "off", "false", "0":
c.MessageAPI = false
}
}
// minversion is ignored
if s := v.Get("mss"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 32); err == nil {
c.MSS = uint32(d)
}
}
if s := v.Get("nakreport"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.NAKReport = true
case "no", "off", "false", "0":
c.NAKReport = false
}
}
if s := v.Get("oheadbw"); len(s) != 0 {
if d, err := strconv.ParseInt(s, 10, 64); err == nil {
c.OverheadBW = d
}
}
if s := v.Get("packetfilter"); len(s) != 0 {
c.PacketFilter = s
}
if s := v.Get("passphrase"); len(s) != 0 {
c.Passphrase = s
}
if s := v.Get("payloadsize"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 32); err == nil {
c.PayloadSize = uint32(d)
}
}
if s := v.Get("pbkeylen"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.PBKeylen = d
}
}
if s := v.Get("peeridletimeo"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.PeerIdleTimeout = time.Duration(d) * time.Millisecond
}
}
if s := v.Get("peerlatency"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.PeerLatency = time.Duration(d) * time.Millisecond
}
}
if s := v.Get("rcvbuf"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 32); err == nil {
c.ReceiverBufferSize = uint32(d)
}
}
if s := v.Get("rcvlatency"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.ReceiverLatency = time.Duration(d) * time.Millisecond
}
}
// retransmitalgo not implemented (there's only one)
if s := v.Get("sndbuf"); len(s) != 0 {
if d, err := strconv.ParseUint(s, 10, 32); err == nil {
c.SendBufferSize = uint32(d)
}
}
if s := v.Get("snddropdelay"); len(s) != 0 {
if d, err := strconv.Atoi(s); err == nil {
c.SendDropDelay = time.Duration(d) * time.Millisecond
}
}
if s := v.Get("streamid"); len(s) != 0 {
c.StreamId = s
}
if s := v.Get("tlpktdrop"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.TooLatePacketDrop = true
case "no", "off", "false", "0":
c.TooLatePacketDrop = false
}
}
if s := v.Get("transtype"); len(s) != 0 {
c.TransmissionType = s
}
if s := v.Get("tsbpdmode"); len(s) != 0 {
switch s {
case "yes", "on", "true", "1":
c.TSBPDMode = true
case "no", "off", "false", "0":
c.TSBPDMode = false
}
}
return nil
}
// MarshalURL returns the SRT URL for this config and the given host and port.
func (c *Config) MarshalURL(host string, port uint) string {
return "srt://" + host + ":" + strconv.FormatUint(uint64(port), 10) + "?" + c.MarshalQuery()
}
// MarshalQuery returns the corresponding query string for a configuration.
func (c *Config) MarshalQuery() string {
q := url.Values{}
if c.Congestion != defaultConfig.Congestion {
q.Set("congestion", c.Congestion)
}
if c.ConnectionTimeout != defaultConfig.ConnectionTimeout {
q.Set("conntimeo", strconv.FormatInt(c.ConnectionTimeout.Milliseconds(), 10))
}
if c.DriftTracer != defaultConfig.DriftTracer {
q.Set("drifttracer", strconv.FormatBool(c.DriftTracer))
}
if c.EnforcedEncryption != defaultConfig.EnforcedEncryption {
q.Set("enforcedencryption", strconv.FormatBool(c.EnforcedEncryption))
}
if c.FC != defaultConfig.FC {
q.Set("fc", strconv.FormatUint(uint64(c.FC), 10))
}
if c.GroupConnect != defaultConfig.GroupConnect {
q.Set("groupconnect", strconv.FormatBool(c.GroupConnect))
}
if c.GroupStabilityTimeout != defaultConfig.GroupStabilityTimeout {
q.Set("groupstabtimeo", strconv.FormatInt(c.GroupStabilityTimeout.Milliseconds(), 10))
}
if c.InputBW != defaultConfig.InputBW {
q.Set("inputbw", strconv.FormatInt(c.InputBW, 10))
}
if c.IPTOS != defaultConfig.IPTOS {
q.Set("iptos", strconv.FormatInt(int64(c.IPTOS), 10))
}
if c.IPTTL != defaultConfig.IPTTL {
q.Set("ipttl", strconv.FormatInt(int64(c.IPTTL), 10))
}
if c.IPv6Only != defaultConfig.IPv6Only {
q.Set("ipv6only", strconv.FormatInt(int64(c.IPv6Only), 10))
}
if len(c.Passphrase) != 0 {
if c.KMPreAnnounce != defaultConfig.KMPreAnnounce {
q.Set("kmpreannounce", strconv.FormatUint(c.KMPreAnnounce, 10))
}
if c.KMRefreshRate != defaultConfig.KMRefreshRate {
q.Set("kmrefreshrate", strconv.FormatUint(c.KMRefreshRate, 10))
}
}
if c.Latency != defaultConfig.Latency {
q.Set("latency", strconv.FormatInt(c.Latency.Milliseconds(), 10))
}
if c.LossMaxTTL != defaultConfig.LossMaxTTL {
q.Set("lossmaxttl", strconv.FormatInt(int64(c.LossMaxTTL), 10))
}
if c.MaxBW != defaultConfig.MaxBW {
q.Set("maxbw", strconv.FormatInt(c.MaxBW, 10))
}
if c.MinInputBW != defaultConfig.InputBW {
q.Set("mininputbw", strconv.FormatInt(c.MinInputBW, 10))
}
if c.MessageAPI != defaultConfig.MessageAPI {
q.Set("messageapi", strconv.FormatBool(c.MessageAPI))
}
if c.MSS != defaultConfig.MSS {
q.Set("mss", strconv.FormatUint(uint64(c.MSS), 10))
}
if c.NAKReport != defaultConfig.NAKReport {
q.Set("nakreport", strconv.FormatBool(c.NAKReport))
}
if c.OverheadBW != defaultConfig.OverheadBW {
q.Set("oheadbw", strconv.FormatInt(c.OverheadBW, 10))
}
if c.PacketFilter != defaultConfig.PacketFilter {
q.Set("packetfilter", c.PacketFilter)
}
if len(c.Passphrase) != 0 {
q.Set("passphrase", c.Passphrase)
}
if c.PayloadSize != defaultConfig.PayloadSize {
q.Set("payloadsize", strconv.FormatUint(uint64(c.PayloadSize), 10))
}
if c.PBKeylen != defaultConfig.PBKeylen {
q.Set("pbkeylen", strconv.FormatInt(int64(c.PBKeylen), 10))
}
if c.PeerIdleTimeout != defaultConfig.PeerIdleTimeout {
q.Set("peeridletimeo", strconv.FormatInt(c.PeerIdleTimeout.Milliseconds(), 10))
}
if c.PeerLatency != defaultConfig.PeerLatency {
q.Set("peerlatency", strconv.FormatInt(c.PeerLatency.Milliseconds(), 10))
}
if c.ReceiverBufferSize != defaultConfig.ReceiverBufferSize {
q.Set("rcvbuf", strconv.FormatInt(int64(c.ReceiverBufferSize), 10))
}
if c.ReceiverLatency != defaultConfig.ReceiverLatency {
q.Set("rcvlatency", strconv.FormatInt(c.ReceiverLatency.Milliseconds(), 10))
}
if c.SendBufferSize != defaultConfig.SendBufferSize {
q.Set("sndbuf", strconv.FormatInt(int64(c.SendBufferSize), 10))
}
if c.SendDropDelay != defaultConfig.SendDropDelay {
q.Set("snddropdelay", strconv.FormatInt(c.SendDropDelay.Milliseconds(), 10))
}
if len(c.StreamId) != 0 {
q.Set("streamid", c.StreamId)
}
if c.TooLatePacketDrop != defaultConfig.TooLatePacketDrop {
q.Set("tlpktdrop", strconv.FormatBool(c.TooLatePacketDrop))
}
if c.TransmissionType != defaultConfig.TransmissionType {
q.Set("transtype", c.TransmissionType)
}
if c.TSBPDMode != defaultConfig.TSBPDMode {
q.Set("tsbpdmode", strconv.FormatBool(c.TSBPDMode))
}
return q.Encode()
}
// Validate validates a configuration or returns an error if a field
// has an invalid value.
func (c Config) Validate() error {
if c.TransmissionType != "live" {
return fmt.Errorf("config: TransmissionType must be 'live'")
}
// Normalize live-mode settings. Note that c is a value receiver, so
// these adjustments are local to Validate and don't leak to the caller.
c.Congestion = "live"
c.NAKReport = true
c.TooLatePacketDrop = true
c.TSBPDMode = true
if c.Congestion != "live" {
return fmt.Errorf("config: Congestion mode must be 'live'")
}
if c.ConnectionTimeout <= 0 {
return fmt.Errorf("config: ConnectionTimeout must be greater than 0")
}
if c.GroupConnect {
return fmt.Errorf("config: GroupConnect is not supported")
}
if c.IPTOS > 255 {
return fmt.Errorf("config: IPTOS must be between 0 and 255")
}
if c.IPTTL > 255 {
return fmt.Errorf("config: IPTTL must be between 0 and 255")
}
if c.IPv6Only > 0 {
return fmt.Errorf("config: IPv6Only is not supported")
}
if c.KMRefreshRate != 0 {
if c.KMPreAnnounce < 1 || c.KMPreAnnounce > c.KMRefreshRate/2 {
return fmt.Errorf("config: KMPreAnnounce must be between 1 and KMRefreshRate/2 (both inclusive)")
}
}
if c.Latency >= 0 {
c.PeerLatency = c.Latency
c.ReceiverLatency = c.Latency
}
if c.MinVersion != SRT_VERSION {
return fmt.Errorf("config: MinVersion must be %#06x", SRT_VERSION)
}
if c.MSS < MIN_MSS_SIZE || c.MSS > MAX_MSS_SIZE {
return fmt.Errorf("config: MSS must be between %d and %d (both inclusive)", MIN_MSS_SIZE, MAX_MSS_SIZE)
}
if !c.NAKReport {
return fmt.Errorf("config: NAKReport must be enabled")
}
if c.OverheadBW < 10 || c.OverheadBW > 100 {
return fmt.Errorf("config: OverheadBW must be between 10 and 100")
}
if len(c.PacketFilter) != 0 {
return fmt.Errorf("config: PacketFilter is not supported")
}
if len(c.Passphrase) != 0 {
if len(c.Passphrase) < MIN_PASSPHRASE_SIZE || len(c.Passphrase) > MAX_PASSPHRASE_SIZE {
return fmt.Errorf("config: Passphrase must be between %d and %d bytes long", MIN_PASSPHRASE_SIZE, MAX_PASSPHRASE_SIZE)
}
}
if c.PayloadSize < MIN_PAYLOAD_SIZE || c.PayloadSize > MAX_PAYLOAD_SIZE {
return fmt.Errorf("config: PayloadSize must be between %d and %d (both inclusive)", MIN_PAYLOAD_SIZE, MAX_PAYLOAD_SIZE)
}
if c.PayloadSize > c.MSS-uint32(SRT_HEADER_SIZE+UDP_HEADER_SIZE) {
return fmt.Errorf("config: PayloadSize must not be larger than %d (MSS - %d)", c.MSS-uint32(SRT_HEADER_SIZE+UDP_HEADER_SIZE), SRT_HEADER_SIZE+UDP_HEADER_SIZE)
}
if c.PBKeylen != 16 && c.PBKeylen != 24 && c.PBKeylen != 32 {
return fmt.Errorf("config: PBKeylen must be 16, 24, or 32 bytes")
}
if c.PeerLatency < 0 {
return fmt.Errorf("config: PeerLatency must not be negative")
}
if c.ReceiverLatency < 0 {
return fmt.Errorf("config: ReceiverLatency must not be negative")
}
if c.SendDropDelay < 0 {
return fmt.Errorf("config: SendDropDelay must not be negative")
}
if len(c.StreamId) > MAX_STREAMID_SIZE {
return fmt.Errorf("config: StreamId must be shorter than or equal to %d bytes", MAX_STREAMID_SIZE)
}
if !c.TooLatePacketDrop {
return fmt.Errorf("config: TooLatePacketDrop must be enabled")
}
if c.TransmissionType != "live" {
return fmt.Errorf("config: TransmissionType must be 'live'")
}
if !c.TSBPDMode {
return fmt.Errorf("config: TSBPDMode must be enabled")
}
return nil
}

vendor/github.com/datarhei/gosrt/connection.go generated vendored Normal file

File diff suppressed because it is too large

vendor/github.com/datarhei/gosrt/dial.go generated vendored Normal file

@@ -0,0 +1,675 @@
package srt
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"math/rand"
"net"
"os"
"sync"
"syscall"
"time"
"github.com/datarhei/gosrt/internal/circular"
"github.com/datarhei/gosrt/internal/crypto"
"github.com/datarhei/gosrt/internal/packet"
)
// ErrClientClosed is returned when the client connection has
// been voluntarily closed.
var ErrClientClosed = errors.New("srt: client closed")
// dialer implements the Conn interface
type dialer struct {
pc *net.UDPConn
localAddr net.Addr
remoteAddr net.Addr
config Config
socketId uint32
initialPacketSequenceNumber circular.Number
crypto crypto.Crypto
conn *srtConn
connLock sync.RWMutex
connChan chan connResponse
start time.Time
rcvQueue chan packet.Packet // for packets that come from the wire
sndQueue chan packet.Packet // for packets that go to the wire
shutdown bool
shutdownLock sync.RWMutex
shutdownOnce sync.Once
stopReader context.CancelFunc
stopWriter context.CancelFunc
doneChan chan error
}
type connResponse struct {
conn *srtConn
err error
}
// Dial connects to the address using the SRT protocol with the given config
// and returns a Conn interface.
//
// The address is of the form "host:port".
//
// Example:
// Dial("srt", "127.0.0.1:3000", DefaultConfig())
//
// In case of an error the returned Conn is nil and the error is non-nil.
func Dial(network, address string, config Config) (Conn, error) {
if network != "srt" {
return nil, fmt.Errorf("the network must be 'srt'")
}
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
if config.Logger == nil {
config.Logger = NewLogger(nil)
}
dl := &dialer{
config: config,
}
raddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, fmt.Errorf("unable to resolve address: %w", err)
}
pc, err := net.DialUDP("udp", nil, raddr)
if err != nil {
return nil, fmt.Errorf("failed dialing: %w", err)
}
file, err := pc.File()
if err != nil {
return nil, err
}
// Set TOS
if config.IPTOS > 0 {
err = syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_TOS, config.IPTOS)
if err != nil {
return nil, fmt.Errorf("failed setting socket option TOS: %w", err)
}
}
// Set TTL
if config.IPTTL > 0 {
err = syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_TTL, config.IPTTL)
if err != nil {
return nil, fmt.Errorf("failed setting socket option TTL: %w", err)
}
}
dl.pc = pc
dl.localAddr = pc.LocalAddr()
dl.remoteAddr = pc.RemoteAddr()
dl.conn = nil
dl.connChan = make(chan connResponse)
dl.rcvQueue = make(chan packet.Packet, 2048)
dl.sndQueue = make(chan packet.Packet, 2048)
dl.doneChan = make(chan error)
dl.start = time.Now()
// create a new socket ID
r := rand.New(rand.NewSource(time.Now().UnixNano()))
dl.socketId = r.Uint32()
dl.initialPacketSequenceNumber = circular.New(r.Uint32()&packet.MAX_SEQUENCENUMBER, packet.MAX_SEQUENCENUMBER)
go func() {
buffer := make([]byte, MAX_MSS_SIZE) // MTU size
for {
if dl.isShutdown() {
dl.doneChan <- ErrClientClosed
return
}
pc.SetReadDeadline(time.Now().Add(3 * time.Second))
n, _, err := pc.ReadFrom(buffer)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
if dl.isShutdown() {
dl.doneChan <- ErrClientClosed
return
}
dl.doneChan <- err
return
}
p := packet.NewPacket(dl.remoteAddr, buffer[:n])
if p == nil {
continue
}
dl.rcvQueue <- p
}
}()
var readerCtx context.Context
readerCtx, dl.stopReader = context.WithCancel(context.Background())
go dl.reader(readerCtx)
var writerCtx context.Context
writerCtx, dl.stopWriter = context.WithCancel(context.Background())
go dl.writer(writerCtx)
// Send the initial handshake request
dl.sendInduction()
dl.log("dial", func() string { return "waiting for response" })
timer := time.AfterFunc(dl.config.ConnectionTimeout, func() {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("connection timeout: server didn't respond"),
}
})
// Wait for handshake to conclude
response := <-dl.connChan
if response.err != nil {
dl.Close()
return nil, response.err
}
timer.Stop()
dl.connLock.Lock()
dl.conn = response.conn
dl.connLock.Unlock()
return dl, nil
}
func (dl *dialer) checkConnection() error {
select {
case err := <-dl.doneChan:
dl.Close()
return err
default:
}
return nil
}
// reader reads packets from the receive queue and pushes them into the connection
func (dl *dialer) reader(ctx context.Context) {
defer func() {
dl.log("dial", func() string { return "left reader loop" })
}()
dl.log("dial", func() string { return "reader loop started" })
for {
select {
case <-ctx.Done():
return
case p := <-dl.rcvQueue:
if dl.isShutdown() {
break
}
dl.log("packet:recv:dump", func() string { return p.Dump() })
if p.Header().DestinationSocketId != dl.socketId {
break
}
if p.Header().IsControlPacket && p.Header().ControlType == packet.CTRLTYPE_HANDSHAKE {
dl.handleHandshake(p)
break
}
dl.connLock.RLock()
if dl.conn == nil {
dl.connLock.RUnlock()
break
}
dl.conn.push(p)
dl.connLock.RUnlock()
}
}
}
// send adds a packet to the send queue
func (dl *dialer) send(p packet.Packet) {
// non-blocking
select {
case dl.sndQueue <- p:
default:
dl.log("dial", func() string { return "send queue is full" })
}
}
// writer reads packets from the send queue and writes them to the wire
func (dl *dialer) writer(ctx context.Context) {
defer func() {
dl.log("dial", func() string { return "left writer loop" })
}()
dl.log("dial", func() string { return "writer loop started" })
var data bytes.Buffer
for {
select {
case <-ctx.Done():
return
case p := <-dl.sndQueue:
data.Reset()
p.Marshal(&data)
buffer := data.Bytes()
dl.log("packet:send:dump", func() string { return p.Dump() })
// Write the packet's contents to the wire.
dl.pc.Write(buffer)
if p.Header().IsControlPacket {
// Control packets can be decommissioned because they will not be sent again
p.Decommission()
}
}
}
}
func (dl *dialer) handleHandshake(p packet.Packet) {
cif := &packet.CIFHandshake{}
err := p.UnmarshalCIF(cif)
dl.log("handshake:recv:dump", func() string { return p.Dump() })
dl.log("handshake:recv:cif", func() string { return cif.String() })
if err != nil {
dl.log("handshake:recv:error", func() string { return err.Error() })
return
}
// assemble the response (4.3.1. Caller-Listener Handshake)
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(dl.start).Microseconds())
p.Header().DestinationSocketId = cif.SRTSocketId
if cif.HandshakeType == packet.HSTYPE_INDUCTION {
// Verify version
if cif.Version != 5 {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support handshake v5"),
}
return
}
// Verify magic number
if cif.ExtensionField != 0x4A17 {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer sent the wrong magic number"),
}
return
}
// Setup crypto context
if len(dl.config.Passphrase) != 0 {
keylen := dl.config.PBKeylen
// If the server advertises a specific block cipher family and key size,
// use this one, otherwise, use the configured one
if cif.EncryptionField != 0 {
switch cif.EncryptionField {
case 2:
keylen = 16
case 3:
keylen = 24
case 4:
keylen = 32
}
}
cr, err := crypto.New(keylen)
if err != nil {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("failed creating crypto context: %w", err),
}
return
}
dl.crypto = cr
}
cif.IsRequest = true
cif.HandshakeType = packet.HSTYPE_CONCLUSION
cif.InitialPacketSequenceNumber = dl.initialPacketSequenceNumber
cif.MaxTransmissionUnitSize = dl.config.MSS // MTU size
cif.MaxFlowWindowSize = dl.config.FC
cif.SRTSocketId = dl.socketId
cif.PeerIP.FromNetAddr(dl.localAddr)
cif.HasHS = true
cif.SRTVersion = SRT_VERSION
cif.SRTFlags.TSBPDSND = true
cif.SRTFlags.TSBPDRCV = true
cif.SRTFlags.CRYPT = true // must always set to true
cif.SRTFlags.TLPKTDROP = true
cif.SRTFlags.PERIODICNAK = true
cif.SRTFlags.REXMITFLG = true
cif.SRTFlags.STREAM = false
cif.SRTFlags.PACKET_FILTER = false
cif.RecvTSBPDDelay = uint16(dl.config.ReceiverLatency.Milliseconds())
cif.SendTSBPDDelay = uint16(dl.config.PeerLatency.Milliseconds())
cif.HasSID = true
cif.StreamId = dl.config.StreamId
if dl.crypto != nil {
cif.HasKM = true
cif.SRTKM = &packet.CIFKM{}
if err := dl.crypto.MarshalKM(cif.SRTKM, dl.config.Passphrase, packet.EvenKeyEncrypted); err != nil {
dl.connChan <- connResponse{
conn: nil,
err: err,
}
return
}
}
p.MarshalCIF(cif)
dl.log("handshake:send:dump", func() string { return p.Dump() })
dl.log("handshake:send:cif", func() string { return cif.String() })
dl.send(p)
} else if cif.HandshakeType == packet.HSTYPE_CONCLUSION {
// We only support HSv5
if cif.Version != 5 {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support handshake v5"),
}
return
}
// Check if the peer version is sufficient
if cif.SRTVersion < dl.config.MinVersion {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer SRT version is not sufficient"),
}
return
}
// Check the required SRT flags
if !cif.SRTFlags.TSBPDSND || !cif.SRTFlags.TSBPDRCV || !cif.SRTFlags.TLPKTDROP || !cif.SRTFlags.PERIODICNAK || !cif.SRTFlags.REXMITFLG {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't agree on SRT flags"),
}
return
}
// We only support live streaming
if cif.SRTFlags.STREAM {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("peer doesn't support live streaming"),
}
return
}
// Use the largest TSBPD delay as advertised by the listener, but
// at least 120ms
tsbpdDelay := uint16(120)
if cif.RecvTSBPDDelay > tsbpdDelay {
tsbpdDelay = cif.RecvTSBPDDelay
}
if cif.SendTSBPDDelay > tsbpdDelay {
tsbpdDelay = cif.SendTSBPDDelay
}
// If the peer has a smaller MTU size, adjust to it
if cif.MaxTransmissionUnitSize < dl.config.MSS {
dl.config.MSS = cif.MaxTransmissionUnitSize
dl.config.PayloadSize = dl.config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE
if dl.config.PayloadSize < MIN_PAYLOAD_SIZE {
dl.sendShutdown(cif.SRTSocketId)
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("effective MSS too small (%d bytes) to fit the minimal payload size (%d bytes)", dl.config.MSS, MIN_PAYLOAD_SIZE),
}
return
}
}
// Create a new connection
conn := newSRTConn(srtConnConfig{
localAddr: dl.localAddr,
remoteAddr: dl.remoteAddr,
config: dl.config,
start: dl.start,
socketId: dl.socketId,
peerSocketId: cif.SRTSocketId,
tsbpdTimeBase: uint64(time.Since(dl.start).Microseconds()),
tsbpdDelay: uint64(tsbpdDelay) * 1000,
initialPacketSequenceNumber: cif.InitialPacketSequenceNumber,
crypto: dl.crypto,
keyBaseEncryption: packet.EvenKeyEncrypted,
onSend: dl.send,
onShutdown: func(socketId uint32) { dl.Close() },
logger: dl.config.Logger,
})
dl.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s)", conn.SocketId(), conn.StreamId()) })
dl.connChan <- connResponse{
conn: conn,
err: nil,
}
} else {
var err error
if cif.HandshakeType.IsRejection() {
err = fmt.Errorf("connection rejected: %s", cif.HandshakeType.String())
} else {
err = fmt.Errorf("unsupported handshake: %s", cif.HandshakeType.String())
}
dl.connChan <- connResponse{
conn: nil,
err: err,
}
}
}
func (dl *dialer) sendInduction() {
p := packet.NewPacket(dl.remoteAddr, nil)
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(dl.start).Microseconds())
p.Header().DestinationSocketId = 0
cif := &packet.CIFHandshake{
IsRequest: true,
Version: 4,
EncryptionField: 0,
ExtensionField: 2,
InitialPacketSequenceNumber: circular.New(0, packet.MAX_SEQUENCENUMBER),
MaxTransmissionUnitSize: dl.config.MSS, // MTU size
MaxFlowWindowSize: dl.config.FC,
HandshakeType: packet.HSTYPE_INDUCTION,
SRTSocketId: dl.socketId,
SynCookie: 0,
}
cif.PeerIP.FromNetAddr(dl.localAddr)
p.MarshalCIF(cif)
dl.log("handshake:send:dump", func() string { return p.Dump() })
dl.log("handshake:send:cif", func() string { return cif.String() })
dl.send(p)
}
func (dl *dialer) sendShutdown(peerSocketId uint32) {
p := packet.NewPacket(dl.remoteAddr, nil)
data := [4]byte{}
binary.BigEndian.PutUint32(data[0:], 0)
p.SetData(data[0:4])
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_SHUTDOWN
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(dl.start).Microseconds())
p.Header().DestinationSocketId = peerSocketId
dl.log("control:send:shutdown:dump", func() string { return p.Dump() })
dl.send(p)
}
func (dl *dialer) LocalAddr() net.Addr {
return dl.conn.LocalAddr()
}
func (dl *dialer) RemoteAddr() net.Addr {
return dl.conn.RemoteAddr()
}
func (dl *dialer) SocketId() uint32 {
return dl.conn.SocketId()
}
func (dl *dialer) PeerSocketId() uint32 {
return dl.conn.PeerSocketId()
}
func (dl *dialer) StreamId() string {
return dl.conn.StreamId()
}
func (dl *dialer) isShutdown() bool {
dl.shutdownLock.RLock()
defer dl.shutdownLock.RUnlock()
return dl.shutdown
}
func (dl *dialer) Close() error {
dl.shutdownOnce.Do(func() {
dl.shutdownLock.Lock()
dl.shutdown = true
dl.shutdownLock.Unlock()
dl.connLock.RLock()
if dl.conn != nil {
dl.conn.Close()
}
dl.connLock.RUnlock()
dl.stopReader()
dl.stopWriter()
dl.log("dial", func() string { return "closing socket" })
dl.pc.Close()
select {
case <-dl.doneChan:
default:
}
})
return nil
}
func (dl *dialer) Read(p []byte) (n int, err error) {
if err := dl.checkConnection(); err != nil {
return 0, err
}
dl.connLock.RLock()
defer dl.connLock.RUnlock()
return dl.conn.Read(p)
}
func (dl *dialer) Write(p []byte) (n int, err error) {
if err := dl.checkConnection(); err != nil {
return 0, err
}
dl.connLock.RLock()
defer dl.connLock.RUnlock()
return dl.conn.Write(p)
}
func (dl *dialer) SetDeadline(t time.Time) error { return dl.conn.SetDeadline(t) }
func (dl *dialer) SetReadDeadline(t time.Time) error { return dl.conn.SetReadDeadline(t) }
func (dl *dialer) SetWriteDeadline(t time.Time) error { return dl.conn.SetWriteDeadline(t) }
func (dl *dialer) Stats() Statistics { return dl.conn.Stats() }
func (dl *dialer) log(topic string, message func() string) {
dl.config.Logger.Print(topic, dl.socketId, 2, message)
}

vendor/github.com/datarhei/gosrt/doc.go generated vendored Normal file

@@ -0,0 +1,61 @@
/*
Package srt provides an interface for network I/O using the SRT protocol (https://github.com/Haivision/srt).
The package gives access to the basic interface provided by the Dial, Listen, and Accept functions and the associated
Conn and Listener interfaces.
The Dial function connects to a server:
conn, err := srt.Dial("srt", "golang.org:6000", srt.Config{
StreamId: "...",
})
if err != nil {
// handle error
}
buffer := make([]byte, 2048)
for {
n, err := conn.Read(buffer)
if err != nil {
// handle error
}
// handle received data
}
conn.Close()
The Listen function creates servers:
ln, err := srt.Listen("srt", ":6000", srt.Config{...})
if err != nil {
// handle error
}
for {
conn, mode, err := ln.Accept(handleConnect)
if err != nil {
// handle error
}
if mode == srt.REJECT {
// rejected connection, ignore
continue
}
if mode == srt.PUBLISH {
go handlePublish(conn)
} else {
go handleSubscribe(conn)
}
}
The ln.Accept function expects a function that takes a srt.ConnRequest
and returns a srt.ConnType. The srt.ConnRequest lets you retrieve the
streamid, based on which you can decide which mode (srt.ConnType) to return.
Check out the Server type, which wraps Listen and Accept into a
convenient framework for your own SRT server.
*/
package srt

vendor/github.com/datarhei/gosrt/go.mod generated vendored Normal file

@@ -0,0 +1,11 @@
module github.com/datarhei/gosrt
go 1.16
require (
github.com/benburkert/openpgp v0.0.0-20160410205803-c2471f86866c
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pkg/profile v1.6.0
github.com/stretchr/testify v1.7.2
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
)

vendor/github.com/datarhei/gosrt/go.sum generated vendored Normal file

@@ -0,0 +1,25 @@
github.com/benburkert/openpgp v0.0.0-20160410205803-c2471f86866c h1:8XZeJrs4+ZYhJeJ2aZxADI2tGADS15AzIF8MQ8XAhT4=
github.com/benburkert/openpgp v0.0.0-20160410205803-c2471f86866c/go.mod h1:x1vxHcL/9AVzuk5HOloOEPrtJY0MaalYr78afXZ+pWI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pkg/profile v1.6.0 h1:hUDfIISABYI59DyeB3OTay/HxSRwTQ8rB/H83k6r5dM=
github.com/pkg/profile v1.6.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s=
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -0,0 +1,184 @@
// Package circular implements "circular numbers". This is a number that can be
// increased (or decreased) indefinitely while only using up a limited amount of
// memory. This feature comes with a limitation on how distant two such
// numbers can be. Circular numbers have a maximum. The maximum distance is
// half the maximum value. If a number that has the maximum value is
// increased by 1, it becomes 0. If a number that has the value of 0 is
// decreased by 1, it becomes the maximum value. By comparing two circular
// numbers it is not possible to tell how often they wrapped. Therefore these
// two numbers must come from the same domain in order to make sense of the
// comparison.
package circular
// Number represents a "circular number". A Number is immutable. All modification
// to a Number will result in a new instance of a Number.
type Number struct {
max uint32
threshold uint32
value uint32
}
// New returns a new circular number with the value of x and the maximum of max.
func New(x, max uint32) Number {
c := Number{
value: 0,
max: max,
threshold: max / 2,
}
if x > max {
return c.Add(x)
}
c.value = x
return c
}
// Val returns the current value of the number.
func (a Number) Val() uint32 {
return a.value
}
// Equals returns whether two circular numbers have the same value.
func (a Number) Equals(b Number) bool {
return a.value == b.value
}
// Distance returns the distance of two circular numbers.
func (a Number) Distance(b Number) uint32 {
if a.Equals(b) {
return 0
}
d := uint32(0)
if a.value > b.value {
d = a.value - b.value
} else {
d = b.value - a.value
}
if d >= a.threshold {
d = a.max - d + 1
}
return d
}
// Lt returns whether the circular number is lower than the circular number b.
func (a Number) Lt(b Number) bool {
if a.Equals(b) {
return false
}
d := uint32(0)
altb := false
if a.value > b.value {
d = a.value - b.value
} else {
d = b.value - a.value
altb = true
}
if d < a.threshold {
return altb
}
return !altb
}
// Lte returns whether the circular number is lower than or equal to the circular number b.
func (a Number) Lte(b Number) bool {
if a.Equals(b) {
return true
}
return a.Lt(b)
}
// Gt returns whether the circular number is greater than the circular number b.
func (a Number) Gt(b Number) bool {
if a.Equals(b) {
return false
}
d := uint32(0)
agtb := false
if a.value > b.value {
d = a.value - b.value
agtb = true
} else {
d = b.value - a.value
}
if d < a.threshold {
return agtb
}
return !agtb
}
// Gte returns whether the circular number is greater than or equal to the circular number b.
func (a Number) Gte(b Number) bool {
if a.Equals(b) {
return true
}
return a.Gt(b)
}
// Inc returns a new circular number with a value that is increased by 1.
func (a Number) Inc() Number {
b := a
if b.value == b.max {
b.value = 0
} else {
b.value++
}
return b
}
// Add returns a new circular number with a value that is increased by b.
func (a Number) Add(b uint32) Number {
c := a
x := c.max - c.value
if b <= x {
c.value += b
} else {
c.value = b - x - 1
}
return c
}
// Dec returns a new circular number with a value that is decreased by 1.
func (a Number) Dec() Number {
b := a
if b.value == 0 {
b.value = b.max
} else {
b.value--
}
return b
}
// Sub returns a new circular number with a value that is decreased by b.
func (a Number) Sub(b uint32) Number {
c := a
if b <= c.value {
c.value -= b
} else {
c.value = c.max - (b - c.value) + 1
}
return c
}


@@ -0,0 +1,103 @@
// Package congestion provides congestion control implementations for SRT
package congestion
import (
"github.com/datarhei/gosrt/internal/circular"
"github.com/datarhei/gosrt/internal/packet"
)
// SendConfig is the configuration for the liveSend congestion control
type SendConfig struct {
InitialSequenceNumber circular.Number
DropInterval uint64
MaxBW int64
InputBW int64
MinInputBW int64
OverheadBW int64
OnDeliver func(p packet.Packet)
}
// Sender is the sending part of the congestion control
type Sender interface {
Stats() SendStats
Flush()
Push(p packet.Packet)
Tick(now uint64)
ACK(sequenceNumber circular.Number)
NAK(sequenceNumbers []circular.Number)
}
// ReceiveConfig is the configuration for the liveRecv congestion control
type ReceiveConfig struct {
InitialSequenceNumber circular.Number
PeriodicACKInterval uint64 // microseconds
PeriodicNAKInterval uint64 // microseconds
OnSendACK func(seq circular.Number, light bool)
OnSendNAK func(from, to circular.Number)
OnDeliver func(p packet.Packet)
}
// Receiver is the receiving part of the congestion control
type Receiver interface {
Stats() ReceiveStats
PacketRate() (pps, bps uint32)
Flush()
Push(pkt packet.Packet)
Tick(now uint64)
SetNAKInterval(nakInterval uint64)
}
// SendStats are collected statistics from liveSend
type SendStats struct {
PktSent uint64
ByteSent uint64
PktSentUnique uint64
ByteSentUnique uint64
PktSndLoss uint64
ByteSndLoss uint64
PktRetrans uint64
ByteRetrans uint64
UsSndDuration uint64 // microseconds
PktSndDrop uint64
ByteSndDrop uint64
// instantaneous
PktSndBuf uint64
ByteSndBuf uint64
MsSndBuf uint64
PktFlightSize uint64
UsPktSndPeriod float64 // microseconds
BytePayload uint64
}
// ReceiveStats are collected statistics from liveRecv
type ReceiveStats struct {
PktRecv uint64
ByteRecv uint64
PktRecvUnique uint64
ByteRecvUnique uint64
PktRcvLoss uint64
ByteRcvLoss uint64
PktRcvRetrans uint64
ByteRcvRetrans uint64
PktRcvDrop uint64
ByteRcvDrop uint64
// instantaneous
PktRcvBuf uint64
ByteRcvBuf uint64
MsRcvBuf uint64
BytePayload uint64
}


@@ -0,0 +1,620 @@
package congestion
import (
"container/list"
"fmt"
"strings"
"sync"
"time"
"github.com/datarhei/gosrt/internal/circular"
"github.com/datarhei/gosrt/internal/packet"
)
// liveSend implements the Sender interface
type liveSend struct {
nextSequenceNumber circular.Number
packetList *list.List
lossList *list.List
lock sync.RWMutex
dropInterval uint64 // microseconds
avgPayloadSize float64 // bytes
pktSndPeriod float64 // microseconds
maxBW float64 // bytes/s
inputBW float64 // bytes/s
overheadBW float64 // percent
statistics SendStats
rate struct {
period time.Duration
last time.Time
bytes uint64
prevBytes uint64
estimatedInputBW float64 // bytes/s
}
deliver func(p packet.Packet)
}
// NewLiveSend takes a SendConfig and returns a new Sender
func NewLiveSend(config SendConfig) Sender {
s := &liveSend{
nextSequenceNumber: config.InitialSequenceNumber,
packetList: list.New(),
lossList: list.New(),
dropInterval: config.DropInterval, // microseconds
avgPayloadSize: 1456, // 5.1.2. SRT's Default LiveCC Algorithm
maxBW: float64(config.MaxBW),
inputBW: float64(config.InputBW),
overheadBW: float64(config.OverheadBW),
deliver: config.OnDeliver,
}
if s.deliver == nil {
s.deliver = func(p packet.Packet) {}
}
s.maxBW = 128 * 1024 * 1024 // 1 Gbit/s
s.pktSndPeriod = (s.avgPayloadSize + 16) * 1_000_000 / s.maxBW
s.rate.period = time.Second
s.rate.last = time.Now()
return s
}
func (s *liveSend) Stats() SendStats {
s.lock.RLock()
defer s.lock.RUnlock()
s.statistics.UsPktSndPeriod = s.pktSndPeriod
s.statistics.BytePayload = uint64(s.avgPayloadSize)
s.statistics.MsSndBuf = 0
max := s.lossList.Back()
min := s.lossList.Front()
if max != nil && min != nil {
s.statistics.MsSndBuf = (max.Value.(packet.Packet).Header().PktTsbpdTime - min.Value.(packet.Packet).Header().PktTsbpdTime) / 1_000
}
return s.statistics
}
func (s *liveSend) Flush() {
s.lock.Lock()
defer s.lock.Unlock()
s.packetList = s.packetList.Init()
s.lossList = s.lossList.Init()
}
func (s *liveSend) Push(p packet.Packet) {
s.lock.Lock()
defer s.lock.Unlock()
if p == nil {
return
}
// assign the packet its sequence number
p.Header().PacketSequenceNumber = s.nextSequenceNumber
s.nextSequenceNumber = s.nextSequenceNumber.Inc()
pktLen := p.Len()
s.statistics.PktSndBuf++
s.statistics.ByteSndBuf += pktLen
// bandwidth calculation
s.rate.bytes += pktLen
now := time.Now()
tdiff := now.Sub(s.rate.last)
if tdiff > s.rate.period {
s.rate.estimatedInputBW = float64(s.rate.bytes-s.rate.prevBytes) / tdiff.Seconds()
s.rate.prevBytes = s.rate.bytes
s.rate.last = now
}
p.Header().Timestamp = uint32(p.Header().PktTsbpdTime & uint64(packet.MAX_TIMESTAMP))
s.packetList.PushBack(p)
s.statistics.PktFlightSize = uint64(s.packetList.Len())
}
func (s *liveSend) Tick(now uint64) {
// deliver packets whose PktTsbpdTime is ripe
s.lock.Lock()
removeList := make([]*list.Element, 0, s.packetList.Len())
for e := s.packetList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
if p.Header().PktTsbpdTime <= now {
s.statistics.PktSent++
s.statistics.PktSentUnique++
s.statistics.ByteSent += p.Len()
s.statistics.ByteSentUnique += p.Len()
s.statistics.UsSndDuration += uint64(s.pktSndPeriod)
// 5.1.2. SRT's Default LiveCC Algorithm
s.avgPayloadSize = 0.875*s.avgPayloadSize + 0.125*float64(p.Len())
s.deliver(p)
removeList = append(removeList, e)
} else {
break
}
}
for _, e := range removeList {
s.lossList.PushBack(e.Value)
s.packetList.Remove(e)
}
s.lock.Unlock()
s.lock.Lock()
removeList = make([]*list.Element, 0, s.lossList.Len())
for e := s.lossList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
if p.Header().PktTsbpdTime+s.dropInterval <= now {
// dropped packet because too old
s.statistics.PktSndDrop++
s.statistics.PktSndLoss++
s.statistics.ByteSndDrop += p.Len()
s.statistics.ByteSndLoss += p.Len()
removeList = append(removeList, e)
}
}
// These packets are not needed anymore (too late)
for _, e := range removeList {
p := e.Value.(packet.Packet)
s.statistics.PktSndBuf--
s.statistics.ByteSndBuf -= p.Len()
s.lossList.Remove(e)
// The packet was dropped because it is too old; release its resources
p.Decommission()
}
s.lock.Unlock()
}
func (s *liveSend) ACK(sequenceNumber circular.Number) {
s.lock.Lock()
defer s.lock.Unlock()
removeList := make([]*list.Element, 0, s.lossList.Len())
for e := s.lossList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
if p.Header().PacketSequenceNumber.Lt(sequenceNumber) {
// remove packet from buffer because it has been successfully transmitted
removeList = append(removeList, e)
} else {
break
}
}
// These packets are not needed anymore (ACK'd)
for _, e := range removeList {
p := e.Value.(packet.Packet)
s.statistics.PktSndBuf--
s.statistics.ByteSndBuf -= p.Len()
s.lossList.Remove(e)
// This packet has been ACK'd and we don't need it anymore
p.Decommission()
}
s.pktSndPeriod = (s.avgPayloadSize + 16) * 1_000_000 / s.maxBW
}
func (s *liveSend) NAK(sequenceNumbers []circular.Number) {
if len(sequenceNumbers) == 0 {
return
}
s.lock.Lock()
defer s.lock.Unlock()
for e := s.lossList.Back(); e != nil; e = e.Prev() {
p := e.Value.(packet.Packet)
for i := 0; i < len(sequenceNumbers); i += 2 {
if p.Header().PacketSequenceNumber.Gte(sequenceNumbers[i]) && p.Header().PacketSequenceNumber.Lte(sequenceNumbers[i+1]) {
s.statistics.PktRetrans++
s.statistics.PktSent++
s.statistics.PktSndLoss++
s.statistics.ByteRetrans += p.Len()
s.statistics.ByteSent += p.Len()
s.statistics.ByteSndLoss += p.Len()
// 5.1.2. SRT's Default LiveCC Algorithm
s.avgPayloadSize = 0.875*s.avgPayloadSize + 0.125*float64(p.Len())
p.Header().RetransmittedPacketFlag = true
s.deliver(p)
}
}
}
}
// liveReceive implements the Receiver interface
type liveReceive struct {
maxSeenSequenceNumber circular.Number
lastACKSequenceNumber circular.Number
lastDeliveredSequenceNumber circular.Number
packetList *list.List
lock sync.RWMutex
nPackets uint
periodicACKInterval uint64 // config
periodicNAKInterval uint64 // config
lastPeriodicACK uint64
lastPeriodicNAK uint64
avgPayloadSize float64 // bytes
statistics ReceiveStats
rate struct {
last time.Time
period time.Duration
packets uint64
prevPackets uint64
bytes uint64
prevBytes uint64
pps uint32
bps uint32
}
sendACK func(seq circular.Number, light bool)
sendNAK func(from, to circular.Number)
deliver func(p packet.Packet)
}
// NewLiveReceive takes a ReceiveConfig and returns a new Receiver
func NewLiveReceive(config ReceiveConfig) Receiver {
r := &liveReceive{
maxSeenSequenceNumber: config.InitialSequenceNumber.Dec(),
lastACKSequenceNumber: config.InitialSequenceNumber.Dec(),
lastDeliveredSequenceNumber: config.InitialSequenceNumber.Dec(),
packetList: list.New(),
periodicACKInterval: config.PeriodicACKInterval,
periodicNAKInterval: config.PeriodicNAKInterval,
avgPayloadSize: 1456, // 5.1.2. SRT's Default LiveCC Algorithm
sendACK: config.OnSendACK,
sendNAK: config.OnSendNAK,
deliver: config.OnDeliver,
}
if r.sendACK == nil {
r.sendACK = func(seq circular.Number, light bool) {}
}
if r.sendNAK == nil {
r.sendNAK = func(from, to circular.Number) {}
}
if r.deliver == nil {
r.deliver = func(p packet.Packet) {}
}
r.rate.last = time.Now()
r.rate.period = time.Second
return r
}
func (r *liveReceive) Stats() ReceiveStats {
r.lock.RLock()
defer r.lock.RUnlock()
r.statistics.BytePayload = uint64(r.avgPayloadSize)
return r.statistics
}
func (r *liveReceive) PacketRate() (pps, bps uint32) {
r.lock.Lock()
defer r.lock.Unlock()
tdiff := time.Since(r.rate.last)
if tdiff < r.rate.period {
pps = r.rate.pps
bps = r.rate.bps
return
}
pdiff := r.rate.packets - r.rate.prevPackets
bdiff := r.rate.bytes - r.rate.prevBytes
r.rate.pps = uint32(float64(pdiff) / tdiff.Seconds())
r.rate.bps = uint32(float64(bdiff) / tdiff.Seconds())
r.rate.prevPackets, r.rate.prevBytes = r.rate.packets, r.rate.bytes
r.rate.last = time.Now()
pps = r.rate.pps
bps = r.rate.bps
return
}
func (r *liveReceive) Flush() {
r.lock.Lock()
defer r.lock.Unlock()
r.packetList = r.packetList.Init()
}
func (r *liveReceive) Push(pkt packet.Packet) {
r.lock.Lock()
defer r.lock.Unlock()
if pkt == nil {
return
}
r.nPackets++
pktLen := pkt.Len()
r.rate.packets++
r.rate.bytes += pktLen
r.statistics.PktRecv++
r.statistics.ByteRecv += pktLen
//pkt.PktTsbpdTime = pkt.Timestamp + r.delay
if pkt.Header().RetransmittedPacketFlag {
r.statistics.PktRcvRetrans++
r.statistics.ByteRcvRetrans += pktLen
}
// 5.1.2. SRT's Default LiveCC Algorithm
r.avgPayloadSize = 0.875*r.avgPayloadSize + 0.125*float64(pktLen)
if pkt.Header().PacketSequenceNumber.Lte(r.lastDeliveredSequenceNumber) {
// too old; everything up to r.lastDeliveredSequenceNumber has already been delivered
r.statistics.PktRcvDrop++
r.statistics.ByteRcvDrop += pktLen
return
}
if pkt.Header().PacketSequenceNumber.Lt(r.lastACKSequenceNumber) {
// already acknowledged, ignoring
r.statistics.PktRcvDrop++
r.statistics.ByteRcvDrop += pktLen
return
}
if pkt.Header().PacketSequenceNumber.Equals(r.maxSeenSequenceNumber.Inc()) {
// in order, the packet we expected
r.maxSeenSequenceNumber = pkt.Header().PacketSequenceNumber
} else if pkt.Header().PacketSequenceNumber.Lte(r.maxSeenSequenceNumber) {
// out of order, is it a missing piece? put it in the correct position
for e := r.packetList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
if p.Header().PacketSequenceNumber == pkt.Header().PacketSequenceNumber {
// already received (has been sent more than once), ignoring
r.statistics.PktRcvDrop++
r.statistics.ByteRcvDrop += pktLen
break
} else if p.Header().PacketSequenceNumber.Gt(pkt.Header().PacketSequenceNumber) {
// late arrival, this fills a gap
r.statistics.PktRcvBuf++
r.statistics.PktRecvUnique++
r.statistics.ByteRcvBuf += pktLen
r.statistics.ByteRecvUnique += pktLen
r.packetList.InsertBefore(pkt, e)
break
}
}
return
} else {
// too far ahead, there are some missing sequence numbers, immediate NAK report
// here we can prevent a possibly unnecessary NAK with SRTO_LOSSMAXTTL
r.sendNAK(r.maxSeenSequenceNumber.Inc(), pkt.Header().PacketSequenceNumber.Dec())
len := uint64(pkt.Header().PacketSequenceNumber.Distance(r.maxSeenSequenceNumber))
r.statistics.PktRcvLoss += len
r.statistics.ByteRcvLoss += len * uint64(r.avgPayloadSize)
r.maxSeenSequenceNumber = pkt.Header().PacketSequenceNumber
}
r.statistics.PktRcvBuf++
r.statistics.PktRecvUnique++
r.statistics.ByteRcvBuf += pktLen
r.statistics.ByteRecvUnique += pktLen
r.packetList.PushBack(pkt)
}
func (r *liveReceive) periodicACK(now uint64) (ok bool, sequenceNumber circular.Number, lite bool) {
r.lock.RLock()
defer r.lock.RUnlock()
// 4.8.1. Packet Acknowledgement (ACKs, ACKACKs)
if now-r.lastPeriodicACK < r.periodicACKInterval {
if r.nPackets >= 64 {
lite = true // send light ACK
} else {
return
}
}
minPktTsbpdTime, maxPktTsbpdTime := uint64(0), uint64(0)
ackSequenceNumber := r.lastDeliveredSequenceNumber
// find the sequence number up to which all packets are in a row.
// the first gap (or the end of the list) is where we can ACK to.
e := r.packetList.Front()
if e != nil {
p := e.Value.(packet.Packet)
minPktTsbpdTime = p.Header().PktTsbpdTime
maxPktTsbpdTime = p.Header().PktTsbpdTime
if p.Header().PacketSequenceNumber.Equals(ackSequenceNumber.Inc()) {
ackSequenceNumber = p.Header().PacketSequenceNumber
for e = e.Next(); e != nil; e = e.Next() {
p = e.Value.(packet.Packet)
if !p.Header().PacketSequenceNumber.Equals(ackSequenceNumber.Inc()) {
break
}
ackSequenceNumber = p.Header().PacketSequenceNumber
maxPktTsbpdTime = p.Header().PktTsbpdTime
}
}
ok = true
sequenceNumber = ackSequenceNumber.Inc()
// keep track of the last ACK'd sequence number. with this we can more
// quickly ignore incoming packets that have a lower sequence number.
r.lastACKSequenceNumber = ackSequenceNumber
}
r.lastPeriodicACK = now
r.nPackets = 0
r.statistics.MsRcvBuf = (maxPktTsbpdTime - minPktTsbpdTime) / 1_000
return
}
func (r *liveReceive) periodicNAK(now uint64) (ok bool, from, to circular.Number) {
r.lock.RLock()
defer r.lock.RUnlock()
if now-r.lastPeriodicNAK < r.periodicNAKInterval {
return
}
// send a periodic NAK
ackSequenceNumber := r.lastDeliveredSequenceNumber
// send a NAK only for the first gap.
// alternatively, a NAK could report at most X gaps, because the size of the NAK packet is limited
for e := r.packetList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
if !p.Header().PacketSequenceNumber.Equals(ackSequenceNumber.Inc()) {
nackSequenceNumber := ackSequenceNumber.Inc()
ok = true
from = nackSequenceNumber
to = p.Header().PacketSequenceNumber.Dec()
break
}
ackSequenceNumber = p.Header().PacketSequenceNumber
}
r.lastPeriodicNAK = now
return
}
func (r *liveReceive) Tick(now uint64) {
if ok, sequenceNumber, lite := r.periodicACK(now); ok {
r.sendACK(sequenceNumber, lite)
}
if ok, from, to := r.periodicNAK(now); ok {
r.sendNAK(from, to)
}
// deliver packets whose PktTsbpdTime is ripe
r.lock.Lock()
defer r.lock.Unlock()
removeList := make([]*list.Element, 0, r.packetList.Len())
for e := r.packetList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
if p.Header().PacketSequenceNumber.Lte(r.lastACKSequenceNumber) && p.Header().PktTsbpdTime <= now {
r.statistics.PktRcvBuf--
r.statistics.ByteRcvBuf -= p.Len()
r.lastDeliveredSequenceNumber = p.Header().PacketSequenceNumber
r.deliver(p)
removeList = append(removeList, e)
} else {
break
}
}
for _, e := range removeList {
r.packetList.Remove(e)
}
}
func (r *liveReceive) SetNAKInterval(nakInterval uint64) {
r.lock.Lock()
defer r.lock.Unlock()
r.periodicNAKInterval = nakInterval
}
func (r *liveReceive) String(t uint64) string {
var b strings.Builder
b.WriteString(fmt.Sprintf("maxSeen=%d lastACK=%d lastDelivered=%d\n", r.maxSeenSequenceNumber.Val(), r.lastACKSequenceNumber.Val(), r.lastDeliveredSequenceNumber.Val()))
r.lock.RLock()
for e := r.packetList.Front(); e != nil; e = e.Next() {
p := e.Value.(packet.Packet)
b.WriteString(fmt.Sprintf(" %d @ %d (in %d)\n", p.Header().PacketSequenceNumber.Val(), p.Header().PktTsbpdTime, int64(p.Header().PktTsbpdTime)-int64(t)))
}
r.lock.RUnlock()
return b.String()
}


@@ -0,0 +1,260 @@
// Package crypto provides SRT cryptography
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"errors"
"fmt"
"github.com/datarhei/gosrt/internal/packet"
"github.com/benburkert/openpgp/aes/keywrap"
"golang.org/x/crypto/pbkdf2"
)
// Crypto implements the SRT data encryption and decryption.
type Crypto interface {
// Generate generates an even or odd SEK.
GenerateSEK(key packet.PacketEncryption) error
// UnmarshalKM unwraps the key with the passphrase in a Key Material Extension Message. If the passphrase
// is wrong, an error is returned.
UnmarshalKM(km *packet.CIFKM, passphrase string) error
// MarshalKM wraps the key with the passphrase and the odd/even SEK for a Key Material Extension Message.
MarshalKM(km *packet.CIFKM, passphrase string, key packet.PacketEncryption) error
// EncryptOrDecryptPayload encrypts or decrypts the data of a packet with an even or odd SEK and
// the sequence number.
EncryptOrDecryptPayload(data []byte, key packet.PacketEncryption, packetSequenceNumber uint32) error
}
// crypto implements the Crypto interface
type crypto struct {
salt []byte
keyLength int
evenSEK []byte
oddSEK []byte
}
// New returns a new SRT data encryption and decryption for the given key length. On failure,
// the error is non-nil.
func New(keyLength int) (Crypto, error) {
// 3.2.2. Key Material
switch keyLength {
case 16:
case 24:
case 32:
default:
return nil, fmt.Errorf("crypto: invalid key size, must be either 16, 24, or 32")
}
c := &crypto{
keyLength: keyLength,
}
// 3.2.2. Key Material: "The only valid length of salt defined is 128 bits."
c.salt = make([]byte, 16)
if err := c.prng(c.salt); err != nil {
return nil, fmt.Errorf("crypto: can't generate salt: %w", err)
}
c.evenSEK = make([]byte, c.keyLength)
if err := c.GenerateSEK(packet.EvenKeyEncrypted); err != nil {
return nil, err
}
c.oddSEK = make([]byte, c.keyLength)
if err := c.GenerateSEK(packet.OddKeyEncrypted); err != nil {
return nil, err
}
return c, nil
}
func (c *crypto) GenerateSEK(key packet.PacketEncryption) error {
if !key.IsValid() {
return fmt.Errorf("crypto: unknown key type")
}
if key == packet.EvenKeyEncrypted {
if err := c.prng(c.evenSEK); err != nil {
return fmt.Errorf("crypto: can't generate even key: %w", err)
}
} else if key == packet.OddKeyEncrypted {
if err := c.prng(c.oddSEK); err != nil {
return fmt.Errorf("crypto: can't generate odd key: %w", err)
}
}
return nil
}
// ErrInvalidKey is returned when the packet encryption is invalid
var ErrInvalidKey = errors.New("crypto: invalid key for encryption. Must be even, odd, or both")
// ErrInvalidWrap is returned when the packet encryption indicates a different length of the wrapped key
var ErrInvalidWrap = errors.New("crypto: the unwrapped key has the wrong length")
func (c *crypto) UnmarshalKM(km *packet.CIFKM, passphrase string) error {
if km.KeyBasedEncryption == packet.UnencryptedPacket || !km.KeyBasedEncryption.IsValid() {
return ErrInvalidKey
}
if len(km.Salt) != 0 {
copy(c.salt, km.Salt)
}
kek := c.calculateKEK(passphrase)
unwrap, err := keywrap.Unwrap(kek, km.Wrap)
if err != nil {
return err
}
n := 1
if km.KeyBasedEncryption == packet.EvenAndOddKey {
n = 2
}
if len(unwrap) != n*c.keyLength {
return ErrInvalidWrap
}
if km.KeyBasedEncryption == packet.EvenKeyEncrypted {
copy(c.evenSEK, unwrap)
} else if km.KeyBasedEncryption == packet.OddKeyEncrypted {
copy(c.oddSEK, unwrap)
} else {
copy(c.evenSEK, unwrap[:c.keyLength])
copy(c.oddSEK, unwrap[c.keyLength:])
}
return nil
}
func (c *crypto) MarshalKM(km *packet.CIFKM, passphrase string, key packet.PacketEncryption) error {
if key == packet.UnencryptedPacket || !key.IsValid() {
return ErrInvalidKey
}
km.S = 0
km.Version = 1
km.PacketType = 2
km.Sign = 0x2029
km.KeyBasedEncryption = key // even or odd key
km.KeyEncryptionKeyIndex = 0
km.Cipher = 2
km.Authentication = 0
km.StreamEncapsulation = 2
km.SLen = 16
km.KLen = uint16(c.keyLength)
if len(km.Salt) != 16 {
km.Salt = make([]byte, 16)
}
copy(km.Salt, c.salt)
n := 1
if key == packet.EvenAndOddKey {
n = 2
}
w := make([]byte, n*c.keyLength)
if key == packet.EvenKeyEncrypted {
copy(w, c.evenSEK)
} else if key == packet.OddKeyEncrypted {
copy(w, c.oddSEK)
} else {
copy(w[:c.keyLength], c.evenSEK)
copy(w[c.keyLength:], c.oddSEK)
}
kek := c.calculateKEK(passphrase)
wrap, err := keywrap.Wrap(kek, w)
if err != nil {
return err
}
if len(km.Wrap) != len(wrap) {
km.Wrap = make([]byte, len(wrap))
}
copy(km.Wrap, wrap)
return nil
}
func (c *crypto) EncryptOrDecryptPayload(data []byte, key packet.PacketEncryption, packetSequenceNumber uint32) error {
// 6.1.2. AES Counter
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
// | 0s | psn | 0 0|
// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
// XOR
// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+
// | MSB(112, Salt) |
// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+
//
// psn (32 bit): packet sequence number
// ctr (16 bit): block counter, all zeros
// nonce (112 bit): 14 most significant bytes of the salt
//
// CTR = (MSB(112, Salt) XOR psn) << 16
ctr := make([]byte, 16)
binary.BigEndian.PutUint32(ctr[10:], packetSequenceNumber)
for i := range ctr[:14] {
ctr[i] ^= c.salt[i]
}
var sek []byte
if key == packet.EvenKeyEncrypted {
sek = c.evenSEK
} else if key == packet.OddKeyEncrypted {
sek = c.oddSEK
} else {
return fmt.Errorf("crypto: invalid SEK selected. Must be either even or odd")
}
// 6.2.2. Encrypting the Payload
// 6.3.2. Decrypting the Payload
block, err := aes.NewCipher(sek)
if err != nil {
return err
}
stream := cipher.NewCTR(block, ctr)
stream.XORKeyStream(data, data)
return nil
}
// calculateKEK calculates a KEK based on the passphrase.
func (c *crypto) calculateKEK(passphrase string) []byte {
// 6.1.4. Key Encrypting Key (KEK)
return pbkdf2.Key([]byte(passphrase), c.salt[8:], 2048, c.keyLength, sha1.New)
}
// prng fills the given slice p with cryptographically secure random bytes.
func (c *crypto) prng(p []byte) error {
n, err := rand.Read(p)
if err != nil {
return err
}
if n != len(p) {
return fmt.Errorf("crypto: random byte sequence is too short")
}
return nil
}

vendor/github.com/datarhei/gosrt/internal/net/ip.go generated vendored Normal file

@@ -0,0 +1,141 @@
package net
import (
"encoding/binary"
"fmt"
"net"
"strings"
)
type IP struct {
ip net.IP
}
func (i *IP) setDefault() {
i.ip = net.ParseIP("127.0.0.1")
}
func (i *IP) isValid() bool {
// a nil IP (failed parse) or the unspecified address is invalid
return i.ip != nil && !i.ip.IsUnspecified()
}
func (i IP) String() string {
return i.ip.String()
}
func (i *IP) Parse(ip string) {
i.ip = net.ParseIP(ip)
if !i.isValid() {
i.setDefault()
}
}
func (i *IP) FromNetIP(ip net.IP) {
i.ip = net.ParseIP(ip.String())
if !i.isValid() {
i.setDefault()
}
}
func (i *IP) FromNetAddr(addr net.Addr) {
if addr.Network() != "udp" {
i.setDefault()
return
}
if a, err := net.ResolveUDPAddr("udp", addr.String()); err == nil {
i.ip = a.IP
} else {
i.setDefault()
}
}
// Unmarshal converts 4 (IPv4) or 16 (IPv6) bytes in little-endian byte order to an IP
func (i *IP) Unmarshal(data []byte) error {
if len(data) != 4 && len(data) != 16 {
return fmt.Errorf("invalid number of bytes")
}
if len(data) == 4 {
ip0 := binary.LittleEndian.Uint32(data[0:])
i.ip = net.IPv4(byte((ip0&0xff000000)>>24), byte((ip0&0x00ff0000)>>16), byte((ip0&0x0000ff00)>>8), byte(ip0&0x0000ff))
} else {
ip3 := binary.LittleEndian.Uint32(data[0:])
ip2 := binary.LittleEndian.Uint32(data[4:])
ip1 := binary.LittleEndian.Uint32(data[8:])
ip0 := binary.LittleEndian.Uint32(data[12:])
if ip0 == 0 && ip1 == 0 && ip2 == 0 {
i.ip = net.IPv4(byte((ip3&0xff000000)>>24), byte((ip3&0x00ff0000)>>16), byte((ip3&0x0000ff00)>>8), byte(ip3&0x0000ff))
} else {
var b strings.Builder
fmt.Fprintf(&b, "%04x:", (ip0&0xffff0000)>>16)
fmt.Fprintf(&b, "%04x:", ip0&0x0000ffff)
fmt.Fprintf(&b, "%04x:", (ip1&0xffff0000)>>16)
fmt.Fprintf(&b, "%04x:", ip1&0x0000ffff)
fmt.Fprintf(&b, "%04x:", (ip2&0xffff0000)>>16)
fmt.Fprintf(&b, "%04x:", ip2&0x0000ffff)
fmt.Fprintf(&b, "%04x:", (ip3&0xffff0000)>>16)
fmt.Fprintf(&b, "%04x", ip3&0x0000ffff)
i.ip = net.ParseIP(b.String())
}
}
if !i.isValid() {
i.setDefault()
}
return nil
}
// Marshal converts an IP to 16 bytes in little-endian byte order
func (i *IP) Marshal(data []byte) {
if len(data) < 16 {
return
}
data[0] = i.ip[15]
data[1] = i.ip[14]
data[2] = i.ip[13]
data[3] = i.ip[12]
if i.ip.To4() != nil {
data[4] = 0
data[5] = 0
data[6] = 0
data[7] = 0
data[8] = 0
data[9] = 0
data[10] = 0
data[11] = 0
data[12] = 0
data[13] = 0
data[14] = 0
data[15] = 0
} else {
data[4] = i.ip[11]
data[5] = i.ip[10]
data[6] = i.ip[9]
data[7] = i.ip[8]
data[8] = i.ip[7]
data[9] = i.ip[6]
data[10] = i.ip[5]
data[11] = i.ip[4]
data[12] = i.ip[3]
data[13] = i.ip[2]
data[14] = i.ip[1]
data[15] = i.ip[0]
}
}


@@ -0,0 +1,74 @@
package net
import (
"crypto/md5"
"encoding/binary"
"math/rand"
"strconv"
"time"
)
type SYNCookie struct {
secret1 string
secret2 string
daddr string
counter func() int64
}
func defaultCounter() int64 {
return time.Now().Unix() >> 6
}
func NewSYNCookie(daddr string, seed int64, counter func() int64) SYNCookie {
s := SYNCookie{
daddr: daddr,
counter: counter,
}
if s.counter == nil {
s.counter = defaultCounter
}
// https://www.calhoun.io/creating-random-strings-in-go/
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
seededRand := rand.New(rand.NewSource(seed))
stringWithCharset := func(length int, charset string) string {
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
s.secret1 = stringWithCharset(32, charset)
s.secret2 = stringWithCharset(32, charset)
return s
}
func (s *SYNCookie) Get(saddr string) uint32 {
return s.calculate(s.counter(), saddr)
}
func (s *SYNCookie) Verify(cookie uint32, saddr string) bool {
counter := s.counter()
if s.calculate(counter, saddr) == cookie {
return true
}
if s.calculate(counter-1, saddr) == cookie {
return true
}
return false
}
func (s *SYNCookie) calculate(counter int64, saddr string) uint32 {
data := s.secret1 + s.daddr + saddr + s.secret2 + strconv.FormatInt(counter, 10)
md5sum := md5.Sum([]byte(data))
return binary.BigEndian.Uint32(md5sum[0:])
}

File diff suppressed because it is too large

vendor/github.com/datarhei/gosrt/listen.go generated vendored Normal file

@@ -0,0 +1,702 @@
package srt
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os"
"sync"
"syscall"
"time"
"github.com/datarhei/gosrt/internal/crypto"
srtnet "github.com/datarhei/gosrt/internal/net"
"github.com/datarhei/gosrt/internal/packet"
)
// ConnType represents the kind of connection as returned
// from the AcceptFunc. It is one of REJECT, PUBLISH, or SUBSCRIBE.
type ConnType int
// String returns a string representation of the ConnType.
func (c ConnType) String() string {
switch c {
case REJECT:
return "REJECT"
case PUBLISH:
return "PUBLISH"
case SUBSCRIBE:
return "SUBSCRIBE"
default:
return ""
}
}
const (
REJECT ConnType = ConnType(1 << iota) // Reject a connection
PUBLISH // This connection is meant to write data to the server
SUBSCRIBE // This connection is meant to read data from a PUBLISHed stream
)
// ConnRequest is an incoming connection request
type ConnRequest interface {
// RemoteAddr returns the address of the peer. The returned net.Addr
// is a copy and can be used at will.
RemoteAddr() net.Addr
// StreamId returns the streamid of the requesting connection. Use this
// to decide what to do with the connection.
StreamId() string
// IsEncrypted returns whether the connection is encrypted. If it is
// encrypted, use SetPassphrase to set the passphrase for decrypting.
IsEncrypted() bool
// SetPassphrase sets the passphrase in order to decrypt the incoming
// data. Returns an error if the passphrase did not work or the connection
// is not encrypted.
SetPassphrase(p string) error
}
// connRequest implements the ConnRequest interface
type connRequest struct {
addr net.Addr
start time.Time
socketId uint32
timestamp uint32
handshake *packet.CIFHandshake
crypto crypto.Crypto
passphrase string
}
func (req *connRequest) RemoteAddr() net.Addr {
addr, _ := net.ResolveUDPAddr("udp", req.addr.String())
return addr
}
func (req *connRequest) StreamId() string {
return req.handshake.StreamId
}
func (req *connRequest) IsEncrypted() bool {
return req.crypto != nil
}
func (req *connRequest) SetPassphrase(passphrase string) error {
if req.crypto == nil {
return fmt.Errorf("listen: request without encryption")
}
if err := req.crypto.UnmarshalKM(req.handshake.SRTKM, passphrase); err != nil {
return err
}
req.passphrase = passphrase
return nil
}
// ErrListenerClosed is returned when the listener is about to shutdown.
var ErrListenerClosed = errors.New("srt: listener closed")
// AcceptFunc receives a connection request and returns the type of connection.
// The Listener requires it for each Accept of a new connection.
type AcceptFunc func(req ConnRequest) ConnType
// Listener waits for new connections
type Listener interface {
// Accept waits for new connections. For each new connection the AcceptFunc
// gets called. Conn is a new connection if AcceptFunc returns PUBLISH or
// SUBSCRIBE. If AcceptFunc returns REJECT, Conn is nil. On failure, the
// error is non-nil, Conn is nil, and ConnType is REJECT. When the listener
// is closed, the error is ErrListenerClosed and ConnType is REJECT.
Accept(AcceptFunc) (Conn, ConnType, error)
// Close closes the listener. It will stop accepting new connections and
// close all currently established connections.
Close()
// Addr returns the address of the listener.
Addr() net.Addr
}
// listener implements the Listener interface.
type listener struct {
pc *net.UDPConn
addr net.Addr
config Config
backlog chan connRequest
conns map[uint32]*srtConn
lock sync.RWMutex
start time.Time
rcvQueue chan packet.Packet
sndQueue chan packet.Packet
syncookie srtnet.SYNCookie
shutdown bool
shutdownLock sync.RWMutex
shutdownOnce sync.Once
stopReader context.CancelFunc
stopWriter context.CancelFunc
doneChan chan error
}
// Listen returns a new listener on the SRT protocol on the address with
// the provided config. The network parameter needs to be "srt".
//
// The address has the form "host:port".
//
// Examples:
// Listen("srt", "127.0.0.1:3000", DefaultConfig())
//
// In case of an error, the returned Listener is nil and the error is non-nil.
func Listen(network, address string, config Config) (Listener, error) {
if network != "srt" {
return nil, fmt.Errorf("listen: the network must be 'srt'")
}
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("listen: invalid config: %w", err)
}
if config.Logger == nil {
config.Logger = NewLogger(nil)
}
ln := &listener{
config: config,
}
raddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, fmt.Errorf("listen: unable to resolve address: %w", err)
}
pc, err := net.ListenUDP("udp", raddr)
if err != nil {
return nil, fmt.Errorf("listen: failed listening: %w", err)
}
file, err := pc.File()
if err != nil {
return nil, err
}
// Set TOS
if config.IPTOS > 0 {
err = syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_TOS, config.IPTOS)
if err != nil {
return nil, fmt.Errorf("listen: failed setting socket option TOS: %w", err)
}
}
// Set TTL
if config.IPTTL > 0 {
err = syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_TTL, config.IPTTL)
if err != nil {
return nil, fmt.Errorf("listen: failed setting socket option TTL: %w", err)
}
}
ln.pc = pc
ln.addr = pc.LocalAddr()
ln.conns = make(map[uint32]*srtConn)
ln.backlog = make(chan connRequest, 128)
ln.rcvQueue = make(chan packet.Packet, 2048)
ln.sndQueue = make(chan packet.Packet, 2048)
ln.syncookie = srtnet.NewSYNCookie(ln.addr.String(), time.Now().UnixNano(), nil)
ln.doneChan = make(chan error)
ln.start = time.Now()
var readerCtx context.Context
readerCtx, ln.stopReader = context.WithCancel(context.Background())
go ln.reader(readerCtx)
var writerCtx context.Context
writerCtx, ln.stopWriter = context.WithCancel(context.Background())
go ln.writer(writerCtx)
go func() {
buffer := make([]byte, config.MSS) // MTU size
for {
if ln.isShutdown() {
ln.doneChan <- ErrListenerClosed
return
}
ln.pc.SetReadDeadline(time.Now().Add(3 * time.Second))
n, addr, err := ln.pc.ReadFrom(buffer)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
if ln.isShutdown() {
ln.doneChan <- ErrListenerClosed
return
}
ln.doneChan <- err
return
}
p := packet.NewPacket(addr, buffer[:n])
if p == nil {
continue
}
ln.rcvQueue <- p
}
}()
return ln, nil
}
func (ln *listener) Accept(acceptFn AcceptFunc) (Conn, ConnType, error) {
if ln.isShutdown() {
return nil, REJECT, ErrListenerClosed
}
select {
case err := <-ln.doneChan:
return nil, REJECT, err
case request := <-ln.backlog:
if acceptFn == nil {
ln.reject(request, packet.REJ_PEER)
break
}
mode := acceptFn(&request)
if mode != PUBLISH && mode != SUBSCRIBE {
ln.reject(request, packet.REJ_PEER)
break
}
if request.crypto != nil && len(request.passphrase) == 0 {
ln.reject(request, packet.REJ_BADSECRET)
break
}
// Create a new socket ID
socketId := uint32(time.Since(ln.start).Microseconds())
// Select the largest TSBPD delay advertised by the caller, but at
// least 120ms
tsbpdDelay := uint16(120)
if request.handshake.RecvTSBPDDelay > tsbpdDelay {
tsbpdDelay = request.handshake.RecvTSBPDDelay
}
if request.handshake.SendTSBPDDelay > tsbpdDelay {
tsbpdDelay = request.handshake.SendTSBPDDelay
}
ln.config.StreamId = request.handshake.StreamId
ln.config.Passphrase = request.passphrase
// Create a new connection
conn := newSRTConn(srtConnConfig{
localAddr: ln.addr,
remoteAddr: request.addr,
config: ln.config,
start: request.start,
socketId: socketId,
peerSocketId: request.handshake.SRTSocketId,
tsbpdTimeBase: uint64(request.timestamp),
tsbpdDelay: uint64(tsbpdDelay) * 1000,
initialPacketSequenceNumber: request.handshake.InitialPacketSequenceNumber,
crypto: request.crypto,
keyBaseEncryption: packet.EvenKeyEncrypted,
onSend: ln.send,
onShutdown: ln.handleShutdown,
logger: ln.config.Logger,
})
ln.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s) %s", conn.SocketId(), conn.StreamId(), mode) })
request.handshake.SRTSocketId = socketId
request.handshake.SynCookie = 0
// 3.2.1.1.1. Handshake Extension Message Flags
request.handshake.SRTVersion = 0x00010402
request.handshake.SRTFlags.TSBPDSND = true
request.handshake.SRTFlags.TSBPDRCV = true
request.handshake.SRTFlags.CRYPT = true
request.handshake.SRTFlags.TLPKTDROP = true
request.handshake.SRTFlags.PERIODICNAK = true
request.handshake.SRTFlags.REXMITFLG = true
request.handshake.SRTFlags.STREAM = false
request.handshake.SRTFlags.PACKET_FILTER = false
ln.accept(request)
// Add the connection to the list of known connections
ln.lock.Lock()
ln.conns[conn.socketId] = conn
ln.lock.Unlock()
return conn, mode, nil
}
return nil, REJECT, nil
}
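The TSBPD delay negotiation inside `Accept` (take the largest delay advertised by the caller, but never less than the 120 ms floor) can be sketched standalone. Names here are illustrative, not part of the package API:

```go
package main

import "fmt"

// negotiateTSBPDDelay mirrors the selection in Accept: take the largest
// TSBPD delay advertised by the caller, floored at 120 ms.
func negotiateTSBPDDelay(recvDelay, sendDelay uint16) uint16 {
	delay := uint16(120)
	if recvDelay > delay {
		delay = recvDelay
	}
	if sendDelay > delay {
		delay = sendDelay
	}
	return delay
}

func main() {
	fmt.Println(negotiateTSBPDDelay(60, 80))   // both below the floor -> 120
	fmt.Println(negotiateTSBPDDelay(250, 100)) // caller advertises more -> 250
}
```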
func (ln *listener) handleShutdown(socketId uint32) {
ln.lock.Lock()
delete(ln.conns, socketId)
ln.lock.Unlock()
}
func (ln *listener) reject(request connRequest, reason packet.HandshakeType) {
p := packet.NewPacket(request.addr, nil)
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(ln.start).Microseconds())
p.Header().DestinationSocketId = request.socketId
request.handshake.HandshakeType = reason
p.MarshalCIF(request.handshake)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return request.handshake.String() })
ln.send(p)
}
func (ln *listener) accept(request connRequest) {
p := packet.NewPacket(request.addr, nil)
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(request.start).Microseconds())
p.Header().DestinationSocketId = request.socketId
p.MarshalCIF(request.handshake)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return request.handshake.String() })
ln.send(p)
}
func (ln *listener) isShutdown() bool {
ln.shutdownLock.RLock()
defer ln.shutdownLock.RUnlock()
return ln.shutdown
}
func (ln *listener) Close() {
ln.shutdownOnce.Do(func() {
ln.shutdownLock.Lock()
ln.shutdown = true
ln.shutdownLock.Unlock()
ln.lock.RLock()
for _, conn := range ln.conns {
conn.close()
}
ln.lock.RUnlock()
ln.stopReader()
ln.stopWriter()
ln.log("listen", func() string { return "closing socket" })
ln.pc.Close()
})
}
func (ln *listener) Addr() net.Addr {
addr, _ := net.ResolveUDPAddr("udp", ln.addr.String())
return addr
}
func (ln *listener) reader(ctx context.Context) {
defer func() {
ln.log("listen", func() string { return "left reader loop" })
}()
ln.log("listen", func() string { return "reader loop started" })
for {
select {
case <-ctx.Done():
return
case p := <-ln.rcvQueue:
if ln.isShutdown() {
break
}
ln.log("packet:recv:dump", func() string { return p.Dump() })
if p.Header().DestinationSocketId == 0 {
if p.Header().IsControlPacket && p.Header().ControlType == packet.CTRLTYPE_HANDSHAKE {
ln.handleHandshake(p)
}
break
}
ln.lock.RLock()
conn, ok := ln.conns[p.Header().DestinationSocketId]
ln.lock.RUnlock()
if !ok {
// ignore the packet, we don't know the destination
break
}
conn.push(p)
}
}
}
func (ln *listener) send(p packet.Packet) {
// non-blocking
select {
case ln.sndQueue <- p:
default:
ln.log("listen", func() string { return "send queue is full" })
}
}
func (ln *listener) writer(ctx context.Context) {
defer func() {
ln.log("listen", func() string { return "left writer loop" })
}()
ln.log("listen", func() string { return "writer loop started" })
var data bytes.Buffer
for {
select {
case <-ctx.Done():
return
case p := <-ln.sndQueue:
data.Reset()
p.Marshal(&data)
buffer := data.Bytes()
ln.log("packet:send:dump", func() string { return p.Dump() })
// Write the packet's contents to the wire
ln.pc.WriteTo(buffer, p.Header().Addr)
if p.Header().IsControlPacket {
// Control packets can be decommissioned because they will not be sent again (data packets might be retransmitted)
p.Decommission()
}
}
}
}
func (ln *listener) handleHandshake(p packet.Packet) {
cif := &packet.CIFHandshake{}
err := p.UnmarshalCIF(cif)
ln.log("handshake:recv:dump", func() string { return p.Dump() })
ln.log("handshake:recv:cif", func() string { return cif.String() })
if err != nil {
ln.log("handshake:recv:error", func() string { return err.Error() })
return
}
// Assemble the response (4.3.1. Caller-Listener Handshake)
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = uint32(time.Since(ln.start).Microseconds())
p.Header().DestinationSocketId = cif.SRTSocketId
cif.PeerIP.FromNetAddr(ln.addr)
if cif.HandshakeType == packet.HSTYPE_INDUCTION {
// cif
cif.Version = 5
cif.EncryptionField = 0 // Don't advertise any specific encryption method
cif.ExtensionField = 0x4A17
//cif.initialPacketSequenceNumber = newCircular(0, MAX_SEQUENCENUMBER)
//cif.maxTransmissionUnitSize = 0
//cif.maxFlowWindowSize = 0
cif.SRTSocketId = 0
cif.SynCookie = ln.syncookie.Get(p.Header().Addr.String())
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
} else if cif.HandshakeType == packet.HSTYPE_CONCLUSION {
// Verify the SYN cookie
if !ln.syncookie.Verify(cif.SynCookie, p.Header().Addr.String()) {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "invalid SYN cookie" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// We only support HSv5
if cif.Version != 5 {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "only HSv5 is supported" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Check if the peer version is sufficient
if cif.SRTVersion < ln.config.MinVersion {
cif.HandshakeType = packet.REJ_VERSION
ln.log("handshake:recv:error", func() string {
return fmt.Sprintf("peer version insufficient (%#06x), expecting at least %#06x", cif.SRTVersion, ln.config.MinVersion)
})
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Check the required SRT flags
if !cif.SRTFlags.TSBPDSND || !cif.SRTFlags.TSBPDRCV || !cif.SRTFlags.TLPKTDROP || !cif.SRTFlags.PERIODICNAK || !cif.SRTFlags.REXMITFLG {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return "not all required flags are set" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// We only support live streaming
if cif.SRTFlags.STREAM {
cif.HandshakeType = packet.REJ_MESSAGEAPI
ln.log("handshake:recv:error", func() string { return "only live streaming is supported" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// Peer is advertising a too big MSS
if cif.MaxTransmissionUnitSize > MAX_MSS_SIZE {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("MTU is too big (%d bytes)", cif.MaxTransmissionUnitSize) })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
// If the peer has a smaller MTU size, adjust to it
if cif.MaxTransmissionUnitSize < ln.config.MSS {
ln.config.MSS = cif.MaxTransmissionUnitSize
ln.config.PayloadSize = ln.config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE
if ln.config.PayloadSize < MIN_PAYLOAD_SIZE {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("payload size is too small (%d bytes)", ln.config.PayloadSize) })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
}
}
// Fill up a connection request with all relevant data and put it into the backlog
c := connRequest{
addr: p.Header().Addr,
start: time.Now(),
socketId: cif.SRTSocketId,
timestamp: p.Header().Timestamp,
handshake: cif,
}
if cif.SRTKM != nil {
cr, err := crypto.New(int(cif.SRTKM.KLen))
if err != nil {
cif.HandshakeType = packet.REJ_ROGUE
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("crypto: %s", err) })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}
c.crypto = cr
}
// If the backlog is full, reject the connection
select {
case ln.backlog <- c:
default:
cif.HandshakeType = packet.REJ_BACKLOG
ln.log("handshake:recv:error", func() string { return "backlog is full" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
}
} else {
if cif.HandshakeType.IsRejection() {
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("connection rejected: %s", cif.HandshakeType.String()) })
} else {
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("unsupported handshake: %s", cif.HandshakeType.String()) })
}
}
}
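The MSS adjustment in the conclusion handshake above (shrink to the peer's MTU, recompute the payload size, reject if it falls below the minimum) can be sketched standalone. The header-size and minimum-payload constants below are assumptions for illustration; the real values live in the package constants (`SRT_HEADER_SIZE`, `UDP_HEADER_SIZE`, `MIN_PAYLOAD_SIZE`):

```go
package main

import "fmt"

// Assumed sizes for illustration only.
const (
	srtHeaderSize  = 16  // SRT packet header
	udpHeaderSize  = 28  // IPv4 + UDP headers
	minPayloadSize = 188 // one MPEG-TS packet
)

// adjustMSS mirrors the negotiation above: if the peer advertises a
// smaller MTU, shrink our MSS and recompute the payload size. A payload
// below the minimum is grounds for rejecting the handshake.
func adjustMSS(ourMSS, peerMTU uint32) (mss, payload uint32, ok bool) {
	mss = ourMSS
	if peerMTU < mss {
		mss = peerMTU
	}
	payload = mss - srtHeaderSize - udpHeaderSize
	return mss, payload, payload >= minPayloadSize
}

func main() {
	mss, payload, ok := adjustMSS(1500, 1200)
	fmt.Println(mss, payload, ok) // 1200 1156 true
}
```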
func (ln *listener) log(topic string, message func() string) {
ln.config.Logger.Print(topic, 0, 2, message)
}

114
vendor/github.com/datarhei/gosrt/log.go generated vendored Normal file

@@ -0,0 +1,114 @@
package srt
import (
"runtime"
"strings"
"time"
)
// Logger is for logging debug messages.
type Logger interface {
// HasTopic returns whether this Logger is logging messages of that topic.
HasTopic(topic string) bool
// Print adds a new message to the message queue. The message itself is
// a function that returns the string to be logged. It will only be
// executed if HasTopic returns true on the given topic.
Print(topic string, socketId uint32, skip int, message func() string)
// Listen returns a read channel for Log messages.
Listen() <-chan Log
// Close closes the logger. No more messages will be logged.
Close()
}
// logger implements a Logger
type logger struct {
logQueue chan Log
topics map[string]bool
}
// NewLogger returns a Logger that only listens on the given list of topics.
func NewLogger(topics []string) Logger {
l := &logger{
logQueue: make(chan Log, 1024),
topics: make(map[string]bool),
}
for _, topic := range topics {
l.topics[topic] = true
}
return l
}
func (l *logger) HasTopic(topic string) bool {
if len(l.topics) == 0 {
return false
}
if ok := l.topics[topic]; ok {
return true
}
n := len(topic)
for {
i := strings.LastIndexByte(topic[:n], ':')
if i == -1 {
break
}
n = i
if ok := l.topics[topic[:n]]; !ok {
continue
}
return true
}
return false
}
func (l *logger) Print(topic string, socketId uint32, skip int, message func() string) {
if !l.HasTopic(topic) {
return
}
_, file, line, _ := runtime.Caller(skip)
msg := Log{
Time: time.Now(),
SocketId: socketId,
Topic: topic,
Message: message(),
File: file,
Line: line,
}
// Write to log queue, but don't block if it's full
select {
case l.logQueue <- msg:
default:
}
}
func (l *logger) Listen() <-chan Log {
return l.logQueue
}
func (l *logger) Close() {
close(l.logQueue)
}
// Log represents a log message
type Log struct {
Time time.Time // Time of when this message has been logged
SocketId uint32 // The socketid if connection related, 0 otherwise
Topic string // The topic of this message
Message string // The message itself
File string // The file in which this message has been dispatched
Line int // The line number in the file in which this message has been dispatched
}

170
vendor/github.com/datarhei/gosrt/pubsub.go generated vendored Normal file

@@ -0,0 +1,170 @@
package srt
import (
"context"
"fmt"
"io"
"sync"
"github.com/datarhei/gosrt/internal/packet"
)
// PubSub is a publish/subscriber service for SRT connections.
type PubSub interface {
// Publish accepts a SRT connection where it reads from. It blocks
// until the connection closes. The returned error indicates why it
// stopped. There can be only one publisher.
Publish(c Conn) error
// Subscribe accepts a SRT connection where it writes the data from
// the publisher to. It blocks until an error happens. If the publisher
// disconnects, io.EOF is returned. There can be an arbitrary number
// of subscribers.
Subscribe(c Conn) error
}
// pubSub is an implementation of the PubSub interface
type pubSub struct {
incoming chan packet.Packet
ctx context.Context
cancel context.CancelFunc
publish bool
publishLock sync.Mutex
listeners map[uint32]chan packet.Packet
listenersLock sync.Mutex
logger Logger
}
// PubSubConfig is for configuring a new PubSub
type PubSubConfig struct {
Logger Logger // Optional logger
}
// NewPubSub returns a PubSub. After the publishing connection closed
// this PubSub can't be used anymore.
func NewPubSub(config PubSubConfig) PubSub {
pb := &pubSub{
incoming: make(chan packet.Packet, 1024),
listeners: make(map[uint32]chan packet.Packet),
logger: config.Logger,
}
pb.ctx, pb.cancel = context.WithCancel(context.Background())
if pb.logger == nil {
pb.logger = NewLogger(nil)
}
go pb.broadcast()
return pb
}
func (pb *pubSub) broadcast() {
defer func() {
pb.logger.Print("pubsub:close", 0, 1, func() string { return "exiting broadcast loop" })
}()
pb.logger.Print("pubsub:new", 0, 1, func() string { return "starting broadcast loop" })
for {
select {
case <-pb.ctx.Done():
return
case p := <-pb.incoming:
pb.listenersLock.Lock()
for socketId, c := range pb.listeners {
pp := p.Clone()
select {
case c <- pp:
default:
pb.logger.Print("pubsub:error", socketId, 1, func() string { return "broadcast target queue is full" })
}
}
pb.listenersLock.Unlock()
// We don't need this packet anymore
p.Decommission()
}
}
}
func (pb *pubSub) Publish(c Conn) error {
pb.publishLock.Lock()
defer pb.publishLock.Unlock()
if pb.publish {
err := fmt.Errorf("only one publisher is allowed")
pb.logger.Print("pubsub:error", 0, 1, func() string { return err.Error() })
return err
}
var p packet.Packet
var err error
conn, ok := c.(*srtConn)
if !ok {
err := fmt.Errorf("the provided connection is not a SRT connection")
pb.logger.Print("pubsub:error", 0, 1, func() string { return err.Error() })
return err
}
pb.logger.Print("pubsub:publish", conn.SocketId(), 1, func() string { return "new publisher" })
pb.publish = true
for {
p, err = conn.readPacket()
if err != nil {
pb.logger.Print("pubsub:error", conn.SocketId(), 1, func() string { return err.Error() })
break
}
select {
case pb.incoming <- p:
default:
pb.logger.Print("pubsub:error", conn.SocketId(), 1, func() string { return "incoming queue is full" })
}
}
pb.cancel()
return err
}
func (pb *pubSub) Subscribe(c Conn) error {
l := make(chan packet.Packet, 1024)
socketId := c.SocketId()
conn, ok := c.(*srtConn)
if !ok {
err := fmt.Errorf("the provided connection is not a SRT connection")
pb.logger.Print("pubsub:error", 0, 1, func() string { return err.Error() })
return err
}
pb.logger.Print("pubsub:subscribe", socketId, 1, func() string { return "new subscriber" })
pb.listenersLock.Lock()
pb.listeners[socketId] = l
pb.listenersLock.Unlock()
defer func() {
pb.listenersLock.Lock()
delete(pb.listeners, socketId)
pb.listenersLock.Unlock()
}()
for {
select {
case <-pb.ctx.Done():
return io.EOF
case p := <-l:
err := conn.writePacket(p)
p.Decommission()
if err != nil {
pb.logger.Print("pubsub:error", socketId, 1, func() string { return err.Error() })
return err
}
}
}
}
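The broadcast loop above fans each incoming packet out to every subscriber's buffered channel with a non-blocking send, so one slow subscriber cannot stall the publisher. A minimal standalone sketch of that pattern (plain `int` payloads instead of packets):

```go
package main

import "fmt"

// broadcast mirrors pubSub.broadcast: try each subscriber's queue, and
// drop the value for any subscriber whose queue is full rather than block.
func broadcast(listeners map[uint32]chan int, v int) (delivered, dropped int) {
	for _, c := range listeners {
		select {
		case c <- v:
			delivered++
		default:
			dropped++ // queue full: drop instead of stalling the publisher
		}
	}
	return
}

func main() {
	fast := make(chan int, 4)
	slow := make(chan int, 1)
	slow <- 0 // the slow subscriber's queue is already full
	d, x := broadcast(map[uint32]chan int{1: fast, 2: slow}, 42)
	fmt.Println(d, x) // 1 1
}
```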

98
vendor/github.com/datarhei/gosrt/server.go generated vendored Normal file

@@ -0,0 +1,98 @@
package srt
import (
"errors"
)
// Server is a framework for a SRT server
type Server struct {
// The address the SRT server should listen on, e.g. ":6001".
Addr string
// Config is the configuration for a SRT listener.
Config *Config
// HandleConnect will be called for each incoming connection. This
// allows you to implement your own interpretation of the streamid
// and authorization. If this is nil, all connections will be
// rejected.
HandleConnect AcceptFunc
// HandlePublish will be called for a publishing connection.
HandlePublish func(conn Conn)
// HandleSubscribe will be called for a subscribing connection.
HandleSubscribe func(conn Conn)
ln Listener
}
// ErrServerClosed is returned when the server is about to shutdown.
var ErrServerClosed = errors.New("srt: server closed")
// ListenAndServe starts the SRT server. It blocks until an error happens.
// If the error is ErrServerClosed the server has shutdown normally.
func (s *Server) ListenAndServe() error {
// Set some defaults if required.
if s.HandlePublish == nil {
s.HandlePublish = s.defaultHandler
}
if s.HandleSubscribe == nil {
s.HandleSubscribe = s.defaultHandler
}
if s.Config == nil {
config := DefaultConfig()
s.Config = &config
}
// Start listening for incoming connections.
ln, err := Listen("srt", s.Addr, *s.Config)
if err != nil {
return err
}
defer ln.Close()
s.ln = ln
for {
// Wait for connections.
conn, mode, err := ln.Accept(s.HandleConnect)
if err != nil {
if err == ErrListenerClosed {
return ErrServerClosed
}
return err
}
if conn == nil {
// rejected connection, ignore
continue
}
if mode == PUBLISH {
go s.HandlePublish(conn)
} else {
go s.HandleSubscribe(conn)
}
}
}
// Shutdown will shut down the server. ListenAndServe will return ErrServerClosed.
func (s *Server) Shutdown() {
if s.ln == nil {
return
}
// Close the listener
s.ln.Close()
s.ln = nil
}
func (s *Server) defaultHandler(conn Conn) {
// Close the incoming connection
conn.Close()
}
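The accept loop in `ListenAndServe` skips rejected (nil) connections and routes the rest to `HandlePublish` or `HandleSubscribe` by mode. A standalone sketch of that dispatch, with hypothetical stand-in types for gosrt's `ConnType` values:

```go
package main

import "fmt"

// Hypothetical stand-ins for the library's connection modes.
type connType int

const (
	reject connType = iota
	publish
	subscribe
)

// dispatch mirrors the loop's decision: rejected connections are
// ignored; publishers and subscribers each get their own handler
// (which the real server runs in a goroutine).
func dispatch(mode connType, accepted bool, onPublish, onSubscribe func()) {
	if !accepted {
		return // rejected connection, ignore
	}
	if mode == publish {
		onPublish()
	} else {
		onSubscribe()
	}
}

func main() {
	var calls []string
	pub := func() { calls = append(calls, "publish") }
	sub := func() { calls = append(calls, "subscribe") }
	dispatch(publish, true, pub, sub)
	dispatch(subscribe, true, pub, sub)
	dispatch(reject, false, pub, sub)
	fmt.Println(calls) // [publish subscribe]
}
```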

61
vendor/github.com/datarhei/gosrt/statistics.go generated vendored Normal file

@@ -0,0 +1,61 @@
// https://github.com/Haivision/srt/blob/master/docs/API/statistics.md
package srt
// Statistics represents the statistics for a connection
type Statistics struct {
MsTimeStamp uint64 // The time elapsed, in milliseconds, since the SRT socket has been created
// Accumulated
PktSent uint64 // The total number of sent DATA packets, including retransmitted packets
PktRecv uint64 // The total number of received DATA packets, including retransmitted packets
PktSentUnique uint64 // The total number of unique DATA packets sent by the SRT sender
PktRecvUnique uint64 // The total number of unique original, retransmitted or recovered by the packet filter DATA packets received in time, decrypted without errors and, as a result, scheduled for delivery to the upstream application by the SRT receiver.
PktSndLoss uint64 // The total number of data packets considered or reported as lost at the sender side. Does not correspond to the packets detected as lost at the receiver side.
PktRcvLoss uint64 // The total number of SRT DATA packets detected as presently missing (either reordered or lost) at the receiver side
PktRetrans uint64 // The total number of retransmitted packets sent by the SRT sender
PktRcvRetrans uint64 // The total number of retransmitted packets registered at the receiver side
PktSentACK uint64 // The total number of sent ACK (Acknowledgement) control packets
PktRecvACK uint64 // The total number of received ACK (Acknowledgement) control packets
PktSentNAK uint64 // The total number of sent NAK (Negative Acknowledgement) control packets
PktRecvNAK uint64 // The total number of received NAK (Negative Acknowledgement) control packets
PktSentKM uint64 // The total number of sent KM (Key Material) control packets
PktRecvKM uint64 // The total number of received KM (Key Material) control packets
UsSndDuration uint64 // The total accumulated time in microseconds, during which the SRT sender has some data to transmit, including packets that have been sent, but not yet acknowledged
PktSndDrop uint64 // The total number of DATA packets dropped by the SRT sender because they had no chance to be delivered in time
PktRcvDrop uint64 // The total number of DATA packets dropped by the SRT receiver and, as a result, not delivered to the upstream application
PktRcvUndecrypt uint64 // The total number of packets that failed to be decrypted at the receiver side
ByteSent uint64 // Same as pktSent, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteRecv uint64 // Same as pktRecv, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteSentUnique uint64 // Same as pktSentUnique, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteRecvUnique uint64 // Same as pktRecvUnique, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteRcvLoss uint64 // Same as pktRcvLoss, but expressed in bytes, including payload and all the headers (IP, TCP, SRT), bytes for the presently missing (either reordered or lost) packets' payloads are estimated based on the average packet size
ByteRetrans uint64 // Same as pktRetrans, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteSndDrop uint64 // Same as pktSndDrop, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteRcvDrop uint64 // Same as pktRcvDrop, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
ByteRcvUndecrypt uint64 // Same as pktRcvUndecrypt, but expressed in bytes, including payload and all the headers (IP, TCP, SRT)
// Instantaneous
UsPktSndPeriod float64 // Current minimum time interval between which consecutive packets are sent, in microseconds
PktFlowWindow uint64 // The maximum number of packets that can be "in flight"
PktFlightSize uint64 // The number of packets in flight
MsRTT float64 // Smoothed round-trip time (SRTT), an exponentially-weighted moving average (EWMA) of an endpoint's RTT samples, in milliseconds
MbpsBandwidth float64 // Estimated bandwidth of the network link, in Mbps
ByteAvailSndBuf uint64 // The available space in the sender's buffer, in bytes
ByteAvailRcvBuf uint64 // The available space in the receiver's buffer, in bytes
MbpsMaxBW float64 // Transmission bandwidth limit, in Mbps
ByteMSS uint64 // Maximum Segment Size (MSS), in bytes
PktSndBuf uint64 // The number of packets in the sender's buffer that are already scheduled for sending or even possibly sent, but not yet acknowledged
ByteSndBuf uint64 // Instantaneous (current) value of pktSndBuf, but expressed in bytes, including payload and all headers (IP, TCP, SRT)
MsSndBuf uint64 // The timespan (msec) of packets in the sender's buffer (unacknowledged packets)
MsSndTsbPdDelay uint64 // Timestamp-based Packet Delivery Delay value of the peer
PktRcvBuf uint64 // The number of acknowledged packets in receiver's buffer
ByteRcvBuf uint64 // Instantaneous (current) value of pktRcvBuf, expressed in bytes, including payload and all headers (IP, TCP, SRT)
MsRcvBuf uint64 // The timespan (msec) of acknowledged packets in the receiver's buffer
MsRcvTsbPdDelay uint64 // Timestamp-based Packet Delivery Delay value set on the socket via SRTO_RCVLATENCY or SRTO_LATENCY
PktReorderTolerance uint64 // Instant value of the packet reorder tolerance
PktRcvAvgBelatedTime uint64 // Accumulated difference between the current time and the time-to-play of a packet that is received late
}
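The counters above can be combined into derived metrics. As one plausible example (not part of the package API), a receiver-side loss percentage over the expected packet count, using a local subset of the fields:

```go
package main

import "fmt"

// stats is a tiny local subset of the Statistics fields above.
type stats struct {
	PktRecv    uint64 // received DATA packets, incl. retransmitted
	PktRcvLoss uint64 // DATA packets detected as missing at the receiver
}

// lossPct returns the share of expected packets that went missing at
// the receiver, in percent. This is an approximation: PktRecv also
// counts retransmitted packets.
func lossPct(s stats) float64 {
	total := s.PktRecv + s.PktRcvLoss
	if total == 0 {
		return 0
	}
	return 100 * float64(s.PktRcvLoss) / float64(total)
}

func main() {
	s := stats{PktRecv: 950, PktRcvLoss: 50}
	fmt.Printf("%.1f%% loss\n", lossPct(s)) // 5.0% loss
}
```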

412
vendor/github.com/invopop/jsonschema/README.md generated vendored Normal file

@@ -0,0 +1,412 @@
# Go JSON Schema Reflection
[![Lint](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml)
[![Test Go](https://github.com/invopop/jsonschema/actions/workflows/test.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/test.yaml)
[![Go Report Card](https://goreportcard.com/badge/github.com/invopop/jsonschema)](https://goreportcard.com/report/github.com/invopop/jsonschema)
[![GoDoc](https://godoc.org/github.com/invopop/jsonschema?status.svg)](https://godoc.org/github.com/invopop/jsonschema)
![Latest Tag](https://img.shields.io/github/v/tag/invopop/jsonschema)
This package can be used to generate [JSON Schemas](http://json-schema.org/latest/json-schema-validation.html) from Go types through reflection.
- Supports arbitrarily complex types, including `interface{}`, maps, slices, etc.
- Supports json-schema features such as minLength, maxLength, pattern, format, etc.
- Supports simple string and numeric enums.
- Supports custom property fields via the `jsonschema_extras` struct tag.
This repository is a fork of the original [jsonschema](https://github.com/alecthomas/jsonschema) by [@alecthomas](https://github.com/alecthomas). At [Invopop](https://invopop.com) we use jsonschema as a cornerstone in our [GOBL library](https://github.com/invopop/gobl), and wanted to be able to continue building and adding features without taking up Alec's time. There have been a few significant changes that probably mean this version is not compatible with Alec's:
- The original was stuck on the draft-04 version of JSON Schema, we've now moved to the latest JSON Schema Draft 2020-12.
- Schema IDs are added automatically from the current Go package's URL in order to be unique, and can be disabled with the `Anonymous` option.
- Support for the `FullyQualifyTypeName` option has been removed. If you have conflicts, you should use multiple schema files with different IDs, set the `DoNotReference` option to true to hide definitions completely, or add your own naming strategy using the `Namer` property.
## Example
The following Go type:
```go
type TestUser struct {
ID int `json:"id"`
Name string `json:"name" jsonschema:"title=the name,description=The name of a friend,example=joe,example=lucy,default=alex"`
Friends []int `json:"friends,omitempty" jsonschema_description:"The list of IDs, omitted when empty"`
Tags map[string]interface{} `json:"tags,omitempty" jsonschema_extras:"a=b,foo=bar,foo=bar1"`
BirthDate time.Time `json:"birth_date,omitempty" jsonschema:"oneof_required=date"`
YearOfBirth string `json:"year_of_birth,omitempty" jsonschema:"oneof_required=year"`
Metadata interface{} `json:"metadata,omitempty" jsonschema:"oneof_type=string;array"`
FavColor string `json:"fav_color,omitempty" jsonschema:"enum=red,enum=green,enum=blue"`
}
```
Results in following JSON Schema:
```go
jsonschema.Reflect(&TestUser{})
```
```json
{
"$schema": "http://json-schema.org/draft/2020-12/schema",
"$ref": "#/$defs/TestUser",
"$defs": {
"TestUser": {
"oneOf": [
{
"required": ["birth_date"],
"title": "date"
},
{
"required": ["year_of_birth"],
"title": "year"
}
],
"properties": {
"id": {
"type": "integer"
},
"name": {
"type": "string",
"title": "the name",
"description": "The name of a friend",
"default": "alex",
"examples": ["joe", "lucy"]
},
"friends": {
"items": {
"type": "integer"
},
"type": "array",
"description": "The list of IDs, omitted when empty"
},
"tags": {
"type": "object",
"a": "b",
"foo": ["bar", "bar1"]
},
"birth_date": {
"type": "string",
"format": "date-time"
},
"year_of_birth": {
"type": "string"
},
"metadata": {
"oneOf": [
{
"type": "string"
},
{
"type": "array"
}
]
},
"fav_color": {
"type": "string",
"enum": ["red", "green", "blue"]
}
},
"additionalProperties": false,
"type": "object",
"required": ["id", "name"]
}
}
}
```
## Configurable behaviour
The behaviour of the schema generator can be altered with parameters when a `jsonschema.Reflector`
instance is created.
### ExpandedStruct
If set to `true`, the top-level struct will not reference itself in the definitions and will instead be expanded inline. The type passed must be a struct type.
e.g.
```go
type GrandfatherType struct {
FamilyName string `json:"family_name" jsonschema:"required"`
}
type SomeBaseType struct {
SomeBaseProperty int `json:"some_base_property"`
// The jsonschema required tag is nonsensical for private and ignored properties.
// Their presence here tests that the fields *will not* be required in the output
// schema, even if they are tagged required.
somePrivateBaseProperty string `json:"i_am_private" jsonschema:"required"`
SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"`
SomeSchemaIgnoredProperty string `jsonschema:"-,required"`
SomeUntaggedBaseProperty bool `jsonschema:"required"`
someUnexportedUntaggedBaseProperty bool
Grandfather GrandfatherType `json:"grand"`
}
```
will output:
```json
{
"$schema": "http://json-schema.org/draft/2020-12/schema",
"required": ["some_base_property", "grand", "SomeUntaggedBaseProperty"],
"properties": {
"SomeUntaggedBaseProperty": {
"type": "boolean"
},
"grand": {
"$schema": "http://json-schema.org/draft/2020-12/schema",
"$ref": "#/definitions/GrandfatherType"
},
"some_base_property": {
"type": "integer"
}
},
"type": "object",
"$defs": {
"GrandfatherType": {
"required": ["family_name"],
"properties": {
"family_name": {
"type": "string"
}
},
"additionalProperties": false,
"type": "object"
}
}
}
```
### PreferYAMLSchema
JSON schemas can also be used to validate YAML; however, YAML frequently uses
different identifiers than JSON, indicated by the `yaml:` tag. The `Reflector` will
by default prefer `json:` tags over `yaml:` tags (and only use the latter if the
former are not present). This behavior can be changed via the `PreferYAMLSchema`
flag, that will switch this behavior: `yaml:` tags will be preferred over
`json:` tags.
With `PreferYAMLSchema: true`, the following struct:
```go
type Person struct {
FirstName string `json:"FirstName" yaml:"first_name"`
}
```
would result in this schema:
```json
{
"$schema": "http://json-schema.org/draft/2020-12/schema",
"$ref": "#/$defs/Person",
"$defs": {
"Person": {
"required": ["first_name"],
"properties": {
"first_name": {
"type": "string"
}
},
"additionalProperties": false,
"type": "object"
}
}
}
```
whereas without the flag one obtains:
```json
{
"$schema": "http://json-schema.org/draft/2020-12/schema",
"$ref": "#/$defs/Person",
"$defs": {
"Person": {
"required": ["FirstName"],
"properties": {
"first_name": {
"type": "string"
}
},
"additionalProperties": false,
"type": "object"
}
}
}
```
### Using Go Comments
Writing a good schema with descriptions inside tags can become cumbersome and tedious, especially if you already have some Go comments around your types and field definitions. If you'd like to take advantage of these existing comments, you can use the `AddGoComments(base, path string)` method that forms part of the reflector to parse your go files and automatically generate a dictionary of Go import paths, types, and fields, to individual comments. These will then be used automatically as description fields, and can be overridden with a manual definition if needed.
Take a simplified example of a User struct which for the sake of simplicity we assume is defined inside this package:
```go
package main
// User is used as a base to provide tests for comments.
type User struct {
// Unique sequential identifier.
ID int `json:"id" jsonschema:"required"`
// Name of the user
Name string `json:"name"`
}
```
To get the comments provided into your JSON schema, use a regular `Reflector` and add the go code using an import module URL and path. Fully qualified go module paths cannot be determined reliably by the `go/parser` library, so we need to introduce this manually:
```go
r := new(Reflector)
if err := r.AddGoComments("github.com/invopop/jsonschema", "./"); err != nil {
// deal with error
}
s := r.Reflect(&User{})
// output
```
Expect the results to be similar to:
```json
{
"$schema": "http://json-schema.org/draft/2020-12/schema",
"$ref": "#/$defs/User",
"$defs": {
"User": {
"required": ["id"],
"properties": {
"id": {
"type": "integer",
"description": "Unique sequential identifier."
},
"name": {
"type": "string",
"description": "Name of the user"
}
},
"additionalProperties": false,
"type": "object",
"description": "User is used as a base to provide tests for comments."
}
}
}
```
### Custom Key Naming
In some situations, the keys actually used to write files are different from Go structs'.
This is often the case when writing a configuration file to YAML or JSON from a Go struct, or when returning a JSON response for a Web API: APIs typically use snake_case, while Go uses PascalCase.
You can pass a `func(string) string` function to `Reflector`'s `KeyNamer` option to map Go field names to JSON key names and reflect the aforementioned transformations, without having to specify `json:"..."` on every struct field.
For example, consider the following struct:
```go
type User struct {
	GivenName      string
	PasswordSalted []byte `json:"salted_password"`
}
```
We can transform field names to snake_case in the generated JSON schema:
```go
r := new(jsonschema.Reflector)
r.KeyNamer = strcase.SnakeCase // from package github.com/stoewer/go-strcase
r.Reflect(&User{})
```
This will yield:
```diff
{
  "$schema": "https://json-schema.org/draft/2020-12/schema",
  "$ref": "#/$defs/User",
  "$defs": {
    "User": {
      "properties": {
-       "GivenName": {
+       "given_name": {
          "type": "string"
        },
        "salted_password": {
          "type": "string",
          "contentEncoding": "base64"
        }
      },
      "additionalProperties": false,
      "type": "object",
-     "required": ["GivenName", "salted_password"]
+     "required": ["given_name", "salted_password"]
    }
  }
}
```
As you can see, if a field has a `json:""` or `yaml:""` tag set, the `key` argument passed to `KeyNamer` will be the value of that tag (if a field has both, the value of `key` respects [`PreferYAMLSchema`](#preferyamlschema)).
### Custom Type Definitions
Sometimes it can be useful to have custom JSON Marshal and Unmarshal methods in your structs that automatically convert, for example, a string into an object.
To override the auto-generated object type for your type, implement the `JSONSchema() *Schema` method, and whatever it returns will be used in the schema definitions.
Take the following simplified example of a `CompactDate` that only includes the Year and Month:
```go
type CompactDate struct {
	Year  int
	Month int
}

func (d *CompactDate) UnmarshalJSON(data []byte) error {
	// expect a quoted "YYYY-MM" value, 9 bytes in total
	if len(data) != 9 {
		return errors.New("invalid compact date length")
	}
	var err error
	d.Year, err = strconv.Atoi(string(data[1:5]))
	if err != nil {
		return err
	}
	d.Month, err = strconv.Atoi(string(data[6:8]))
	if err != nil {
		return err
	}
	return nil
}

func (d *CompactDate) MarshalJSON() ([]byte, error) {
	buf := new(bytes.Buffer)
	buf.WriteByte('"')
	buf.WriteString(fmt.Sprintf("%d-%02d", d.Year, d.Month))
	buf.WriteByte('"')
	return buf.Bytes(), nil
}

func (CompactDate) JSONSchema() *Schema {
	return &Schema{
		Type:        "string",
		Title:       "Compact Date",
		Description: "Short date that only includes year and month",
		Pattern:     "^[0-9]{4}-[0-1][0-9]$",
	}
}
```
The resulting schema generated for this struct would look like:
```json
{
  "$schema": "https://json-schema.org/draft/2020-12/schema",
  "$ref": "#/$defs/CompactDate",
  "$defs": {
    "CompactDate": {
      "pattern": "^[0-9]{4}-[0-1][0-9]$",
      "type": "string",
      "title": "Compact Date",
      "description": "Short date that only includes year and month"
    }
  }
}
```

View File

@@ -0,0 +1,90 @@
package jsonschema
import (
	"fmt"
	"go/ast"
	"go/doc"
	"go/parser"
	"go/token"
	"io/fs"
	gopath "path"
	"path/filepath"
	"strings"
)

// ExtractGoComments will read all the go files contained in the provided path,
// including sub-directories, in order to generate a dictionary of comments
// associated with Types and Fields. The results will be added to the `commentMap`
// provided in the parameters and expected to be used for Schema "description" fields.
//
// The `go/parser` library is used to extract all the comments and unfortunately doesn't
// have a built-in way to determine the fully qualified name of a package. The `base` parameter,
// the URL used to import that package, is thus required to be able to match reflected types.
//
// When parsing type comments, we use `go/doc`'s Synopsis method to extract the first phrase
// only. Field comments, which tend to be much shorter, will include everything.
func ExtractGoComments(base, path string, commentMap map[string]string) error {
	fset := token.NewFileSet()
	dict := make(map[string][]*ast.Package)
	err := filepath.Walk(path, func(path string, info fs.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			d, err := parser.ParseDir(fset, path, nil, parser.ParseComments)
			if err != nil {
				return err
			}
			for _, v := range d {
				// paths may have multiple packages, like for tests
				k := gopath.Join(base, path)
				dict[k] = append(dict[k], v)
			}
		}
		return nil
	})
	if err != nil {
		return err
	}

	for pkg, p := range dict {
		for _, f := range p {
			gtxt := ""
			typ := ""
			ast.Inspect(f, func(n ast.Node) bool {
				switch x := n.(type) {
				case *ast.TypeSpec:
					typ = x.Name.String()
					if !ast.IsExported(typ) {
						typ = ""
					} else {
						txt := x.Doc.Text()
						if txt == "" && gtxt != "" {
							txt = gtxt
							gtxt = ""
						}
						txt = doc.Synopsis(txt)
						commentMap[fmt.Sprintf("%s.%s", pkg, typ)] = strings.TrimSpace(txt)
					}
				case *ast.Field:
					txt := x.Doc.Text()
					if typ != "" && txt != "" {
						for _, n := range x.Names {
							if ast.IsExported(n.String()) {
								k := fmt.Sprintf("%s.%s.%s", pkg, typ, n)
								commentMap[k] = strings.TrimSpace(txt)
							}
						}
					}
				case *ast.GenDecl:
					// remember for the next type
					gtxt = x.Doc.Text()
				}
				return true
			})
		}
	}
	return nil
}

View File

@@ -1,6 +1,6 @@
module github.com/alecthomas/jsonschema
module github.com/invopop/jsonschema
go 1.12
go 1.16
require (
github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0

76
vendor/github.com/invopop/jsonschema/id.go generated vendored Normal file
View File

@@ -0,0 +1,76 @@
package jsonschema
import (
	"errors"
	"fmt"
	"net/url"
	"strings"
)

// ID represents a Schema ID type which should always be a URI.
// See draft-bhutton-json-schema-00 section 8.2.1
type ID string

// EmptyID is used to explicitly define an ID with no value.
const EmptyID ID = ""

// Validate is used to check if the ID looks like a proper schema.
// This is done by parsing the ID as a URL and checking it has all the
// relevant parts.
func (id ID) Validate() error {
	u, err := url.Parse(id.String())
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}
	if u.Hostname() == "" {
		return errors.New("missing hostname")
	}
	if !strings.Contains(u.Hostname(), ".") {
		return errors.New("hostname does not look valid")
	}
	if u.Path == "" {
		return errors.New("path is expected")
	}
	if u.Scheme != "https" && u.Scheme != "http" {
		return errors.New("unexpected schema")
	}
	return nil
}

// Anchor sets the anchor part of the schema URI.
func (id ID) Anchor(name string) ID {
	b := id.Base()
	return ID(b.String() + "#" + name)
}

// Def adds or replaces a definition identifier.
func (id ID) Def(name string) ID {
	b := id.Base()
	return ID(b.String() + "#/$defs/" + name)
}

// Add appends the provided path to the id, and removes any
// anchor data that might be there.
func (id ID) Add(path string) ID {
	b := id.Base()
	if !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	return ID(b.String() + path)
}

// Base removes any anchor information from the schema
func (id ID) Base() ID {
	s := id.String()
	i := strings.LastIndex(s, "#")
	if i != -1 {
		s = s[0:i]
	}
	s = strings.TrimRight(s, "/")
	return ID(s)
}

// String provides string version of ID
func (id ID) String() string {
	return string(id)
}

1069
vendor/github.com/invopop/jsonschema/reflect.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

18
vendor/github.com/invopop/jsonschema/utils.go generated vendored Normal file
View File

@@ -0,0 +1,18 @@
package jsonschema
import (
	"regexp"
	"strings"
)

var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")

// ToSnakeCase converts the provided string into snake case using dashes.
// This is useful for Schema IDs and definitions to be coherent with
// common JSON Schema examples.
func ToSnakeCase(str string) string {
	snake := matchFirstCap.ReplaceAllString(str, "${1}-${2}")
	snake = matchAllCap.ReplaceAllString(snake, "${1}-${2}")
	return strings.ToLower(snake)
}

View File

@@ -351,10 +351,9 @@ func PartitionsWithContext(ctx context.Context, all bool) ([]PartitionStat, erro
// so we get the real device name from its major/minor number
if d.Device == "/dev/root" {
devpath, err := os.Readlink(common.HostSys("/dev/block/" + blockDeviceID))
if err != nil {
return nil, err
if err == nil {
d.Device = strings.Replace(d.Device, "root", filepath.Base(devpath), 1)
}
d.Device = strings.Replace(d.Device, "root", filepath.Base(devpath), 1)
}
}
ret = append(ret, d)

View File

@@ -12,6 +12,7 @@ import (
"github.com/shirou/gopsutil/v3/internal/common"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var (
@@ -44,6 +45,14 @@ type diskPerformance struct {
alignmentPadding uint32 // necessary for 32bit support, see https://github.com/elastic/beats/pull/16553
}
func init() {
// enable disk performance counters on Windows Server editions (needs to run as admin)
key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Services\PartMgr`, registry.SET_VALUE)
if err == nil {
key.SetDWordValue("EnableCounterForIoctl", 1)
}
}
func UsageWithContext(ctx context.Context, path string) (*UsageStat, error) {
lpFreeBytesAvailable := int64(0)
lpTotalNumberOfBytes := int64(0)

View File

@@ -285,8 +285,7 @@ func PidExistsWithContext(ctx context.Context, pid int32) (bool, error) {
}
return false, err
}
const STILL_ACTIVE = 259 // https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getexitcodeprocess
h, err := windows.OpenProcess(processQueryInformation, false, uint32(pid))
h, err := windows.OpenProcess(windows.SYNCHRONIZE, false, uint32(pid))
if err == windows.ERROR_ACCESS_DENIED {
return true, nil
}
@@ -296,10 +295,9 @@ func PidExistsWithContext(ctx context.Context, pid int32) (bool, error) {
if err != nil {
return false, err
}
defer syscall.CloseHandle(syscall.Handle(h))
var exitCode uint32
err = windows.GetExitCodeProcess(h, &exitCode)
return exitCode == STILL_ACTIVE, err
defer windows.CloseHandle(h)
event, err := windows.WaitForSingleObject(h, 0)
return event == uint32(windows.WAIT_TIMEOUT), err
}
func (p *Process) PpidWithContext(ctx context.Context) (int32, error) {

View File

@@ -9,7 +9,7 @@ package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatability
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)

View File

@@ -6,7 +6,7 @@ require (
github.com/labstack/echo/v4 v4.7.2
github.com/mailru/easyjson v0.7.7 // indirect
github.com/stretchr/testify v1.7.0
github.com/swaggo/files v0.0.0-20210815190702-a29dd2bc99b2
github.com/swaggo/files v0.0.0-20220610200504-28940afbdbfe
github.com/swaggo/swag v1.8.1
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4

View File

@@ -69,6 +69,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/swaggo/files v0.0.0-20210815190702-a29dd2bc99b2 h1:+iNTcqQJy0OZ5jk6a5NLib47eqXK8uYcPX+O4+cBpEM=
github.com/swaggo/files v0.0.0-20210815190702-a29dd2bc99b2/go.mod h1:lKJPbtWzJ9JhsTN1k1gZgleJWY/cqq0psdoMmaThG3w=
github.com/swaggo/files v0.0.0-20220610200504-28940afbdbfe h1:K8pHPVoTgxFJt1lXuIzzOX7zZhZFldJQK/CgKx9BFIc=
github.com/swaggo/files v0.0.0-20220610200504-28940afbdbfe/go.mod h1:lKJPbtWzJ9JhsTN1k1gZgleJWY/cqq0psdoMmaThG3w=
github.com/swaggo/swag v1.8.1 h1:JuARzFX1Z1njbCGz+ZytBR15TFJwF2Q7fu8puJHhQYI=
github.com/swaggo/swag v1.8.1/go.mod h1:ugemnJsPZm/kRwFUnzBlbHRd0JY9zE1M4F+uy2pAaPQ=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=

View File

@@ -22,6 +22,22 @@ type Config struct {
InstanceName string
DeepLinking bool
PersistAuthorization bool
// The information for OAuth2 integration, if any.
OAuth *OAuthConfig
}
// OAuthConfig stores configuration for Swagger UI OAuth2 integration. See
// https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/ for further details.
type OAuthConfig struct {
// The ID of the client sent to the OAuth2 IAM provider.
ClientId string
// The OAuth2 realm that the client should operate in. If not applicable, use empty string.
Realm string
// The name to display for the application in the authentication popup.
AppName string
}
// URL presents the url pointing to API definition (normally swagger.json or swagger.yaml).
@@ -67,6 +83,12 @@ func PersistAuthorization(persistAuthorization bool) func(*Config) {
}
}
func OAuth(config *OAuthConfig) func(*Config) {
return func(c *Config) {
c.OAuth = config
}
}
func newConfig(configFns ...func(*Config)) *Config {
config := Config{
URL: "doc.json",
@@ -250,6 +272,15 @@ window.onload = function() {
],
layout: "StandaloneLayout"
})
{{if .OAuth}}
ui.initOAuth({
clientId: "{{.OAuth.ClientId}}",
realm: "{{.OAuth.Realm}}",
appName: "{{.OAuth.AppName}}"
})
{{end}}
window.ui = ui
}
</script>

View File

@@ -1,5 +1,13 @@
all: build
.PHONY: init
init:
git submodule update --init --recursive
.PHONY: deps
deps:
go install github.com/UnnoTed/fileb0x@v1.1.4
.PHONY: build
build:
fileb0x fileb0x/b0x.yaml

View File

@@ -1,5 +1,5 @@
// Code generated by fileb0x at "2021-08-11 21:14:46.428511689 +0300 EEST m=+0.096329763" from config file "b0x.yaml" DO NOT EDIT.
// modification hash(37610a5b0ca328f5072d5ee653766db2.84893f7d7f6af7d7916db9fe20160151)
// Code generated by fileb0x at "2022-06-10 22:55:21.84892167 +0300 EEST m=+0.038018428" from config file "b0x.yaml" DO NOT EDIT.
// modification hash(c06e09df5514d5bdf3c145144658221e.84893f7d7f6af7d7916db9fe20160151)
package swaggerFiles

View File

@@ -1,5 +1,5 @@
// Code generaTed by fileb0x at "2021-08-11 21:11:48.973787292 +0300 EEST m=+0.273940433" from config file "b0x.yaml" DO NOT EDIT.
// modified(2021-08-11 21:10:36.055919109 +0300 EEST)
// Code generaTed by fileb0x at "2022-06-10 22:55:22.230336439 +0300 EEST m=+0.419433197" from config file "b0x.yaml" DO NOT EDIT.
// modified(2022-06-10 22:48:15.340539512 +0300 EEST)
// original path: swagger-ui/dist/favicon-16x16.png
package swaggerFiles

View File

@@ -1,5 +1,5 @@
// Code generaTed by fileb0x at "2021-08-11 21:11:49.069266556 +0300 EEST m=+0.369419692" from config file "b0x.yaml" DO NOT EDIT.
// modified(2021-08-11 21:10:36.055919109 +0300 EEST)
// Code generaTed by fileb0x at "2022-06-10 22:55:22.211623651 +0300 EEST m=+0.400720419" from config file "b0x.yaml" DO NOT EDIT.
// modified(2022-06-10 22:48:15.340539512 +0300 EEST)
// original path: swagger-ui/dist/favicon-32x32.png
package swaggerFiles

30
vendor/github.com/swaggo/files/b0xfile__index.css.go generated vendored Normal file
View File

@@ -0,0 +1,30 @@
// Code generaTed by fileb0x at "2022-06-10 22:55:21.999275787 +0300 EEST m=+0.188372544" from config file "b0x.yaml" DO NOT EDIT.
// modified(2022-06-10 22:50:50.595877768 +0300 EEST)
// original path: swagger-ui/dist/index.css
package swaggerFiles
import (
"os"
)
// FileIndexCSS is "/index.css"
var FileIndexCSS = []byte("\x68\x74\x6d\x6c\x20\x7b\x0a\x20\x20\x20\x20\x62\x6f\x78\x2d\x73\x69\x7a\x69\x6e\x67\x3a\x20\x62\x6f\x72\x64\x65\x72\x2d\x62\x6f\x78\x3b\x0a\x20\x20\x20\x20\x6f\x76\x65\x72\x66\x6c\x6f\x77\x3a\x20\x2d\x6d\x6f\x7a\x2d\x73\x63\x72\x6f\x6c\x6c\x62\x61\x72\x73\x2d\x76\x65\x72\x74\x69\x63\x61\x6c\x3b\x0a\x20\x20\x20\x20\x6f\x76\x65\x72\x66\x6c\x6f\x77\x2d\x79\x3a\x20\x73\x63\x72\x6f\x6c\x6c\x3b\x0a\x7d\x0a\x0a\x2a\x2c\x0a\x2a\x3a\x62\x65\x66\x6f\x72\x65\x2c\x0a\x2a\x3a\x61\x66\x74\x65\x72\x20\x7b\x0a\x20\x20\x20\x20\x62\x6f\x78\x2d\x73\x69\x7a\x69\x6e\x67\x3a\x20\x69\x6e\x68\x65\x72\x69\x74\x3b\x0a\x7d\x0a\x0a\x62\x6f\x64\x79\x20\x7b\x0a\x20\x20\x20\x20\x6d\x61\x72\x67\x69\x6e\x3a\x20\x30\x3b\x0a\x20\x20\x20\x20\x62\x61\x63\x6b\x67\x72\x6f\x75\x6e\x64\x3a\x20\x23\x66\x61\x66\x61\x66\x61\x3b\x0a\x7d\x0a")
func init() {
f, err := FS.OpenFile(CTX, "/index.css", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
if err != nil {
panic(err)
}
_, err = f.Write(FileIndexCSS)
if err != nil {
panic(err)
}
err = f.Close()
if err != nil {
panic(err)
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,30 @@
// Code generaTed by fileb0x at "2022-06-10 22:55:22.243360665 +0300 EEST m=+0.432457425" from config file "b0x.yaml" DO NOT EDIT.
// modified(2022-06-10 22:50:50.595877768 +0300 EEST)
// original path: swagger-ui/dist/swagger-initializer.js
package swaggerFiles
import (
"os"
)
// FileSwaggerInitializerJs is "/swagger-initializer.js"
var FileSwaggerInitializerJs = []byte("\x77\x69\x6e\x64\x6f\x77\x2e\x6f\x6e\x6c\x6f\x61\x64\x20\x3d\x20\x66\x75\x6e\x63\x74\x69\x6f\x6e\x28\x29\x20\x7b\x0a\x20\x20\x2f\x2f\x3c\x65\x64\x69\x74\x6f\x72\x2d\x66\x6f\x6c\x64\x20\x64\x65\x73\x63\x3d\x22\x43\x68\x61\x6e\x67\x65\x61\x62\x6c\x65\x20\x43\x6f\x6e\x66\x69\x67\x75\x72\x61\x74\x69\x6f\x6e\x20\x42\x6c\x6f\x63\x6b\x22\x3e\x0a\x0a\x20\x20\x2f\x2f\x20\x74\x68\x65\x20\x66\x6f\x6c\x6c\x6f\x77\x69\x6e\x67\x20\x6c\x69\x6e\x65\x73\x20\x77\x69\x6c\x6c\x20\x62\x65\x20\x72\x65\x70\x6c\x61\x63\x65\x64\x20\x62\x79\x20\x64\x6f\x63\x6b\x65\x72\x2f\x63\x6f\x6e\x66\x69\x67\x75\x72\x61\x74\x6f\x72\x2c\x20\x77\x68\x65\x6e\x20\x69\x74\x20\x72\x75\x6e\x73\x20\x69\x6e\x20\x61\x20\x64\x6f\x63\x6b\x65\x72\x2d\x63\x6f\x6e\x74\x61\x69\x6e\x65\x72\x0a\x20\x20\x77\x69\x6e\x64\x6f\x77\x2e\x75\x69\x20\x3d\x20\x53\x77\x61\x67\x67\x65\x72\x55\x49\x42\x75\x6e\x64\x6c\x65\x28\x7b\x0a\x20\x20\x20\x20\x75\x72\x6c\x3a\x20\x22\x68\x74\x74\x70\x73\x3a\x2f\x2f\x70\x65\x74\x73\x74\x6f\x72\x65\x2e\x73\x77\x61\x67\x67\x65\x72\x2e\x69\x6f\x2f\x76\x32\x2f\x73\x77\x61\x67\x67\x65\x72\x2e\x6a\x73\x6f\x6e\x22\x2c\x0a\x20\x20\x20\x20\x64\x6f\x6d\x5f\x69\x64\x3a\x20\x27\x23\x73\x77\x61\x67\x67\x65\x72\x2d\x75\x69\x27\x2c\x0a\x20\x20\x20\x20\x64\x65\x65\x70\x4c\x69\x6e\x6b\x69\x6e\x67\x3a\x20\x74\x72\x75\x65\x2c\x0a\x20\x20\x20\x20\x70\x72\x65\x73\x65\x74\x73\x3a\x20\x5b\x0a\x20\x20\x20\x20\x20\x20\x53\x77\x61\x67\x67\x65\x72\x55\x49\x42\x75\x6e\x64\x6c\x65\x2e\x70\x72\x65\x73\x65\x74\x73\x2e\x61\x70\x69\x73\x2c\x0a\x20\x20\x20\x20\x20\x20\x53\x77\x61\x67\x67\x65\x72\x55\x49\x53\x74\x61\x6e\x64\x61\x6c\x6f\x6e\x65\x50\x72\x65\x73\x65\x74\x0a\x20\x20\x20\x20\x5d\x2c\x0a\x20\x20\x20\x20\x70\x6c\x75\x67\x69\x6e\x73\x3a\x20\x5b\x0a\x20\x20\x20\x20\x20\x20\x53\x77\x61\x67\x67\x65\x72\x55\x49\x42\x75\x6e\x64\x6c\x65\x2e\x70\x6c\x75\x67\x69\x6e\x73\x2e\x44\x6f\x77\x6e\x6c\x6f\x61\x64\x55\x72\x6c\x0a\x20\x20\x20\x20\x5d\x2c\x0a\x20\x20\x20\x20\x6c\x61\x79\x6f\x75\x74\x3a\x20\x22\x53\x74\x61\x6e\x64\x61\x6c\x6f\x6e\x65\x4c\x61\x79\x6f\x75\x74\x22\x0a\x20\x20\x7d\x29\x3b\x0a\x0a\x20\x20\x2f\x2f\x3c\x2f\x65\x64\x69\x74\x6f\x72\x2d\x66\x6f\x6c\x64\x3e\x0a\x7d\x3b\x0a")
func init() {
f, err := FS.OpenFile(CTX, "/swagger-initializer.js", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
if err != nil {
panic(err)
}
_, err = f.Write(FileSwaggerInitializerJs)
if err != nil {
panic(err)
}
err = f.Close()
if err != nil {
panic(err)
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,7 @@
# Dockerfile References: https://docs.docker.com/engine/reference/builder/
# Start from the latest golang base image
FROM golang:1.17-alpine as builder
FROM golang:1.18.3-alpine as builder
# Set the Current Working Directory inside the container
WORKDIR /app

View File

@@ -63,7 +63,7 @@ $ go install github.com/swaggo/swag/cmd/swag@latest
swag init
```
Make sure to import the generated `docs/docs.go` file so that the specific configuration gets initialized. If your general API index is not written in `main.go`, you can use the `-g` flag to tell swag.
Make sure to import the generated `docs/docs.go` file so that the specific configuration gets initialized. If your general API comments are not written in `main.go`, you can use the `-g` flag to tell swag.
```bash
swag init -g http/api.go

View File

@@ -388,10 +388,6 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error {
varNamesTag := ps.tag.Get("x-enum-varnames")
if varNamesTag != "" {
if schema.Extensions == nil {
schema.Extensions = map[string]interface{}{}
}
varNames := strings.Split(varNamesTag, ",")
if len(varNames) != len(field.enums) {
return fmt.Errorf("invalid count of x-enum-varnames. expected %d, got %d", len(field.enums), len(varNames))
@@ -403,7 +399,19 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error {
field.enumVarNames = append(field.enumVarNames, v)
}
schema.Extensions["x-enum-varnames"] = field.enumVarNames
if field.schemaType == ARRAY {
// Add the var names in the items schema
if schema.Items.Schema.Extensions == nil {
schema.Items.Schema.Extensions = map[string]interface{}{}
}
schema.Items.Schema.Extensions["x-enum-varnames"] = field.enumVarNames
} else {
// Add to top level schema
if schema.Extensions == nil {
schema.Extensions = map[string]interface{}{}
}
schema.Extensions["x-enum-varnames"] = field.enumVarNames
}
}
eleSchema := schema

109
vendor/github.com/swaggo/swag/generics.go generated vendored Normal file
View File

@@ -0,0 +1,109 @@
//go:build go1.18
// +build go1.18
package swag
import (
"go/ast"
"strings"
)
func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()
if typeSpecDef.TypeSpec.TypeParams != nil {
fullName = fullName + "["
for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List {
if i > 0 {
fullName = fullName + "-"
}
fullName = fullName + typeParam.Names[0].Name
}
fullName = fullName + "]"
}
return fullName
}
func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
genericParams := strings.Split(strings.TrimRight(fullGenericForm, "]"), "[")
if len(genericParams) == 1 {
return nil
}
genericParams = strings.Split(genericParams[1], ",")
for i, p := range genericParams {
genericParams[i] = strings.TrimSpace(p)
}
genericParamTypeDefs := map[string]*TypeSpecDef{}
if len(genericParams) != len(original.TypeSpec.TypeParams.List) {
return nil
}
for i, genericParam := range genericParams {
tdef, ok := pkgDefs.uniqueDefinitions[genericParam]
if !ok {
return nil
}
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = tdef
}
parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Doc: original.TypeSpec.Doc,
Comment: original.TypeSpec.Comment,
Assign: original.TypeSpec.Assign,
},
}
ident := &ast.Ident{
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
}
genNameParts := strings.Split(fullGenericForm, "[")
if strings.Contains(genNameParts[0], ".") {
genNameParts[0] = strings.Split(genNameParts[0], ".")[1]
}
ident.Name = genNameParts[0] + "-" + strings.Replace(strings.Join(genericParams, "-"), ".", "_", -1)
ident.Name = strings.Replace(strings.Replace(ident.Name, "\t", "", -1), " ", "", -1)
parametrizedTypeSpec.TypeSpec.Name = ident
origStructType := original.TypeSpec.Type.(*ast.StructType)
newStructTypeDef := &ast.StructType{
Struct: origStructType.Struct,
Incomplete: origStructType.Incomplete,
Fields: &ast.FieldList{
Opening: origStructType.Fields.Opening,
Closing: origStructType.Fields.Closing,
},
}
for _, field := range origStructType.Fields.List {
newField := &ast.Field{
Doc: field.Doc,
Names: field.Names,
Tag: field.Tag,
Comment: field.Comment,
}
if genTypeSpec, ok := genericParamTypeDefs[field.Type.(*ast.Ident).Name]; ok {
newField.Type = genTypeSpec.TypeSpec.Type
} else {
newField.Type = field.Type
}
newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}
parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef
return parametrizedTypeSpec
}

12
vendor/github.com/swaggo/swag/generics_other.go generated vendored Normal file
View File

@@ -0,0 +1,12 @@
//go:build !go1.18
// +build !go1.18
package swag
func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}
func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
return original
}

74
vendor/github.com/swaggo/swag/golist.go generated vendored Normal file
View File

@@ -0,0 +1,74 @@
package swag
import (
"bytes"
"context"
"encoding/json"
"fmt"
"go/build"
"os/exec"
"path/filepath"
)
func listPackages(ctx context.Context, dir string, env []string, args ...string) (pkgs []*build.Package, finalErr error) {
cmd := exec.CommandContext(ctx, "go", append([]string{"list", "-json", "-e"}, args...)...)
cmd.Env = env
cmd.Dir = dir
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}
var stderrBuf bytes.Buffer
cmd.Stderr = &stderrBuf
defer func() {
if stderrBuf.Len() > 0 {
finalErr = fmt.Errorf("%v\n%s", finalErr, stderrBuf.Bytes())
}
}()
err = cmd.Start()
if err != nil {
return nil, err
}
dec := json.NewDecoder(stdout)
for dec.More() {
var pkg build.Package
err = dec.Decode(&pkg)
if err != nil {
return nil, err
}
pkgs = append(pkgs, &pkg)
}
err = cmd.Wait()
if err != nil {
return nil, err
}
return pkgs, nil
}
func (parser *Parser) getAllGoFileInfoFromDepsByList(pkg *build.Package) error {
ignoreInternal := pkg.Goroot && !parser.ParseInternal
if ignoreInternal { // ignored internal
return nil
}
srcDir := pkg.Dir
var err error
for i := range pkg.GoFiles {
err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.GoFiles[i]), nil)
if err != nil {
return err
}
}
// parse .go source files that import "C"
for i := range pkg.CgoFiles {
err = parser.parseFile(pkg.ImportPath, filepath.Join(srcDir, pkg.CgoFiles[i]), nil)
if err != nil {
return err
}
}
return nil
}

View File

@@ -385,7 +385,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
if objectType == PRIMITIVE {
param.Schema = PrimitiveSchema(refType)
} else {
schema, err := operation.parseAPIObjectSchema(objectType, refType, astFile)
schema, err := operation.parseAPIObjectSchema(commentLine, objectType, refType, astFile)
if err != nil {
return err
}
@@ -933,7 +933,16 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a
}), nil
}
func (operation *Operation) parseAPIObjectSchema(schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
func (operation *Operation) parseAPIObjectSchema(commentLine, schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
if strings.HasSuffix(refType, ",") && strings.Contains(refType, "[") {
// regexp may have broken generics syntax. find closing bracket and add it back
allMatchesLenOffset := strings.Index(commentLine, refType) + len(refType)
lostPartEndIdx := strings.Index(commentLine[allMatchesLenOffset:], "]")
if lostPartEndIdx >= 0 {
refType += commentLine[allMatchesLenOffset : allMatchesLenOffset+lostPartEndIdx+1]
}
}
switch schemaType {
case OBJECT:
if !strings.HasPrefix(refType, "[]") {
@@ -969,7 +978,7 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as
description := strings.Trim(matches[4], "\"")
schema, err := operation.parseAPIObjectSchema(strings.Trim(matches[2], "{}"), matches[3], astFile)
schema, err := operation.parseAPIObjectSchema(commentLine, strings.Trim(matches[2], "{}"), matches[3], astFile)
if err != nil {
return err
}

View File

@@ -6,6 +6,7 @@ import (
"go/token"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
@@ -78,6 +79,11 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF
func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, file *ast.File) error) error {
sortedFiles := make([]*AstFileInfo, 0, len(files))
for _, info := range files {
// ignore package path prefix with 'vendor' or $GOROOT,
// because the router info of api will not be included these files.
if strings.HasPrefix(info.PackagePath, "vendor") || strings.HasPrefix(info.Path, runtime.GOROOT()) {
continue
}
sortedFiles = append(sortedFiles, info)
}
@@ -128,7 +134,8 @@ func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packag
pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef)
}
fullName := typeSpecDef.FullName()
fullName := typeSpecFullName(typeSpecDef)
anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName]
if ok {
if typeSpecDef.PkgPath == anotherTypeDef.PkgPath {
@@ -286,7 +293,7 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File
return pkgDefs.uniqueDefinitions[typeName]
}
parts := strings.Split(typeName, ".")
parts := strings.Split(strings.Split(typeName, "[")[0], ".")
if len(parts) > 1 {
isAliasPkgName := func(file *ast.File, pkgName string) bool {
if file != nil && file.Imports != nil {
@@ -322,6 +329,22 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File
}
}
if strings.Contains(typeName, "[") {
// joinedParts differs from typeName in that it does not contain any type parameters
joinedParts := strings.Join(parts, ".")
for tName, tSpec := range pkgDefs.uniqueDefinitions {
if !strings.Contains(tName, "[") {
continue
}
if strings.Contains(tName, joinedParts) {
if parametrized := pkgDefs.parametrizeStruct(tSpec, typeName); parametrized != nil {
return parametrized
}
}
}
}
return pkgDefs.findTypeSpec(pkgPath, parts[1])
}

View File

@@ -1,6 +1,7 @@
package swag
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -48,9 +49,21 @@ const (
deprecatedAttr = "@deprecated"
securityAttr = "@security"
titleAttr = "@title"
conNameAttr = "@contact.name"
conURLAttr = "@contact.url"
conEmailAttr = "@contact.email"
licNameAttr = "@license.name"
licURLAttr = "@license.url"
versionAttr = "@version"
descriptionAttr = "@description"
descriptionMarkdownAttr = "@description.markdown"
secBasicAttr = "@securitydefinitions.basic"
secAPIKeyAttr = "@securitydefinitions.apikey"
secApplicationAttr = "@securitydefinitions.oauth2.application"
secImplicitAttr = "@securitydefinitions.oauth2.implicit"
secPasswordAttr = "@securitydefinitions.oauth2.password"
secAccessCodeAttr = "@securitydefinitions.oauth2.accesscode"
tosAttr = "@termsofservice"
xCodeSamplesAttr = "@x-codesamples"
scopeAttrPrefix = "@scope."
)
@@ -140,6 +153,9 @@ type Parser struct {
// Overrides allows global replacements of types. A blank replacement will be skipped.
Overrides map[string]string
// parseGoList whether swag use go list to parse dependency
parseGoList bool
}
// FieldParserFactory create FieldParser.
@@ -241,7 +257,10 @@ func SetStrict(strict bool) func(*Parser) {
// SetDebugger allows the use of user-defined implementations.
func SetDebugger(logger Debugger) func(parser *Parser) {
return func(p *Parser) {
p.debug = logger
if logger != nil {
p.debug = logger
}
}
}
@@ -261,6 +280,13 @@ func SetOverrides(overrides map[string]string) func(parser *Parser) {
}
}
// ParseUsingGoList sets whether swag use go list to parse dependency
func ParseUsingGoList(enabled bool) func(parser *Parser) {
return func(p *Parser) {
p.parseGoList = enabled
}
}
// ParseAPI parses general api info for given searchDir and mainAPIFile.
func (parser *Parser) ParseAPI(searchDir string, mainAPIFile string, parseDepth int) error {
return parser.ParseAPIMultiSearchDir([]string{searchDir}, mainAPIFile, parseDepth)
@@ -287,26 +313,41 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st
return err
}
// Use 'go list' command instead of depth.Resolve()
if parser.ParseDependency {
var tree depth.Tree
tree.ResolveInternal = true
tree.MaxDepth = parseDepth
if parser.parseGoList {
pkgs, err := listPackages(context.Background(), filepath.Dir(absMainAPIFilePath), nil, "-deps")
if err != nil {
return fmt.Errorf("pkg %s cannot find all dependencies, %s", filepath.Dir(absMainAPIFilePath), err)
}
pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath))
if err != nil {
return err
}
length := len(pkgs)
for i := 0; i < length; i++ {
err := parser.getAllGoFileInfoFromDepsByList(pkgs[i])
if err != nil {
return err
}
}
} else {
var t depth.Tree
t.ResolveInternal = true
t.MaxDepth = parseDepth
err = tree.Resolve(pkgName)
if err != nil {
return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err)
}
for i := 0; i < len(tree.Root.Deps); i++ {
err := parser.getAllGoFileInfoFromDeps(&tree.Root.Deps[i])
pkgName, err := getPkgName(filepath.Dir(absMainAPIFilePath))
if err != nil {
return err
}
err = t.Resolve(pkgName)
if err != nil {
return fmt.Errorf("pkg %s cannot find all dependencies, %s", pkgName, err)
}
for i := 0; i < len(t.Root.Deps); i++ {
err := parser.getAllGoFileInfoFromDeps(&t.Root.Deps[i])
if err != nil {
return err
}
}
}
}
@@ -356,14 +397,6 @@ func getPkgName(searchDir string) (string, error) {
return outStr, nil
}
func initIfEmpty(license *spec.License) *spec.License {
if license == nil {
return new(spec.License)
}
return license
}
// ParseGeneralAPIInfo parses the general API info for the given mainAPIFile path.
func (parser *Parser) ParseGeneralAPIInfo(mainAPIFile string) error {
fileTree, err := goparser.ParseFile(token.NewFileSet(), mainAPIFile, nil, goparser.ParseComments)
@@ -396,16 +429,15 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error {
commentLine := comments[line]
attribute := strings.Split(commentLine, " ")[0]
value := strings.TrimSpace(commentLine[len(attribute):])
multilineBlock := false
if previousAttribute == attribute {
multilineBlock = true
}
switch strings.ToLower(attribute) {
case versionAttr:
parser.swagger.Info.Version = value
case titleAttr:
parser.swagger.Info.Title = value
switch attr := strings.ToLower(attribute); attr {
case versionAttr, titleAttr, tosAttr, licNameAttr, licURLAttr, conNameAttr, conURLAttr, conEmailAttr:
setSwaggerInfo(parser.swagger, attr, value)
case descriptionAttr:
if multilineBlock {
parser.swagger.Info.Description += "\n" + value
@@ -413,32 +445,20 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error {
continue
}
parser.swagger.Info.Description = value
case "@description.markdown":
setSwaggerInfo(parser.swagger, attr, value)
case descriptionMarkdownAttr:
commentInfo, err := getMarkdownForTag("api", parser.markdownFileDir)
if err != nil {
return err
}
parser.swagger.Info.Description = string(commentInfo)
case "@termsofservice":
parser.swagger.Info.TermsOfService = value
case "@contact.name":
parser.swagger.Info.Contact.Name = value
case "@contact.email":
parser.swagger.Info.Contact.Email = value
case "@contact.url":
parser.swagger.Info.Contact.URL = value
case "@license.name":
parser.swagger.Info.License = initIfEmpty(parser.swagger.Info.License)
parser.swagger.Info.License.Name = value
case "@license.url":
parser.swagger.Info.License = initIfEmpty(parser.swagger.Info.License)
parser.swagger.Info.License.URL = value
setSwaggerInfo(parser.swagger, descriptionAttr, string(commentInfo))
case "@host":
parser.swagger.Host = value
case "@basepath":
parser.swagger.BasePath = value
case acceptAttr:
err := parser.ParseAcceptComment(value)
if err != nil {
@@ -450,7 +470,7 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error {
return err
}
case "@schemes":
parser.swagger.Schemes = getSchemes(commentLine)
parser.swagger.Schemes = strings.Split(value, " ")
case "@tag.name":
parser.swagger.Tags = append(parser.swagger.Tags, spec.Tag{
TagProps: spec.TagProps{
@@ -487,43 +507,15 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error {
tag.TagProps.ExternalDocs.Description = value
replaceLastTag(parser.swagger.Tags, tag)
case "@securitydefinitions.basic":
parser.swagger.SecurityDefinitions[value] = spec.BasicAuth()
case "@securitydefinitions.apikey":
attrMap, _, extensions, err := parseSecAttr(attribute, []string{"@in", "@name"}, comments, &line)
case secBasicAttr, secAPIKeyAttr, secApplicationAttr, secImplicitAttr, secPasswordAttr, secAccessCodeAttr:
scheme, err := parseSecAttributes(attribute, comments, &line)
if err != nil {
return err
}
parser.swagger.SecurityDefinitions[value] = tryAddDescription(spec.APIKeyAuth(attrMap["@name"], attrMap["@in"]), extensions)
case "@securitydefinitions.oauth2.application":
attrMap, scopes, extensions, err := parseSecAttr(attribute, []string{"@tokenurl"}, comments, &line)
if err != nil {
return err
}
parser.swagger.SecurityDefinitions[value] = scheme
parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2Application(attrMap["@tokenurl"], scopes, extensions), extensions)
case "@securitydefinitions.oauth2.implicit":
attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@authorizationurl"}, comments, &line)
if err != nil {
return err
}
parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2Implicit(attrs["@authorizationurl"], scopes, ext), ext)
case "@securitydefinitions.oauth2.password":
attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@tokenurl"}, comments, &line)
if err != nil {
return err
}
parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2Password(attrs["@tokenurl"], scopes, ext), ext)
case "@securitydefinitions.oauth2.accesscode":
attrs, scopes, ext, err := parseSecAttr(attribute, []string{"@tokenurl", "@authorizationurl"}, comments, &line)
if err != nil {
return err
}
parser.swagger.SecurityDefinitions[value] = tryAddDescription(secOAuth2AccessToken(attrs["@authorizationurl"], attrs["@tokenurl"], scopes, ext), ext)
case "@query.collection.format":
parser.collectionFormatInQuery = value
default:
@@ -578,14 +570,140 @@ func parseGeneralAPIInfo(parser *Parser, comments []string) error {
return nil
}
func tryAddDescription(spec *spec.SecurityScheme, extensions map[string]interface{}) *spec.SecurityScheme {
if val, ok := extensions["@description"]; ok {
if str, ok := val.(string); ok {
spec.Description = str
func setSwaggerInfo(swagger *spec.Swagger, attribute, value string) {
switch attribute {
case versionAttr:
swagger.Info.Version = value
case titleAttr:
swagger.Info.Title = value
case tosAttr:
swagger.Info.TermsOfService = value
case descriptionAttr:
swagger.Info.Description = value
case conNameAttr:
swagger.Info.Contact.Name = value
case conEmailAttr:
swagger.Info.Contact.Email = value
case conURLAttr:
swagger.Info.Contact.URL = value
case licNameAttr:
swagger.Info.License = initIfEmpty(swagger.Info.License)
swagger.Info.License.Name = value
case licURLAttr:
swagger.Info.License = initIfEmpty(swagger.Info.License)
swagger.Info.License.URL = value
}
}
func parseSecAttributes(context string, lines []string, index *int) (*spec.SecurityScheme, error) {
const (
in = "@in"
name = "@name"
descriptionAttr = "@description"
tokenURL = "@tokenurl"
authorizationURL = "@authorizationurl"
)
var search []string
attribute := strings.ToLower(strings.Split(lines[*index], " ")[0])
switch attribute {
case secBasicAttr:
return spec.BasicAuth(), nil
case secAPIKeyAttr:
search = []string{in, name}
case secApplicationAttr, secPasswordAttr:
search = []string{tokenURL}
case secImplicitAttr:
search = []string{authorizationURL}
case secAccessCodeAttr:
search = []string{tokenURL, authorizationURL}
}
// For the first line we get the attributes in the context parameter, so we skip to the next one
*index++
attrMap, scopes := make(map[string]string), make(map[string]string)
extensions, description := make(map[string]interface{}), ""
for ; *index < len(lines); *index++ {
v := lines[*index]
securityAttr := strings.ToLower(strings.Split(v, " ")[0])
for _, findterm := range search {
if securityAttr == findterm {
attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):])
continue
}
}
isExists, err := isExistsScope(securityAttr)
if err != nil {
return nil, err
}
if isExists {
scopes[securityAttr[len(scopeAttrPrefix):]] = v[len(securityAttr):]
}
if strings.HasPrefix(securityAttr, "@x-") {
// Add the custom attribute without the @
extensions[securityAttr[1:]] = strings.TrimSpace(v[len(securityAttr):])
}
// Not a mandatory field
if securityAttr == descriptionAttr {
description = strings.TrimSpace(v[len(securityAttr):])
}
// Start of the next securityDefinitions block
if strings.Index(securityAttr, "@securitydefinitions.") == 0 {
// Go back to the previous line and break
*index--
break
}
}
return spec
if len(attrMap) != len(search) {
return nil, fmt.Errorf("%s is %v required", context, search)
}
var scheme *spec.SecurityScheme
switch attribute {
case secAPIKeyAttr:
scheme = spec.APIKeyAuth(attrMap[name], attrMap[in])
case secApplicationAttr:
scheme = spec.OAuth2Application(attrMap[tokenURL])
case secImplicitAttr:
scheme = spec.OAuth2Implicit(attrMap[authorizationURL])
case secPasswordAttr:
scheme = spec.OAuth2Password(attrMap[tokenURL])
case secAccessCodeAttr:
scheme = spec.OAuth2AccessToken(attrMap[authorizationURL], attrMap[tokenURL])
}
scheme.Description = description
for extKey, extValue := range extensions {
scheme.AddExtension(extKey, extValue)
}
for scope, scopeDescription := range scopes {
scheme.AddScope(scope, scopeDescription)
}
return scheme, nil
}
func initIfEmpty(license *spec.License) *spec.License {
if license == nil {
return new(spec.License)
}
return license
}
// ParseAcceptComment parses comment for given `accept` comment string.
@@ -611,121 +729,6 @@ func isGeneralAPIComment(comments []string) bool {
return true
}
func parseSecAttr(context string, search []string, lines []string, index *int) (map[string]string, map[string]string, map[string]interface{}, error) {
attrMap := map[string]string{}
scopes := map[string]string{}
extensions := map[string]interface{}{}
// For the first line we get the attributes in the context parameter, so we skip to the next one
*index++
for ; *index < len(lines); *index++ {
v := lines[*index]
securityAttr := strings.ToLower(strings.Split(v, " ")[0])
for _, findterm := range search {
if securityAttr == findterm {
attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):])
continue
}
}
isExists, err := isExistsScope(securityAttr)
if err != nil {
return nil, nil, nil, err
}
if isExists {
scopes[securityAttr[len(scopeAttrPrefix):]] = v[len(securityAttr):]
}
if strings.HasPrefix(securityAttr, "@x-") {
// Add the custom attribute without the @
extensions[securityAttr[1:]] = strings.TrimSpace(v[len(securityAttr):])
}
// Not mandatory field
if securityAttr == "@description" {
extensions[securityAttr] = strings.TrimSpace(v[len(securityAttr):])
}
// next securityDefinitions
if strings.Index(securityAttr, "@securitydefinitions.") == 0 {
// Go back to the previous line and break
*index--
break
}
}
if len(attrMap) != len(search) {
return nil, nil, nil, fmt.Errorf("%s is %v required", context, search)
}
return attrMap, scopes, extensions, nil
}
type (
authExtensions map[string]interface{}
authScopes map[string]string
)
func secOAuth2Application(tokenURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme {
securityScheme := spec.OAuth2Application(tokenURL)
securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions)
for scope, description := range scopes {
securityScheme.AddScope(scope, description)
}
return securityScheme
}
func secOAuth2Implicit(authorizationURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme {
securityScheme := spec.OAuth2Implicit(authorizationURL)
securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions)
for scope, description := range scopes {
securityScheme.AddScope(scope, description)
}
return securityScheme
}
func secOAuth2Password(tokenURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme {
securityScheme := spec.OAuth2Password(tokenURL)
securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions)
for scope, description := range scopes {
securityScheme.AddScope(scope, description)
}
return securityScheme
}
func secOAuth2AccessToken(authorizationURL, tokenURL string, scopes authScopes, extensions authExtensions) *spec.SecurityScheme {
securityScheme := spec.OAuth2AccessToken(authorizationURL, tokenURL)
securityScheme.VendorExtensible.Extensions = handleSecuritySchemaExtensions(extensions)
for scope, description := range scopes {
securityScheme.AddScope(scope, description)
}
return securityScheme
}
func handleSecuritySchemaExtensions(providedExtensions authExtensions) spec.Extensions {
var extensions spec.Extensions
if len(providedExtensions) > 0 {
extensions = make(map[string]interface{}, len(providedExtensions))
for key, value := range providedExtensions {
extensions[key] = value
}
}
return extensions
}
func getMarkdownForTag(tagName string, dirPath string) ([]byte, error) {
filesInfos, err := ioutil.ReadDir(dirPath)
if err != nil {
@@ -771,13 +774,6 @@ func isExistsScope(scope string) (bool, error) {
return strings.Contains(scope, scopeAttrPrefix), nil
}
// getSchemes parses swagger schemes for given commentLine.
func getSchemes(commentLine string) []string {
attribute := strings.ToLower(strings.Split(commentLine, " ")[0])
return strings.Split(strings.TrimSpace(commentLine[len(attribute):]), " ")
}
// ParseRouterAPIInfo parses the router API info for the given astFile.
func (parser *Parser) ParseRouterAPIInfo(fileName string, astFile *ast.File) error {
for _, astDescription := range astFile.Decls {
@@ -1365,7 +1361,7 @@ func replaceLastTag(slice []spec.Tag, element spec.Tag) {
slice = append(slice[:len(slice)-1], element)
}
// defineTypeOfExample example value define the type (object and array unsupported)
// defineTypeOfExample example value define the type (object and array unsupported).
func defineTypeOfExample(schemaType, arrayType, exampleValue string) (interface{}, error) {
switch schemaType {
case STRING:
@@ -1,4 +1,4 @@
package swag
// Version of swag.
const Version = "v1.8.1"
const Version = "v1.8.3"
@@ -0,0 +1,35 @@
package validator
import (
"fmt"
"github.com/vektah/gqlparser/v2/ast"
. "github.com/vektah/gqlparser/v2/validator"
)
func init() {
AddRule("KnownRootType", func(observers *Events, addError AddErrFunc) {
// A query's root must be a valid type. Surprisingly, this isn't
// checked anywhere else!
observers.OnOperation(func(walker *Walker, operation *ast.OperationDefinition) {
var def *ast.Definition
switch operation.Operation {
case ast.Query, "":
def = walker.Schema.Query
case ast.Mutation:
def = walker.Schema.Mutation
case ast.Subscription:
def = walker.Schema.Subscription
default:
// This shouldn't even parse; if it did we probably need to
// update this switch block to add the new operation type.
panic(fmt.Sprintf(`got unknown operation type "%s"`, operation.Operation))
}
if def == nil {
addError(
Message(`Schema does not support operation type "%s"`, operation.Operation),
At(operation.Position))
}
})
})
}
@@ -81,7 +81,22 @@ func ValidateSchemaDocument(ast *SchemaDocument) (*Schema, *gqlerror.Error) {
for i, dir := range ast.Directives {
if schema.Directives[dir.Name] != nil {
return nil, gqlerror.ErrorPosf(dir.Position, "Cannot redeclare directive %s.", dir.Name)
// While the spec says SDL must not (§3.5) explicitly define builtin
// scalars, it may (§3.13) define builtin directives. Here we check for
// that, and reject doubly-defined directives otherwise.
switch dir.Name {
case "include", "skip", "deprecated", "specifiedBy": // the builtins
// In principle here we might want to validate that the
// directives are the same. But they might not be, if the
// server has an older spec than we do. (Plus, validating this
// is a lot of work.) So we just keep the first one we saw.
// That's an arbitrary choice, but in theory the only way it
// fails is if the server is using features newer than this
// version of gqlparser, in which case they're in trouble
// anyway.
default:
return nil, gqlerror.ErrorPosf(dir.Position, "Cannot redeclare directive %s.", dir.Name)
}
}
schema.Directives[dir.Name] = ast.Directives[i]
}
@@ -523,6 +523,11 @@ directives:
message: "Cannot redeclare directive A."
locations: [{line: 2, column: 12}]
- name: can redeclare builtin directives
input: |
directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT
directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT
- name: must be declared
input: |
type User {