mirror of
https://github.com/pion/stun.git
synced 2025-09-26 20:01:18 +08:00
Upgrade golangci-lint, more linters
Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
@@ -25,17 +25,32 @@ linters-settings:
|
|||||||
- ^os.Exit$
|
- ^os.Exit$
|
||||||
- ^panic$
|
- ^panic$
|
||||||
- ^print(ln)?$
|
- ^print(ln)?$
|
||||||
|
varnamelen:
|
||||||
|
max-distance: 12
|
||||||
|
min-name-length: 2
|
||||||
|
ignore-type-assert-ok: true
|
||||||
|
ignore-map-index-ok: true
|
||||||
|
ignore-chan-recv-ok: true
|
||||||
|
ignore-decls:
|
||||||
|
- i int
|
||||||
|
- n int
|
||||||
|
- w io.Writer
|
||||||
|
- r io.Reader
|
||||||
|
- b []byte
|
||||||
|
|
||||||
linters:
|
linters:
|
||||||
enable:
|
enable:
|
||||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||||
- bidichk # Checks for dangerous unicode character sequences
|
- bidichk # Checks for dangerous unicode character sequences
|
||||||
- bodyclose # checks whether HTTP response body is closed successfully
|
- bodyclose # checks whether HTTP response body is closed successfully
|
||||||
|
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||||
- contextcheck # check the function whether use a non-inherited context
|
- contextcheck # check the function whether use a non-inherited context
|
||||||
|
- cyclop # checks function and package cyclomatic complexity
|
||||||
- decorder # check declaration order and count of types, constants, variables and functions
|
- decorder # check declaration order and count of types, constants, variables and functions
|
||||||
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||||
- dupl # Tool for code clone detection
|
- dupl # Tool for code clone detection
|
||||||
- durationcheck # check for two durations multiplied together
|
- durationcheck # check for two durations multiplied together
|
||||||
|
- err113 # Golang linter to check the errors handling expressions
|
||||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||||
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
|
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
|
||||||
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
|
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
|
||||||
@@ -46,18 +61,17 @@ linters:
|
|||||||
- forcetypeassert # finds forced type assertions
|
- forcetypeassert # finds forced type assertions
|
||||||
- gci # Gci control golang package import order and make it always deterministic.
|
- gci # Gci control golang package import order and make it always deterministic.
|
||||||
- gochecknoglobals # Checks that no globals are present in Go code
|
- gochecknoglobals # Checks that no globals are present in Go code
|
||||||
- gochecknoinits # Checks that no init functions are present in Go code
|
|
||||||
- gocognit # Computes and checks the cognitive complexity of functions
|
- gocognit # Computes and checks the cognitive complexity of functions
|
||||||
- goconst # Finds repeated strings that could be replaced by a constant
|
- goconst # Finds repeated strings that could be replaced by a constant
|
||||||
- gocritic # The most opinionated Go source code linter
|
- gocritic # The most opinionated Go source code linter
|
||||||
|
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||||
|
- godot # Check if comments end in a period
|
||||||
- godox # Tool for detection of FIXME, TODO and other comment keywords
|
- godox # Tool for detection of FIXME, TODO and other comment keywords
|
||||||
- err113 # Golang linter to check the errors handling expressions
|
|
||||||
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
||||||
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
|
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
|
||||||
- goheader # Checks is file header matches to pattern
|
- goheader # Checks is file header matches to pattern
|
||||||
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
|
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
|
||||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
||||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
|
||||||
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
|
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
|
||||||
- gosec # Inspects source code for security problems
|
- gosec # Inspects source code for security problems
|
||||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||||
@@ -65,9 +79,15 @@ linters:
|
|||||||
- grouper # An analyzer to analyze expression groups.
|
- grouper # An analyzer to analyze expression groups.
|
||||||
- importas # Enforces consistent import aliases
|
- importas # Enforces consistent import aliases
|
||||||
- ineffassign # Detects when assignments to existing variables are not used
|
- ineffassign # Detects when assignments to existing variables are not used
|
||||||
|
- lll # Reports long lines
|
||||||
|
- maintidx # maintidx measures the maintainability index of each function.
|
||||||
|
- makezero # Finds slice declarations with non-zero initial length
|
||||||
- misspell # Finds commonly misspelled English words in comments
|
- misspell # Finds commonly misspelled English words in comments
|
||||||
|
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||||
|
- nestif # Reports deeply nested if statements
|
||||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
||||||
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
|
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
|
||||||
|
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||||
- noctx # noctx finds sending http request without context.Context
|
- noctx # noctx finds sending http request without context.Context
|
||||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
- predeclared # find code that shadows one of Go's predeclared identifiers
|
||||||
- revive # golint replacement, finds style mistakes
|
- revive # golint replacement, finds style mistakes
|
||||||
@@ -75,28 +95,22 @@ linters:
|
|||||||
- stylecheck # Stylecheck is a replacement for golint
|
- stylecheck # Stylecheck is a replacement for golint
|
||||||
- tagliatelle # Checks the struct tags.
|
- tagliatelle # Checks the struct tags.
|
||||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
||||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||||
- unconvert # Remove unnecessary type conversions
|
- unconvert # Remove unnecessary type conversions
|
||||||
- unparam # Reports unused function parameters
|
- unparam # Reports unused function parameters
|
||||||
- unused # Checks Go code for unused constants, variables, functions and types
|
- unused # Checks Go code for unused constants, variables, functions and types
|
||||||
|
- varnamelen # checks that the length of a variable's name matches its scope
|
||||||
- wastedassign # wastedassign finds wasted assignment statements
|
- wastedassign # wastedassign finds wasted assignment statements
|
||||||
- whitespace # Tool for detection of leading and trailing whitespace
|
- whitespace # Tool for detection of leading and trailing whitespace
|
||||||
disable:
|
disable:
|
||||||
- depguard # Go linter that checks if package imports are in a list of acceptable packages
|
- depguard # Go linter that checks if package imports are in a list of acceptable packages
|
||||||
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
|
||||||
- cyclop # checks function and package cyclomatic complexity
|
|
||||||
- funlen # Tool for detection of long functions
|
- funlen # Tool for detection of long functions
|
||||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
- gochecknoinits # Checks that no init functions are present in Go code
|
||||||
- godot # Check if comments end in a period
|
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||||
- gomnd # An analyzer to detect magic numbers.
|
- interfacebloat # A linter that checks length of interface.
|
||||||
- ireturn # Accept Interfaces, Return Concrete Types
|
- ireturn # Accept Interfaces, Return Concrete Types
|
||||||
- lll # Reports long lines
|
- mnd # An analyzer to detect magic numbers
|
||||||
- maintidx # maintidx measures the maintainability index of each function.
|
|
||||||
- makezero # Finds slice declarations with non-zero initial length
|
|
||||||
- nakedret # Finds naked returns in functions greater than a specified function length
|
|
||||||
- nestif # Reports deeply nested if statements
|
|
||||||
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
|
||||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||||
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
|
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
|
||||||
- prealloc # Finds slice declarations that could potentially be preallocated
|
- prealloc # Finds slice declarations that could potentially be preallocated
|
||||||
@@ -104,8 +118,7 @@ linters:
|
|||||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||||
- testpackage # linter that makes you use a separate _test package
|
- testpackage # linter that makes you use a separate _test package
|
||||||
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||||
- varnamelen # checks that the length of a variable's name matches its scope
|
|
||||||
- wrapcheck # Checks that errors returned from external packages are wrapped
|
- wrapcheck # Checks that errors returned from external packages are wrapped
|
||||||
- wsl # Whitespace Linter - Forces you to use empty lines!
|
- wsl # Whitespace Linter - Forces you to use empty lines!
|
||||||
|
|
||||||
@@ -123,3 +136,4 @@ issues:
|
|||||||
- path: cmd
|
- path: cmd
|
||||||
linters:
|
linters:
|
||||||
- forbidigo
|
- forbidigo
|
||||||
|
|
||||||
|
32
addr.go
32
addr.go
@@ -15,7 +15,7 @@ import (
|
|||||||
// This attribute is used only by servers for achieving backwards
|
// This attribute is used only by servers for achieving backwards
|
||||||
// compatibility with RFC 3489 clients.
|
// compatibility with RFC 3489 clients.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.1
|
// RFC 5389 Section 15.1.
|
||||||
type MappedAddress struct {
|
type MappedAddress struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port int
|
Port int
|
||||||
@@ -23,7 +23,7 @@ type MappedAddress struct {
|
|||||||
|
|
||||||
// AlternateServer represents ALTERNATE-SERVER attribute.
|
// AlternateServer represents ALTERNATE-SERVER attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.11
|
// RFC 5389 Section 15.11.
|
||||||
type AlternateServer struct {
|
type AlternateServer struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port int
|
Port int
|
||||||
@@ -31,7 +31,7 @@ type AlternateServer struct {
|
|||||||
|
|
||||||
// ResponseOrigin represents RESPONSE-ORIGIN attribute.
|
// ResponseOrigin represents RESPONSE-ORIGIN attribute.
|
||||||
//
|
//
|
||||||
// RFC 5780 Section 7.3
|
// RFC 5780 Section 7.3.
|
||||||
type ResponseOrigin struct {
|
type ResponseOrigin struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port int
|
Port int
|
||||||
@@ -39,7 +39,7 @@ type ResponseOrigin struct {
|
|||||||
|
|
||||||
// OtherAddress represents OTHER-ADDRESS attribute.
|
// OtherAddress represents OTHER-ADDRESS attribute.
|
||||||
//
|
//
|
||||||
// RFC 5780 Section 7.4
|
// RFC 5780 Section 7.4.
|
||||||
type OtherAddress struct {
|
type OtherAddress struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port int
|
Port int
|
||||||
@@ -48,12 +48,14 @@ type OtherAddress struct {
|
|||||||
// AddTo adds ALTERNATE-SERVER attribute to message.
|
// AddTo adds ALTERNATE-SERVER attribute to message.
|
||||||
func (s *AlternateServer) AddTo(m *Message) error {
|
func (s *AlternateServer) AddTo(m *Message) error {
|
||||||
a := (*MappedAddress)(s)
|
a := (*MappedAddress)(s)
|
||||||
|
|
||||||
return a.AddToAs(m, AttrAlternateServer)
|
return a.AddToAs(m, AttrAlternateServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFrom decodes ALTERNATE-SERVER from message.
|
// GetFrom decodes ALTERNATE-SERVER from message.
|
||||||
func (s *AlternateServer) GetFrom(m *Message) error {
|
func (s *AlternateServer) GetFrom(m *Message) error {
|
||||||
a := (*MappedAddress)(s)
|
a := (*MappedAddress)(s)
|
||||||
|
|
||||||
return a.GetFromAs(m, AttrAlternateServer)
|
return a.GetFromAs(m, AttrAlternateServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,14 +65,14 @@ func (a MappedAddress) String() string {
|
|||||||
|
|
||||||
// GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t.
|
// GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t.
|
||||||
func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
|
func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
|
||||||
v, err := m.Get(t)
|
value, err := m.Get(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(v) <= 4 {
|
if len(value) <= 4 {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
family := bin.Uint16(v[0:2])
|
family := bin.Uint16(value[0:2])
|
||||||
if family != familyIPv6 && family != familyIPv4 {
|
if family != familyIPv6 && family != familyIPv4 {
|
||||||
return newDecodeErr("xor-mapped address", "family",
|
return newDecodeErr("xor-mapped address", "family",
|
||||||
fmt.Sprintf("bad value %d", family),
|
fmt.Sprintf("bad value %d", family),
|
||||||
@@ -91,13 +93,14 @@ func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
|
|||||||
for i := range a.IP {
|
for i := range a.IP {
|
||||||
a.IP[i] = 0
|
a.IP[i] = 0
|
||||||
}
|
}
|
||||||
a.Port = int(bin.Uint16(v[2:4]))
|
a.Port = int(bin.Uint16(value[2:4]))
|
||||||
copy(a.IP, v[4:])
|
copy(a.IP, value[4:])
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddToAs adds MAPPED-ADDRESS value to m as t attribute.
|
// AddToAs adds MAPPED-ADDRESS value to m as t attribute.
|
||||||
func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
|
func (a *MappedAddress) AddToAs(msg *Message, attrType AttrType) error {
|
||||||
var (
|
var (
|
||||||
family = familyIPv4
|
family = familyIPv4
|
||||||
ip = a.IP
|
ip = a.IP
|
||||||
@@ -114,9 +117,10 @@ func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
|
|||||||
value := make([]byte, 128)
|
value := make([]byte, 128)
|
||||||
value[0] = 0 // first 8 bits are zeroes
|
value[0] = 0 // first 8 bits are zeroes
|
||||||
bin.PutUint16(value[0:2], family)
|
bin.PutUint16(value[0:2], family)
|
||||||
bin.PutUint16(value[2:4], uint16(a.Port))
|
bin.PutUint16(value[2:4], uint16(a.Port)) //nolint:gosec //G115
|
||||||
copy(value[4:], ip)
|
copy(value[4:], ip)
|
||||||
m.Add(t, value[:4+len(ip)])
|
msg.Add(attrType, value[:4+len(ip)])
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,12 +137,14 @@ func (a *MappedAddress) GetFrom(m *Message) error {
|
|||||||
// AddTo adds OTHER-ADDRESS attribute to message.
|
// AddTo adds OTHER-ADDRESS attribute to message.
|
||||||
func (o *OtherAddress) AddTo(m *Message) error {
|
func (o *OtherAddress) AddTo(m *Message) error {
|
||||||
a := (*MappedAddress)(o)
|
a := (*MappedAddress)(o)
|
||||||
|
|
||||||
return a.AddToAs(m, AttrOtherAddress)
|
return a.AddToAs(m, AttrOtherAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFrom decodes OTHER-ADDRESS from message.
|
// GetFrom decodes OTHER-ADDRESS from message.
|
||||||
func (o *OtherAddress) GetFrom(m *Message) error {
|
func (o *OtherAddress) GetFrom(m *Message) error {
|
||||||
a := (*MappedAddress)(o)
|
a := (*MappedAddress)(o)
|
||||||
|
|
||||||
return a.GetFromAs(m, AttrOtherAddress)
|
return a.GetFromAs(m, AttrOtherAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,12 +155,14 @@ func (o OtherAddress) String() string {
|
|||||||
// AddTo adds RESPONSE-ORIGIN attribute to message.
|
// AddTo adds RESPONSE-ORIGIN attribute to message.
|
||||||
func (o *ResponseOrigin) AddTo(m *Message) error {
|
func (o *ResponseOrigin) AddTo(m *Message) error {
|
||||||
a := (*MappedAddress)(o)
|
a := (*MappedAddress)(o)
|
||||||
|
|
||||||
return a.AddToAs(m, AttrResponseOrigin)
|
return a.AddToAs(m, AttrResponseOrigin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFrom decodes RESPONSE-ORIGIN from message.
|
// GetFrom decodes RESPONSE-ORIGIN from message.
|
||||||
func (o *ResponseOrigin) GetFrom(m *Message) error {
|
func (o *ResponseOrigin) GetFrom(m *Message) error {
|
||||||
a := (*MappedAddress)(o)
|
a := (*MappedAddress)(o)
|
||||||
|
|
||||||
return a.GetFromAs(m, AttrResponseOrigin)
|
return a.GetFromAs(m, AttrResponseOrigin)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
12
addr_test.go
12
addr_test.go
@@ -11,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMappedAddress(t *testing.T) {
|
func TestMappedAddress(t *testing.T) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
addr := &MappedAddress{
|
addr := &MappedAddress{
|
||||||
IP: net.ParseIP("122.12.34.5"),
|
IP: net.ParseIP("122.12.34.5"),
|
||||||
Port: 5412,
|
Port: 5412,
|
||||||
@@ -23,17 +23,17 @@ func TestMappedAddress(t *testing.T) {
|
|||||||
badAddr := &MappedAddress{
|
badAddr := &MappedAddress{
|
||||||
IP: net.IP{1, 2, 3},
|
IP: net.IP{1, 2, 3},
|
||||||
}
|
}
|
||||||
if err := badAddr.AddTo(m); err == nil {
|
if err := badAddr.AddTo(msg); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("AddTo", func(t *testing.T) {
|
t.Run("AddTo", func(t *testing.T) {
|
||||||
if err := addr.AddTo(m); err != nil {
|
if err := addr.AddTo(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
t.Run("GetFrom", func(t *testing.T) {
|
t.Run("GetFrom", func(t *testing.T) {
|
||||||
got := new(MappedAddress)
|
got := new(MappedAddress)
|
||||||
if err := got.GetFrom(m); err != nil {
|
if err := got.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !got.IP.Equal(addr.IP) {
|
if !got.IP.Equal(addr.IP) {
|
||||||
@@ -46,9 +46,9 @@ func TestMappedAddress(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("Bad family", func(t *testing.T) {
|
t.Run("Bad family", func(t *testing.T) {
|
||||||
v, _ := m.Attributes.Get(AttrMappedAddress)
|
v, _ := msg.Attributes.Get(AttrMappedAddress)
|
||||||
v.Value[0] = 32
|
v.Value[0] = 32
|
||||||
if err := got.GetFrom(m); err == nil {
|
if err := got.GetFrom(msg); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
16
agent.go
16
agent.go
@@ -24,6 +24,7 @@ func NewAgent(h Handler) *Agent {
|
|||||||
transactions: make(map[transactionID]agentTransaction),
|
transactions: make(map[transactionID]agentTransaction),
|
||||||
handler: h,
|
handler: h,
|
||||||
}
|
}
|
||||||
|
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +81,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
|
|||||||
a.mux.Lock()
|
a.mux.Lock()
|
||||||
if a.closed {
|
if a.closed {
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return ErrAgentClosed
|
return ErrAgentClosed
|
||||||
}
|
}
|
||||||
t, exists := a.transactions[id]
|
t, exists := a.transactions[id]
|
||||||
@@ -93,6 +95,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
|
|||||||
TransactionID: t.id,
|
TransactionID: t.id,
|
||||||
Error: err,
|
Error: err,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,6 +127,7 @@ func (a *Agent) Start(id [TransactionIDSize]byte, deadline time.Time) error {
|
|||||||
id: id,
|
id: id,
|
||||||
deadline: deadline,
|
deadline: deadline,
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,6 +151,7 @@ func (a *Agent) Collect(gcTime time.Time) error {
|
|||||||
// All transactions should be already closed
|
// All transactions should be already closed
|
||||||
// during Close() call.
|
// during Close() call.
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return ErrAgentClosed
|
return ErrAgentClosed
|
||||||
}
|
}
|
||||||
// Adding all transactions with deadline before gcTime
|
// Adding all transactions with deadline before gcTime
|
||||||
@@ -175,24 +180,27 @@ func (a *Agent) Collect(gcTime time.Time) error {
|
|||||||
event.TransactionID = id
|
event.TransactionID = id
|
||||||
h(event)
|
h(event)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process incoming message, synchronously passing it to handler.
|
// Process incoming message, synchronously passing it to handler.
|
||||||
func (a *Agent) Process(m *Message) error {
|
func (a *Agent) Process(m *Message) error {
|
||||||
e := Event{
|
event := Event{
|
||||||
TransactionID: m.TransactionID,
|
TransactionID: m.TransactionID,
|
||||||
Message: m,
|
Message: m,
|
||||||
}
|
}
|
||||||
a.mux.Lock()
|
a.mux.Lock()
|
||||||
if a.closed {
|
if a.closed {
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return ErrAgentClosed
|
return ErrAgentClosed
|
||||||
}
|
}
|
||||||
h := a.handler
|
h := a.handler
|
||||||
delete(a.transactions, m.TransactionID)
|
delete(a.transactions, m.TransactionID)
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
h(e)
|
h(event)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,10 +209,12 @@ func (a *Agent) SetHandler(h Handler) error {
|
|||||||
a.mux.Lock()
|
a.mux.Lock()
|
||||||
if a.closed {
|
if a.closed {
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return ErrAgentClosed
|
return ErrAgentClosed
|
||||||
}
|
}
|
||||||
a.handler = h
|
a.handler = h
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,6 +227,7 @@ func (a *Agent) Close() error {
|
|||||||
a.mux.Lock()
|
a.mux.Lock()
|
||||||
if a.closed {
|
if a.closed {
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return ErrAgentClosed
|
return ErrAgentClosed
|
||||||
}
|
}
|
||||||
for _, t := range a.transactions {
|
for _, t := range a.transactions {
|
||||||
@@ -227,6 +238,7 @@ func (a *Agent) Close() error {
|
|||||||
a.closed = true
|
a.closed = true
|
||||||
a.handler = nil
|
a.handler = nil
|
||||||
a.mux.Unlock()
|
a.mux.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -10,49 +10,49 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAgent_ProcessInTransaction(t *testing.T) {
|
func TestAgent_ProcessInTransaction(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
a := NewAgent(func(e Event) {
|
agent := NewAgent(func(e Event) {
|
||||||
if e.Error != nil {
|
if e.Error != nil {
|
||||||
t.Errorf("got error: %s", e.Error)
|
t.Errorf("got error: %s", e.Error)
|
||||||
}
|
}
|
||||||
if !e.Message.Equal(m) {
|
if !e.Message.Equal(msg) {
|
||||||
t.Errorf("%s (got) != %s (expected)", e.Message, m)
|
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if err := m.NewTransactionID(); err != nil {
|
if err := msg.NewTransactionID(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := a.Start(m.TransactionID, time.Time{}); err != nil {
|
if err := agent.Start(msg.TransactionID, time.Time{}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := a.Process(m); err != nil {
|
if err := agent.Process(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_Process(t *testing.T) {
|
func TestAgent_Process(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
a := NewAgent(func(e Event) {
|
agent := NewAgent(func(e Event) {
|
||||||
if e.Error != nil {
|
if e.Error != nil {
|
||||||
t.Errorf("got error: %s", e.Error)
|
t.Errorf("got error: %s", e.Error)
|
||||||
}
|
}
|
||||||
if !e.Message.Equal(m) {
|
if !e.Message.Equal(msg) {
|
||||||
t.Errorf("%s (got) != %s (expected)", e.Message, m)
|
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if err := m.NewTransactionID(); err != nil {
|
if err := msg.NewTransactionID(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := a.Process(m); err != nil {
|
if err := agent.Process(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := a.Process(m); !errors.Is(err, ErrAgentClosed) {
|
if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
|
||||||
t.Errorf("closed agent should return <%s>, but got <%s>",
|
t.Errorf("closed agent should return <%s>, but got <%s>",
|
||||||
ErrAgentClosed, err,
|
ErrAgentClosed, err,
|
||||||
)
|
)
|
||||||
@@ -60,27 +60,27 @@ func TestAgent_Process(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_Start(t *testing.T) {
|
func TestAgent_Start(t *testing.T) {
|
||||||
a := NewAgent(nil)
|
agent := NewAgent(nil)
|
||||||
id := NewTransactionID()
|
id := NewTransactionID()
|
||||||
deadline := time.Now().AddDate(0, 0, 1)
|
deadline := time.Now().AddDate(0, 0, 1)
|
||||||
if err := a.Start(id, deadline); err != nil {
|
if err := agent.Start(id, deadline); err != nil {
|
||||||
t.Errorf("failed to statt transaction: %s", err)
|
t.Errorf("failed to statt transaction: %s", err)
|
||||||
}
|
}
|
||||||
if err := a.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
|
if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
|
||||||
t.Errorf("duplicate start should return <%s>, got <%s>",
|
t.Errorf("duplicate start should return <%s>, got <%s>",
|
||||||
ErrTransactionExists, err,
|
ErrTransactionExists, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
id = NewTransactionID()
|
id = NewTransactionID()
|
||||||
if err := a.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
|
if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
|
||||||
t.Errorf("start on closed agent should return <%s>, got <%s>",
|
t.Errorf("start on closed agent should return <%s>, got <%s>",
|
||||||
ErrAgentClosed, err,
|
ErrAgentClosed, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if err := a.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
|
if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
|
||||||
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
|
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
|
||||||
ErrAgentClosed, err,
|
ErrAgentClosed, err,
|
||||||
)
|
)
|
||||||
@@ -89,18 +89,18 @@ func TestAgent_Start(t *testing.T) {
|
|||||||
|
|
||||||
func TestAgent_Stop(t *testing.T) {
|
func TestAgent_Stop(t *testing.T) {
|
||||||
called := make(chan Event, 1)
|
called := make(chan Event, 1)
|
||||||
a := NewAgent(func(e Event) {
|
agent := NewAgent(func(e Event) {
|
||||||
called <- e
|
called <- e
|
||||||
})
|
})
|
||||||
if err := a.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
|
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
|
||||||
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
|
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
|
||||||
}
|
}
|
||||||
id := NewTransactionID()
|
id := NewTransactionID()
|
||||||
timeout := time.Millisecond * 200
|
timeout := time.Millisecond * 200
|
||||||
if err := a.Start(id, time.Now().Add(timeout)); err != nil {
|
if err := agent.Start(id, time.Now().Add(timeout)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := a.Stop(id); err != nil {
|
if err := agent.Stop(id); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
@@ -113,19 +113,19 @@ func TestAgent_Stop(t *testing.T) {
|
|||||||
case <-time.After(timeout * 2):
|
case <-time.After(timeout * 2):
|
||||||
t.Fatal("timed out")
|
t.Fatal("timed out")
|
||||||
}
|
}
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := a.Close(); !errors.Is(err, ErrAgentClosed) {
|
if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
|
||||||
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
|
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
|
||||||
}
|
}
|
||||||
if err := a.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
|
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
|
||||||
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
|
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_GC(t *testing.T) {
|
func TestAgent_GC(t *testing.T) { //nolint:cyclop
|
||||||
a := NewAgent(nil)
|
agent := NewAgent(nil)
|
||||||
shouldTimeOutID := make(map[transactionID]bool)
|
shouldTimeOutID := make(map[transactionID]bool)
|
||||||
deadline := time.Date(2027, time.November, 21,
|
deadline := time.Date(2027, time.November, 21,
|
||||||
23, 0, 0, 0,
|
23, 0, 0, 0,
|
||||||
@@ -133,7 +133,7 @@ func TestAgent_GC(t *testing.T) {
|
|||||||
)
|
)
|
||||||
gcDeadline := deadline.Add(-time.Second)
|
gcDeadline := deadline.Add(-time.Second)
|
||||||
deadlineNotGC := gcDeadline.AddDate(0, 0, -1)
|
deadlineNotGC := gcDeadline.AddDate(0, 0, -1)
|
||||||
a.SetHandler(func(e Event) { //nolint:errcheck,gosec
|
agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
|
||||||
id := e.TransactionID
|
id := e.TransactionID
|
||||||
shouldTimeOut, found := shouldTimeOutID[id]
|
shouldTimeOut, found := shouldTimeOutID[id]
|
||||||
if !found {
|
if !found {
|
||||||
@@ -149,67 +149,67 @@ func TestAgent_GC(t *testing.T) {
|
|||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
id := NewTransactionID()
|
id := NewTransactionID()
|
||||||
shouldTimeOutID[id] = false
|
shouldTimeOutID[id] = false
|
||||||
if err := a.Start(id, deadline); err != nil {
|
if err := agent.Start(id, deadline); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
id := NewTransactionID()
|
id := NewTransactionID()
|
||||||
shouldTimeOutID[id] = true
|
shouldTimeOutID[id] = true
|
||||||
if err := a.Start(id, deadlineNotGC); err != nil {
|
if err := agent.Start(id, deadlineNotGC); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := a.Collect(gcDeadline); err != nil {
|
if err := agent.Collect(gcDeadline); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := a.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
|
if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
|
||||||
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
|
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAgent_GC(b *testing.B) {
|
func BenchmarkAgent_GC(b *testing.B) {
|
||||||
a := NewAgent(nil)
|
agent := NewAgent(nil)
|
||||||
deadline := time.Now().AddDate(0, 0, 1)
|
deadline := time.Now().AddDate(0, 0, 1)
|
||||||
for i := 0; i < agentCollectCap; i++ {
|
for i := 0; i < agentCollectCap; i++ {
|
||||||
if err := a.Start(NewTransactionID(), deadline); err != nil {
|
if err := agent.Start(NewTransactionID(), deadline); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
b.Error(err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
gcDeadline := deadline.Add(-time.Second)
|
gcDeadline := deadline.Add(-time.Second)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := a.Collect(gcDeadline); err != nil {
|
if err := agent.Collect(gcDeadline); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAgent_Process(b *testing.B) {
|
func BenchmarkAgent_Process(b *testing.B) {
|
||||||
a := NewAgent(nil)
|
agent := NewAgent(nil)
|
||||||
deadline := time.Now().AddDate(0, 0, 1)
|
deadline := time.Now().AddDate(0, 0, 1)
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
if err := a.Start(NewTransactionID(), deadline); err != nil {
|
if err := agent.Start(NewTransactionID(), deadline); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := a.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
b.Error(err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := MustBuild(TransactionID)
|
m := MustBuild(TransactionID)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := a.Process(m); err != nil {
|
if err := agent.Process(m); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -21,6 +21,7 @@ func (a Attributes) Get(t AttrType) (RawAttribute, bool) {
|
|||||||
return candidate, true
|
return candidate, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return RawAttribute{}, false
|
return RawAttribute{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +78,7 @@ const (
|
|||||||
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
|
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
// Attributes from RFC 5780 NAT Behavior Discovery
|
// Attributes from RFC 5780 NAT Behavior Discovery.
|
||||||
const (
|
const (
|
||||||
AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST
|
AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST
|
||||||
AttrPadding AttrType = 0x0026 // PADDING
|
AttrPadding AttrType = 0x0026 // PADDING
|
||||||
@@ -166,6 +167,7 @@ func (t AttrType) String() string {
|
|||||||
// Just return hex representation of unknown attribute type.
|
// Just return hex representation of unknown attribute type.
|
||||||
return fmt.Sprintf("0x%x", uint16(t))
|
return fmt.Sprintf("0x%x", uint16(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,6 +188,7 @@ type RawAttribute struct {
|
|||||||
// the Length field.
|
// the Length field.
|
||||||
func (a RawAttribute) AddTo(m *Message) error {
|
func (a RawAttribute) AddTo(m *Message) error {
|
||||||
m.Add(a.Type, a.Value)
|
m.Add(a.Type, a.Value)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,6 +208,7 @@ func (a RawAttribute) Equal(b RawAttribute) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,6 +228,7 @@ func (m *Message) Get(t AttrType) ([]byte, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrAttributeNotFound
|
return nil, ErrAttributeNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return v.Value, nil
|
return v.Value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +245,7 @@ func nearestPaddedValueLength(l int) int {
|
|||||||
if n < l {
|
if n < l {
|
||||||
n += padding
|
n += padding
|
||||||
}
|
}
|
||||||
|
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,5 +256,6 @@ func compatAttrType(val uint16) AttrType {
|
|||||||
if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN
|
if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN
|
||||||
return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on)
|
return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on)
|
||||||
}
|
}
|
||||||
|
|
||||||
return AttrType(val)
|
return AttrType(val)
|
||||||
}
|
}
|
||||||
|
@@ -44,13 +44,13 @@ func TestRawAttribute_AddTo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_GetNoAllocs(t *testing.T) {
|
func TestMessage_GetNoAllocs(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
NewSoftware("c").AddTo(m) //nolint:errcheck,gosec
|
NewSoftware("c").AddTo(msg) //nolint:errcheck,gosec
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
|
|
||||||
t.Run("Default", func(t *testing.T) {
|
t.Run("Default", func(t *testing.T) {
|
||||||
allocs := testing.AllocsPerRun(10, func() {
|
allocs := testing.AllocsPerRun(10, func() {
|
||||||
m.Get(AttrSoftware) //nolint:errcheck,gosec
|
msg.Get(AttrSoftware) //nolint:errcheck,gosec
|
||||||
})
|
})
|
||||||
if allocs > 0 {
|
if allocs > 0 {
|
||||||
t.Error("allocated memory, but should not")
|
t.Error("allocated memory, but should not")
|
||||||
@@ -58,7 +58,7 @@ func TestMessage_GetNoAllocs(t *testing.T) {
|
|||||||
})
|
})
|
||||||
t.Run("Not found", func(t *testing.T) {
|
t.Run("Not found", func(t *testing.T) {
|
||||||
allocs := testing.AllocsPerRun(10, func() {
|
allocs := testing.AllocsPerRun(10, func() {
|
||||||
m.Get(AttrOrigin) //nolint:errcheck,gosec
|
msg.Get(AttrOrigin) //nolint:errcheck,gosec
|
||||||
})
|
})
|
||||||
if allocs > 0 {
|
if allocs > 0 {
|
||||||
t.Error("allocated memory, but should not")
|
t.Error("allocated memory, but should not")
|
||||||
|
@@ -17,6 +17,7 @@ func CheckSize(_ AttrType, got, expected int) error {
|
|||||||
if got == expected {
|
if got == expected {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrAttributeSizeInvalid
|
return ErrAttributeSizeInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,6 +25,7 @@ func checkHMAC(got, expected []byte) error {
|
|||||||
if hmac.Equal(got, expected) {
|
if hmac.Equal(got, expected) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrIntegrityMismatch
|
return ErrIntegrityMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,6 +33,7 @@ func checkFingerprint(got, expected uint32) error {
|
|||||||
if got == expected {
|
if got == expected {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrFingerprintMismatch
|
return ErrFingerprintMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,6 +47,7 @@ func CheckOverflow(_ AttrType, got, maxVal int) error {
|
|||||||
if got <= maxVal {
|
if got <= maxVal {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrAttributeSizeOverflow
|
return ErrAttributeSizeOverflow
|
||||||
}
|
}
|
||||||
|
|
||||||
|
126
client.go
126
client.go
@@ -21,7 +21,7 @@ import (
|
|||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI
|
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI.
|
||||||
var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
|
var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
|
||||||
|
|
||||||
// Dial connects to the address on the named network and then
|
// Dial connects to the address on the named network and then
|
||||||
@@ -31,10 +31,11 @@ func Dial(network, address string) (*Client, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewClient(conn)
|
return NewClient(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialConfig is used to pass configuration to DialURI()
|
// DialConfig is used to pass configuration to DialURI().
|
||||||
type DialConfig struct {
|
type DialConfig struct {
|
||||||
DTLSConfig dtls.Config
|
DTLSConfig dtls.Config
|
||||||
TLSConfig tls.Config
|
TLSConfig tls.Config
|
||||||
@@ -44,7 +45,7 @@ type DialConfig struct {
|
|||||||
|
|
||||||
// DialURI connect to the STUN/TURN URI and then
|
// DialURI connect to the STUN/TURN URI and then
|
||||||
// initializes Client on that connection, returning error if any.
|
// initializes Client on that connection, returning error if any.
|
||||||
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
|
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop
|
||||||
var conn Connection
|
var conn Connection
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -203,7 +204,7 @@ const (
|
|||||||
// provide any API for it, so if you need to read application data, wrap the
|
// provide any API for it, so if you need to read application data, wrap the
|
||||||
// connection with your (de-)multiplexer and pass the wrapper as conn.
|
// connection with your (de-)multiplexer and pass the wrapper as conn.
|
||||||
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
|
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
|
||||||
c := &Client{
|
client := &Client{
|
||||||
close: make(chan struct{}),
|
close: make(chan struct{}),
|
||||||
c: conn,
|
c: conn,
|
||||||
clock: systemClock(),
|
clock: systemClock(),
|
||||||
@@ -214,32 +215,33 @@ func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
|
|||||||
closeConn: true,
|
closeConn: true,
|
||||||
}
|
}
|
||||||
for _, o := range options {
|
for _, o := range options {
|
||||||
o(c)
|
o(client)
|
||||||
}
|
}
|
||||||
if c.c == nil {
|
if client.c == nil {
|
||||||
return nil, ErrNoConnection
|
return nil, ErrNoConnection
|
||||||
}
|
}
|
||||||
if c.a == nil {
|
if client.a == nil {
|
||||||
c.a = NewAgent(nil)
|
client.a = NewAgent(nil)
|
||||||
}
|
}
|
||||||
if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
|
if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if c.collector == nil {
|
if client.collector == nil {
|
||||||
c.collector = &tickerCollector{
|
client.collector = &tickerCollector{
|
||||||
close: make(chan struct{}),
|
close: make(chan struct{}),
|
||||||
clock: c.clock,
|
clock: client.clock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := c.collector.Start(c.rtoRate, func(t time.Time) {
|
if err := client.collector.Start(client.rtoRate, func(t time.Time) {
|
||||||
closedOrPanic(c.a.Collect(t))
|
closedOrPanic(client.a.Collect(t))
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.wg.Add(1)
|
client.wg.Add(1)
|
||||||
go c.readUntilClosed()
|
go client.readUntilClosed()
|
||||||
runtime.SetFinalizer(c, clientFinalizer)
|
runtime.SetFinalizer(client, clientFinalizer)
|
||||||
return c, nil
|
|
||||||
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func clientFinalizer(c *Client) {
|
func clientFinalizer(c *Client) {
|
||||||
@@ -252,6 +254,7 @@ func clientFinalizer(c *Client) {
|
|||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.Println("client: called finalizer on non-closed client") // nolint
|
log.Println("client: called finalizer on non-closed client") // nolint
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("client: called finalizer on non-closed client:", err) // nolint
|
log.Println("client: called finalizer on non-closed client:", err) // nolint
|
||||||
@@ -353,6 +356,7 @@ func (c *Client) start(t *clientTransaction) error {
|
|||||||
return ErrTransactionExists
|
return ErrTransactionExists
|
||||||
}
|
}
|
||||||
c.t[t.id] = t
|
c.t[t.id] = t
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,6 +403,7 @@ func sprintErr(err error) string {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return "<nil>" //nolint:goconst
|
return "<nil>" //nolint:goconst
|
||||||
}
|
}
|
||||||
|
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,18 +460,21 @@ func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error
|
|||||||
select {
|
select {
|
||||||
case <-a.close:
|
case <-a.close:
|
||||||
t.Stop()
|
t.Stop()
|
||||||
|
|
||||||
return
|
return
|
||||||
case <-t.C:
|
case <-t.C:
|
||||||
f(a.clock.Now())
|
f(a.clock.Now())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *tickerCollector) Close() error {
|
func (a *tickerCollector) Close() error {
|
||||||
close(a.close)
|
close(a.close)
|
||||||
a.wg.Wait()
|
a.wg.Wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -481,6 +489,7 @@ func (c *Client) Close() error {
|
|||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
c.mux.Unlock()
|
c.mux.Unlock()
|
||||||
|
|
||||||
return ErrClientClosed
|
return ErrClientClosed
|
||||||
}
|
}
|
||||||
c.closed = true
|
c.closed = true
|
||||||
@@ -498,6 +507,7 @@ func (c *Client) Close() error {
|
|||||||
if agentErr == nil && connErr == nil {
|
if agentErr == nil && connErr == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return CloseErr{
|
return CloseErr{
|
||||||
AgentErr: agentErr,
|
AgentErr: agentErr,
|
||||||
ConnectionErr: connErr,
|
ConnectionErr: connErr,
|
||||||
@@ -566,6 +576,7 @@ func (c *Client) checkInit() error {
|
|||||||
if c == nil || c.c == nil || c.a == nil || c.close == nil {
|
if c == nil || c.c == nil || c.a == nil || c.close == nil {
|
||||||
return ErrClientNotInitialized
|
return ErrClientNotInitialized
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -590,6 +601,7 @@ func (c *Client) Do(m *Message, f func(Event)) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.wait()
|
h.wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -611,80 +623,85 @@ var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleAgentCallback(e Event) {
|
func (c *Client) handleAgentCallback(event Event) { //nolint:cyclop
|
||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
c.mux.Unlock()
|
c.mux.Unlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t, found := c.t[e.TransactionID]
|
transaction, found := c.t[event.TransactionID]
|
||||||
if found {
|
if found {
|
||||||
delete(c.t, t.id)
|
delete(c.t, transaction.id)
|
||||||
}
|
}
|
||||||
c.mux.Unlock()
|
c.mux.Unlock()
|
||||||
if !found {
|
if !found {
|
||||||
if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
|
if c.handler != nil && !errors.Is(event.Error, ErrTransactionStopped) {
|
||||||
c.handler(e)
|
c.handler(event)
|
||||||
}
|
}
|
||||||
// Ignoring.
|
// Ignoring.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
|
if atomic.LoadInt32(&c.maxAttempts) <= transaction.attempt || event.Error == nil {
|
||||||
// Transaction completed.
|
// Transaction completed.
|
||||||
t.handle(e)
|
transaction.handle(event)
|
||||||
putClientTransaction(t)
|
putClientTransaction(transaction)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Doing re-transmission.
|
// Doing re-transmission.
|
||||||
t.attempt++
|
transaction.attempt++
|
||||||
b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
|
buff := bufferPool.Get().(*buffer) //nolint:forcetypeassert
|
||||||
b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
|
buff.buf = buff.buf[:copy(buff.buf[:cap(buff.buf)], transaction.raw)]
|
||||||
defer bufferPool.Put(b)
|
defer bufferPool.Put(buff)
|
||||||
var (
|
var (
|
||||||
now = c.clock.Now()
|
now = c.clock.Now()
|
||||||
timeOut = t.nextTimeout(now)
|
timeOut = transaction.nextTimeout(now)
|
||||||
id = t.id
|
id = transaction.id
|
||||||
)
|
)
|
||||||
// Starting client transaction.
|
// Starting client transaction.
|
||||||
if startErr := c.start(t); startErr != nil {
|
if startErr := c.start(transaction); startErr != nil {
|
||||||
c.delete(id)
|
c.delete(id)
|
||||||
e.Error = startErr
|
event.Error = startErr
|
||||||
t.handle(e)
|
transaction.handle(event)
|
||||||
putClientTransaction(t)
|
putClientTransaction(transaction)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Starting agent transaction.
|
// Starting agent transaction.
|
||||||
if startErr := c.a.Start(id, timeOut); startErr != nil {
|
if startErr := c.a.Start(id, timeOut); startErr != nil {
|
||||||
c.delete(id)
|
c.delete(id)
|
||||||
e.Error = startErr
|
event.Error = startErr
|
||||||
t.handle(e)
|
transaction.handle(event)
|
||||||
putClientTransaction(t)
|
putClientTransaction(transaction)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Writing message to connection again.
|
// Writing message to connection again.
|
||||||
_, writeErr := c.c.Write(b.buf)
|
_, writeErr := c.c.Write(buff.buf)
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
c.delete(id)
|
c.delete(id)
|
||||||
e.Error = writeErr
|
event.Error = writeErr
|
||||||
// Stopping agent transaction instead of waiting until it's deadline.
|
// Stopping agent transaction instead of waiting until it's deadline.
|
||||||
// This will call handleAgentCallback with "ErrTransactionStopped" error
|
// This will call handleAgentCallback with "ErrTransactionStopped" error
|
||||||
// which will be ignored.
|
// which will be ignored.
|
||||||
if stopErr := c.a.Stop(id); stopErr != nil {
|
if stopErr := c.a.Stop(id); stopErr != nil {
|
||||||
// Failed to stop agent transaction. Wrapping the error in StopError.
|
// Failed to stop agent transaction. Wrapping the error in StopError.
|
||||||
e.Error = StopErr{
|
event.Error = StopErr{
|
||||||
Err: stopErr,
|
Err: stopErr,
|
||||||
Cause: writeErr,
|
Cause: writeErr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.handle(e)
|
transaction.handle(event)
|
||||||
putClientTransaction(t)
|
putClientTransaction(transaction)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts transaction (if h set) and writes message to server, handler
|
// Start starts transaction (if h set) and writes message to server, handler
|
||||||
// is called asynchronously.
|
// is called asynchronously.
|
||||||
func (c *Client) Start(m *Message, h Handler) error {
|
func (c *Client) Start(msg *Message, handler Handler) error {
|
||||||
if err := c.checkInit(); err != nil {
|
if err := c.checkInit(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -694,34 +711,35 @@ func (c *Client) Start(m *Message, h Handler) error {
|
|||||||
if closed {
|
if closed {
|
||||||
return ErrClientClosed
|
return ErrClientClosed
|
||||||
}
|
}
|
||||||
if h != nil {
|
if handler != nil {
|
||||||
// Starting transaction only if h is set. Useful for indications.
|
// Starting transaction only if h is set. Useful for indications.
|
||||||
t := acquireClientTransaction()
|
t := acquireClientTransaction()
|
||||||
t.id = m.TransactionID
|
t.id = msg.TransactionID
|
||||||
t.start = c.clock.Now()
|
t.start = c.clock.Now()
|
||||||
t.h = h
|
t.h = handler
|
||||||
t.rto = time.Duration(atomic.LoadInt64(&c.rto))
|
t.rto = time.Duration(atomic.LoadInt64(&c.rto))
|
||||||
t.attempt = 0
|
t.attempt = 0
|
||||||
t.raw = append(t.raw[:0], m.Raw...)
|
t.raw = append(t.raw[:0], msg.Raw...)
|
||||||
t.calls = 0
|
t.calls = 0
|
||||||
d := t.nextTimeout(t.start)
|
d := t.nextTimeout(t.start)
|
||||||
if err := c.start(t); err != nil {
|
if err := c.start(t); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := c.a.Start(m.TransactionID, d); err != nil {
|
if err := c.a.Start(msg.TransactionID, d); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := m.WriteTo(c.c)
|
_, err := msg.WriteTo(c.c)
|
||||||
if err != nil && h != nil {
|
if err != nil && handler != nil {
|
||||||
c.delete(m.TransactionID)
|
c.delete(msg.TransactionID)
|
||||||
// Stopping transaction instead of waiting until deadline.
|
// Stopping transaction instead of waiting until deadline.
|
||||||
if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
|
if stopErr := c.a.Stop(msg.TransactionID); stopErr != nil {
|
||||||
return StopErr{
|
return StopErr{
|
||||||
Err: stopErr,
|
Err: stopErr,
|
||||||
Cause: err,
|
Cause: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
163
client_test.go
163
client_test.go
@@ -38,11 +38,13 @@ type TestAgent struct {
|
|||||||
|
|
||||||
func (n *TestAgent) SetHandler(h Handler) error {
|
func (n *TestAgent) SetHandler(h Handler) error {
|
||||||
n.h = h
|
n.h = h
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *TestAgent) Close() error {
|
func (n *TestAgent) Close() error {
|
||||||
close(n.e)
|
close(n.e)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +56,7 @@ func (n *TestAgent) Start(id [TransactionIDSize]byte, _ time.Time) error {
|
|||||||
n.e <- Event{
|
n.e <- Event{
|
||||||
TransactionID: id,
|
TransactionID: id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,6 +72,7 @@ func (noopConnection) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
func (noopConnection) Read([]byte) (int, error) {
|
func (noopConnection) Read([]byte) (int, error) {
|
||||||
time.Sleep(time.Millisecond)
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +140,7 @@ func (t *testConnection) Close() error {
|
|||||||
return errClientAlreadyStopped
|
return errClientAlreadyStopped
|
||||||
}
|
}
|
||||||
t.stopped = true
|
t.stopped = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,6 +153,7 @@ func (t *testConnection) Read(b []byte) (int, error) {
|
|||||||
if t.read != nil {
|
if t.read != nil {
|
||||||
return t.read(b)
|
return t.read(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
return copy(b, t.b), nil
|
return copy(b, t.b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +171,7 @@ func TestClosedOrPanic(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_Start(t *testing.T) {
|
func TestClient_Start(t *testing.T) { //nolint:cyclop
|
||||||
response := MustBuild(TransactionID, BindingSuccess)
|
response := MustBuild(TransactionID, BindingSuccess)
|
||||||
response.Encode()
|
response.Encode()
|
||||||
write := make(chan struct{}, 1)
|
write := make(chan struct{}, 1)
|
||||||
@@ -178,6 +184,7 @@ func TestClient_Start(t *testing.T) {
|
|||||||
case <-read:
|
case <-read:
|
||||||
t.Log("reading")
|
t.Log("reading")
|
||||||
copy(i, response.Raw)
|
copy(i, response.Raw)
|
||||||
|
|
||||||
return len(response.Raw), nil
|
return len(response.Raw), nil
|
||||||
case <-time.After(time.Millisecond * 10):
|
case <-time.After(time.Millisecond * 10):
|
||||||
return 0, errClientReadTimedOut
|
return 0, errClientReadTimedOut
|
||||||
@@ -188,33 +195,34 @@ func TestClient_Start(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case <-write:
|
case <-write:
|
||||||
t.Log("writing")
|
t.Log("writing")
|
||||||
|
|
||||||
return len(bytes), nil
|
return len(bytes), nil
|
||||||
case <-time.After(time.Millisecond * 10):
|
case <-time.After(time.Millisecond * 10):
|
||||||
return 0, errClientWriteTimedOut
|
return 0, errClientWriteTimedOut
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err := NewClient(conn)
|
client, err := NewClient(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := c.Close(); err == nil {
|
if err := client.Close(); err == nil {
|
||||||
t.Error("second close should fail")
|
t.Error("second close should fail")
|
||||||
}
|
}
|
||||||
if err := c.Do(MustBuild(TransactionID), nil); err == nil {
|
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
|
||||||
t.Error("Do after Close should fail")
|
t.Error("Do after Close should fail")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
m := MustBuild(response, BindingRequest)
|
msg := MustBuild(response, BindingRequest)
|
||||||
t.Log("init")
|
t.Log("init")
|
||||||
got := make(chan struct{})
|
got := make(chan struct{})
|
||||||
write <- struct{}{}
|
write <- struct{}{}
|
||||||
t.Log("starting the first transaction")
|
t.Log("starting the first transaction")
|
||||||
if err := c.Start(m, func(event Event) {
|
if err := client.Start(msg, func(event Event) {
|
||||||
t.Log("got first transaction callback")
|
t.Log("got first transaction callback")
|
||||||
if event.Error != nil {
|
if event.Error != nil {
|
||||||
t.Error(event.Error)
|
t.Error(event.Error)
|
||||||
@@ -224,7 +232,7 @@ func TestClient_Start(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
t.Log("starting the second transaction")
|
t.Log("starting the second transaction")
|
||||||
if err := c.Start(m, func(Event) {
|
if err := client.Start(msg, func(Event) {
|
||||||
t.Error("should not be called")
|
t.Error("should not be called")
|
||||||
}); !errors.Is(err, ErrTransactionExists) {
|
}); !errors.Is(err, ErrTransactionExists) {
|
||||||
t.Errorf("unexpected error %v", err)
|
t.Errorf("unexpected error %v", err)
|
||||||
@@ -247,25 +255,25 @@ func TestClient_Do(t *testing.T) {
|
|||||||
return len(bytes), nil
|
return len(bytes), nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err := NewClient(conn)
|
client, err := NewClient(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := c.Close(); err == nil {
|
if err := client.Close(); err == nil {
|
||||||
t.Error("second close should fail")
|
t.Error("second close should fail")
|
||||||
}
|
}
|
||||||
if err := c.Do(MustBuild(TransactionID), nil); err == nil {
|
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
|
||||||
t.Error("Do after Close should fail")
|
t.Error("Do after Close should fail")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
m := MustBuild(
|
m := MustBuild(
|
||||||
NewTransactionIDSetter(response.TransactionID),
|
NewTransactionIDSetter(response.TransactionID),
|
||||||
)
|
)
|
||||||
if err := c.Do(m, func(event Event) {
|
if err := client.Do(m, func(event Event) {
|
||||||
if event.Error != nil {
|
if event.Error != nil {
|
||||||
t.Error(event.Error)
|
t.Error(event.Error)
|
||||||
}
|
}
|
||||||
@@ -273,13 +281,13 @@ func TestClient_Do(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
m = MustBuild(TransactionID)
|
m = MustBuild(TransactionID)
|
||||||
if err := c.Do(m, nil); err != nil {
|
if err := client.Do(m, nil); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseErr_Error(t *testing.T) {
|
func TestCloseErr_Error(t *testing.T) {
|
||||||
for id, c := range []struct {
|
for id, testCase := range []struct {
|
||||||
Err CloseErr
|
Err CloseErr
|
||||||
Out string
|
Out string
|
||||||
}{
|
}{
|
||||||
@@ -291,16 +299,16 @@ func TestCloseErr_Error(t *testing.T) {
|
|||||||
ConnectionErr: io.ErrUnexpectedEOF,
|
ConnectionErr: io.ErrUnexpectedEOF,
|
||||||
}, "failed to close: unexpected EOF (connection), <nil> (agent)"},
|
}, "failed to close: unexpected EOF (connection), <nil> (agent)"},
|
||||||
} {
|
} {
|
||||||
if out := c.Err.Error(); out != c.Out {
|
if out := testCase.Err.Error(); out != testCase.Out {
|
||||||
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
||||||
id, c.Err, out, c.Out,
|
id, testCase.Err, out, testCase.Out,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStopErr_Error(t *testing.T) {
|
func TestStopErr_Error(t *testing.T) {
|
||||||
for id, c := range []struct {
|
for id, testcase := range []struct {
|
||||||
Err StopErr
|
Err StopErr
|
||||||
Out string
|
Out string
|
||||||
}{
|
}{
|
||||||
@@ -312,9 +320,9 @@ func TestStopErr_Error(t *testing.T) {
|
|||||||
Cause: io.ErrUnexpectedEOF,
|
Cause: io.ErrUnexpectedEOF,
|
||||||
}, "error while stopping due to unexpected EOF: <nil>"},
|
}, "error while stopping due to unexpected EOF: <nil>"},
|
||||||
} {
|
} {
|
||||||
if out := c.Err.Error(); out != c.Out {
|
if out := testcase.Err.Error(); out != testcase.Out {
|
||||||
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
||||||
id, c.Err, out, c.Out,
|
id, testcase.Err, out, testcase.Out,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -352,7 +360,7 @@ func TestClientAgentError(t *testing.T) {
|
|||||||
return len(bytes), nil
|
return len(bytes), nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err := NewClient(conn,
|
client, err := NewClient(conn,
|
||||||
WithAgent(errorAgent{
|
WithAgent(errorAgent{
|
||||||
startErr: io.ErrUnexpectedEOF,
|
startErr: io.ErrUnexpectedEOF,
|
||||||
}),
|
}),
|
||||||
@@ -361,15 +369,15 @@ func TestClientAgentError(t *testing.T) {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
m := MustBuild(NewTransactionIDSetter(response.TransactionID))
|
m := MustBuild(NewTransactionIDSetter(response.TransactionID))
|
||||||
if err := c.Do(m, nil); err != nil {
|
if err := client.Do(m, nil); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := c.Do(m, func(event Event) {
|
if err := client.Do(m, func(event Event) {
|
||||||
if event.Error == nil {
|
if event.Error == nil {
|
||||||
t.Error("error expected")
|
t.Error("error expected")
|
||||||
}
|
}
|
||||||
@@ -384,20 +392,20 @@ func TestClientConnErr(t *testing.T) {
|
|||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err := NewClient(conn)
|
client, err := NewClient(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
m := MustBuild(TransactionID)
|
m := MustBuild(TransactionID)
|
||||||
if err := c.Do(m, nil); err == nil {
|
if err := client.Do(m, nil); err == nil {
|
||||||
t.Error("error expected")
|
t.Error("error expected")
|
||||||
}
|
}
|
||||||
if err := c.Do(m, NoopHandler()); err == nil {
|
if err := client.Do(m, NoopHandler()); err == nil {
|
||||||
t.Error("error expected")
|
t.Error("error expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -408,7 +416,7 @@ func TestClientConnErrStopErr(t *testing.T) {
|
|||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err := NewClient(conn,
|
client, err := NewClient(conn,
|
||||||
WithAgent(errorAgent{
|
WithAgent(errorAgent{
|
||||||
stopErr: io.ErrUnexpectedEOF,
|
stopErr: io.ErrUnexpectedEOF,
|
||||||
}),
|
}),
|
||||||
@@ -417,12 +425,12 @@ func TestClientConnErrStopErr(t *testing.T) {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := c.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
m := MustBuild(TransactionID)
|
m := MustBuild(TransactionID)
|
||||||
if err := c.Do(m, NoopHandler()); err == nil {
|
if err := client.Do(m, NoopHandler()); err == nil {
|
||||||
t.Error("error expected")
|
t.Error("error expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -556,11 +564,13 @@ func (a *gcWaitAgent) Stop([TransactionIDSize]byte) error {
|
|||||||
|
|
||||||
func (a *gcWaitAgent) Close() error {
|
func (a *gcWaitAgent) Close() error {
|
||||||
close(a.gc)
|
close(a.gc)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *gcWaitAgent) Collect(time.Time) error {
|
func (a *gcWaitAgent) Collect(time.Time) error {
|
||||||
a.gc <- struct{}{}
|
a.gc <- struct{}{}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -617,6 +627,7 @@ func captureLog() (*bytes.Buffer, func()) {
|
|||||||
log.SetOutput(&buf)
|
log.SetOutput(&buf)
|
||||||
f := log.Flags()
|
f := log.Flags()
|
||||||
log.SetFlags(0)
|
log.SetFlags(0)
|
||||||
|
|
||||||
return &buf, func() {
|
return &buf, func() {
|
||||||
log.SetFlags(f)
|
log.SetFlags(f)
|
||||||
log.SetOutput(os.Stderr)
|
log.SetOutput(os.Stderr)
|
||||||
@@ -633,12 +644,12 @@ func TestClientFinalizer(t *testing.T) {
|
|||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err := NewClient(conn)
|
client, err := NewClient(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
log.Panic(err)
|
||||||
}
|
}
|
||||||
clientFinalizer(c)
|
clientFinalizer(client)
|
||||||
clientFinalizer(c)
|
clientFinalizer(client)
|
||||||
response := MustBuild(TransactionID, BindingSuccess)
|
response := MustBuild(TransactionID, BindingSuccess)
|
||||||
response.Encode()
|
response.Encode()
|
||||||
conn = &testConnection{
|
conn = &testConnection{
|
||||||
@@ -647,7 +658,7 @@ func TestClientFinalizer(t *testing.T) {
|
|||||||
return len(bytes), nil
|
return len(bytes), nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c, err = NewClient(conn,
|
client, err = NewClient(conn,
|
||||||
WithAgent(errorAgent{
|
WithAgent(errorAgent{
|
||||||
closeErr: io.ErrUnexpectedEOF,
|
closeErr: io.ErrUnexpectedEOF,
|
||||||
}),
|
}),
|
||||||
@@ -655,7 +666,7 @@ func TestClientFinalizer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
log.Panic(err)
|
||||||
}
|
}
|
||||||
clientFinalizer(c)
|
clientFinalizer(client)
|
||||||
reader := bufio.NewScanner(buf)
|
reader := bufio.NewScanner(buf)
|
||||||
var lines int
|
var lines int
|
||||||
expectedLines := []string{
|
expectedLines := []string{
|
||||||
@@ -700,6 +711,7 @@ func (m *manualCollector) Collect(t time.Time) {
|
|||||||
|
|
||||||
func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error {
|
func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error {
|
||||||
m.f = f
|
m.f = f
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,12 +729,14 @@ func (m *manualClock) Add(d time.Duration) time.Time {
|
|||||||
v := m.current.Add(d)
|
v := m.current.Add(d)
|
||||||
m.current = v
|
m.current = v
|
||||||
m.mux.Unlock()
|
m.mux.Unlock()
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manualClock) Now() time.Time {
|
func (m *manualClock) Now() time.Time {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
return m.current
|
return m.current
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -735,6 +749,7 @@ type manualAgent struct {
|
|||||||
|
|
||||||
func (n *manualAgent) SetHandler(h Handler) error {
|
func (n *manualAgent) SetHandler(h Handler) error {
|
||||||
n.h = h
|
n.h = h
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -748,6 +763,7 @@ func (n *manualAgent) Process(m *Message) error {
|
|||||||
if n.process != nil {
|
if n.process != nil {
|
||||||
return n.process(m)
|
return n.process(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -759,6 +775,7 @@ func (n *manualAgent) Stop(id [TransactionIDSize]byte) error {
|
|||||||
if n.stop != nil {
|
if n.stop != nil {
|
||||||
return n.stop(id)
|
return n.stop(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -788,9 +805,10 @@ func TestClientRetransmission(t *testing.T) {
|
|||||||
Message: response,
|
Message: response,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, err := NewClient(connR,
|
client, err := NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -799,7 +817,7 @@ func TestClientRetransmission(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
c.SetRTO(time.Second)
|
client.SetRTO(time.Second)
|
||||||
gotReads := make(chan struct{})
|
gotReads := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
@@ -819,7 +837,7 @@ func TestClientRetransmission(t *testing.T) {
|
|||||||
}
|
}
|
||||||
gotReads <- struct{}{}
|
gotReads <- struct{}{}
|
||||||
}()
|
}()
|
||||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||||
if event.Error != nil {
|
if event.Error != nil {
|
||||||
t.Error("failed")
|
t.Error("failed")
|
||||||
}
|
}
|
||||||
@@ -829,7 +847,9 @@ func TestClientRetransmission(t *testing.T) {
|
|||||||
<-gotReads
|
<-gotReads
|
||||||
}
|
}
|
||||||
|
|
||||||
func testClientDoConcurrent(t *testing.T, concurrency int) {
|
func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
response := MustBuild(TransactionID, BindingSuccess)
|
response := MustBuild(TransactionID, BindingSuccess)
|
||||||
response.Encode()
|
response.Encode()
|
||||||
connL, connR := net.Pipe()
|
connL, connR := net.Pipe()
|
||||||
@@ -846,9 +866,10 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
|
|||||||
TransactionID: id,
|
TransactionID: id,
|
||||||
Message: response,
|
Message: response,
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, err := NewClient(connR,
|
client, err := NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -856,7 +877,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
c.SetRTO(time.Second)
|
client.SetRTO(time.Second)
|
||||||
conns := new(sync.WaitGroup)
|
conns := new(sync.WaitGroup)
|
||||||
wg := new(sync.WaitGroup)
|
wg := new(sync.WaitGroup)
|
||||||
for i := 0; i < concurrency; i++ {
|
for i := 0; i < concurrency; i++ {
|
||||||
@@ -880,7 +901,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if doErr := c.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
|
if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
|
||||||
if event.Error != nil {
|
if event.Error != nil {
|
||||||
t.Error("failed")
|
t.Error("failed")
|
||||||
}
|
}
|
||||||
@@ -962,14 +983,14 @@ func TestClient_Close(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClientDefaultHandler(t *testing.T) {
|
func TestClientDefaultHandler(t *testing.T) {
|
||||||
a := &TestAgent{
|
agent := &TestAgent{
|
||||||
e: make(chan Event),
|
e: make(chan Event),
|
||||||
}
|
}
|
||||||
id := NewTransactionID()
|
id := NewTransactionID()
|
||||||
handlerCalled := make(chan struct{})
|
handlerCalled := make(chan struct{})
|
||||||
called := false
|
called := false
|
||||||
c, createErr := NewClient(noopConnection{},
|
client, createErr := NewClient(noopConnection{},
|
||||||
WithAgent(a),
|
WithAgent(agent),
|
||||||
WithHandler(func(e Event) {
|
WithHandler(func(e Event) {
|
||||||
if called {
|
if called {
|
||||||
t.Error("should not be called twice")
|
t.Error("should not be called twice")
|
||||||
@@ -985,7 +1006,7 @@ func TestClientDefaultHandler(t *testing.T) {
|
|||||||
t.Fatal(createErr)
|
t.Fatal(createErr)
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
a.h(Event{
|
agent.h(Event{
|
||||||
TransactionID: id,
|
TransactionID: id,
|
||||||
})
|
})
|
||||||
}()
|
}()
|
||||||
@@ -995,11 +1016,11 @@ func TestClientDefaultHandler(t *testing.T) {
|
|||||||
case <-time.After(time.Millisecond * 100):
|
case <-time.After(time.Millisecond * 100):
|
||||||
t.Fatal("timed out")
|
t.Fatal("timed out")
|
||||||
}
|
}
|
||||||
if closeErr := c.Close(); closeErr != nil {
|
if closeErr := client.Close(); closeErr != nil {
|
||||||
t.Error(closeErr)
|
t.Error(closeErr)
|
||||||
}
|
}
|
||||||
// Handler call should be ignored.
|
// Handler call should be ignored.
|
||||||
a.h(Event{})
|
agent.h(Event{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientClosedStart(t *testing.T) {
|
func TestClientClosedStart(t *testing.T) {
|
||||||
@@ -1047,9 +1068,10 @@ func TestWithNoRetransmit(t *testing.T) {
|
|||||||
Error: ErrTransactionTimeOut,
|
Error: ErrTransactionTimeOut,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, err := NewClient(connR,
|
client, err := NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -1071,7 +1093,7 @@ func TestWithNoRetransmit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
gotReads <- struct{}{}
|
gotReads <- struct{}{}
|
||||||
}()
|
}()
|
||||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||||
if !errors.Is(event.Error, ErrTransactionTimeOut) {
|
if !errors.Is(event.Error, ErrTransactionTimeOut) {
|
||||||
t.Error("unexpected error")
|
t.Error("unexpected error")
|
||||||
}
|
}
|
||||||
@@ -1087,7 +1109,7 @@ func (c callbackClock) Now() time.Time {
|
|||||||
return c()
|
return c()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientRTOStartErr(t *testing.T) {
|
func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
|
||||||
response := MustBuild(TransactionID, BindingSuccess)
|
response := MustBuild(TransactionID, BindingSuccess)
|
||||||
response.Encode()
|
response.Encode()
|
||||||
connL, connR := net.Pipe()
|
connL, connR := net.Pipe()
|
||||||
@@ -1116,13 +1138,14 @@ func TestClientRTOStartErr(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
t.Log("clock returned")
|
t.Log("clock returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Now()
|
return time.Now()
|
||||||
})
|
})
|
||||||
agent := &manualAgent{}
|
agent := &manualAgent{}
|
||||||
attempt := 0
|
attempt := 0
|
||||||
gotReads := make(chan struct{})
|
gotReads := make(chan struct{})
|
||||||
var (
|
var (
|
||||||
c *Client
|
client *Client
|
||||||
startClientErr error
|
startClientErr error
|
||||||
)
|
)
|
||||||
agent.start = func(id [TransactionIDSize]byte, _ time.Time) error {
|
agent.start = func(id [TransactionIDSize]byte, _ time.Time) error {
|
||||||
@@ -1146,7 +1169,7 @@ func TestClientRTOStartErr(t *testing.T) {
|
|||||||
t.Log("clock locked")
|
t.Log("clock locked")
|
||||||
<-clockLocked
|
<-clockLocked
|
||||||
t.Log("closing client")
|
t.Log("closing client")
|
||||||
if closeErr := c.Close(); closeErr != nil {
|
if closeErr := client.Close(); closeErr != nil {
|
||||||
t.Error(closeErr)
|
t.Error(closeErr)
|
||||||
}
|
}
|
||||||
t.Log("client closed, unlocking clock")
|
t.Log("client closed, unlocking clock")
|
||||||
@@ -1154,9 +1177,10 @@ func TestClientRTOStartErr(t *testing.T) {
|
|||||||
t.Log("clock unlocked")
|
t.Log("clock unlocked")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, startClientErr = NewClient(connR,
|
client, startClientErr = NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -1186,7 +1210,7 @@ func TestClientRTOStartErr(t *testing.T) {
|
|||||||
t.Log("starting")
|
t.Log("starting")
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||||
if !errors.Is(event.Error, ErrClientClosed) {
|
if !errors.Is(event.Error, ErrClientClosed) {
|
||||||
t.Error(event.Error)
|
t.Error(event.Error)
|
||||||
}
|
}
|
||||||
@@ -1203,7 +1227,7 @@ func TestClientRTOStartErr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientRTOWriteErr(t *testing.T) {
|
func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
|
||||||
response := MustBuild(TransactionID, BindingSuccess)
|
response := MustBuild(TransactionID, BindingSuccess)
|
||||||
response.Encode()
|
response.Encode()
|
||||||
connL, connR := net.Pipe()
|
connL, connR := net.Pipe()
|
||||||
@@ -1232,13 +1256,14 @@ func TestClientRTOWriteErr(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
t.Log("clock returned")
|
t.Log("clock returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Now()
|
return time.Now()
|
||||||
})
|
})
|
||||||
agent := &manualAgent{}
|
agent := &manualAgent{}
|
||||||
attempt := 0
|
attempt := 0
|
||||||
gotReads := make(chan struct{})
|
gotReads := make(chan struct{})
|
||||||
var (
|
var (
|
||||||
c *Client
|
client *Client
|
||||||
startClientErr error
|
startClientErr error
|
||||||
)
|
)
|
||||||
agentStopErr := errClientAgentCantStop
|
agentStopErr := errClientAgentCantStop
|
||||||
@@ -1274,9 +1299,10 @@ func TestClientRTOWriteErr(t *testing.T) {
|
|||||||
t.Log("clock unlocked")
|
t.Log("clock unlocked")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, startClientErr = NewClient(connR,
|
client, startClientErr = NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -1306,7 +1332,7 @@ func TestClientRTOWriteErr(t *testing.T) {
|
|||||||
t.Log("starting")
|
t.Log("starting")
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||||
var e StopErr
|
var e StopErr
|
||||||
if !errors.As(event.Error, &e) {
|
if !errors.As(event.Error, &e) {
|
||||||
t.Error(event.Error)
|
t.Error(event.Error)
|
||||||
@@ -1346,7 +1372,7 @@ func TestClientRTOAgentErr(t *testing.T) {
|
|||||||
attempt := 0
|
attempt := 0
|
||||||
gotReads := make(chan struct{})
|
gotReads := make(chan struct{})
|
||||||
var (
|
var (
|
||||||
c *Client
|
client *Client
|
||||||
startClientErr error
|
startClientErr error
|
||||||
)
|
)
|
||||||
agentStartErr := errClientStartRefused
|
agentStartErr := errClientStartRefused
|
||||||
@@ -1361,9 +1387,10 @@ func TestClientRTOAgentErr(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
return agentStartErr
|
return agentStartErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, startClientErr = NewClient(connR,
|
client, startClientErr = NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -1384,7 +1411,7 @@ func TestClientRTOAgentErr(t *testing.T) {
|
|||||||
gotReads <- struct{}{}
|
gotReads <- struct{}{}
|
||||||
}()
|
}()
|
||||||
t.Log("starting")
|
t.Log("starting")
|
||||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||||
if !errors.Is(event.Error, agentStartErr) {
|
if !errors.Is(event.Error, agentStartErr) {
|
||||||
t.Error(event.Error)
|
t.Error(event.Error)
|
||||||
}
|
}
|
||||||
@@ -1415,9 +1442,10 @@ func TestClient_HandleProcessError(t *testing.T) {
|
|||||||
processCalled := make(chan struct{}, 1)
|
processCalled := make(chan struct{}, 1)
|
||||||
agent.process = func(*Message) error {
|
agent.process = func(*Message) error {
|
||||||
processCalled <- struct{}{}
|
processCalled <- struct{}{}
|
||||||
|
|
||||||
return ErrAgentClosed
|
return ErrAgentClosed
|
||||||
}
|
}
|
||||||
c, startClientErr := NewClient(connR,
|
client, startClientErr := NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -1440,7 +1468,7 @@ func TestClient_HandleProcessError(t *testing.T) {
|
|||||||
case <-time.After(time.Second * 5):
|
case <-time.After(time.Second * 5):
|
||||||
t.Error("reads timeout")
|
t.Error("reads timeout")
|
||||||
}
|
}
|
||||||
if closeErr := c.Close(); closeErr != nil {
|
if closeErr := client.Close(); closeErr != nil {
|
||||||
t.Error(closeErr)
|
t.Error(closeErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1475,9 +1503,10 @@ func TestClientImmediateTimeout(t *testing.T) {
|
|||||||
Error: ErrTransactionTimeOut,
|
Error: ErrTransactionTimeOut,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c, err := NewClient(connR,
|
client, err := NewClient(connR,
|
||||||
WithAgent(agent),
|
WithAgent(agent),
|
||||||
WithClock(clock),
|
WithClock(clock),
|
||||||
WithCollector(collector),
|
WithCollector(collector),
|
||||||
@@ -1498,7 +1527,7 @@ func TestClientImmediateTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
gotReads <- struct{}{}
|
gotReads <- struct{}{}
|
||||||
}()
|
}()
|
||||||
c.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
|
client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
|
||||||
if errors.Is(e.Error, ErrTransactionTimeOut) {
|
if errors.Is(e.Error, ErrTransactionTimeOut) {
|
||||||
t.Error("unexpected error")
|
t.Error("unexpected error")
|
||||||
}
|
}
|
||||||
|
@@ -30,7 +30,7 @@ var (
|
|||||||
realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals
|
realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() { //nolint:gocognit
|
func main() { //nolint:gocognit,cyclop
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
uri, err := stun.ParseURI(*uriStr)
|
uri, err := stun.ParseURI(*uriStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -88,7 +88,7 @@ func main() { //nolint:gocognit
|
|||||||
log.Print("Using crypto/rand as random source for transaction id")
|
log.Print("Using crypto/rand as random source for transaction id")
|
||||||
}
|
}
|
||||||
for i := 0; i < *workers; i++ {
|
for i := 0; i < *workers; i++ {
|
||||||
c, clientErr := stun.DialURI(uri, &stun.DialConfig{})
|
client, clientErr := stun.DialURI(uri, &stun.DialConfig{})
|
||||||
if clientErr != nil {
|
if clientErr != nil {
|
||||||
log.Panicf("Failed to create client: %s", clientErr)
|
log.Panicf("Failed to create client: %s", clientErr)
|
||||||
}
|
}
|
||||||
@@ -105,12 +105,13 @@ func main() { //nolint:gocognit
|
|||||||
req.Type = stun.BindingRequest
|
req.Type = stun.BindingRequest
|
||||||
req.WriteHeader()
|
req.WriteHeader()
|
||||||
atomic.AddInt64(&request, 1)
|
atomic.AddInt64(&request, 1)
|
||||||
if doErr := c.Do(req, func(event stun.Event) {
|
if doErr := client.Do(req, func(event stun.Event) {
|
||||||
if event.Error != nil {
|
if event.Error != nil {
|
||||||
if !errors.Is(event.Error, stun.ErrTransactionTimeOut) {
|
if !errors.Is(event.Error, stun.ErrTransactionTimeOut) {
|
||||||
log.Printf("Failed STUN transaction: %s", event.Error)
|
log.Printf("Failed STUN transaction: %s", event.Error)
|
||||||
}
|
}
|
||||||
atomic.AddInt64(&requestErr, 1)
|
atomic.AddInt64(&requestErr, 1)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.AddInt64(&requestOK, 1)
|
atomic.AddInt64(&requestOK, 1)
|
||||||
|
@@ -31,11 +31,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// we only try the first address, so restrict ourselves to IPv4
|
// we only try the first address, so restrict ourselves to IPv4
|
||||||
c, err := stun.DialURI(uri, &stun.DialConfig{})
|
client, err := stun.DialURI(uri, &stun.DialConfig{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to dial: %s", err)
|
log.Fatalf("Failed to dial: %s", err)
|
||||||
}
|
}
|
||||||
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
log.Fatalf("Failed STUN transaction: %s", res.Error)
|
log.Fatalf("Failed STUN transaction: %s", res.Error)
|
||||||
}
|
}
|
||||||
@@ -49,7 +49,7 @@ func main() {
|
|||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Fatal("Do:", err)
|
log.Fatal("Do:", err)
|
||||||
}
|
}
|
||||||
if err := c.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
log.Fatalf("Failed to close connection: %s", err)
|
log.Fatalf("Failed to close connection: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -84,7 +84,7 @@ func multiplex(conn *net.UDPConn, stunAddr net.Addr, stunConn io.Reader) {
|
|||||||
|
|
||||||
var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals
|
var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals
|
||||||
|
|
||||||
func main() {
|
func main() { //nolint:cyclop
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
isServer := flag.Arg(0) == ""
|
isServer := flag.Arg(0) == ""
|
||||||
@@ -112,7 +112,7 @@ func main() {
|
|||||||
|
|
||||||
stunL, stunR := net.Pipe()
|
stunL, stunR := net.Pipe()
|
||||||
|
|
||||||
c, err := stun.NewClient(stunR)
|
client, err := stun.NewClient(stunR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("Failed to create client: %s", err)
|
log.Panicf("Failed to create client: %s", err)
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func main() {
|
|||||||
// This can fail if your NAT Server is strict and will use separate ports
|
// This can fail if your NAT Server is strict and will use separate ports
|
||||||
// for application data and STUN
|
// for application data and STUN
|
||||||
var gotAddr stun.XORMappedAddress
|
var gotAddr stun.XORMappedAddress
|
||||||
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
log.Panicf("Failed STUN transaction: %s", res.Error)
|
log.Panicf("Failed STUN transaction: %s", res.Error)
|
||||||
}
|
}
|
||||||
@@ -153,7 +153,7 @@ func main() {
|
|||||||
// Any ping-pong will work, but we are just making binding requests.
|
// Any ping-pong will work, but we are just making binding requests.
|
||||||
// Note that STUN Server is not mandatory for keep alive, application
|
// Note that STUN Server is not mandatory for keep alive, application
|
||||||
// data will keep alive that binding too.
|
// data will keep alive that binding too.
|
||||||
go keepAlive(c)
|
go keepAlive(client)
|
||||||
|
|
||||||
notify := make(chan os.Signal, 1)
|
notify := make(chan os.Signal, 1)
|
||||||
signal.Notify(notify, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(notify, os.Interrupt, syscall.SIGTERM)
|
||||||
@@ -168,6 +168,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
case <-notify:
|
case <-notify:
|
||||||
log.Println("Stopping")
|
log.Println("Stopping")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -203,10 +204,12 @@ func main() {
|
|||||||
|
|
||||||
case m := <-messages:
|
case m := <-messages:
|
||||||
log.Printf("Got response from %s: %s", m.addr, m.text)
|
log.Printf("Got response from %s: %s", m.addr, m.text)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-notify:
|
case <-notify:
|
||||||
log.Print("Stopping")
|
log.Print("Stopping")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -30,10 +30,14 @@ func (c *stunServerConn) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address") //nolint:gochecknoglobals
|
//nolint:gochecknoglobals
|
||||||
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response") //nolint:gochecknoglobals
|
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address")
|
||||||
verbose = flag.Int("verbose", 1, "the verbosity level") //nolint:gochecknoglobals
|
//nolint:gochecknoglobals
|
||||||
log logging.LeveledLogger //nolint:gochecknoglobals
|
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response")
|
||||||
|
//nolint:gochecknoglobals
|
||||||
|
verbose = flag.Int("verbose", 1, "the verbosity level")
|
||||||
|
//nolint:gochecknoglobals
|
||||||
|
log logging.LeveledLogger
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -70,11 +74,12 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RFC5780: 4.3. Determining NAT Mapping Behavior
|
// RFC5780: 4.3. Determining NAT Mapping Behavior.
|
||||||
func mappingTests(addrStr string) error {
|
func mappingTests(addrStr string) error { //nolint:cyclop
|
||||||
mapTestConn, err := connect(addrStr)
|
mapTestConn, err := connect(addrStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Error creating STUN connection: %s", err)
|
log.Warnf("Error creating STUN connection: %s", err)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,11 +96,13 @@ func mappingTests(addrStr string) error {
|
|||||||
resps1 := parse(resp)
|
resps1 := parse(resp)
|
||||||
if resps1.xorAddr == nil || resps1.otherAddr == nil {
|
if resps1.xorAddr == nil || resps1.otherAddr == nil {
|
||||||
log.Info("Error: NAT discovery feature not supported by this server")
|
log.Info("Error: NAT discovery feature not supported by this server")
|
||||||
|
|
||||||
return errNoOtherAddress
|
return errNoOtherAddress
|
||||||
}
|
}
|
||||||
addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
|
addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr)
|
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mapTestConn.OtherAddr = addr
|
mapTestConn.OtherAddr = addr
|
||||||
@@ -104,6 +111,7 @@ func mappingTests(addrStr string) error {
|
|||||||
// Assert mapping behavior
|
// Assert mapping behavior
|
||||||
if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() {
|
if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() {
|
||||||
log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)")
|
log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +129,7 @@ func mappingTests(addrStr string) error {
|
|||||||
log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr)
|
log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr)
|
||||||
if resps2.xorAddr.String() == resps1.xorAddr.String() {
|
if resps2.xorAddr.String() == resps1.xorAddr.String() {
|
||||||
log.Warn("=> NAT mapping behavior: endpoint independent")
|
log.Warn("=> NAT mapping behavior: endpoint independent")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,11 +152,12 @@ func mappingTests(addrStr string) error {
|
|||||||
return mapTestConn.Close()
|
return mapTestConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RFC5780: 4.4. Determining NAT Filtering Behavior
|
// RFC5780: 4.4. Determining NAT Filtering Behavior.
|
||||||
func filteringTests(addrStr string) error {
|
func filteringTests(addrStr string) error { //nolint:cyclop
|
||||||
mapTestConn, err := connect(addrStr)
|
mapTestConn, err := connect(addrStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Error creating STUN connection: %s", err)
|
log.Warnf("Error creating STUN connection: %s", err)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,11 +172,13 @@ func filteringTests(addrStr string) error {
|
|||||||
resps := parse(resp)
|
resps := parse(resp)
|
||||||
if resps.xorAddr == nil || resps.otherAddr == nil {
|
if resps.xorAddr == nil || resps.otherAddr == nil {
|
||||||
log.Warn("Error: NAT discovery feature not supported by this server")
|
log.Warn("Error: NAT discovery feature not supported by this server")
|
||||||
|
|
||||||
return errNoOtherAddress
|
return errNoOtherAddress
|
||||||
}
|
}
|
||||||
addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String())
|
addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr)
|
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mapTestConn.OtherAddr = addr
|
mapTestConn.OtherAddr = addr
|
||||||
@@ -180,6 +192,7 @@ func filteringTests(addrStr string) error {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
parse(resp) // just to print out the resp
|
parse(resp) // just to print out the resp
|
||||||
log.Warn("=> NAT filtering behavior: endpoint independent")
|
log.Warn("=> NAT filtering behavior: endpoint independent")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
} else if !errors.Is(err, errTimedOut) {
|
} else if !errors.Is(err, errTimedOut) {
|
||||||
return err // something else went wrong
|
return err // something else went wrong
|
||||||
@@ -201,7 +214,7 @@ func filteringTests(addrStr string) error {
|
|||||||
return mapTestConn.Close()
|
return mapTestConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse a STUN message
|
// Parse a STUN message.
|
||||||
func parse(msg *stun.Message) (ret struct {
|
func parse(msg *stun.Message) (ret struct {
|
||||||
xorAddr *stun.XORMappedAddress
|
xorAddr *stun.XORMappedAddress
|
||||||
otherAddr *stun.OtherAddress
|
otherAddr *stun.OtherAddress
|
||||||
@@ -249,15 +262,17 @@ func parse(msg *stun.Message) (ret struct {
|
|||||||
log.Debugf("\t%v (l=%v)", attr, attr.Length)
|
log.Debugf("\t%v (l=%v)", attr, attr.Length)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given an address string, returns a StunServerConn
|
// Given an address string, returns a StunServerConn.
|
||||||
func connect(addrStr string) (*stunServerConn, error) {
|
func connect(addrStr string) (*stunServerConn, error) {
|
||||||
log.Infof("Connecting to STUN server: %s", addrStr)
|
log.Infof("Connecting to STUN server: %s", addrStr)
|
||||||
addr, err := net.ResolveUDPAddr("udp4", addrStr)
|
addr, err := net.ResolveUDPAddr("udp4", addrStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Error resolving address: %s", err)
|
log.Warnf("Error resolving address: %s", err)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,7 +293,7 @@ func connect(addrStr string) (*stunServerConn, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send request and wait for response or timeout
|
// Send request and wait for response or timeout.
|
||||||
func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) {
|
func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) {
|
||||||
_ = msg.NewTransactionID()
|
_ = msg.NewTransactionID()
|
||||||
log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize)
|
log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize)
|
||||||
@@ -289,6 +304,7 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
|
|||||||
_, err := c.conn.WriteTo(msg.Raw, addr)
|
_, err := c.conn.WriteTo(msg.Raw, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Error sending request to %v", addr)
|
log.Warnf("Error sending request to %v", addr)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,9 +314,11 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, errResponseMessage
|
return nil, errResponseMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
case <-time.After(time.Duration(*timeoutPtr) * time.Second):
|
case <-time.After(time.Duration(*timeoutPtr) * time.Second):
|
||||||
log.Infof("Timed out waiting for response from server %v", addr)
|
log.Infof("Timed out waiting for response from server %v", addr)
|
||||||
|
|
||||||
return nil, errTimedOut
|
return nil, errTimedOut
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -315,6 +333,7 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
|
|||||||
n, addr, err := conn.ReadFromUDP(buf)
|
n, addr, err := conn.ReadFromUDP(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(messages)
|
close(messages)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Infof("Response from %v: (%v bytes)", addr, n)
|
log.Infof("Response from %v: (%v bytes)", addr, n)
|
||||||
@@ -326,11 +345,13 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("Error decoding message: %v", err)
|
log.Infof("Error decoding message: %v", err)
|
||||||
close(messages)
|
close(messages)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
messages <- m
|
messages <- m
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -26,7 +26,7 @@ const (
|
|||||||
timeoutMillis = 500
|
timeoutMillis = 500
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() { //nolint:gocognit
|
func main() { //nolint:gocognit,cyclop
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
srvAddr, err := net.ResolveUDPAddr(udp, *server)
|
srvAddr, err := net.ResolveUDPAddr(udp, *server)
|
||||||
@@ -87,11 +87,13 @@ func main() { //nolint:gocognit
|
|||||||
decErr := m.Decode()
|
decErr := m.Decode()
|
||||||
if decErr != nil {
|
if decErr != nil {
|
||||||
log.Println("decode:", decErr)
|
log.Println("decode:", decErr)
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
var xorAddr stun.XORMappedAddress
|
var xorAddr stun.XORMappedAddress
|
||||||
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
||||||
log.Println("getFrom:", getErr)
|
log.Println("getFrom:", getErr)
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,6 +162,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
|
|||||||
n, _, err := conn.ReadFromUDP(buf)
|
n, _, err := conn.ReadFromUDP(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(messages)
|
close(messages)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
buf = buf[:n]
|
buf = buf[:n]
|
||||||
@@ -167,6 +170,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
|
|||||||
messages <- buf
|
messages <- buf
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func test(network string) {
|
func test(network string) { //nolint:cyclop
|
||||||
addr := resolve(network)
|
addr := resolve(network)
|
||||||
fmt.Println("START", strings.ToUpper(addr.Network())) //nolint
|
fmt.Println("START", strings.ToUpper(addr.Network())) //nolint
|
||||||
var (
|
var (
|
||||||
|
25
errorcode.go
25
errorcode.go
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
// ErrorCodeAttribute represents ERROR-CODE attribute.
|
// ErrorCodeAttribute represents ERROR-CODE attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.6
|
// RFC 5389 Section 15.6.
|
||||||
type ErrorCodeAttribute struct {
|
type ErrorCodeAttribute struct {
|
||||||
Code ErrorCode
|
Code ErrorCode
|
||||||
Reason []byte
|
Reason []byte
|
||||||
@@ -31,7 +31,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AddTo adds ERROR-CODE to m.
|
// AddTo adds ERROR-CODE to m.
|
||||||
func (c ErrorCodeAttribute) AddTo(m *Message) error {
|
func (c ErrorCodeAttribute) AddTo(msg *Message) error {
|
||||||
value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB)
|
value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB)
|
||||||
if err := CheckOverflow(AttrErrorCode,
|
if err := CheckOverflow(AttrErrorCode,
|
||||||
len(c.Reason)+errorCodeReasonStart,
|
len(c.Reason)+errorCodeReasonStart,
|
||||||
@@ -45,26 +45,28 @@ func (c ErrorCodeAttribute) AddTo(m *Message) error {
|
|||||||
value[errorCodeClassByte] = class
|
value[errorCodeClassByte] = class
|
||||||
value[errorCodeNumberByte] = number
|
value[errorCodeNumberByte] = number
|
||||||
copy(value[errorCodeReasonStart:], c.Reason)
|
copy(value[errorCodeReasonStart:], c.Reason)
|
||||||
m.Add(AttrErrorCode, value)
|
msg.Add(AttrErrorCode, value)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
|
// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
|
||||||
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
||||||
v, err := m.Get(AttrErrorCode)
|
value, err := m.Get(AttrErrorCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(v) < errorCodeReasonStart {
|
if len(value) < errorCodeReasonStart {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
var (
|
var (
|
||||||
class = uint16(v[errorCodeClassByte])
|
class = uint16(value[errorCodeClassByte])
|
||||||
number = uint16(v[errorCodeNumberByte])
|
number = uint16(value[errorCodeNumberByte])
|
||||||
code = int(class*errorCodeModulo + number)
|
code = int(class*errorCodeModulo + number)
|
||||||
)
|
)
|
||||||
c.Code = ErrorCode(code)
|
c.Code = ErrorCode(code)
|
||||||
c.Reason = v[errorCodeReasonStart:]
|
c.Reason = value[errorCodeReasonStart:]
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,6 +88,7 @@ func (c ErrorCode) AddTo(m *Message) error {
|
|||||||
Code: c,
|
Code: c,
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
}
|
}
|
||||||
|
|
||||||
return a.AddTo(m)
|
return a.AddTo(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +111,7 @@ const (
|
|||||||
|
|
||||||
// Error codes from RFC 5766.
|
// Error codes from RFC 5766.
|
||||||
//
|
//
|
||||||
// RFC 5766 Section 15
|
// RFC 5766 Section 15.
|
||||||
const (
|
const (
|
||||||
CodeForbidden ErrorCode = 403 // Forbidden
|
CodeForbidden ErrorCode = 403 // Forbidden
|
||||||
CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch
|
CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch
|
||||||
@@ -120,7 +123,7 @@ const (
|
|||||||
|
|
||||||
// Error codes from RFC 6062.
|
// Error codes from RFC 6062.
|
||||||
//
|
//
|
||||||
// RFC 6062 Section 6.3
|
// RFC 6062 Section 6.3.
|
||||||
const (
|
const (
|
||||||
CodeConnAlreadyExists ErrorCode = 446
|
CodeConnAlreadyExists ErrorCode = 446
|
||||||
CodeConnTimeoutOrFailure ErrorCode = 447
|
CodeConnTimeoutOrFailure ErrorCode = 447
|
||||||
@@ -128,7 +131,7 @@ const (
|
|||||||
|
|
||||||
// Error codes from RFC 6156.
|
// Error codes from RFC 6156.
|
||||||
//
|
//
|
||||||
// RFC 6156 Section 10.2
|
// RFC 6156 Section 10.2.
|
||||||
const (
|
const (
|
||||||
CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported
|
CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported
|
||||||
CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch
|
CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch
|
||||||
|
@@ -92,23 +92,23 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestErrorCode(t *testing.T) {
|
func TestErrorCode(t *testing.T) {
|
||||||
a := &ErrorCodeAttribute{
|
attr := &ErrorCodeAttribute{
|
||||||
Code: 404,
|
Code: 404,
|
||||||
Reason: []byte("not found!"),
|
Reason: []byte("not found!"),
|
||||||
}
|
}
|
||||||
if a.String() != "404: not found!" {
|
if attr.String() != "404: not found!" {
|
||||||
t.Error("bad string", a)
|
t.Error("bad string", attr)
|
||||||
}
|
}
|
||||||
m := New()
|
m := New()
|
||||||
cod := ErrorCode(666)
|
cod := ErrorCode(666)
|
||||||
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
|
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
|
||||||
t.Error("should be ErrNoDefaultReason", err)
|
t.Error("should be ErrNoDefaultReason", err)
|
||||||
}
|
}
|
||||||
if err := a.GetFrom(m); err == nil {
|
if err := attr.GetFrom(m); err == nil {
|
||||||
t.Error("attr should not be in message")
|
t.Error("attr should not be in message")
|
||||||
}
|
}
|
||||||
a.Reason = make([]byte, 2048)
|
attr.Reason = make([]byte, 2048)
|
||||||
if err := a.AddTo(m); err == nil {
|
if err := attr.AddTo(m); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
// FingerprintAttr represents FINGERPRINT attribute.
|
// FingerprintAttr represents FINGERPRINT attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.5
|
// RFC 5389 Section 15.5.
|
||||||
type FingerprintAttr struct{}
|
type FingerprintAttr struct{}
|
||||||
|
|
||||||
// ErrFingerprintMismatch means that computed fingerprint differs from expected.
|
// ErrFingerprintMismatch means that computed fingerprint differs from expected.
|
||||||
@@ -50,6 +50,7 @@ func (FingerprintAttr) AddTo(m *Message) error {
|
|||||||
bin.PutUint32(b, val)
|
bin.PutUint32(b, val)
|
||||||
m.Length = l
|
m.Length = l
|
||||||
m.Add(AttrFingerprint, b)
|
m.Add(AttrFingerprint, b)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,5 +67,6 @@ func (FingerprintAttr) Check(m *Message) error {
|
|||||||
val := bin.Uint32(b)
|
val := bin.Uint32(b)
|
||||||
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
|
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
|
||||||
expected := FingerprintValue(m.Raw[:attrStart])
|
expected := FingerprintValue(m.Raw[:attrStart])
|
||||||
|
|
||||||
return checkFingerprint(val, expected)
|
return checkFingerprint(val, expected)
|
||||||
}
|
}
|
||||||
|
@@ -13,20 +13,20 @@ import (
|
|||||||
|
|
||||||
func BenchmarkFingerprint_AddTo(b *testing.B) {
|
func BenchmarkFingerprint_AddTo(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
s := NewSoftware("software")
|
s := NewSoftware("software")
|
||||||
addr := &XORMappedAddress{
|
addr := &XORMappedAddress{
|
||||||
IP: net.IPv4(213, 1, 223, 5),
|
IP: net.IPv4(213, 1, 223, 5),
|
||||||
}
|
}
|
||||||
addAttr(b, m, addr)
|
addAttr(b, msg, addr)
|
||||||
addAttr(b, m, s)
|
addAttr(b, msg, s)
|
||||||
b.SetBytes(int64(len(m.Raw)))
|
b.SetBytes(int64(len(msg.Raw)))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
Fingerprint.AddTo(m) //nolint:errcheck,gosec
|
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
|
||||||
m.WriteLength()
|
msg.WriteLength()
|
||||||
m.Length -= attributeHeaderSize + fingerprintSize
|
msg.Length -= attributeHeaderSize + fingerprintSize
|
||||||
m.Raw = m.Raw[:m.Length+messageHeaderSize]
|
msg.Raw = msg.Raw[:msg.Length+messageHeaderSize]
|
||||||
m.Attributes = m.Attributes[:len(m.Attributes)-1]
|
msg.Attributes = msg.Attributes[:len(msg.Attributes)-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -109,6 +109,7 @@ func FuzzSetters(f *testing.F) {
|
|||||||
if !IsAttrSizeOverflow(err) {
|
if !IsAttrSizeOverflow(err) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,5 +151,6 @@ func (a attributes) pick(v byte) struct {
|
|||||||
t AttrType
|
t AttrType
|
||||||
} {
|
} {
|
||||||
idx := int(v) % len(a)
|
idx := int(v) % len(a)
|
||||||
|
|
||||||
return a[idx]
|
return a[idx]
|
||||||
}
|
}
|
||||||
|
@@ -44,6 +44,7 @@ func (m *Message) Build(setters ...Setter) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ func (m *Message) Check(checkers ...Checker) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,6 +66,7 @@ func (m *Message) Parse(getters ...Getter) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,6 +76,7 @@ func MustBuild(setters ...Setter) *Message {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err) //nolint
|
panic(err) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,6 +86,7 @@ func Build(setters ...Setter) (*Message, error) {
|
|||||||
if err := m.Build(setters...); err != nil {
|
if err := m.Build(setters...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,5 +110,6 @@ func (m *Message) ForEach(t AttrType, f func(m *Message) error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
func BenchmarkBuildOverhead(b *testing.B) {
|
func BenchmarkBuildOverhead(b *testing.B) {
|
||||||
var (
|
var (
|
||||||
t = BindingRequest
|
msgType = BindingRequest
|
||||||
username = NewUsername("username")
|
username = NewUsername("username")
|
||||||
nonce = NewNonce("nonce")
|
nonce = NewNonce("nonce")
|
||||||
realm = NewRealm("example.org")
|
realm = NewRealm("example.org")
|
||||||
@@ -22,14 +22,14 @@ func BenchmarkBuildOverhead(b *testing.B) {
|
|||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := new(Message)
|
m := new(Message)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
m.Build(&t, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
|
m.Build(&msgType, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("BuildNonPointer", func(b *testing.B) {
|
b.Run("BuildNonPointer", func(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := new(Message)
|
m := new(Message)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
m.Build(t, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
|
m.Build(msgType, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("Raw", func(b *testing.B) {
|
b.Run("Raw", func(b *testing.B) {
|
||||||
@@ -38,7 +38,7 @@ func BenchmarkBuildOverhead(b *testing.B) {
|
|||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
m.Reset()
|
m.Reset()
|
||||||
m.WriteHeader()
|
m.WriteHeader()
|
||||||
m.SetType(t)
|
m.SetType(msgType)
|
||||||
username.AddTo(m) //nolint:errcheck,gosec
|
username.AddTo(m) //nolint:errcheck,gosec
|
||||||
nonce.AddTo(m) //nolint:errcheck,gosec
|
nonce.AddTo(m) //nolint:errcheck,gosec
|
||||||
realm.AddTo(m) //nolint:errcheck,gosec
|
realm.AddTo(m) //nolint:errcheck,gosec
|
||||||
@@ -52,7 +52,7 @@ func TestMessage_Apply(t *testing.T) {
|
|||||||
integrity = NewShortTermIntegrity("password")
|
integrity = NewShortTermIntegrity("password")
|
||||||
decoded = new(Message)
|
decoded = new(Message)
|
||||||
)
|
)
|
||||||
m, err := Build(BindingRequest, TransactionID,
|
msg, err := Build(BindingRequest, TransactionID,
|
||||||
NewUsername("username"),
|
NewUsername("username"),
|
||||||
NewNonce("nonce"),
|
NewNonce("nonce"),
|
||||||
NewRealm("example.org"),
|
NewRealm("example.org"),
|
||||||
@@ -62,13 +62,13 @@ func TestMessage_Apply(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to build:", err)
|
t.Fatal("failed to build:", err)
|
||||||
}
|
}
|
||||||
if err = m.Check(Fingerprint, integrity); err != nil {
|
if err = msg.Check(Fingerprint, integrity); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if _, err := decoded.Write(m.Raw); err != nil {
|
if _, err := decoded.Write(msg.Raw); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !decoded.Equal(m) {
|
if !decoded.Equal(msg) {
|
||||||
t.Error("not equal")
|
t.Error("not equal")
|
||||||
}
|
}
|
||||||
if err := integrity.Check(decoded); err != nil {
|
if err := integrity.Check(decoded); err != nil {
|
||||||
@@ -96,32 +96,32 @@ func (e errReturner) GetFrom(*Message) error {
|
|||||||
|
|
||||||
func TestHelpersErrorHandling(t *testing.T) {
|
func TestHelpersErrorHandling(t *testing.T) {
|
||||||
m := New()
|
m := New()
|
||||||
e := errReturner{Err: errTError}
|
errReturn := errReturner{Err: errTError}
|
||||||
if err := m.Build(e); !errors.Is(err, e.Err) {
|
if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) {
|
||||||
t.Error(err, "!=", e.Err)
|
t.Error(err, "!=", errReturn.Err)
|
||||||
}
|
}
|
||||||
if err := m.Check(e); !errors.Is(err, e.Err) {
|
if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
|
||||||
t.Error(err, "!=", e.Err)
|
t.Error(err, "!=", errReturn.Err)
|
||||||
}
|
}
|
||||||
if err := m.Parse(e); !errors.Is(err, e.Err) {
|
if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
|
||||||
t.Error(err, "!=", e.Err)
|
t.Error(err, "!=", errReturn.Err)
|
||||||
}
|
}
|
||||||
t.Run("MustBuild", func(t *testing.T) {
|
t.Run("MustBuild", func(t *testing.T) {
|
||||||
t.Run("Positive", func(*testing.T) {
|
t.Run("Positive", func(*testing.T) {
|
||||||
MustBuild(NewTransactionIDSetter(transactionID{}))
|
MustBuild(NewTransactionIDSetter(transactionID{}))
|
||||||
})
|
})
|
||||||
defer func() {
|
defer func() {
|
||||||
if p, ok := recover().(error); !ok || !errors.Is(p, e.Err) {
|
if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) {
|
||||||
t.Errorf("%s != %s",
|
t.Errorf("%s != %s",
|
||||||
p, e.Err,
|
p, errReturn.Err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
MustBuild(e)
|
MustBuild(errReturn)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_ForEach(t *testing.T) {
|
func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
|
||||||
initial := New()
|
initial := New()
|
||||||
if err := initial.Build(
|
if err := initial.Build(
|
||||||
NewRealm("realm1"), NewRealm("realm2"),
|
NewRealm("realm1"), NewRealm("realm2"),
|
||||||
@@ -135,6 +135,7 @@ func TestMessage_ForEach(t *testing.T) {
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
t.Run("NoResults", func(t *testing.T) {
|
t.Run("NoResults", func(t *testing.T) {
|
||||||
@@ -144,6 +145,7 @@ func TestMessage_ForEach(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if err := m.ForEach(AttrUsername, func(*Message) error {
|
if err := m.ForEach(AttrUsername, func(*Message) error {
|
||||||
t.Error("should not be called")
|
t.Error("should not be called")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -160,6 +162,7 @@ func TestMessage_ForEach(t *testing.T) {
|
|||||||
t.Error("called multiple times")
|
t.Error("called multiple times")
|
||||||
}
|
}
|
||||||
calls++
|
calls++
|
||||||
|
|
||||||
return ErrAttributeNotFound
|
return ErrAttributeNotFound
|
||||||
}); !errors.Is(err, ErrAttributeNotFound) {
|
}); !errors.Is(err, ErrAttributeNotFound) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -169,14 +172,15 @@ func TestMessage_ForEach(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("Positive", func(t *testing.T) {
|
t.Run("Positive", func(t *testing.T) {
|
||||||
m := newMessage()
|
msg := newMessage()
|
||||||
var realms []string
|
var realms []string
|
||||||
if err := m.ForEach(AttrRealm, func(m *Message) error {
|
if err := msg.ForEach(AttrRealm, func(m *Message) error {
|
||||||
var realm Realm
|
var realm Realm
|
||||||
if err := realm.GetFrom(m); err != nil {
|
if err := realm.GetFrom(m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
realms = append(realms, realm.String())
|
realms = append(realms, realm.String())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -190,18 +194,18 @@ func TestMessage_ForEach(t *testing.T) {
|
|||||||
if realms[1] != "realm2" {
|
if realms[1] != "realm2" {
|
||||||
t.Error("bad value for 2 realm")
|
t.Error("bad value for 2 realm")
|
||||||
}
|
}
|
||||||
if !m.Equal(initial) {
|
if !msg.Equal(initial) {
|
||||||
t.Error("m should be equal to initial")
|
t.Error("m should be equal to initial")
|
||||||
}
|
}
|
||||||
t.Run("ZeroAlloc", func(t *testing.T) {
|
t.Run("ZeroAlloc", func(t *testing.T) {
|
||||||
m = newMessage()
|
msg = newMessage()
|
||||||
var realm Realm
|
var realm Realm
|
||||||
testutil.ShouldNotAllocate(t, func() {
|
testutil.ShouldNotAllocate(t, func() {
|
||||||
if err := m.ForEach(AttrRealm, realm.GetFrom); err != nil {
|
if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if !m.Equal(initial) {
|
if !msg.Equal(initial) {
|
||||||
t.Error("m should be equal to initial")
|
t.Error("m should be equal to initial")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -216,6 +220,7 @@ func ExampleMessage_ForEach() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Println(r)
|
fmt.Println(r)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
fmt.Println("error:", err)
|
fmt.Println("error:", err)
|
||||||
|
32
iana_test.go
32
iana_test.go
@@ -14,22 +14,24 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadCSV(t testing.TB, name string) [][]string {
|
func loadCSV(tb testing.TB, name string) [][]string {
|
||||||
t.Helper()
|
tb.Helper()
|
||||||
data := loadData(t, name)
|
|
||||||
|
data := loadData(tb, name)
|
||||||
r := csv.NewReader(bytes.NewReader(data))
|
r := csv.NewReader(bytes.NewReader(data))
|
||||||
r.Comment = '#'
|
r.Comment = '#'
|
||||||
records, err := r.ReadAll()
|
records, err := r.ReadAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return records
|
return records
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIANA(t *testing.T) {
|
func TestIANA(t *testing.T) { //nolint:cyclop
|
||||||
t.Run("Methods", func(t *testing.T) {
|
t.Run("Methods", func(t *testing.T) {
|
||||||
records := loadCSV(t, "stun-parameters-2.csv")
|
records := loadCSV(t, "stun-parameters-2.csv")
|
||||||
m := make(map[string]Method)
|
methods := make(map[string]Method)
|
||||||
for _, r := range records[1:] {
|
for _, r := range records[1:] {
|
||||||
var (
|
var (
|
||||||
v = r[0]
|
v = r[0]
|
||||||
@@ -43,10 +45,10 @@ func TestIANA(t *testing.T) {
|
|||||||
t.Fatal(parseErr)
|
t.Fatal(parseErr)
|
||||||
}
|
}
|
||||||
t.Logf("value: 0x%x, name: %s", val, name)
|
t.Logf("value: 0x%x, name: %s", val, name)
|
||||||
m[name] = Method(val)
|
methods[name] = Method(val) //nolint:gosec // G115
|
||||||
}
|
}
|
||||||
for val, name := range methodName() {
|
for val, name := range methodName() {
|
||||||
mapped, ok := m[name]
|
mapped, ok := methods[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("failed to find method %s in IANA", name)
|
t.Errorf("failed to find method %s in IANA", name)
|
||||||
}
|
}
|
||||||
@@ -57,7 +59,7 @@ func TestIANA(t *testing.T) {
|
|||||||
})
|
})
|
||||||
t.Run("Attributes", func(t *testing.T) {
|
t.Run("Attributes", func(t *testing.T) {
|
||||||
records := loadCSV(t, "stun-parameters-4.csv")
|
records := loadCSV(t, "stun-parameters-4.csv")
|
||||||
m := map[string]AttrType{}
|
attrTypes := map[string]AttrType{}
|
||||||
for _, r := range records[1:] {
|
for _, r := range records[1:] {
|
||||||
var (
|
var (
|
||||||
v = r[0]
|
v = r[0]
|
||||||
@@ -71,16 +73,16 @@ func TestIANA(t *testing.T) {
|
|||||||
t.Fatal(parseErr)
|
t.Fatal(parseErr)
|
||||||
}
|
}
|
||||||
t.Logf("value: 0x%x, name: %s", val, name)
|
t.Logf("value: 0x%x, name: %s", val, name)
|
||||||
m[name] = AttrType(val)
|
attrTypes[name] = AttrType(val) //nolint:gosec // G115
|
||||||
}
|
}
|
||||||
// Not registered in IANA.
|
// Not registered in IANA.
|
||||||
for k, v := range map[string]AttrType{
|
for k, v := range map[string]AttrType{
|
||||||
"ORIGIN": 0x802F,
|
"ORIGIN": 0x802F,
|
||||||
} {
|
} {
|
||||||
m[k] = v
|
attrTypes[k] = v
|
||||||
}
|
}
|
||||||
for val, name := range attrNames() {
|
for val, name := range attrNames() {
|
||||||
mapped, ok := m[name]
|
mapped, ok := attrTypes[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("failed to find attribute %s in IANA", name)
|
t.Errorf("failed to find attribute %s in IANA", name)
|
||||||
}
|
}
|
||||||
@@ -91,7 +93,7 @@ func TestIANA(t *testing.T) {
|
|||||||
})
|
})
|
||||||
t.Run("ErrorCodes", func(t *testing.T) {
|
t.Run("ErrorCodes", func(t *testing.T) {
|
||||||
records := loadCSV(t, "stun-parameters-6.csv")
|
records := loadCSV(t, "stun-parameters-6.csv")
|
||||||
m := map[string]ErrorCode{}
|
errorCodes := map[string]ErrorCode{}
|
||||||
for _, r := range records[1:] {
|
for _, r := range records[1:] {
|
||||||
var (
|
var (
|
||||||
v = r[0]
|
v = r[0]
|
||||||
@@ -105,11 +107,11 @@ func TestIANA(t *testing.T) {
|
|||||||
t.Fatal(parseErr)
|
t.Fatal(parseErr)
|
||||||
}
|
}
|
||||||
t.Logf("value: 0x%x, name: %s", val, name)
|
t.Logf("value: 0x%x, name: %s", val, name)
|
||||||
m[name] = ErrorCode(val)
|
errorCodes[name] = ErrorCode(val)
|
||||||
}
|
}
|
||||||
for val, nameB := range errorReasons {
|
for val, nameB := range errorReasons {
|
||||||
name := string(nameB)
|
name := string(nameB)
|
||||||
mapped, ok := m[name]
|
mapped, ok := errorCodes[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("failed to find error code %s in IANA", name)
|
t.Errorf("failed to find error code %s in IANA", name)
|
||||||
}
|
}
|
||||||
|
46
integrity.go
46
integrity.go
@@ -22,6 +22,7 @@ func NewLongTermIntegrity(username, realm, password string) MessageIntegrity {
|
|||||||
k := strings.Join([]string{username, realm, password}, credentialsSep)
|
k := strings.Join([]string{username, realm, password}, credentialsSep)
|
||||||
h := md5.New() //nolint:gosec
|
h := md5.New() //nolint:gosec
|
||||||
fmt.Fprint(h, k) //nolint:errcheck
|
fmt.Fprint(h, k) //nolint:errcheck
|
||||||
|
|
||||||
return MessageIntegrity(h.Sum(nil))
|
return MessageIntegrity(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,13 +37,14 @@ func NewShortTermIntegrity(password string) MessageIntegrity {
|
|||||||
// AddTo and Check methods are using zero-allocation version of hmac, see
|
// AddTo and Check methods are using zero-allocation version of hmac, see
|
||||||
// newHMAC function and internal/hmac/pool.go.
|
// newHMAC function and internal/hmac/pool.go.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.4
|
// RFC 5389 Section 15.4.
|
||||||
type MessageIntegrity []byte
|
type MessageIntegrity []byte
|
||||||
|
|
||||||
func newHMAC(key, message, buf []byte) []byte {
|
func newHMAC(key, message, buf []byte) []byte {
|
||||||
mac := hmac.AcquireSHA1(key)
|
mac := hmac.AcquireSHA1(key)
|
||||||
writeOrPanic(mac, message)
|
writeOrPanic(mac, message)
|
||||||
defer hmac.PutSHA1(mac)
|
defer hmac.PutSHA1(mac)
|
||||||
|
|
||||||
return mac.Sum(buf)
|
return mac.Sum(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,8 +61,8 @@ var ErrFingerprintBeforeIntegrity = errors.New("FINGERPRINT before MESSAGE-INTEG
|
|||||||
// AddTo adds MESSAGE-INTEGRITY attribute to message.
|
// AddTo adds MESSAGE-INTEGRITY attribute to message.
|
||||||
//
|
//
|
||||||
// CPU costly, see BenchmarkMessageIntegrity_AddTo.
|
// CPU costly, see BenchmarkMessageIntegrity_AddTo.
|
||||||
func (i MessageIntegrity) AddTo(m *Message) error {
|
func (i MessageIntegrity) AddTo(msg *Message) error {
|
||||||
for _, a := range m.Attributes {
|
for _, a := range msg.Attributes {
|
||||||
// Message should not contain FINGERPRINT attribute
|
// Message should not contain FINGERPRINT attribute
|
||||||
// before MESSAGE-INTEGRITY.
|
// before MESSAGE-INTEGRITY.
|
||||||
if a.Type == AttrFingerprint {
|
if a.Type == AttrFingerprint {
|
||||||
@@ -70,19 +72,20 @@ func (i MessageIntegrity) AddTo(m *Message) error {
|
|||||||
// The text used as input to HMAC is the STUN message,
|
// The text used as input to HMAC is the STUN message,
|
||||||
// including the header, up to and including the attribute preceding the
|
// including the header, up to and including the attribute preceding the
|
||||||
// MESSAGE-INTEGRITY attribute.
|
// MESSAGE-INTEGRITY attribute.
|
||||||
length := m.Length
|
length := msg.Length
|
||||||
// Adjusting m.Length to contain MESSAGE-INTEGRITY TLV.
|
// Adjusting m.Length to contain MESSAGE-INTEGRITY TLV.
|
||||||
m.Length += messageIntegritySize + attributeHeaderSize
|
msg.Length += messageIntegritySize + attributeHeaderSize
|
||||||
m.WriteLength() // writing length to m.Raw
|
msg.WriteLength() // writing length to m.Raw
|
||||||
v := newHMAC(i, m.Raw, m.Raw[len(m.Raw):]) // calculating HMAC for adjusted m.Raw
|
v := newHMAC(i, msg.Raw, msg.Raw[len(msg.Raw):]) // calculating HMAC for adjusted m.Raw
|
||||||
m.Length = length // changing m.Length back
|
msg.Length = length // changing m.Length back
|
||||||
|
|
||||||
// Copy hmac value to temporary variable to protect it from resetting
|
// Copy hmac value to temporary variable to protect it from resetting
|
||||||
// while processing m.Add call.
|
// while processing m.Add call.
|
||||||
vBuf := make([]byte, sha1.Size)
|
vBuf := make([]byte, sha1.Size)
|
||||||
copy(vBuf, v)
|
copy(vBuf, v)
|
||||||
|
|
||||||
m.Add(AttrMessageIntegrity, vBuf)
|
msg.Add(AttrMessageIntegrity, vBuf)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,8 +95,8 @@ var ErrIntegrityMismatch = errors.New("integrity check failed")
|
|||||||
// Check checks MESSAGE-INTEGRITY attribute.
|
// Check checks MESSAGE-INTEGRITY attribute.
|
||||||
//
|
//
|
||||||
// CPU costly, see BenchmarkMessageIntegrity_Check.
|
// CPU costly, see BenchmarkMessageIntegrity_Check.
|
||||||
func (i MessageIntegrity) Check(m *Message) error {
|
func (i MessageIntegrity) Check(msg *Message) error {
|
||||||
v, err := m.Get(AttrMessageIntegrity)
|
val, err := msg.Get(AttrMessageIntegrity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -101,11 +104,11 @@ func (i MessageIntegrity) Check(m *Message) error {
|
|||||||
// Adjusting length in header to match m.Raw that was
|
// Adjusting length in header to match m.Raw that was
|
||||||
// used when computing HMAC.
|
// used when computing HMAC.
|
||||||
var (
|
var (
|
||||||
length = m.Length
|
length = msg.Length
|
||||||
afterIntegrity = false
|
afterIntegrity = false
|
||||||
sizeReduced int
|
sizeReduced int
|
||||||
)
|
)
|
||||||
for _, a := range m.Attributes {
|
for _, a := range msg.Attributes {
|
||||||
if afterIntegrity {
|
if afterIntegrity {
|
||||||
sizeReduced += nearestPaddedValueLength(int(a.Length))
|
sizeReduced += nearestPaddedValueLength(int(a.Length))
|
||||||
sizeReduced += attributeHeaderSize
|
sizeReduced += attributeHeaderSize
|
||||||
@@ -114,13 +117,14 @@ func (i MessageIntegrity) Check(m *Message) error {
|
|||||||
afterIntegrity = true
|
afterIntegrity = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.Length -= uint32(sizeReduced)
|
msg.Length -= uint32(sizeReduced) //nolint:gosec // G115
|
||||||
m.WriteLength()
|
msg.WriteLength()
|
||||||
// startOfHMAC should be first byte of integrity attribute.
|
// startOfHMAC should be first byte of integrity attribute.
|
||||||
startOfHMAC := messageHeaderSize + m.Length - (attributeHeaderSize + messageIntegritySize)
|
startOfHMAC := messageHeaderSize + msg.Length - (attributeHeaderSize + messageIntegritySize)
|
||||||
b := m.Raw[:startOfHMAC] // data before integrity attribute
|
b := msg.Raw[:startOfHMAC] // data before integrity attribute
|
||||||
expected := newHMAC(i, b, m.Raw[len(m.Raw):])
|
expected := newHMAC(i, b, msg.Raw[len(msg.Raw):])
|
||||||
m.Length = length
|
msg.Length = length
|
||||||
m.WriteLength() // writing length back
|
msg.WriteLength() // writing length back
|
||||||
return checkHMAC(v, expected)
|
|
||||||
|
return checkHMAC(val, expected)
|
||||||
}
|
}
|
||||||
|
@@ -10,18 +10,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
|
func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
|
||||||
i := NewLongTermIntegrity("user", "realm", "pass")
|
integrity := NewLongTermIntegrity("user", "realm", "pass")
|
||||||
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
|
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(expected, i) {
|
if !bytes.Equal(expected, integrity) {
|
||||||
t.Error(ErrIntegrityMismatch)
|
t.Error(ErrIntegrityMismatch)
|
||||||
}
|
}
|
||||||
t.Run("Check", func(t *testing.T) {
|
t.Run("Check", func(t *testing.T) {
|
||||||
m := new(Message)
|
m := new(Message)
|
||||||
m.WriteHeader()
|
m.WriteHeader()
|
||||||
if err := i.AddTo(m); err != nil {
|
if err := integrity.AddTo(m); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
|
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
|
||||||
@@ -31,39 +31,39 @@ func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
|
|||||||
if err := dM.Decode(); err != nil {
|
if err := dM.Decode(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := i.Check(dM); err != nil {
|
if err := integrity.Check(dM); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
dM.Raw[24] += 12 // HMAC now invalid
|
dM.Raw[24] += 12 // HMAC now invalid
|
||||||
if i.Check(dM) == nil {
|
if integrity.Check(dM) == nil {
|
||||||
t.Error("should be invalid")
|
t.Error("should be invalid")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessageIntegrityWithFingerprint(t *testing.T) {
|
func TestMessageIntegrityWithFingerprint(t *testing.T) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
m.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
|
msg.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
|
NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
|
||||||
i := NewShortTermIntegrity("pwd")
|
integrity := NewShortTermIntegrity("pwd")
|
||||||
if i.String() != "KEY: 0x707764" {
|
if integrity.String() != "KEY: 0x707764" {
|
||||||
t.Error("bad string", i)
|
t.Error("bad string", integrity)
|
||||||
}
|
}
|
||||||
if err := i.Check(m); err == nil {
|
if err := integrity.Check(msg); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
if err := i.AddTo(m); err != nil {
|
if err := integrity.AddTo(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := Fingerprint.AddTo(m); err != nil {
|
if err := Fingerprint.AddTo(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := i.Check(m); err != nil {
|
if err := integrity.Check(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m.Raw[24] = 33
|
msg.Raw[24] = 33
|
||||||
if err := i.Check(m); err == nil {
|
if err := integrity.Check(msg); err == nil {
|
||||||
t.Fatal("mismatch expected")
|
t.Fatal("mismatch expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -64,6 +64,7 @@ func (h *hmac) Sum(in []byte) []byte {
|
|||||||
h.outer.Write(h.opad) //nolint:errcheck,gosec
|
h.outer.Write(h.opad) //nolint:errcheck,gosec
|
||||||
}
|
}
|
||||||
h.outer.Write(in[origLen:]) //nolint:errcheck,gosec
|
h.outer.Write(in[origLen:]) //nolint:errcheck,gosec
|
||||||
|
|
||||||
return h.outer.Sum(in[:origLen])
|
return h.outer.Sum(in[:origLen])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,6 +80,7 @@ func (h *hmac) Reset() {
|
|||||||
if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert
|
if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert
|
||||||
panic(err) //nolint
|
panic(err) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -22,7 +22,7 @@ type hmacTest struct {
|
|||||||
blocksize int
|
blocksize int
|
||||||
}
|
}
|
||||||
|
|
||||||
func hmacTests() []hmacTest {
|
func hmacTests() []hmacTest { //nolint:maintidx
|
||||||
return []hmacTest{
|
return []hmacTest{
|
||||||
// Tests from US FIPS 198
|
// Tests from US FIPS 198
|
||||||
// https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf
|
// https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf
|
||||||
@@ -523,41 +523,42 @@ func hmacTests() []hmacTest {
|
|||||||
|
|
||||||
func TestHMAC(t *testing.T) {
|
func TestHMAC(t *testing.T) {
|
||||||
for i, tt := range hmacTests() {
|
for i, tt := range hmacTests() {
|
||||||
h := New(tt.hash, tt.key)
|
hsh := New(tt.hash, tt.key)
|
||||||
if s := h.Size(); s != tt.size {
|
if s := hsh.Size(); s != tt.size {
|
||||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||||
}
|
}
|
||||||
if b := h.BlockSize(); b != tt.blocksize {
|
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||||
}
|
}
|
||||||
for j := 0; j < 4; j++ {
|
for j := 0; j < 4; j++ { //nolint:varnamelen
|
||||||
n, err := h.Write(tt.in)
|
n, err := hsh.Write(tt.in)
|
||||||
if n != len(tt.in) || err != nil {
|
if n != len(tt.in) || err != nil {
|
||||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Repetitive Sum() calls should return the same value
|
// Repetitive Sum() calls should return the same value
|
||||||
for k := 0; k < 2; k++ {
|
for k := 0; k < 2; k++ {
|
||||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||||
if sum != tt.out {
|
if sum != tt.out {
|
||||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second iteration: make sure reset works.
|
// Second iteration: make sure reset works.
|
||||||
h.Reset()
|
hsh.Reset()
|
||||||
|
|
||||||
// Third and fourth iteration: make sure hmac works on
|
// Third and fourth iteration: make sure hmac works on
|
||||||
// hashes without MarshalBinary/UnmarshalBinary
|
// hashes without MarshalBinary/UnmarshalBinary
|
||||||
if j == 1 {
|
if j == 1 {
|
||||||
h = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
|
hsh = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// justHash implements just the hash.Hash methods and nothing else
|
// justHash implements just the hash.Hash methods and nothing else.
|
||||||
type justHash struct {
|
type justHash struct {
|
||||||
hash.Hash
|
hash.Hash
|
||||||
}
|
}
|
||||||
|
@@ -40,6 +40,7 @@ func (h *hmac) resetTo(key []byte) {
|
|||||||
var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals
|
var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
h := New(sha1.New, make([]byte, sha1.BlockSize))
|
h := New(sha1.New, make([]byte, sha1.BlockSize))
|
||||||
|
|
||||||
return h
|
return h
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -49,6 +50,7 @@ func AcquireSHA1(key []byte) hash.Hash {
|
|||||||
h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert
|
h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert
|
||||||
assertHMACSize(h, sha1.Size, sha1.BlockSize)
|
assertHMACSize(h, sha1.Size, sha1.BlockSize)
|
||||||
h.resetTo(key)
|
h.resetTo(key)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,7 @@ func PutSHA1(h hash.Hash) {
|
|||||||
var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals
|
var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
h := New(sha256.New, make([]byte, sha256.BlockSize))
|
h := New(sha256.New, make([]byte, sha256.BlockSize))
|
||||||
|
|
||||||
return h
|
return h
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -71,6 +74,7 @@ func AcquireSHA256(key []byte) hash.Hash {
|
|||||||
h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert
|
h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert
|
||||||
assertHMACSize(h, sha256.Size, sha256.BlockSize)
|
assertHMACSize(h, sha256.Size, sha256.BlockSize)
|
||||||
h.resetTo(key)
|
h.resetTo(key)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -42,100 +42,103 @@ func BenchmarkHMACSHA1_512_Pool(b *testing.B) {
|
|||||||
|
|
||||||
func TestHMACReset(t *testing.T) {
|
func TestHMACReset(t *testing.T) {
|
||||||
for i, tt := range hmacTests() {
|
for i, tt := range hmacTests() {
|
||||||
h := New(tt.hash, tt.key)
|
hsh := New(tt.hash, tt.key)
|
||||||
h.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
|
hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
|
||||||
if s := h.Size(); s != tt.size {
|
if s := hsh.Size(); s != tt.size {
|
||||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||||
}
|
}
|
||||||
if b := h.BlockSize(); b != tt.blocksize {
|
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||||
}
|
}
|
||||||
for j := 0; j < 2; j++ {
|
for j := 0; j < 2; j++ {
|
||||||
n, err := h.Write(tt.in)
|
n, err := hsh.Write(tt.in)
|
||||||
if n != len(tt.in) || err != nil {
|
if n != len(tt.in) || err != nil {
|
||||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Repetitive Sum() calls should return the same value
|
// Repetitive Sum() calls should return the same value
|
||||||
for k := 0; k < 2; k++ {
|
for k := 0; k < 2; k++ {
|
||||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||||
if sum != tt.out {
|
if sum != tt.out {
|
||||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second iteration: make sure reset works.
|
// Second iteration: make sure reset works.
|
||||||
h.Reset()
|
hsh.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl
|
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
|
||||||
for i, tt := range hmacTests() {
|
for i, tt := range hmacTests() {
|
||||||
if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size {
|
if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
h := AcquireSHA1(tt.key)
|
hsh := AcquireSHA1(tt.key)
|
||||||
if s := h.Size(); s != tt.size {
|
if s := hsh.Size(); s != tt.size {
|
||||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||||
}
|
}
|
||||||
if b := h.BlockSize(); b != tt.blocksize {
|
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||||
}
|
}
|
||||||
for j := 0; j < 2; j++ {
|
for j := 0; j < 2; j++ {
|
||||||
n, err := h.Write(tt.in)
|
n, err := hsh.Write(tt.in)
|
||||||
if n != len(tt.in) || err != nil {
|
if n != len(tt.in) || err != nil {
|
||||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Repetitive Sum() calls should return the same value
|
// Repetitive Sum() calls should return the same value
|
||||||
for k := 0; k < 2; k++ {
|
for k := 0; k < 2; k++ {
|
||||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||||
if sum != tt.out {
|
if sum != tt.out {
|
||||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second iteration: make sure reset works.
|
// Second iteration: make sure reset works.
|
||||||
h.Reset()
|
hsh.Reset()
|
||||||
}
|
}
|
||||||
PutSHA1(h)
|
PutSHA1(hsh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl
|
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
|
||||||
for i, tt := range hmacTests() {
|
for i, tt := range hmacTests() {
|
||||||
if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size {
|
if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
h := AcquireSHA256(tt.key)
|
hsh := AcquireSHA256(tt.key)
|
||||||
if s := h.Size(); s != tt.size {
|
if s := hsh.Size(); s != tt.size {
|
||||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||||
}
|
}
|
||||||
if b := h.BlockSize(); b != tt.blocksize {
|
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||||
}
|
}
|
||||||
for j := 0; j < 2; j++ {
|
for j := 0; j < 2; j++ {
|
||||||
n, err := h.Write(tt.in)
|
n, err := hsh.Write(tt.in)
|
||||||
if n != len(tt.in) || err != nil {
|
if n != len(tt.in) || err != nil {
|
||||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Repetitive Sum() calls should return the same value
|
// Repetitive Sum() calls should return the same value
|
||||||
for k := 0; k < 2; k++ {
|
for k := 0; k < 2; k++ {
|
||||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||||
if sum != tt.out {
|
if sum != tt.out {
|
||||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second iteration: make sure reset works.
|
// Second iteration: make sure reset works.
|
||||||
h.Reset()
|
hsh.Reset()
|
||||||
}
|
}
|
||||||
PutSHA256(h)
|
PutSHA256(hsh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -10,8 +10,11 @@ import (
|
|||||||
|
|
||||||
// ShouldNotAllocate fails if f allocates.
|
// ShouldNotAllocate fails if f allocates.
|
||||||
func ShouldNotAllocate(t *testing.T, f func()) {
|
func ShouldNotAllocate(t *testing.T, f func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
if Race {
|
if Race {
|
||||||
t.Skip("skip while running with -race")
|
t.Skip("skip while running with -race")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if a := testing.AllocsPerRun(10, f); a > 0 {
|
if a := testing.AllocsPerRun(10, f); a > 0 {
|
||||||
|
115
message.go
115
message.go
@@ -32,6 +32,7 @@ const (
|
|||||||
// as source.
|
// as source.
|
||||||
func NewTransactionID() (b [TransactionIDSize]byte) {
|
func NewTransactionID() (b [TransactionIDSize]byte) {
|
||||||
readFullOrPanic(rand.Reader, b[:])
|
readFullOrPanic(rand.Reader, b[:])
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ func IsMessage(b []byte) bool {
|
|||||||
// New returns *Message with pre-allocated Raw.
|
// New returns *Message with pre-allocated Raw.
|
||||||
func New() *Message {
|
func New() *Message {
|
||||||
const defaultRawCapacity = 120
|
const defaultRawCapacity = 120
|
||||||
|
|
||||||
return &Message{
|
return &Message{
|
||||||
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
|
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
|
||||||
}
|
}
|
||||||
@@ -59,6 +61,7 @@ func Decode(data []byte, m *Message) error {
|
|||||||
return ErrDecodeToNil
|
return ErrDecodeToNil
|
||||||
}
|
}
|
||||||
m.Raw = append(m.Raw[:0], data...)
|
m.Raw = append(m.Raw[:0], data...)
|
||||||
|
|
||||||
return m.Decode()
|
return m.Decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,6 +85,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
|
|||||||
// contract induced by other implementations.
|
// contract induced by other implementations.
|
||||||
b := make([]byte, len(m.Raw))
|
b := make([]byte, len(m.Raw))
|
||||||
copy(b, m.Raw)
|
copy(b, m.Raw)
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +93,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
|
|||||||
func (m *Message) UnmarshalBinary(data []byte) error {
|
func (m *Message) UnmarshalBinary(data []byte) error {
|
||||||
// We can't retain data, copy is expected by interface contract.
|
// We can't retain data, copy is expected by interface contract.
|
||||||
m.Raw = append(m.Raw[:0], data...)
|
m.Raw = append(m.Raw[:0], data...)
|
||||||
|
|
||||||
return m.Decode()
|
return m.Decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,6 +113,7 @@ func (m *Message) GobDecode(data []byte) error {
|
|||||||
func (m *Message) AddTo(b *Message) error {
|
func (m *Message) AddTo(b *Message) error {
|
||||||
b.TransactionID = m.TransactionID
|
b.TransactionID = m.TransactionID
|
||||||
b.WriteTransactionID()
|
b.WriteTransactionID()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,6 +124,7 @@ func (m *Message) NewTransactionID() error {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
m.WriteTransactionID()
|
m.WriteTransactionID()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,6 +134,7 @@ func (m *Message) String() string {
|
|||||||
for k, a := range m.Attributes {
|
for k, a := range m.Attributes {
|
||||||
aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
|
aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
|
return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,6 +152,7 @@ func (m *Message) grow(n int) {
|
|||||||
}
|
}
|
||||||
if cap(m.Raw) >= n {
|
if cap(m.Raw) >= n {
|
||||||
m.Raw = m.Raw[:n]
|
m.Raw = m.Raw[:n]
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
|
m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
|
||||||
@@ -153,7 +162,7 @@ func (m *Message) grow(n int) {
|
|||||||
//
|
//
|
||||||
// Value of attribute is copied to internal buffer so
|
// Value of attribute is copied to internal buffer so
|
||||||
// it is safe to reuse v.
|
// it is safe to reuse v.
|
||||||
func (m *Message) Add(t AttrType, v []byte) {
|
func (m *Message) Add(attrType AttrType, val []byte) {
|
||||||
// Allocating buffer for TLV (type-length-value).
|
// Allocating buffer for TLV (type-length-value).
|
||||||
// T = t, L = len(v), V = v.
|
// T = t, L = len(v), V = v.
|
||||||
// m.Raw will look like:
|
// m.Raw will look like:
|
||||||
@@ -163,31 +172,33 @@ func (m *Message) Add(t AttrType, v []byte) {
|
|||||||
// [first:last] <- same as previous
|
// [first:last] <- same as previous
|
||||||
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
|
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
|
||||||
// T L V
|
// T L V
|
||||||
allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V)
|
allocSize := attributeHeaderSize + len(val) // ~ len(TLV) = len(TL) + len(V)
|
||||||
first := messageHeaderSize + int(m.Length) // first byte number
|
first := messageHeaderSize + int(m.Length) // first byte number
|
||||||
last := first + allocSize // last byte number
|
last := first + allocSize // last byte number
|
||||||
m.grow(last) // growing cap(Raw) to fit TLV
|
m.grow(last) // growing cap(Raw) to fit TLV
|
||||||
m.Raw = m.Raw[:last] // now len(Raw) = last
|
m.Raw = m.Raw[:last] // now len(Raw) = last
|
||||||
m.Length += uint32(allocSize) // rendering length change
|
//nolint:gosec // G115
|
||||||
|
m.Length += uint32(allocSize) // rendering length change
|
||||||
|
|
||||||
// Sub-slicing internal buffer to simplify encoding.
|
// Sub-slicing internal buffer to simplify encoding.
|
||||||
buf := m.Raw[first:last] // slice for TLV
|
buf := m.Raw[first:last] // slice for TLV
|
||||||
value := buf[attributeHeaderSize:] // slice for V
|
value := buf[attributeHeaderSize:] // slice for V
|
||||||
attr := RawAttribute{
|
attr := RawAttribute{
|
||||||
Type: t, // T
|
Type: attrType, // T
|
||||||
Length: uint16(len(v)), // L
|
//nolint:gosec // G115
|
||||||
Value: value, // V
|
Length: uint16(len(val)), // L
|
||||||
|
Value: value, // V
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encoding attribute TLV to allocated buffer.
|
// Encoding attribute TLV to allocated buffer.
|
||||||
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
|
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
|
||||||
bin.PutUint16(buf[2:4], attr.Length) // L
|
bin.PutUint16(buf[2:4], attr.Length) // L
|
||||||
copy(value, v) // V
|
copy(value, val) // V
|
||||||
|
|
||||||
// Checking that attribute value needs padding.
|
// Checking that attribute value needs padding.
|
||||||
if attr.Length%padding != 0 {
|
if attr.Length%padding != 0 {
|
||||||
// Performing padding.
|
// Performing padding.
|
||||||
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
|
bytesToAdd := nearestPaddedValueLength(len(val)) - len(val)
|
||||||
last += bytesToAdd
|
last += bytesToAdd
|
||||||
m.grow(last)
|
m.grow(last)
|
||||||
// setting all padding bytes to zero
|
// setting all padding bytes to zero
|
||||||
@@ -197,7 +208,8 @@ func (m *Message) Add(t AttrType, v []byte) {
|
|||||||
for i := range buf {
|
for i := range buf {
|
||||||
buf[i] = 0
|
buf[i] = 0
|
||||||
}
|
}
|
||||||
m.Raw = m.Raw[:last] // increasing buffer length
|
m.Raw = m.Raw[:last] // increasing buffer length
|
||||||
|
//nolint:gosec // G115
|
||||||
m.Length += uint32(bytesToAdd) // rendering length change
|
m.Length += uint32(bytesToAdd) // rendering length change
|
||||||
}
|
}
|
||||||
m.Attributes = append(m.Attributes, attr)
|
m.Attributes = append(m.Attributes, attr)
|
||||||
@@ -213,6 +225,7 @@ func attrSliceEqual(a, b Attributes) bool {
|
|||||||
}
|
}
|
||||||
if attrB.Equal(attr) {
|
if attrB.Equal(attr) {
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -220,56 +233,59 @@ func attrSliceEqual(a, b Attributes) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func attrEqual(a, b Attributes) bool {
|
func attrEqual(attrA, attrB Attributes) bool {
|
||||||
if a == nil && b == nil {
|
if attrA == nil && attrB == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if a == nil || b == nil {
|
if attrA == nil || attrB == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(a) != len(b) {
|
if len(attrA) != len(attrB) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if !attrSliceEqual(a, b) {
|
if !attrSliceEqual(attrA, attrB) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if !attrSliceEqual(b, a) {
|
if !attrSliceEqual(attrB, attrA) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Equal returns true if Message b equals to m.
|
// Equal returns true if Message msg equals to m.
|
||||||
// Ignores m.Raw.
|
// Ignores m.Raw.
|
||||||
func (m *Message) Equal(b *Message) bool {
|
func (m *Message) Equal(msg *Message) bool {
|
||||||
if m == nil && b == nil {
|
if m == nil && msg == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if m == nil || b == nil {
|
if m == nil || msg == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if m.Type != b.Type {
|
if m.Type != msg.Type {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if m.TransactionID != b.TransactionID {
|
if m.TransactionID != msg.TransactionID {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if m.Length != b.Length {
|
if m.Length != msg.Length {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if !attrEqual(m.Attributes, b.Attributes) {
|
if !attrEqual(m.Attributes, msg.Attributes) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteLength writes m.Length to m.Raw.
|
// WriteLength writes m.Length to m.Raw.
|
||||||
func (m *Message) WriteLength() {
|
func (m *Message) WriteLength() {
|
||||||
m.grow(4)
|
m.grow(4)
|
||||||
bin.PutUint16(m.Raw[2:4], uint16(m.Length))
|
bin.PutUint16(m.Raw[2:4], uint16(m.Length)) //nolint:gosec // G115
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
|
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
|
||||||
@@ -322,6 +338,7 @@ func (m *Message) Encode() {
|
|||||||
// call result.
|
// call result.
|
||||||
func (m *Message) WriteTo(w io.Writer) (int64, error) {
|
func (m *Message) WriteTo(w io.Writer) (int64, error) {
|
||||||
n, err := w.Write(m.Raw)
|
n, err := w.Write(m.Raw)
|
||||||
|
|
||||||
return int64(n), err
|
return int64(n), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,6 +357,7 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) {
|
|||||||
return int64(n), err
|
return int64(n), err
|
||||||
}
|
}
|
||||||
m.Raw = tBuf[:n]
|
m.Raw = tBuf[:n]
|
||||||
|
|
||||||
return int64(n), m.Decode()
|
return int64(n), m.Decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,22 +373,24 @@ func (m *Message) Decode() error {
|
|||||||
return ErrUnexpectedHeaderEOF
|
return ErrUnexpectedHeaderEOF
|
||||||
}
|
}
|
||||||
var (
|
var (
|
||||||
t = bin.Uint16(buf[0:2]) // first 2 bytes
|
msgType = bin.Uint16(buf[0:2]) // first 2 bytes
|
||||||
size = int(bin.Uint16(buf[2:4])) // second 2 bytes
|
size = int(bin.Uint16(buf[2:4])) // second 2 bytes
|
||||||
cookie = bin.Uint32(buf[4:8]) // last 4 bytes
|
cookie = bin.Uint32(buf[4:8]) // last 4 bytes
|
||||||
fullSize = messageHeaderSize + size // len(m.Raw)
|
fullSize = messageHeaderSize + size // len(m.Raw)
|
||||||
)
|
)
|
||||||
if cookie != magicCookie {
|
if cookie != magicCookie {
|
||||||
msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
|
msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
|
||||||
|
|
||||||
return newDecodeErr("message", "cookie", msg)
|
return newDecodeErr("message", "cookie", msg)
|
||||||
}
|
}
|
||||||
if len(buf) < fullSize {
|
if len(buf) < fullSize {
|
||||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
|
msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
|
||||||
|
|
||||||
return newAttrDecodeErr("message", msg)
|
return newAttrDecodeErr("message", msg)
|
||||||
}
|
}
|
||||||
// saving header data
|
// saving header data
|
||||||
m.Type.ReadValue(t)
|
m.Type.ReadValue(msgType)
|
||||||
m.Length = uint32(size)
|
m.Length = uint32(size) //nolint:gosec // G115
|
||||||
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
||||||
|
|
||||||
m.Attributes = m.Attributes[:0]
|
m.Attributes = m.Attributes[:0]
|
||||||
@@ -382,28 +402,31 @@ func (m *Message) Decode() error {
|
|||||||
// checking that we have enough bytes to read header
|
// checking that we have enough bytes to read header
|
||||||
if len(b) < attributeHeaderSize {
|
if len(b) < attributeHeaderSize {
|
||||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
|
msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
|
||||||
|
|
||||||
return newAttrDecodeErr("header", msg)
|
return newAttrDecodeErr("header", msg)
|
||||||
}
|
}
|
||||||
var (
|
var (
|
||||||
a = RawAttribute{
|
attr = RawAttribute{
|
||||||
Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
|
Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
|
||||||
Length: bin.Uint16(b[2:4]), // second 2 bytes
|
Length: bin.Uint16(b[2:4]), // second 2 bytes
|
||||||
}
|
}
|
||||||
aL = int(a.Length) // attribute length
|
aL = int(attr.Length) // attribute length
|
||||||
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
|
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
|
||||||
)
|
)
|
||||||
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
||||||
offset += attributeHeaderSize
|
offset += attributeHeaderSize
|
||||||
if len(b) < aBuffL { // checking size
|
if len(b) < aBuffL { // checking size
|
||||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
|
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, attr.Type)
|
||||||
|
|
||||||
return newAttrDecodeErr("value", msg)
|
return newAttrDecodeErr("value", msg)
|
||||||
}
|
}
|
||||||
a.Value = b[:aL]
|
attr.Value = b[:aL]
|
||||||
offset += aBuffL
|
offset += aBuffL
|
||||||
b = b[aBuffL:]
|
b = b[aBuffL:]
|
||||||
|
|
||||||
m.Attributes = append(m.Attributes, a)
|
m.Attributes = append(m.Attributes, attr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -412,12 +435,14 @@ func (m *Message) Decode() error {
|
|||||||
// Any error is unrecoverable, but message could be partially decoded.
|
// Any error is unrecoverable, but message could be partially decoded.
|
||||||
func (m *Message) Write(tBuf []byte) (int, error) {
|
func (m *Message) Write(tBuf []byte) (int, error) {
|
||||||
m.Raw = append(m.Raw[:0], tBuf...)
|
m.Raw = append(m.Raw[:0], tBuf...)
|
||||||
|
|
||||||
return len(tBuf), m.Decode()
|
return len(tBuf), m.Decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloneTo clones m to b securing any further m mutations.
|
// CloneTo clones m to b securing any further m mutations.
|
||||||
func (m *Message) CloneTo(b *Message) error {
|
func (m *Message) CloneTo(b *Message) error {
|
||||||
b.Raw = append(b.Raw[:0], m.Raw...)
|
b.Raw = append(b.Raw[:0], m.Raw...)
|
||||||
|
|
||||||
return b.Decode()
|
return b.Decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -436,7 +461,7 @@ const (
|
|||||||
var (
|
var (
|
||||||
// Binding request message type.
|
// Binding request message type.
|
||||||
BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
|
BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
|
||||||
// Binding success response message type
|
// Binding success response message type.
|
||||||
BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
|
BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
|
||||||
// Binding error response message type.
|
// Binding error response message type.
|
||||||
BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
|
BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
|
||||||
@@ -501,6 +526,7 @@ func (m Method) String() string {
|
|||||||
// Falling back to hex representation.
|
// Falling back to hex representation.
|
||||||
s = fmt.Sprintf("0x%x", uint16(m))
|
s = fmt.Sprintf("0x%x", uint16(m))
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -513,6 +539,7 @@ type MessageType struct {
|
|||||||
// AddTo sets m type to t.
|
// AddTo sets m type to t.
|
||||||
func (t MessageType) AddTo(m *Message) error {
|
func (t MessageType) AddTo(m *Message) error {
|
||||||
m.SetType(t)
|
m.SetType(t)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -554,13 +581,13 @@ func (t MessageType) Value() uint16 {
|
|||||||
|
|
||||||
// Warning: Abandon all hope ye who enter here.
|
// Warning: Abandon all hope ye who enter here.
|
||||||
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
|
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
|
||||||
m := uint16(t.Method)
|
msg := uint16(t.Method)
|
||||||
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
|
a := msg & methodABits // A = M * 0b0000000000001111 (right 4 bits)
|
||||||
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
|
b := msg & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
|
||||||
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
|
d := msg & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
|
||||||
|
|
||||||
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
|
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
|
||||||
m = a + (b << methodBShift) + (d << methodDShift)
|
msg = a + (b << methodBShift) + (d << methodDShift)
|
||||||
|
|
||||||
// C0 is zero bit of C, C1 is first bit.
|
// C0 is zero bit of C, C1 is first bit.
|
||||||
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
|
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
|
||||||
@@ -573,7 +600,7 @@ func (t MessageType) Value() uint16 {
|
|||||||
c1 := (c & c1Bit) << classC1Shift
|
c1 := (c & c1Bit) << classC1Shift
|
||||||
class := c0 + c1
|
class := c0 + c1
|
||||||
|
|
||||||
return m + class
|
return msg + class
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadValue decodes uint16 into MessageType.
|
// ReadValue decodes uint16 into MessageType.
|
||||||
@@ -604,6 +631,7 @@ func (m *Message) Contains(t AttrType) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -618,5 +646,6 @@ func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
|
|||||||
func (t transactionIDValueSetter) AddTo(m *Message) error {
|
func (t transactionIDValueSetter) AddTo(m *Message) error {
|
||||||
m.TransactionID = t
|
m.TransactionID = t
|
||||||
m.WriteTransactionID()
|
m.WriteTransactionID()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
222
message_test.go
222
message_test.go
@@ -27,9 +27,11 @@ type attributeEncoder interface {
|
|||||||
AddTo(m *Message) error
|
AddTo(m *Message) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func addAttr(t testing.TB, m *Message, a attributeEncoder) {
|
func addAttr(tb testing.TB, m *Message, a attributeEncoder) {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
if err := a.AddTo(m); err != nil {
|
if err := a.AddTo(m); err != nil {
|
||||||
t.Error(err)
|
tb.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,21 +131,21 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_WriteTo(t *testing.T) {
|
func TestMessage_WriteTo(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
m.TransactionID = NewTransactionID()
|
msg.TransactionID = NewTransactionID()
|
||||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
if _, err := m.WriteTo(buf); err != nil {
|
if _, err := msg.WriteTo(buf); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
mDecoded := New()
|
mDecoded := New()
|
||||||
if _, err := mDecoded.ReadFrom(buf); err != nil {
|
if _, err := mDecoded.ReadFrom(buf); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !mDecoded.Equal(m) {
|
if !mDecoded.Equal(msg) {
|
||||||
t.Error(mDecoded, "!", m)
|
t.Error(mDecoded, "!", msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,23 +277,23 @@ func BenchmarkMessage_WriteTo(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMessage_ReadFrom(b *testing.B) {
|
func BenchmarkMessage_ReadFrom(b *testing.B) {
|
||||||
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
m := &Message{
|
msg := &Message{
|
||||||
Type: mType,
|
Type: mType,
|
||||||
Length: 0,
|
Length: 0,
|
||||||
TransactionID: [TransactionIDSize]byte{
|
TransactionID: [TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.SetBytes(int64(len(m.Raw)))
|
b.SetBytes(int64(len(msg.Raw)))
|
||||||
reader := m.reader()
|
reader := msg.reader()
|
||||||
mRec := New()
|
mRec := New()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if _, err := mRec.ReadFrom(reader); err != nil {
|
if _, err := mRec.ReadFrom(reader); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
reader.Reset(m.Raw)
|
reader.Reset(msg.Raw)
|
||||||
mRec.Reset()
|
mRec.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -342,7 +344,7 @@ func TestMessageClass_String(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAttrType_String(t *testing.T) {
|
func TestAttrType_String(t *testing.T) {
|
||||||
v := [...]AttrType{
|
attrType := [...]AttrType{
|
||||||
AttrMappedAddress,
|
AttrMappedAddress,
|
||||||
AttrUsername,
|
AttrUsername,
|
||||||
AttrErrorCode,
|
AttrErrorCode,
|
||||||
@@ -355,7 +357,7 @@ func TestAttrType_String(t *testing.T) {
|
|||||||
AttrAlternateServer,
|
AttrAlternateServer,
|
||||||
AttrFingerprint,
|
AttrFingerprint,
|
||||||
}
|
}
|
||||||
for _, k := range v {
|
for _, k := range attrType {
|
||||||
if k.String() == "" {
|
if k.String() == "" {
|
||||||
t.Error(k, "bad stringer")
|
t.Error(k, "bad stringer")
|
||||||
}
|
}
|
||||||
@@ -379,80 +381,80 @@ func TestMethod_String(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAttribute_Equal(t *testing.T) {
|
func TestAttribute_Equal(t *testing.T) {
|
||||||
a := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||||
b := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||||
if !a.Equal(b) {
|
if !attr1.Equal(attr2) {
|
||||||
t.Error("should equal")
|
t.Error("should equal")
|
||||||
}
|
}
|
||||||
if a.Equal(RawAttribute{Type: 0x2}) {
|
if attr1.Equal(RawAttribute{Type: 0x2}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
if a.Equal(RawAttribute{Length: 0x2}) {
|
if attr1.Equal(RawAttribute{Length: 0x2}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
if a.Equal(RawAttribute{Length: 0x3}) {
|
if attr1.Equal(RawAttribute{Length: 0x3}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
if a.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
|
if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_Equal(t *testing.T) {
|
func TestMessage_Equal(t *testing.T) { //nolint:cyclop
|
||||||
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||||
attrs := Attributes{attr}
|
attrs := Attributes{attr}
|
||||||
a := &Message{Attributes: attrs, Length: 4 + 2}
|
msg1 := &Message{Attributes: attrs, Length: 4 + 2}
|
||||||
b := &Message{Attributes: attrs, Length: 4 + 2}
|
msg2 := &Message{Attributes: attrs, Length: 4 + 2}
|
||||||
if !a.Equal(b) {
|
if !msg1.Equal(msg2) {
|
||||||
t.Error("should equal")
|
t.Error("should equal")
|
||||||
}
|
}
|
||||||
if a.Equal(&Message{Type: MessageType{Class: 128}}) {
|
if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
tID := [TransactionIDSize]byte{
|
tID := [TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||||
}
|
}
|
||||||
if a.Equal(&Message{TransactionID: tID}) {
|
if msg1.Equal(&Message{TransactionID: tID}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
if a.Equal(&Message{Length: 3}) {
|
if msg1.Equal(&Message{Length: 3}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
tAttrs := Attributes{
|
tAttrs := Attributes{
|
||||||
{Length: 1, Value: []byte{0x1}, Type: 0x1},
|
{Length: 1, Value: []byte{0x1}, Type: 0x1},
|
||||||
}
|
}
|
||||||
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
tAttrs = Attributes{
|
tAttrs = Attributes{
|
||||||
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
|
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
|
||||||
}
|
}
|
||||||
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
if !(*Message)(nil).Equal(nil) {
|
if !(*Message)(nil).Equal(nil) {
|
||||||
t.Error("nil should be equal to nil")
|
t.Error("nil should be equal to nil")
|
||||||
}
|
}
|
||||||
if a.Equal(nil) {
|
if msg1.Equal(nil) {
|
||||||
t.Error("non-nil should not be equal to nil")
|
t.Error("non-nil should not be equal to nil")
|
||||||
}
|
}
|
||||||
t.Run("Nil attributes", func(t *testing.T) {
|
t.Run("Nil attributes", func(t *testing.T) {
|
||||||
a := &Message{
|
msg1 := &Message{
|
||||||
Attributes: nil,
|
Attributes: nil,
|
||||||
Length: 4 + 2,
|
Length: 4 + 2,
|
||||||
}
|
}
|
||||||
b := &Message{
|
msg2 := &Message{
|
||||||
Attributes: attrs,
|
Attributes: attrs,
|
||||||
Length: 4 + 2,
|
Length: 4 + 2,
|
||||||
}
|
}
|
||||||
if a.Equal(b) {
|
if msg1.Equal(msg2) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
if b.Equal(a) {
|
if msg2.Equal(msg1) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
b.Attributes = nil
|
msg2.Attributes = nil
|
||||||
if !a.Equal(b) {
|
if !msg1.Equal(msg2) {
|
||||||
t.Error("should equal")
|
t.Error("should equal")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -547,6 +549,8 @@ func BenchmarkIsMessage(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func loadData(tb testing.TB, name string) []byte {
|
func loadData(tb testing.TB, name string) []byte {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
name = filepath.Join("testdata", name)
|
name = filepath.Join("testdata", name)
|
||||||
f, err := os.Open(name) //nolint:gosec
|
f, err := os.Open(name) //nolint:gosec
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -561,6 +565,7 @@ func loadData(tb testing.TB, name string) []byte {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
tb.Fatal(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -582,7 +587,7 @@ func TestMessageFromBrowsers(t *testing.T) {
|
|||||||
t.Fatal("failed to skip header of csv: ", err)
|
t.Fatal("failed to skip header of csv: ", err)
|
||||||
}
|
}
|
||||||
crcTable := crc64.MakeTable(crc64.ISO)
|
crcTable := crc64.MakeTable(crc64.ISO)
|
||||||
m := New()
|
msg := New()
|
||||||
for {
|
for {
|
||||||
line, err := reader.Read()
|
line, err := reader.Read()
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) {
|
||||||
@@ -602,10 +607,10 @@ func TestMessageFromBrowsers(t *testing.T) {
|
|||||||
if b != crc64.Checksum(data, crcTable) {
|
if b != crc64.Checksum(data, crcTable) {
|
||||||
t.Error("crc64 check failed for ", line[1])
|
t.Error("crc64 check failed for ", line[1])
|
||||||
}
|
}
|
||||||
if _, err = m.Write(data); err != nil {
|
if _, err = msg.Write(data); err != nil {
|
||||||
t.Error("failed to decode ", line[1], " as message: ", err)
|
t.Error("failed to decode ", line[1], " as message: ", err)
|
||||||
}
|
}
|
||||||
m.Reset()
|
msg.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -622,42 +627,42 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkMessageFull(b *testing.B) {
|
func BenchmarkMessageFull(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
s := NewSoftware("software")
|
s := NewSoftware("software")
|
||||||
addr := &XORMappedAddress{
|
addr := &XORMappedAddress{
|
||||||
IP: net.IPv4(213, 1, 223, 5),
|
IP: net.IPv4(213, 1, 223, 5),
|
||||||
}
|
}
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := addr.AddTo(m); err != nil {
|
if err := addr.AddTo(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := s.AddTo(m); err != nil {
|
if err := s.AddTo(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
m.WriteAttributes()
|
msg.WriteAttributes()
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
Fingerprint.AddTo(m) //nolint:errcheck,gosec
|
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
m.Reset()
|
msg.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkMessageFullHardcore(b *testing.B) {
|
func BenchmarkMessageFullHardcore(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
s := NewSoftware("software")
|
s := NewSoftware("software")
|
||||||
addr := &XORMappedAddress{
|
addr := &XORMappedAddress{
|
||||||
IP: net.IPv4(213, 1, 223, 5),
|
IP: net.IPv4(213, 1, 223, 5),
|
||||||
}
|
}
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := addr.AddTo(m); err != nil {
|
if err := addr.AddTo(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := s.AddTo(m); err != nil {
|
if err := s.AddTo(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
m.Reset()
|
msg.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -689,8 +694,8 @@ func TestMessage_Contains(t *testing.T) {
|
|||||||
|
|
||||||
func ExampleMessage() {
|
func ExampleMessage() {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
m.Build(BindingRequest, //nolint:errcheck,gosec
|
msg.Build(BindingRequest, //nolint:errcheck,gosec
|
||||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||||
}),
|
}),
|
||||||
@@ -707,8 +712,8 @@ func ExampleMessage() {
|
|||||||
// m.Build(&software) // no allocations
|
// m.Build(&software) // no allocations
|
||||||
// If you pass software as value, there will be 1 allocation.
|
// If you pass software as value, there will be 1 allocation.
|
||||||
// This rule is correct for all setters.
|
// This rule is correct for all setters.
|
||||||
fmt.Println(m, "buff length:", len(m.Raw))
|
fmt.Println(msg, "buff length:", len(msg.Raw))
|
||||||
n, err := m.WriteTo(buf)
|
n, err := msg.WriteTo(buf)
|
||||||
fmt.Println("wrote", n, "err", err)
|
fmt.Println("wrote", n, "err", err)
|
||||||
|
|
||||||
// Decoding from buf new *Message.
|
// Decoding from buf new *Message.
|
||||||
@@ -743,6 +748,7 @@ func ExampleMessage() {
|
|||||||
fmt.Println("fingerprint: failed")
|
fmt.Println("fingerprint: failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:lll
|
||||||
// Output:
|
// Output:
|
||||||
// Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68
|
// Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68
|
||||||
// wrote 68 err <nil>
|
// wrote 68 err <nil>
|
||||||
@@ -811,8 +817,8 @@ func TestAllocationsGetters(t *testing.T) {
|
|||||||
NewShortTermIntegrity("pwd"),
|
NewShortTermIntegrity("pwd"),
|
||||||
Fingerprint,
|
Fingerprint,
|
||||||
}
|
}
|
||||||
m := New()
|
msg := New()
|
||||||
if err := m.Build(setters...); err != nil {
|
if err := msg.Build(setters...); err != nil {
|
||||||
t.Error("failed to build", err)
|
t.Error("failed to build", err)
|
||||||
}
|
}
|
||||||
getters := []Getter{
|
getters := []Getter{
|
||||||
@@ -826,7 +832,7 @@ func TestAllocationsGetters(t *testing.T) {
|
|||||||
g := g
|
g := g
|
||||||
i := i
|
i := i
|
||||||
allocs := testing.AllocsPerRun(10, func() {
|
allocs := testing.AllocsPerRun(10, func() {
|
||||||
if err := g.GetFrom(m); err != nil {
|
if err := g.GetFrom(msg); err != nil {
|
||||||
t.Errorf("[%d] failed to get", i)
|
t.Errorf("[%d] failed to get", i)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -837,8 +843,8 @@ func TestAllocationsGetters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMessageFullSize(t *testing.T) {
|
func TestMessageFullSize(t *testing.T) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
if err := m.Build(BindingRequest,
|
if err := msg.Build(BindingRequest,
|
||||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||||
}),
|
}),
|
||||||
@@ -848,18 +854,18 @@ func TestMessageFullSize(t *testing.T) {
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m.Raw = m.Raw[:len(m.Raw)-10]
|
msg.Raw = msg.Raw[:len(msg.Raw)-10]
|
||||||
|
|
||||||
decoder := new(Message)
|
decoder := new(Message)
|
||||||
decoder.Raw = m.Raw[:len(m.Raw)-10]
|
decoder.Raw = msg.Raw[:len(msg.Raw)-10]
|
||||||
if err := decoder.Decode(); err == nil {
|
if err := decoder.Decode(); err == nil {
|
||||||
t.Error("decode on truncated buffer should error")
|
t.Error("decode on truncated buffer should error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_CloneTo(t *testing.T) {
|
func TestMessage_CloneTo(t *testing.T) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
if err := m.Build(BindingRequest,
|
if err := msg.Build(BindingRequest,
|
||||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||||
}),
|
}),
|
||||||
@@ -869,29 +875,29 @@ func TestMessage_CloneTo(t *testing.T) {
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m.Encode()
|
msg.Encode()
|
||||||
b := new(Message)
|
msg2 := new(Message)
|
||||||
if err := m.CloneTo(b); err != nil {
|
if err := msg.CloneTo(msg2); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !b.Equal(m) {
|
if !msg2.Equal(msg) {
|
||||||
t.Fatal("not equal")
|
t.Fatal("not equal")
|
||||||
}
|
}
|
||||||
// Corrupting m and checking that b is not corrupted.
|
// Corrupting m and checking that b is not corrupted.
|
||||||
s, ok := b.Attributes.Get(AttrSoftware)
|
s, ok := msg2.Attributes.Get(AttrSoftware)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("no software attribute")
|
t.Fatal("no software attribute")
|
||||||
}
|
}
|
||||||
s.Value[0] = 'k'
|
s.Value[0] = 'k'
|
||||||
if b.Equal(m) {
|
if msg2.Equal(msg) {
|
||||||
t.Fatal("should not be equal")
|
t.Fatal("should not be equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkMessage_CloneTo(b *testing.B) {
|
func BenchmarkMessage_CloneTo(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
if err := m.Build(BindingRequest,
|
if err := msg.Build(BindingRequest,
|
||||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||||
}),
|
}),
|
||||||
@@ -901,19 +907,19 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
b.SetBytes(int64(len(m.Raw)))
|
b.SetBytes(int64(len(msg.Raw)))
|
||||||
a := new(Message)
|
a := new(Message)
|
||||||
m.CloneTo(a) //nolint:errcheck,gosec
|
msg.CloneTo(a) //nolint:errcheck,gosec
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := m.CloneTo(a); err != nil {
|
if err := msg.CloneTo(a); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_AddTo(t *testing.T) {
|
func TestMessage_AddTo(t *testing.T) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
if err := m.Build(BindingRequest,
|
if err := msg.Build(BindingRequest,
|
||||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||||
}),
|
}),
|
||||||
@@ -921,19 +927,19 @@ func TestMessage_AddTo(t *testing.T) {
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m.Encode()
|
msg.Encode()
|
||||||
b := new(Message)
|
b := new(Message)
|
||||||
if err := m.CloneTo(b); err != nil {
|
if err := msg.CloneTo(b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m.TransactionID = [TransactionIDSize]byte{
|
msg.TransactionID = [TransactionIDSize]byte{
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
|
||||||
}
|
}
|
||||||
if b.Equal(m) {
|
if b.Equal(msg) {
|
||||||
t.Fatal("should not be equal")
|
t.Fatal("should not be equal")
|
||||||
}
|
}
|
||||||
m.AddTo(b) //nolint:errcheck,gosec
|
msg.AddTo(b) //nolint:errcheck,gosec
|
||||||
if !b.Equal(m) {
|
if !b.Equal(msg) {
|
||||||
t.Fatal("should be equal")
|
t.Fatal("should be equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -964,22 +970,22 @@ func TestDecode(t *testing.T) {
|
|||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
m := New()
|
msg := New()
|
||||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
m.TransactionID = NewTransactionID()
|
msg.TransactionID = NewTransactionID()
|
||||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
mDecoded := New()
|
mDecoded := New()
|
||||||
if err := Decode(m.Raw, mDecoded); err != nil {
|
if err := Decode(msg.Raw, mDecoded); err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
if !mDecoded.Equal(m) {
|
if !mDecoded.Equal(msg) {
|
||||||
t.Error("decoded result is not equal to encoded message")
|
t.Error("decoded result is not equal to encoded message")
|
||||||
}
|
}
|
||||||
t.Run("ZeroAlloc", func(t *testing.T) {
|
t.Run("ZeroAlloc", func(t *testing.T) {
|
||||||
allocs := testing.AllocsPerRun(10, func() {
|
allocs := testing.AllocsPerRun(10, func() {
|
||||||
mDecoded.Reset()
|
mDecoded.Reset()
|
||||||
if err := Decode(m.Raw, mDecoded); err != nil {
|
if err := Decode(msg.Raw, mDecoded); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -1008,22 +1014,22 @@ func BenchmarkDecode(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_MarshalBinary(t *testing.T) {
|
func TestMessage_MarshalBinary(t *testing.T) {
|
||||||
m := MustBuild(
|
msg := MustBuild(
|
||||||
NewSoftware("software"),
|
NewSoftware("software"),
|
||||||
&XORMappedAddress{
|
&XORMappedAddress{
|
||||||
IP: net.IPv4(213, 1, 223, 5),
|
IP: net.IPv4(213, 1, 223, 5),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
data, err := m.MarshalBinary()
|
data, err := msg.MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset m.Raw to check retention.
|
// Reset m.Raw to check retention.
|
||||||
for i := range m.Raw {
|
for i := range msg.Raw {
|
||||||
m.Raw[i] = 0
|
msg.Raw[i] = 0
|
||||||
}
|
}
|
||||||
if err := m.UnmarshalBinary(data); err != nil {
|
if err := msg.UnmarshalBinary(data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1031,28 +1037,28 @@ func TestMessage_MarshalBinary(t *testing.T) {
|
|||||||
for i := range data {
|
for i := range data {
|
||||||
data[i] = 0
|
data[i] = 0
|
||||||
}
|
}
|
||||||
if err := m.Decode(); err != nil {
|
if err := msg.Decode(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_GobDecode(t *testing.T) {
|
func TestMessage_GobDecode(t *testing.T) {
|
||||||
m := MustBuild(
|
msg := MustBuild(
|
||||||
NewSoftware("software"),
|
NewSoftware("software"),
|
||||||
&XORMappedAddress{
|
&XORMappedAddress{
|
||||||
IP: net.IPv4(213, 1, 223, 5),
|
IP: net.IPv4(213, 1, 223, 5),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
data, err := m.GobEncode()
|
data, err := msg.GobEncode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset m.Raw to check retention.
|
// Reset m.Raw to check retention.
|
||||||
for i := range m.Raw {
|
for i := range msg.Raw {
|
||||||
m.Raw[i] = 0
|
msg.Raw[i] = 0
|
||||||
}
|
}
|
||||||
if err := m.GobDecode(data); err != nil {
|
if err := msg.GobDecode(data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1060,7 +1066,7 @@ func TestMessage_GobDecode(t *testing.T) {
|
|||||||
for i := range data {
|
for i := range data {
|
||||||
data[i] = 0
|
data[i] = 0
|
||||||
}
|
}
|
||||||
if err := m.Decode(); err != nil {
|
if err := msg.Decode(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -8,7 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRFC5769(t *testing.T) {
|
func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||||
// Test Vectors for Session Traversal Utilities for NAT (STUN)
|
// Test Vectors for Session Traversal Utilities for NAT (STUN)
|
||||||
// see https://tools.ietf.org/html/rfc5769
|
// see https://tools.ietf.org/html/rfc5769
|
||||||
t.Run("Request", func(t *testing.T) {
|
t.Run("Request", func(t *testing.T) {
|
||||||
@@ -46,7 +46,7 @@ func TestRFC5769(t *testing.T) {
|
|||||||
t.Error("check failed: ", err)
|
t.Error("check failed: ", err)
|
||||||
}
|
}
|
||||||
t.Run("Long-Term credentials", func(t *testing.T) {
|
t.Run("Long-Term credentials", func(t *testing.T) {
|
||||||
m := &Message{
|
msg := &Message{
|
||||||
Raw: []byte("\x00\x01\x00\x60" +
|
Raw: []byte("\x00\x01\x00\x60" +
|
||||||
"\x21\x12\xa4\x42" +
|
"\x21\x12\xa4\x42" +
|
||||||
"\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" +
|
"\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" +
|
||||||
@@ -64,11 +64,11 @@ func TestRFC5769(t *testing.T) {
|
|||||||
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
|
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if err := m.Decode(); err != nil {
|
if err := msg.Decode(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
u := new(Username)
|
u := new(Username)
|
||||||
if err := u.GetFrom(m); err != nil {
|
if err := u.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
|
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
|
||||||
@@ -76,14 +76,14 @@ func TestRFC5769(t *testing.T) {
|
|||||||
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
|
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
|
||||||
}
|
}
|
||||||
n := new(Nonce)
|
n := new(Nonce)
|
||||||
if err := n.GetFrom(m); err != nil {
|
if err := n.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
|
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
|
||||||
t.Error("bad nonce")
|
t.Error("bad nonce")
|
||||||
}
|
}
|
||||||
r := new(Realm)
|
r := new(Realm)
|
||||||
if err := r.GetFrom(m); err != nil {
|
if err := r.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if r.String() != "example.org" { //nolint:goconst
|
if r.String() != "example.org" { //nolint:goconst
|
||||||
@@ -95,14 +95,14 @@ func TestRFC5769(t *testing.T) {
|
|||||||
"example.org",
|
"example.org",
|
||||||
"TheMatrIX",
|
"TheMatrIX",
|
||||||
)
|
)
|
||||||
if err := i.Check(m); err != nil {
|
if err := i.Check(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("Response", func(t *testing.T) {
|
t.Run("Response", func(t *testing.T) {
|
||||||
t.Run("IPv4", func(t *testing.T) {
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
m := &Message{
|
msg := &Message{
|
||||||
Raw: []byte("\x01\x01\x00\x3c" +
|
Raw: []byte("\x01\x01\x00\x3c" +
|
||||||
"\x21\x12\xa4\x42" +
|
"\x21\x12\xa4\x42" +
|
||||||
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
|
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
|
||||||
@@ -117,21 +117,21 @@ func TestRFC5769(t *testing.T) {
|
|||||||
"\xc0\x7d\x4c\x96",
|
"\xc0\x7d\x4c\x96",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if err := m.Decode(); err != nil {
|
if err := msg.Decode(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
software := new(Software)
|
software := new(Software)
|
||||||
if err := software.GetFrom(m); err != nil {
|
if err := software.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if software.String() != "test vector" {
|
if software.String() != "test vector" {
|
||||||
t.Error("bad software: ", software)
|
t.Error("bad software: ", software)
|
||||||
}
|
}
|
||||||
if err := Fingerprint.Check(m); err != nil {
|
if err := Fingerprint.Check(msg); err != nil {
|
||||||
t.Error("Check failed: ", err)
|
t.Error("Check failed: ", err)
|
||||||
}
|
}
|
||||||
addr := new(XORMappedAddress)
|
addr := new(XORMappedAddress)
|
||||||
if err := addr.GetFrom(m); err != nil {
|
if err := addr.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
|
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
|
||||||
@@ -140,12 +140,12 @@ func TestRFC5769(t *testing.T) {
|
|||||||
if addr.Port != 32853 {
|
if addr.Port != 32853 {
|
||||||
t.Error("bad Port")
|
t.Error("bad Port")
|
||||||
}
|
}
|
||||||
if err := Fingerprint.Check(m); err != nil {
|
if err := Fingerprint.Check(msg); err != nil {
|
||||||
t.Error("check failed: ", err)
|
t.Error("check failed: ", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("IPv6", func(t *testing.T) {
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
m := &Message{
|
msg := &Message{
|
||||||
Raw: []byte("\x01\x01\x00\x48" +
|
Raw: []byte("\x01\x01\x00\x48" +
|
||||||
"\x21\x12\xa4\x42" +
|
"\x21\x12\xa4\x42" +
|
||||||
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
|
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
|
||||||
@@ -162,21 +162,21 @@ func TestRFC5769(t *testing.T) {
|
|||||||
"\xc8\xfb\x0b\x4c",
|
"\xc8\xfb\x0b\x4c",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if err := m.Decode(); err != nil {
|
if err := msg.Decode(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
software := new(Software)
|
software := new(Software)
|
||||||
if err := software.GetFrom(m); err != nil {
|
if err := software.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if software.String() != "test vector" {
|
if software.String() != "test vector" {
|
||||||
t.Error("bad software: ", software)
|
t.Error("bad software: ", software)
|
||||||
}
|
}
|
||||||
if err := Fingerprint.Check(m); err != nil {
|
if err := Fingerprint.Check(msg); err != nil {
|
||||||
t.Error("Check failed: ", err)
|
t.Error("Check failed: ", err)
|
||||||
}
|
}
|
||||||
addr := new(XORMappedAddress)
|
addr := new(XORMappedAddress)
|
||||||
if err := addr.GetFrom(m); err != nil {
|
if err := addr.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
|
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
|
||||||
@@ -185,7 +185,7 @@ func TestRFC5769(t *testing.T) {
|
|||||||
if addr.Port != 32853 {
|
if addr.Port != 32853 {
|
||||||
t.Error("bad Port")
|
t.Error("bad Port")
|
||||||
}
|
}
|
||||||
if err := Fingerprint.Check(m); err != nil {
|
if err := Fingerprint.Check(msg); err != nil {
|
||||||
t.Error("check failed: ", err)
|
t.Error("check failed: ", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
2
stun.go
2
stun.go
@@ -27,6 +27,7 @@ func readFullOrPanic(r io.Reader, v []byte) int {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err) //nolint
|
panic(err) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ func writeOrPanic(w io.Writer, v []byte) int {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err) //nolint
|
panic(err) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -13,14 +13,19 @@ import (
|
|||||||
|
|
||||||
var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
|
var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
|
||||||
|
|
||||||
// NewUDPServer creates an udp server for testing. The supplied handler function will be called with the request
|
// NewUDPServer creates an udp server for testing.
|
||||||
|
// The supplied handler function will be called with the request
|
||||||
// and should be used to emulate the server behavior.
|
// and should be used to emulate the server behavior.
|
||||||
|
//
|
||||||
|
//nolint:cyclop
|
||||||
func NewUDPServer(
|
func NewUDPServer(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
network string,
|
network string,
|
||||||
maxMessageSize int,
|
maxMessageSize int,
|
||||||
handler func(req []byte) ([]byte, error),
|
handler func(req []byte) ([]byte, error),
|
||||||
) (net.Addr, func(t *testing.T), error) {
|
) (net.Addr, func(t *testing.T), error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
var ip string
|
var ip string
|
||||||
switch network {
|
switch network {
|
||||||
case "udp4":
|
case "udp4":
|
||||||
@@ -50,28 +55,34 @@ func NewUDPServer(
|
|||||||
n, addr, err := udpConn.ReadFrom(bs)
|
n, addr, err := udpConn.ReadFrom(bs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := handler(bs[:n])
|
resp, err := handler(bs[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = udpConn.WriteTo(resp, addr)
|
_, err = udpConn.WriteTo(resp, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return serverAddr, func(t *testing.T) {
|
return serverAddr, func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err) //nolint
|
t.Fatal(err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -79,7 +90,7 @@ func NewUDPServer(
|
|||||||
|
|
||||||
err := udpConn.Close()
|
err := udpConn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err) //nolint
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
<-errCh
|
<-errCh
|
||||||
|
10
textattrs.go
10
textattrs.go
@@ -10,7 +10,7 @@ func NewUsername(username string) Username {
|
|||||||
|
|
||||||
// Username represents USERNAME attribute.
|
// Username represents USERNAME attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.3
|
// RFC 5389 Section 15.3.
|
||||||
type Username []byte
|
type Username []byte
|
||||||
|
|
||||||
func (u Username) String() string {
|
func (u Username) String() string {
|
||||||
@@ -37,7 +37,7 @@ func NewRealm(realm string) Realm {
|
|||||||
|
|
||||||
// Realm represents REALM attribute.
|
// Realm represents REALM attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.7
|
// RFC 5389 Section 15.7.
|
||||||
type Realm []byte
|
type Realm []byte
|
||||||
|
|
||||||
func (n Realm) String() string {
|
func (n Realm) String() string {
|
||||||
@@ -60,7 +60,7 @@ const softwareRawMaxB = 763
|
|||||||
|
|
||||||
// Software is SOFTWARE attribute.
|
// Software is SOFTWARE attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.10
|
// RFC 5389 Section 15.10.
|
||||||
type Software []byte
|
type Software []byte
|
||||||
|
|
||||||
func (s Software) String() string {
|
func (s Software) String() string {
|
||||||
@@ -84,7 +84,7 @@ func (s *Software) GetFrom(m *Message) error {
|
|||||||
|
|
||||||
// Nonce represents NONCE attribute.
|
// Nonce represents NONCE attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.8
|
// RFC 5389 Section 15.8.
|
||||||
type Nonce []byte
|
type Nonce []byte
|
||||||
|
|
||||||
// NewNonce returns new Nonce from string.
|
// NewNonce returns new Nonce from string.
|
||||||
@@ -118,6 +118,7 @@ func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.Add(t, v)
|
m.Add(t, v)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,5 +129,6 @@ func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*v = a
|
*v = a
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -13,26 +13,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSoftware_GetFrom(t *testing.T) {
|
func TestSoftware_GetFrom(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
v := "Client v0.0.1"
|
val := "Client v0.0.1"
|
||||||
m.Add(AttrSoftware, []byte(v))
|
msg.Add(AttrSoftware, []byte(val))
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
|
|
||||||
m2 := &Message{
|
m2 := &Message{
|
||||||
Raw: make([]byte, 0, 256),
|
Raw: make([]byte, 0, 256),
|
||||||
}
|
}
|
||||||
software := new(Software)
|
software := new(Software)
|
||||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := software.GetFrom(m); err != nil {
|
if err := software.GetFrom(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if software.String() != v {
|
if software.String() != val {
|
||||||
t.Errorf("Expected %q, got %q.", v, software)
|
t.Errorf("Expected %q, got %q.", val, software)
|
||||||
}
|
}
|
||||||
|
|
||||||
sAttr, ok := m.Attributes.Get(AttrSoftware)
|
sAttr, ok := msg.Attributes.Get(AttrSoftware)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("software attribute should be found")
|
t.Error("software attribute should be found")
|
||||||
}
|
}
|
||||||
@@ -90,22 +90,22 @@ func BenchmarkUsername_GetFrom(b *testing.B) {
|
|||||||
|
|
||||||
func TestUsername(t *testing.T) {
|
func TestUsername(t *testing.T) {
|
||||||
username := "username"
|
username := "username"
|
||||||
u := NewUsername(username)
|
uName := NewUsername(username)
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
t.Run("Bad length", func(t *testing.T) {
|
t.Run("Bad length", func(t *testing.T) {
|
||||||
badU := make(Username, 600)
|
badU := make(Username, 600)
|
||||||
if err := badU.AddTo(m); !IsAttrSizeOverflow(err) {
|
if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) {
|
||||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("AddTo", func(t *testing.T) {
|
t.Run("AddTo", func(t *testing.T) {
|
||||||
if err := u.AddTo(m); err != nil {
|
if err := uName.AddTo(msg); err != nil {
|
||||||
t.Error("errored:", err)
|
t.Error("errored:", err)
|
||||||
}
|
}
|
||||||
t.Run("GetFrom", func(t *testing.T) {
|
t.Run("GetFrom", func(t *testing.T) {
|
||||||
got := new(Username)
|
got := new(Username)
|
||||||
if err := got.GetFrom(m); err != nil {
|
if err := got.GetFrom(msg); err != nil {
|
||||||
t.Error("errored:", err)
|
t.Error("errored:", err)
|
||||||
}
|
}
|
||||||
if got.String() != username {
|
if got.String() != username {
|
||||||
@@ -136,10 +136,10 @@ func TestUsername(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRealm_GetFrom(t *testing.T) {
|
func TestRealm_GetFrom(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
v := "realm"
|
val := "realm"
|
||||||
m.Add(AttrRealm, []byte(v))
|
msg.Add(AttrRealm, []byte(val))
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
|
|
||||||
m2 := &Message{
|
m2 := &Message{
|
||||||
Raw: make([]byte, 0, 256),
|
Raw: make([]byte, 0, 256),
|
||||||
@@ -148,17 +148,17 @@ func TestRealm_GetFrom(t *testing.T) {
|
|||||||
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
|
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
|
||||||
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
||||||
}
|
}
|
||||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := r.GetFrom(m); err != nil {
|
if err := r.GetFrom(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if r.String() != v {
|
if r.String() != val {
|
||||||
t.Errorf("Expected %q, got %q.", v, r)
|
t.Errorf("Expected %q, got %q.", val, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
rAttr, ok := m.Attributes.Get(AttrRealm)
|
rAttr, ok := msg.Attributes.Get(AttrRealm)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("realm attribute should be found")
|
t.Error("realm attribute should be found")
|
||||||
}
|
}
|
||||||
@@ -180,26 +180,26 @@ func TestRealm_AddTo_Invalid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNonce_GetFrom(t *testing.T) {
|
func TestNonce_GetFrom(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
v := "example.org"
|
val := "example.org"
|
||||||
m.Add(AttrNonce, []byte(v))
|
msg.Add(AttrNonce, []byte(val))
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
|
|
||||||
m2 := &Message{
|
m2 := &Message{
|
||||||
Raw: make([]byte, 0, 256),
|
Raw: make([]byte, 0, 256),
|
||||||
}
|
}
|
||||||
var nonce Nonce
|
var nonce Nonce
|
||||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if err := nonce.GetFrom(m); err != nil {
|
if err := nonce.GetFrom(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if nonce.String() != v {
|
if nonce.String() != val {
|
||||||
t.Errorf("Expected %q, got %q.", v, nonce)
|
t.Errorf("Expected %q, got %q.", val, nonce)
|
||||||
}
|
}
|
||||||
|
|
||||||
nAttr, ok := m.Attributes.Get(AttrNonce)
|
nAttr, ok := msg.Attributes.Get(AttrNonce)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("nonce attribute should be found")
|
t.Error("nonce attribute should be found")
|
||||||
}
|
}
|
||||||
|
@@ -7,7 +7,7 @@ import "errors"
|
|||||||
|
|
||||||
// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute.
|
// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.9
|
// RFC 5389 Section 15.9.
|
||||||
type UnknownAttributes []AttrType
|
type UnknownAttributes []AttrType
|
||||||
|
|
||||||
func (a UnknownAttributes) String() string {
|
func (a UnknownAttributes) String() string {
|
||||||
@@ -22,6 +22,7 @@ func (a UnknownAttributes) String() string {
|
|||||||
s += ", "
|
s += ", "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,6 +40,7 @@ func (a UnknownAttributes) AddTo(m *Message) error {
|
|||||||
bin.PutUint16(v[first:last], t.Value())
|
bin.PutUint16(v[first:last], t.Value())
|
||||||
}
|
}
|
||||||
m.Add(AttrUnknownAttributes, v)
|
m.Add(AttrUnknownAttributes, v)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,5 +64,6 @@ func (a *UnknownAttributes) GetFrom(m *Message) error {
|
|||||||
*a = append(*a, AttrType(bin.Uint16(v[first:last])))
|
*a = append(*a, AttrType(bin.Uint16(v[first:last])))
|
||||||
first = last
|
first = last
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -8,26 +8,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestUnknownAttributes(t *testing.T) {
|
func TestUnknownAttributes(t *testing.T) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
a := &UnknownAttributes{
|
attr := &UnknownAttributes{
|
||||||
AttrDontFragment,
|
AttrDontFragment,
|
||||||
AttrChannelNumber,
|
AttrChannelNumber,
|
||||||
}
|
}
|
||||||
if a.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
|
if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
|
||||||
t.Error("bad String:", a)
|
t.Error("bad String:", attr)
|
||||||
}
|
}
|
||||||
if (UnknownAttributes{}).String() != "<nil>" {
|
if (UnknownAttributes{}).String() != "<nil>" {
|
||||||
t.Error("bad blank string")
|
t.Error("bad blank string")
|
||||||
}
|
}
|
||||||
if err := a.AddTo(m); err != nil {
|
if err := attr.AddTo(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
t.Run("GetFrom", func(t *testing.T) {
|
t.Run("GetFrom", func(t *testing.T) {
|
||||||
attrs := make(UnknownAttributes, 10)
|
attrs := make(UnknownAttributes, 10)
|
||||||
if err := attrs.GetFrom(m); err != nil {
|
if err := attrs.GetFrom(msg); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
for i, at := range *a {
|
for i, at := range *attr {
|
||||||
if at != attrs[i] {
|
if at != attrs[i] {
|
||||||
t.Error("expected", at, "!=", attrs[i])
|
t.Error("expected", at, "!=", attrs[i])
|
||||||
}
|
}
|
||||||
@@ -44,8 +44,8 @@ func TestUnknownAttributes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUnknownAttributes(b *testing.B) {
|
func BenchmarkUnknownAttributes(b *testing.B) {
|
||||||
m := new(Message)
|
msg := new(Message)
|
||||||
a := UnknownAttributes{
|
attr := UnknownAttributes{
|
||||||
AttrDontFragment,
|
AttrDontFragment,
|
||||||
AttrChannelNumber,
|
AttrChannelNumber,
|
||||||
AttrRealm,
|
AttrRealm,
|
||||||
@@ -54,20 +54,20 @@ func BenchmarkUnknownAttributes(b *testing.B) {
|
|||||||
b.Run("AddTo", func(b *testing.B) {
|
b.Run("AddTo", func(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := a.AddTo(m); err != nil {
|
if err := attr.AddTo(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
m.Reset()
|
msg.Reset()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("GetFrom", func(b *testing.B) {
|
b.Run("GetFrom", func(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
if err := a.AddTo(m); err != nil {
|
if err := attr.AddTo(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
attrs := make(UnknownAttributes, 0, 10)
|
attrs := make(UnknownAttributes, 0, 10)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := attrs.GetFrom(m); err != nil {
|
if err := attrs.GetFrom(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
attrs = attrs[:0]
|
attrs = attrs[:0]
|
||||||
|
47
uri.go
47
uri.go
@@ -124,7 +124,7 @@ func (t ProtoType) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI
|
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI.
|
||||||
type URI struct {
|
type URI struct {
|
||||||
Scheme SchemeType
|
Scheme SchemeType
|
||||||
Host string
|
Host string
|
||||||
@@ -137,73 +137,76 @@ type URI struct {
|
|||||||
// ParseURI parses a STUN or TURN urls following the ABNF syntax described in
|
// ParseURI parses a STUN or TURN urls following the ABNF syntax described in
|
||||||
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
|
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
|
||||||
// respectively.
|
// respectively.
|
||||||
func ParseURI(raw string) (*URI, error) { //nolint:gocognit
|
func ParseURI(raw string) (*URI, error) { //nolint:gocognit,cyclop
|
||||||
rawParts, err := url.Parse(raw)
|
rawParts, err := url.Parse(raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var u URI
|
var uri URI
|
||||||
u.Scheme = NewSchemeType(rawParts.Scheme)
|
uri.Scheme = NewSchemeType(rawParts.Scheme)
|
||||||
if u.Scheme == SchemeTypeUnknown {
|
if uri.Scheme == SchemeTypeUnknown {
|
||||||
return nil, ErrSchemeType
|
return nil, ErrSchemeType
|
||||||
}
|
}
|
||||||
|
|
||||||
var rawPort string
|
var rawPort string
|
||||||
if u.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil {
|
if uri.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { //nolint:nestif
|
||||||
var e *net.AddrError
|
var e *net.AddrError
|
||||||
if errors.As(err, &e) {
|
if errors.As(err, &e) {
|
||||||
if e.Err == "missing port in address" {
|
if e.Err == "missing port in address" {
|
||||||
nextRawURL := u.Scheme.String() + ":" + rawParts.Opaque
|
nextRawURL := uri.Scheme.String() + ":" + rawParts.Opaque
|
||||||
switch {
|
switch {
|
||||||
case u.Scheme == SchemeTypeSTUN || u.Scheme == SchemeTypeTURN:
|
case uri.Scheme == SchemeTypeSTUN || uri.Scheme == SchemeTypeTURN:
|
||||||
nextRawURL += ":3478"
|
nextRawURL += ":3478"
|
||||||
if rawParts.RawQuery != "" {
|
if rawParts.RawQuery != "" {
|
||||||
nextRawURL += "?" + rawParts.RawQuery
|
nextRawURL += "?" + rawParts.RawQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
return ParseURI(nextRawURL)
|
return ParseURI(nextRawURL)
|
||||||
case u.Scheme == SchemeTypeSTUNS || u.Scheme == SchemeTypeTURNS:
|
case uri.Scheme == SchemeTypeSTUNS || uri.Scheme == SchemeTypeTURNS:
|
||||||
nextRawURL += ":5349"
|
nextRawURL += ":5349"
|
||||||
if rawParts.RawQuery != "" {
|
if rawParts.RawQuery != "" {
|
||||||
nextRawURL += "?" + rawParts.RawQuery
|
nextRawURL += "?" + rawParts.RawQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
return ParseURI(nextRawURL)
|
return ParseURI(nextRawURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.Host == "" {
|
if uri.Host == "" {
|
||||||
return nil, ErrHost
|
return nil, ErrHost
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.Port, err = strconv.Atoi(rawPort); err != nil {
|
if uri.Port, err = strconv.Atoi(rawPort); err != nil {
|
||||||
return nil, ErrPort
|
return nil, ErrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
switch u.Scheme {
|
switch uri.Scheme {
|
||||||
case SchemeTypeSTUN:
|
case SchemeTypeSTUN:
|
||||||
qArgs, err := url.ParseQuery(rawParts.RawQuery)
|
qArgs, err := url.ParseQuery(rawParts.RawQuery)
|
||||||
if err != nil || len(qArgs) > 0 {
|
if err != nil || len(qArgs) > 0 {
|
||||||
return nil, ErrSTUNQuery
|
return nil, ErrSTUNQuery
|
||||||
}
|
}
|
||||||
u.Proto = ProtoTypeUDP
|
uri.Proto = ProtoTypeUDP
|
||||||
case SchemeTypeSTUNS:
|
case SchemeTypeSTUNS:
|
||||||
qArgs, err := url.ParseQuery(rawParts.RawQuery)
|
qArgs, err := url.ParseQuery(rawParts.RawQuery)
|
||||||
if err != nil || len(qArgs) > 0 {
|
if err != nil || len(qArgs) > 0 {
|
||||||
return nil, ErrSTUNQuery
|
return nil, ErrSTUNQuery
|
||||||
}
|
}
|
||||||
u.Proto = ProtoTypeTCP
|
uri.Proto = ProtoTypeTCP
|
||||||
case SchemeTypeTURN:
|
case SchemeTypeTURN:
|
||||||
proto, err := parseProto(rawParts.RawQuery)
|
proto, err := parseProto(rawParts.RawQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.Proto = proto
|
uri.Proto = proto
|
||||||
if u.Proto == ProtoTypeUnknown {
|
if uri.Proto == ProtoTypeUnknown {
|
||||||
u.Proto = ProtoTypeUDP
|
uri.Proto = ProtoTypeUDP
|
||||||
}
|
}
|
||||||
case SchemeTypeTURNS:
|
case SchemeTypeTURNS:
|
||||||
proto, err := parseProto(rawParts.RawQuery)
|
proto, err := parseProto(rawParts.RawQuery)
|
||||||
@@ -211,15 +214,15 @@ func ParseURI(raw string) (*URI, error) { //nolint:gocognit
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.Proto = proto
|
uri.Proto = proto
|
||||||
if u.Proto == ProtoTypeUnknown {
|
if uri.Proto == ProtoTypeUnknown {
|
||||||
u.Proto = ProtoTypeTCP
|
uri.Proto = ProtoTypeTCP
|
||||||
}
|
}
|
||||||
|
|
||||||
case SchemeTypeUnknown:
|
case SchemeTypeUnknown:
|
||||||
}
|
}
|
||||||
|
|
||||||
return &u, nil
|
return &uri, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseProto(raw string) (ProtoType, error) {
|
func parseProto(raw string) (ProtoType, error) {
|
||||||
@@ -233,6 +236,7 @@ func parseProto(raw string) (ProtoType, error) {
|
|||||||
if proto = NewProtoType(rawProto); proto == ProtoType(0) {
|
if proto = NewProtoType(rawProto); proto == ProtoType(0) {
|
||||||
return ProtoTypeUnknown, ErrProtoType
|
return ProtoTypeUnknown, ErrProtoType
|
||||||
}
|
}
|
||||||
|
|
||||||
return proto, nil
|
return proto, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,6 +252,7 @@ func (u URI) String() string {
|
|||||||
if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS {
|
if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS {
|
||||||
rawURL += "?transport=" + u.Proto.String()
|
rawURL += "?transport=" + u.Proto.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawURL
|
return rawURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
12
uri_test.go
12
uri_test.go
@@ -35,8 +35,16 @@ func TestParseURL(t *testing.T) {
|
|||||||
{"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP},
|
{"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP},
|
||||||
{"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
|
{"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
|
||||||
{"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
|
{"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
|
||||||
{"turn:google.de?transport=udp", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
|
{
|
||||||
{"turns:google.de?transport=tcp", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
|
"turn:google.de?transport=udp",
|
||||||
|
"turn:google.de:3478?transport=udp",
|
||||||
|
SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"turns:google.de?transport=tcp",
|
||||||
|
"turns:google.de:5349?transport=tcp",
|
||||||
|
SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, testCase := range testCases {
|
for i, testCase := range testCases {
|
||||||
|
33
xoraddr.go
33
xoraddr.go
@@ -20,7 +20,7 @@ const (
|
|||||||
|
|
||||||
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
|
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
|
||||||
//
|
//
|
||||||
// RFC 5389 Section 15.2
|
// RFC 5389 Section 15.2.
|
||||||
type XORMappedAddress struct {
|
type XORMappedAddress struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port int
|
Port int
|
||||||
@@ -43,14 +43,15 @@ func isZeros(p net.IP) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
|
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
|
||||||
var ErrBadIPLength = errors.New("invalid length of IP value")
|
var ErrBadIPLength = errors.New("invalid length of IP value")
|
||||||
|
|
||||||
// AddToAs adds XOR-MAPPED-ADDRESS value to m as t attribute.
|
// AddToAs adds XOR-MAPPED-ADDRESS value to msg as attr attribute.
|
||||||
func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
|
func (a XORMappedAddress) AddToAs(msg *Message, attr AttrType) error {
|
||||||
var (
|
var (
|
||||||
family = familyIPv4
|
family = familyIPv4
|
||||||
ip = a.IP
|
ip = a.IP
|
||||||
@@ -67,12 +68,13 @@ func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
|
|||||||
value := make([]byte, 32+128)
|
value := make([]byte, 32+128)
|
||||||
value[0] = 0 // first 8 bits are zeroes
|
value[0] = 0 // first 8 bits are zeroes
|
||||||
xorValue := make([]byte, net.IPv6len)
|
xorValue := make([]byte, net.IPv6len)
|
||||||
copy(xorValue[4:], m.TransactionID[:])
|
copy(xorValue[4:], msg.TransactionID[:])
|
||||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||||
bin.PutUint16(value[0:2], family)
|
bin.PutUint16(value[0:2], family)
|
||||||
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
|
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) //nolint:gosec // G115, false positive, port
|
||||||
xor.XorBytes(value[4:4+len(ip)], ip, xorValue)
|
xor.XorBytes(value[4:4+len(ip)], ip, xorValue)
|
||||||
m.Add(t, value[:4+len(ip)])
|
msg.Add(attr, value[:4+len(ip)])
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,13 +85,13 @@ func (a XORMappedAddress) AddTo(m *Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message
|
// GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message
|
||||||
// getting it as for t type.
|
// getting it as for attr type.
|
||||||
func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
|
func (a *XORMappedAddress) GetFromAs(msg *Message, attr AttrType) error {
|
||||||
v, err := m.Get(t)
|
value, err := msg.Get(attr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
family := bin.Uint16(v[0:2])
|
family := bin.Uint16(value[0:2])
|
||||||
if family != familyIPv6 && family != familyIPv4 {
|
if family != familyIPv6 && family != familyIPv4 {
|
||||||
return newDecodeErr("xor-mapped address", "family",
|
return newDecodeErr("xor-mapped address", "family",
|
||||||
fmt.Sprintf("bad value %d", family),
|
fmt.Sprintf("bad value %d", family),
|
||||||
@@ -110,17 +112,18 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
|
|||||||
for i := range a.IP {
|
for i := range a.IP {
|
||||||
a.IP[i] = 0
|
a.IP[i] = 0
|
||||||
}
|
}
|
||||||
if len(v) <= 4 {
|
if len(value) <= 4 {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
if err := CheckOverflow(t, len(v[4:]), len(a.IP)); err != nil {
|
if err := CheckOverflow(attr, len(value[4:]), len(a.IP)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
a.Port = int(bin.Uint16(value[2:4])) ^ (magicCookie >> 16)
|
||||||
xorValue := make([]byte, 4+TransactionIDSize)
|
xorValue := make([]byte, 4+TransactionIDSize)
|
||||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||||
copy(xorValue[4:], m.TransactionID[:])
|
copy(xorValue[4:], msg.TransactionID[:])
|
||||||
xor.XorBytes(a.IP, v[4:], xorValue)
|
xor.XorBytes(a.IP, value[4:], xorValue)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -29,21 +29,21 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
|
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
|
||||||
m := New()
|
msg := New()
|
||||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Error(err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
copy(m.TransactionID[:], transactionID)
|
copy(msg.TransactionID[:], transactionID)
|
||||||
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
|
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Error(err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
m.Add(AttrXORMappedAddress, addrValue)
|
msg.Add(AttrXORMappedAddress, addrValue)
|
||||||
addr := new(XORMappedAddress)
|
addr := new(XORMappedAddress)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
if err := addr.GetFrom(m); err != nil {
|
if err := addr.GetFrom(msg); err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,54 +94,54 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
copy(m.TransactionID[:], transactionID)
|
copy(msg.TransactionID[:], transactionID)
|
||||||
expectedIP := net.ParseIP("213.141.156.236")
|
expectedIP := net.ParseIP("213.141.156.236")
|
||||||
expectedPort := 21254
|
expectedPort := 21254
|
||||||
addr := new(XORMappedAddress)
|
addr := new(XORMappedAddress)
|
||||||
|
|
||||||
if err = addr.GetFrom(m); err == nil {
|
if err = addr.GetFrom(msg); err == nil {
|
||||||
t.Fatal(err, "should be nil")
|
t.Fatal(err, "should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr.IP = expectedIP
|
addr.IP = expectedIP
|
||||||
addr.Port = expectedPort
|
addr.Port = expectedPort
|
||||||
addr.AddTo(m) //nolint:errcheck,gosec
|
addr.AddTo(msg) //nolint:errcheck,gosec
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
|
|
||||||
mRes := New()
|
mRes := New()
|
||||||
binary.BigEndian.PutUint16(m.Raw[20+4:20+4+2], 0x21)
|
binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
|
||||||
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
|
if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err = addr.GetFrom(m); err == nil {
|
if err = addr.GetFrom(msg); err == nil {
|
||||||
t.Fatal(err, "should not be nil")
|
t.Fatal(err, "should not be nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestXORMappedAddress_AddTo(t *testing.T) {
|
func TestXORMappedAddress_AddTo(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
copy(m.TransactionID[:], transactionID)
|
copy(msg.TransactionID[:], transactionID)
|
||||||
expectedIP := net.ParseIP("213.141.156.236")
|
expectedIP := net.ParseIP("213.141.156.236")
|
||||||
expectedPort := 21254
|
expectedPort := 21254
|
||||||
addr := &XORMappedAddress{
|
addr := &XORMappedAddress{
|
||||||
IP: net.ParseIP("213.141.156.236"),
|
IP: net.ParseIP("213.141.156.236"),
|
||||||
Port: expectedPort,
|
Port: expectedPort,
|
||||||
}
|
}
|
||||||
if err = addr.AddTo(m); err != nil {
|
if err = addr.AddTo(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
mRes := New()
|
mRes := New()
|
||||||
if _, err = mRes.Write(m.Raw); err != nil {
|
if _, err = mRes.Write(msg.Raw); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err = addr.GetFrom(mRes); err != nil {
|
if err = addr.GetFrom(mRes); err != nil {
|
||||||
@@ -156,27 +156,27 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
|
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
|
||||||
m := New()
|
msg := New()
|
||||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
copy(m.TransactionID[:], transactionID)
|
copy(msg.TransactionID[:], transactionID)
|
||||||
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
|
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
|
||||||
expectedPort := 21254
|
expectedPort := 21254
|
||||||
addr := &XORMappedAddress{
|
addr := &XORMappedAddress{
|
||||||
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
|
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
|
||||||
Port: 21254,
|
Port: 21254,
|
||||||
}
|
}
|
||||||
addr.AddTo(m) //nolint:errcheck,gosec
|
addr.AddTo(msg) //nolint:errcheck,gosec
|
||||||
m.WriteHeader()
|
msg.WriteHeader()
|
||||||
|
|
||||||
mRes := New()
|
mRes := New()
|
||||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
if _, err = mRes.ReadFrom(msg.reader()); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
gotAddr := new(XORMappedAddress)
|
gotAddr := new(XORMappedAddress)
|
||||||
if err = gotAddr.GetFrom(m); err != nil {
|
if err = gotAddr.GetFrom(msg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !gotAddr.IP.Equal(expectedIP) {
|
if !gotAddr.IP.Equal(expectedIP) {
|
||||||
|
Reference in New Issue
Block a user