Upgrade golangci-lint, more linters

Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
Joe Turki
2025-01-18 10:48:23 -06:00
parent be5e65e013
commit 0397f2187b
44 changed files with 883 additions and 656 deletions

View File

@@ -25,17 +25,32 @@ linters-settings:
- ^os.Exit$ - ^os.Exit$
- ^panic$ - ^panic$
- ^print(ln)?$ - ^print(ln)?$
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters: linters:
enable: enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences - bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully - bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context - contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions - decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection - dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together - durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@@ -46,18 +61,17 @@ linters:
- forcetypeassert # finds forced type assertions - forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic. - gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code - gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions - gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant - goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter - gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords - godox # Tool for detection of FIXME, TODO and other comment keywords
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed. - gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern - goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end - goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems - gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code - gosimple # Linter for Go source code that specializes in simplifying a code
@@ -65,9 +79,15 @@ linters:
- grouper # An analyzer to analyze expression groups. - grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases - importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used - ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments - misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context - noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers - predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes - revive # golint replacement, finds style mistakes
@@ -75,28 +95,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint - stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags. - tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions - unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters - unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types - unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements - wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace - whitespace # Tool for detection of leading and trailing whitespace
disable: disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages - depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- funlen # Tool for detection of long functions - funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions - gochecknoinits # Checks that no init functions are present in Go code
- godot # Check if comments end in a period - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- gomnd # An analyzer to detect magic numbers. - interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types - ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines - mnd # An analyzer to detect magic numbers
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- nolintlint # Reports ill-formed or insufficient nolint directives - nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated - prealloc # Finds slice declarations that could potentially be preallocated
@@ -104,8 +118,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully - rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package - testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- varnamelen # checks that the length of a variable's name matches its scope
- wrapcheck # Checks that errors returned from external packages are wrapped - wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines! - wsl # Whitespace Linter - Forces you to use empty lines!
@@ -123,3 +136,4 @@ issues:
- path: cmd - path: cmd
linters: linters:
- forbidigo - forbidigo

32
addr.go
View File

@@ -15,7 +15,7 @@ import (
// This attribute is used only by servers for achieving backwards // This attribute is used only by servers for achieving backwards
// compatibility with RFC 3489 clients. // compatibility with RFC 3489 clients.
// //
// RFC 5389 Section 15.1 // RFC 5389 Section 15.1.
type MappedAddress struct { type MappedAddress struct {
IP net.IP IP net.IP
Port int Port int
@@ -23,7 +23,7 @@ type MappedAddress struct {
// AlternateServer represents ALTERNATE-SERVER attribute. // AlternateServer represents ALTERNATE-SERVER attribute.
// //
// RFC 5389 Section 15.11 // RFC 5389 Section 15.11.
type AlternateServer struct { type AlternateServer struct {
IP net.IP IP net.IP
Port int Port int
@@ -31,7 +31,7 @@ type AlternateServer struct {
// ResponseOrigin represents RESPONSE-ORIGIN attribute. // ResponseOrigin represents RESPONSE-ORIGIN attribute.
// //
// RFC 5780 Section 7.3 // RFC 5780 Section 7.3.
type ResponseOrigin struct { type ResponseOrigin struct {
IP net.IP IP net.IP
Port int Port int
@@ -39,7 +39,7 @@ type ResponseOrigin struct {
// OtherAddress represents OTHER-ADDRESS attribute. // OtherAddress represents OTHER-ADDRESS attribute.
// //
// RFC 5780 Section 7.4 // RFC 5780 Section 7.4.
type OtherAddress struct { type OtherAddress struct {
IP net.IP IP net.IP
Port int Port int
@@ -48,12 +48,14 @@ type OtherAddress struct {
// AddTo adds ALTERNATE-SERVER attribute to message. // AddTo adds ALTERNATE-SERVER attribute to message.
func (s *AlternateServer) AddTo(m *Message) error { func (s *AlternateServer) AddTo(m *Message) error {
a := (*MappedAddress)(s) a := (*MappedAddress)(s)
return a.AddToAs(m, AttrAlternateServer) return a.AddToAs(m, AttrAlternateServer)
} }
// GetFrom decodes ALTERNATE-SERVER from message. // GetFrom decodes ALTERNATE-SERVER from message.
func (s *AlternateServer) GetFrom(m *Message) error { func (s *AlternateServer) GetFrom(m *Message) error {
a := (*MappedAddress)(s) a := (*MappedAddress)(s)
return a.GetFromAs(m, AttrAlternateServer) return a.GetFromAs(m, AttrAlternateServer)
} }
@@ -63,14 +65,14 @@ func (a MappedAddress) String() string {
// GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t. // GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t.
func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error { func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
v, err := m.Get(t) value, err := m.Get(t)
if err != nil { if err != nil {
return err return err
} }
if len(v) <= 4 { if len(value) <= 4 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
family := bin.Uint16(v[0:2]) family := bin.Uint16(value[0:2])
if family != familyIPv6 && family != familyIPv4 { if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family", return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family), fmt.Sprintf("bad value %d", family),
@@ -91,13 +93,14 @@ func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
for i := range a.IP { for i := range a.IP {
a.IP[i] = 0 a.IP[i] = 0
} }
a.Port = int(bin.Uint16(v[2:4])) a.Port = int(bin.Uint16(value[2:4]))
copy(a.IP, v[4:]) copy(a.IP, value[4:])
return nil return nil
} }
// AddToAs adds MAPPED-ADDRESS value to m as t attribute. // AddToAs adds MAPPED-ADDRESS value to m as t attribute.
func (a *MappedAddress) AddToAs(m *Message, t AttrType) error { func (a *MappedAddress) AddToAs(msg *Message, attrType AttrType) error {
var ( var (
family = familyIPv4 family = familyIPv4
ip = a.IP ip = a.IP
@@ -114,9 +117,10 @@ func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
value := make([]byte, 128) value := make([]byte, 128)
value[0] = 0 // first 8 bits are zeroes value[0] = 0 // first 8 bits are zeroes
bin.PutUint16(value[0:2], family) bin.PutUint16(value[0:2], family)
bin.PutUint16(value[2:4], uint16(a.Port)) bin.PutUint16(value[2:4], uint16(a.Port)) //nolint:gosec //G115
copy(value[4:], ip) copy(value[4:], ip)
m.Add(t, value[:4+len(ip)]) msg.Add(attrType, value[:4+len(ip)])
return nil return nil
} }
@@ -133,12 +137,14 @@ func (a *MappedAddress) GetFrom(m *Message) error {
// AddTo adds OTHER-ADDRESS attribute to message. // AddTo adds OTHER-ADDRESS attribute to message.
func (o *OtherAddress) AddTo(m *Message) error { func (o *OtherAddress) AddTo(m *Message) error {
a := (*MappedAddress)(o) a := (*MappedAddress)(o)
return a.AddToAs(m, AttrOtherAddress) return a.AddToAs(m, AttrOtherAddress)
} }
// GetFrom decodes OTHER-ADDRESS from message. // GetFrom decodes OTHER-ADDRESS from message.
func (o *OtherAddress) GetFrom(m *Message) error { func (o *OtherAddress) GetFrom(m *Message) error {
a := (*MappedAddress)(o) a := (*MappedAddress)(o)
return a.GetFromAs(m, AttrOtherAddress) return a.GetFromAs(m, AttrOtherAddress)
} }
@@ -149,12 +155,14 @@ func (o OtherAddress) String() string {
// AddTo adds RESPONSE-ORIGIN attribute to message. // AddTo adds RESPONSE-ORIGIN attribute to message.
func (o *ResponseOrigin) AddTo(m *Message) error { func (o *ResponseOrigin) AddTo(m *Message) error {
a := (*MappedAddress)(o) a := (*MappedAddress)(o)
return a.AddToAs(m, AttrResponseOrigin) return a.AddToAs(m, AttrResponseOrigin)
} }
// GetFrom decodes RESPONSE-ORIGIN from message. // GetFrom decodes RESPONSE-ORIGIN from message.
func (o *ResponseOrigin) GetFrom(m *Message) error { func (o *ResponseOrigin) GetFrom(m *Message) error {
a := (*MappedAddress)(o) a := (*MappedAddress)(o)
return a.GetFromAs(m, AttrResponseOrigin) return a.GetFromAs(m, AttrResponseOrigin)
} }

View File

@@ -11,7 +11,7 @@ import (
) )
func TestMappedAddress(t *testing.T) { func TestMappedAddress(t *testing.T) {
m := new(Message) msg := new(Message)
addr := &MappedAddress{ addr := &MappedAddress{
IP: net.ParseIP("122.12.34.5"), IP: net.ParseIP("122.12.34.5"),
Port: 5412, Port: 5412,
@@ -23,17 +23,17 @@ func TestMappedAddress(t *testing.T) {
badAddr := &MappedAddress{ badAddr := &MappedAddress{
IP: net.IP{1, 2, 3}, IP: net.IP{1, 2, 3},
} }
if err := badAddr.AddTo(m); err == nil { if err := badAddr.AddTo(msg); err == nil {
t.Error("should error") t.Error("should error")
} }
}) })
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil { if err := addr.AddTo(msg); err != nil {
t.Error(err) t.Error(err)
} }
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress) got := new(MappedAddress)
if err := got.GetFrom(m); err != nil { if err := got.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if !got.IP.Equal(addr.IP) { if !got.IP.Equal(addr.IP) {
@@ -46,9 +46,9 @@ func TestMappedAddress(t *testing.T) {
} }
}) })
t.Run("Bad family", func(t *testing.T) { t.Run("Bad family", func(t *testing.T) {
v, _ := m.Attributes.Get(AttrMappedAddress) v, _ := msg.Attributes.Get(AttrMappedAddress)
v.Value[0] = 32 v.Value[0] = 32
if err := got.GetFrom(m); err == nil { if err := got.GetFrom(msg); err == nil {
t.Error("should error") t.Error("should error")
} }
}) })

View File

@@ -24,6 +24,7 @@ func NewAgent(h Handler) *Agent {
transactions: make(map[transactionID]agentTransaction), transactions: make(map[transactionID]agentTransaction),
handler: h, handler: h,
} }
return a return a
} }
@@ -80,6 +81,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
a.mux.Lock() a.mux.Lock()
if a.closed { if a.closed {
a.mux.Unlock() a.mux.Unlock()
return ErrAgentClosed return ErrAgentClosed
} }
t, exists := a.transactions[id] t, exists := a.transactions[id]
@@ -93,6 +95,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
TransactionID: t.id, TransactionID: t.id,
Error: err, Error: err,
}) })
return nil return nil
} }
@@ -124,6 +127,7 @@ func (a *Agent) Start(id [TransactionIDSize]byte, deadline time.Time) error {
id: id, id: id,
deadline: deadline, deadline: deadline,
} }
return nil return nil
} }
@@ -147,6 +151,7 @@ func (a *Agent) Collect(gcTime time.Time) error {
// All transactions should be already closed // All transactions should be already closed
// during Close() call. // during Close() call.
a.mux.Unlock() a.mux.Unlock()
return ErrAgentClosed return ErrAgentClosed
} }
// Adding all transactions with deadline before gcTime // Adding all transactions with deadline before gcTime
@@ -175,24 +180,27 @@ func (a *Agent) Collect(gcTime time.Time) error {
event.TransactionID = id event.TransactionID = id
h(event) h(event)
} }
return nil return nil
} }
// Process incoming message, synchronously passing it to handler. // Process incoming message, synchronously passing it to handler.
func (a *Agent) Process(m *Message) error { func (a *Agent) Process(m *Message) error {
e := Event{ event := Event{
TransactionID: m.TransactionID, TransactionID: m.TransactionID,
Message: m, Message: m,
} }
a.mux.Lock() a.mux.Lock()
if a.closed { if a.closed {
a.mux.Unlock() a.mux.Unlock()
return ErrAgentClosed return ErrAgentClosed
} }
h := a.handler h := a.handler
delete(a.transactions, m.TransactionID) delete(a.transactions, m.TransactionID)
a.mux.Unlock() a.mux.Unlock()
h(e) h(event)
return nil return nil
} }
@@ -201,10 +209,12 @@ func (a *Agent) SetHandler(h Handler) error {
a.mux.Lock() a.mux.Lock()
if a.closed { if a.closed {
a.mux.Unlock() a.mux.Unlock()
return ErrAgentClosed return ErrAgentClosed
} }
a.handler = h a.handler = h
a.mux.Unlock() a.mux.Unlock()
return nil return nil
} }
@@ -217,6 +227,7 @@ func (a *Agent) Close() error {
a.mux.Lock() a.mux.Lock()
if a.closed { if a.closed {
a.mux.Unlock() a.mux.Unlock()
return ErrAgentClosed return ErrAgentClosed
} }
for _, t := range a.transactions { for _, t := range a.transactions {
@@ -227,6 +238,7 @@ func (a *Agent) Close() error {
a.closed = true a.closed = true
a.handler = nil a.handler = nil
a.mux.Unlock() a.mux.Unlock()
return nil return nil
} }

View File

@@ -10,49 +10,49 @@ import (
) )
func TestAgent_ProcessInTransaction(t *testing.T) { func TestAgent_ProcessInTransaction(t *testing.T) {
m := New() msg := New()
a := NewAgent(func(e Event) { agent := NewAgent(func(e Event) {
if e.Error != nil { if e.Error != nil {
t.Errorf("got error: %s", e.Error) t.Errorf("got error: %s", e.Error)
} }
if !e.Message.Equal(m) { if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, m) t.Errorf("%s (got) != %s (expected)", e.Message, msg)
} }
}) })
if err := m.NewTransactionID(); err != nil { if err := msg.NewTransactionID(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Start(m.TransactionID, time.Time{}); err != nil { if err := agent.Start(msg.TransactionID, time.Time{}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Process(m); err != nil { if err := agent.Process(msg); err != nil {
t.Error(err) t.Error(err)
} }
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestAgent_Process(t *testing.T) { func TestAgent_Process(t *testing.T) {
m := New() msg := New()
a := NewAgent(func(e Event) { agent := NewAgent(func(e Event) {
if e.Error != nil { if e.Error != nil {
t.Errorf("got error: %s", e.Error) t.Errorf("got error: %s", e.Error)
} }
if !e.Message.Equal(m) { if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, m) t.Errorf("%s (got) != %s (expected)", e.Message, msg)
} }
}) })
if err := m.NewTransactionID(); err != nil { if err := msg.NewTransactionID(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Process(m); err != nil { if err := agent.Process(msg); err != nil {
t.Error(err) t.Error(err)
} }
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
t.Error(err) t.Error(err)
} }
if err := a.Process(m); !errors.Is(err, ErrAgentClosed) { if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
t.Errorf("closed agent should return <%s>, but got <%s>", t.Errorf("closed agent should return <%s>, but got <%s>",
ErrAgentClosed, err, ErrAgentClosed, err,
) )
@@ -60,27 +60,27 @@ func TestAgent_Process(t *testing.T) {
} }
func TestAgent_Start(t *testing.T) { func TestAgent_Start(t *testing.T) {
a := NewAgent(nil) agent := NewAgent(nil)
id := NewTransactionID() id := NewTransactionID()
deadline := time.Now().AddDate(0, 0, 1) deadline := time.Now().AddDate(0, 0, 1)
if err := a.Start(id, deadline); err != nil { if err := agent.Start(id, deadline); err != nil {
t.Errorf("failed to statt transaction: %s", err) t.Errorf("failed to statt transaction: %s", err)
} }
if err := a.Start(id, deadline); !errors.Is(err, ErrTransactionExists) { if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
t.Errorf("duplicate start should return <%s>, got <%s>", t.Errorf("duplicate start should return <%s>, got <%s>",
ErrTransactionExists, err, ErrTransactionExists, err,
) )
} }
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
t.Error(err) t.Error(err)
} }
id = NewTransactionID() id = NewTransactionID()
if err := a.Start(id, deadline); !errors.Is(err, ErrAgentClosed) { if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("start on closed agent should return <%s>, got <%s>", t.Errorf("start on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err, ErrAgentClosed, err,
) )
} }
if err := a.SetHandler(nil); !errors.Is(err, ErrAgentClosed) { if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>", t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err, ErrAgentClosed, err,
) )
@@ -89,18 +89,18 @@ func TestAgent_Start(t *testing.T) {
func TestAgent_Stop(t *testing.T) { func TestAgent_Stop(t *testing.T) {
called := make(chan Event, 1) called := make(chan Event, 1)
a := NewAgent(func(e Event) { agent := NewAgent(func(e Event) {
called <- e called <- e
}) })
if err := a.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) { if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists) t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
} }
id := NewTransactionID() id := NewTransactionID()
timeout := time.Millisecond * 200 timeout := time.Millisecond * 200
if err := a.Start(id, time.Now().Add(timeout)); err != nil { if err := agent.Start(id, time.Now().Add(timeout)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Stop(id); err != nil { if err := agent.Stop(id); err != nil {
t.Fatal(err) t.Fatal(err)
} }
select { select {
@@ -113,19 +113,19 @@ func TestAgent_Stop(t *testing.T) {
case <-time.After(timeout * 2): case <-time.After(timeout * 2):
t.Fatal("timed out") t.Fatal("timed out")
} }
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Close(); !errors.Is(err, ErrAgentClosed) { if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed) t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
} }
if err := a.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) { if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed) t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
} }
} }
func TestAgent_GC(t *testing.T) { func TestAgent_GC(t *testing.T) { //nolint:cyclop
a := NewAgent(nil) agent := NewAgent(nil)
shouldTimeOutID := make(map[transactionID]bool) shouldTimeOutID := make(map[transactionID]bool)
deadline := time.Date(2027, time.November, 21, deadline := time.Date(2027, time.November, 21,
23, 0, 0, 0, 23, 0, 0, 0,
@@ -133,7 +133,7 @@ func TestAgent_GC(t *testing.T) {
) )
gcDeadline := deadline.Add(-time.Second) gcDeadline := deadline.Add(-time.Second)
deadlineNotGC := gcDeadline.AddDate(0, 0, -1) deadlineNotGC := gcDeadline.AddDate(0, 0, -1)
a.SetHandler(func(e Event) { //nolint:errcheck,gosec agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
id := e.TransactionID id := e.TransactionID
shouldTimeOut, found := shouldTimeOutID[id] shouldTimeOut, found := shouldTimeOutID[id]
if !found { if !found {
@@ -149,67 +149,67 @@ func TestAgent_GC(t *testing.T) {
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
id := NewTransactionID() id := NewTransactionID()
shouldTimeOutID[id] = false shouldTimeOutID[id] = false
if err := a.Start(id, deadline); err != nil { if err := agent.Start(id, deadline); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
id := NewTransactionID() id := NewTransactionID()
shouldTimeOutID[id] = true shouldTimeOutID[id] = true
if err := a.Start(id, deadlineNotGC); err != nil { if err := agent.Start(id, deadlineNotGC); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
if err := a.Collect(gcDeadline); err != nil { if err := agent.Collect(gcDeadline); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
t.Error(err) t.Error(err)
} }
if err := a.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) { if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err) t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
} }
} }
func BenchmarkAgent_GC(b *testing.B) { func BenchmarkAgent_GC(b *testing.B) {
a := NewAgent(nil) agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1) deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < agentCollectCap; i++ { for i := 0; i < agentCollectCap; i++ {
if err := a.Start(NewTransactionID(), deadline); err != nil { if err := agent.Start(NewTransactionID(), deadline); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
defer func() { defer func() {
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
b.Error(err) b.Error(err)
} }
}() }()
b.ReportAllocs() b.ReportAllocs()
gcDeadline := deadline.Add(-time.Second) gcDeadline := deadline.Add(-time.Second)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := a.Collect(gcDeadline); err != nil { if err := agent.Collect(gcDeadline); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
} }
func BenchmarkAgent_Process(b *testing.B) { func BenchmarkAgent_Process(b *testing.B) {
a := NewAgent(nil) agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1) deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
if err := a.Start(NewTransactionID(), deadline); err != nil { if err := agent.Start(NewTransactionID(), deadline); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
defer func() { defer func() {
if err := a.Close(); err != nil { if err := agent.Close(); err != nil {
b.Error(err) b.Error(err)
} }
}() }()
b.ReportAllocs() b.ReportAllocs()
m := MustBuild(TransactionID) m := MustBuild(TransactionID)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := a.Process(m); err != nil { if err := agent.Process(m); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }

View File

@@ -21,6 +21,7 @@ func (a Attributes) Get(t AttrType) (RawAttribute, bool) {
return candidate, true return candidate, true
} }
} }
return RawAttribute{}, false return RawAttribute{}, false
} }
@@ -77,7 +78,7 @@ const (
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
) )
// Attributes from RFC 5780 NAT Behavior Discovery // Attributes from RFC 5780 NAT Behavior Discovery.
const ( const (
AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST
AttrPadding AttrType = 0x0026 // PADDING AttrPadding AttrType = 0x0026 // PADDING
@@ -166,6 +167,7 @@ func (t AttrType) String() string {
// Just return hex representation of unknown attribute type. // Just return hex representation of unknown attribute type.
return fmt.Sprintf("0x%x", uint16(t)) return fmt.Sprintf("0x%x", uint16(t))
} }
return s return s
} }
@@ -186,6 +188,7 @@ type RawAttribute struct {
// the Length field. // the Length field.
func (a RawAttribute) AddTo(m *Message) error { func (a RawAttribute) AddTo(m *Message) error {
m.Add(a.Type, a.Value) m.Add(a.Type, a.Value)
return nil return nil
} }
@@ -205,6 +208,7 @@ func (a RawAttribute) Equal(b RawAttribute) bool {
return false return false
} }
} }
return true return true
} }
@@ -224,6 +228,7 @@ func (m *Message) Get(t AttrType) ([]byte, error) {
if !ok { if !ok {
return nil, ErrAttributeNotFound return nil, ErrAttributeNotFound
} }
return v.Value, nil return v.Value, nil
} }
@@ -240,6 +245,7 @@ func nearestPaddedValueLength(l int) int {
if n < l { if n < l {
n += padding n += padding
} }
return n return n
} }
@@ -250,5 +256,6 @@ func compatAttrType(val uint16) AttrType {
if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN
return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on) return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on)
} }
return AttrType(val) return AttrType(val)
} }

View File

@@ -44,13 +44,13 @@ func TestRawAttribute_AddTo(t *testing.T) {
} }
func TestMessage_GetNoAllocs(t *testing.T) { func TestMessage_GetNoAllocs(t *testing.T) {
m := New() msg := New()
NewSoftware("c").AddTo(m) //nolint:errcheck,gosec NewSoftware("c").AddTo(msg) //nolint:errcheck,gosec
m.WriteHeader() msg.WriteHeader()
t.Run("Default", func(t *testing.T) { t.Run("Default", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
m.Get(AttrSoftware) //nolint:errcheck,gosec msg.Get(AttrSoftware) //nolint:errcheck,gosec
}) })
if allocs > 0 { if allocs > 0 {
t.Error("allocated memory, but should not") t.Error("allocated memory, but should not")
@@ -58,7 +58,7 @@ func TestMessage_GetNoAllocs(t *testing.T) {
}) })
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
m.Get(AttrOrigin) //nolint:errcheck,gosec msg.Get(AttrOrigin) //nolint:errcheck,gosec
}) })
if allocs > 0 { if allocs > 0 {
t.Error("allocated memory, but should not") t.Error("allocated memory, but should not")

View File

@@ -17,6 +17,7 @@ func CheckSize(_ AttrType, got, expected int) error {
if got == expected { if got == expected {
return nil return nil
} }
return ErrAttributeSizeInvalid return ErrAttributeSizeInvalid
} }
@@ -24,6 +25,7 @@ func checkHMAC(got, expected []byte) error {
if hmac.Equal(got, expected) { if hmac.Equal(got, expected) {
return nil return nil
} }
return ErrIntegrityMismatch return ErrIntegrityMismatch
} }
@@ -31,6 +33,7 @@ func checkFingerprint(got, expected uint32) error {
if got == expected { if got == expected {
return nil return nil
} }
return ErrFingerprintMismatch return ErrFingerprintMismatch
} }
@@ -44,6 +47,7 @@ func CheckOverflow(_ AttrType, got, maxVal int) error {
if got <= maxVal { if got <= maxVal {
return nil return nil
} }
return ErrAttributeSizeOverflow return ErrAttributeSizeOverflow
} }

126
client.go
View File

@@ -21,7 +21,7 @@ import (
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
) )
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI // ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI.
var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport") var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
// Dial connects to the address on the named network and then // Dial connects to the address on the named network and then
@@ -31,10 +31,11 @@ func Dial(network, address string) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewClient(conn) return NewClient(conn)
} }
// DialConfig is used to pass configuration to DialURI() // DialConfig is used to pass configuration to DialURI().
type DialConfig struct { type DialConfig struct {
DTLSConfig dtls.Config DTLSConfig dtls.Config
TLSConfig tls.Config TLSConfig tls.Config
@@ -44,7 +45,7 @@ type DialConfig struct {
// DialURI connect to the STUN/TURN URI and then // DialURI connect to the STUN/TURN URI and then
// initializes Client on that connection, returning error if any. // initializes Client on that connection, returning error if any.
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop
var conn Connection var conn Connection
var err error var err error
@@ -203,7 +204,7 @@ const (
// provide any API for it, so if you need to read application data, wrap the // provide any API for it, so if you need to read application data, wrap the
// connection with your (de-)multiplexer and pass the wrapper as conn. // connection with your (de-)multiplexer and pass the wrapper as conn.
func NewClient(conn Connection, options ...ClientOption) (*Client, error) { func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
c := &Client{ client := &Client{
close: make(chan struct{}), close: make(chan struct{}),
c: conn, c: conn,
clock: systemClock(), clock: systemClock(),
@@ -214,32 +215,33 @@ func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
closeConn: true, closeConn: true,
} }
for _, o := range options { for _, o := range options {
o(c) o(client)
} }
if c.c == nil { if client.c == nil {
return nil, ErrNoConnection return nil, ErrNoConnection
} }
if c.a == nil { if client.a == nil {
c.a = NewAgent(nil) client.a = NewAgent(nil)
} }
if err := c.a.SetHandler(c.handleAgentCallback); err != nil { if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
return nil, err return nil, err
} }
if c.collector == nil { if client.collector == nil {
c.collector = &tickerCollector{ client.collector = &tickerCollector{
close: make(chan struct{}), close: make(chan struct{}),
clock: c.clock, clock: client.clock,
} }
} }
if err := c.collector.Start(c.rtoRate, func(t time.Time) { if err := client.collector.Start(client.rtoRate, func(t time.Time) {
closedOrPanic(c.a.Collect(t)) closedOrPanic(client.a.Collect(t))
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
c.wg.Add(1) client.wg.Add(1)
go c.readUntilClosed() go client.readUntilClosed()
runtime.SetFinalizer(c, clientFinalizer) runtime.SetFinalizer(client, clientFinalizer)
return c, nil
return client, nil
} }
func clientFinalizer(c *Client) { func clientFinalizer(c *Client) {
@@ -252,6 +254,7 @@ func clientFinalizer(c *Client) {
} }
if err == nil { if err == nil {
log.Println("client: called finalizer on non-closed client") // nolint log.Println("client: called finalizer on non-closed client") // nolint
return return
} }
log.Println("client: called finalizer on non-closed client:", err) // nolint log.Println("client: called finalizer on non-closed client:", err) // nolint
@@ -353,6 +356,7 @@ func (c *Client) start(t *clientTransaction) error {
return ErrTransactionExists return ErrTransactionExists
} }
c.t[t.id] = t c.t[t.id] = t
return nil return nil
} }
@@ -399,6 +403,7 @@ func sprintErr(err error) string {
if err == nil { if err == nil {
return "<nil>" //nolint:goconst return "<nil>" //nolint:goconst
} }
return err.Error() return err.Error()
} }
@@ -455,18 +460,21 @@ func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error
select { select {
case <-a.close: case <-a.close:
t.Stop() t.Stop()
return return
case <-t.C: case <-t.C:
f(a.clock.Now()) f(a.clock.Now())
} }
} }
}() }()
return nil return nil
} }
func (a *tickerCollector) Close() error { func (a *tickerCollector) Close() error {
close(a.close) close(a.close)
a.wg.Wait() a.wg.Wait()
return nil return nil
} }
@@ -481,6 +489,7 @@ func (c *Client) Close() error {
c.mux.Lock() c.mux.Lock()
if c.closed { if c.closed {
c.mux.Unlock() c.mux.Unlock()
return ErrClientClosed return ErrClientClosed
} }
c.closed = true c.closed = true
@@ -498,6 +507,7 @@ func (c *Client) Close() error {
if agentErr == nil && connErr == nil { if agentErr == nil && connErr == nil {
return nil return nil
} }
return CloseErr{ return CloseErr{
AgentErr: agentErr, AgentErr: agentErr,
ConnectionErr: connErr, ConnectionErr: connErr,
@@ -566,6 +576,7 @@ func (c *Client) checkInit() error {
if c == nil || c.c == nil || c.a == nil || c.close == nil { if c == nil || c.c == nil || c.a == nil || c.close == nil {
return ErrClientNotInitialized return ErrClientNotInitialized
} }
return nil return nil
} }
@@ -590,6 +601,7 @@ func (c *Client) Do(m *Message, f func(Event)) error {
return err return err
} }
h.wait() h.wait()
return nil return nil
} }
@@ -611,80 +623,85 @@ var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
}, },
} }
func (c *Client) handleAgentCallback(e Event) { func (c *Client) handleAgentCallback(event Event) { //nolint:cyclop
c.mux.Lock() c.mux.Lock()
if c.closed { if c.closed {
c.mux.Unlock() c.mux.Unlock()
return return
} }
t, found := c.t[e.TransactionID] transaction, found := c.t[event.TransactionID]
if found { if found {
delete(c.t, t.id) delete(c.t, transaction.id)
} }
c.mux.Unlock() c.mux.Unlock()
if !found { if !found {
if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) { if c.handler != nil && !errors.Is(event.Error, ErrTransactionStopped) {
c.handler(e) c.handler(event)
} }
// Ignoring. // Ignoring.
return return
} }
if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil { if atomic.LoadInt32(&c.maxAttempts) <= transaction.attempt || event.Error == nil {
// Transaction completed. // Transaction completed.
t.handle(e) transaction.handle(event)
putClientTransaction(t) putClientTransaction(transaction)
return return
} }
// Doing re-transmission. // Doing re-transmission.
t.attempt++ transaction.attempt++
b := bufferPool.Get().(*buffer) //nolint:forcetypeassert buff := bufferPool.Get().(*buffer) //nolint:forcetypeassert
b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)] buff.buf = buff.buf[:copy(buff.buf[:cap(buff.buf)], transaction.raw)]
defer bufferPool.Put(b) defer bufferPool.Put(buff)
var ( var (
now = c.clock.Now() now = c.clock.Now()
timeOut = t.nextTimeout(now) timeOut = transaction.nextTimeout(now)
id = t.id id = transaction.id
) )
// Starting client transaction. // Starting client transaction.
if startErr := c.start(t); startErr != nil { if startErr := c.start(transaction); startErr != nil {
c.delete(id) c.delete(id)
e.Error = startErr event.Error = startErr
t.handle(e) transaction.handle(event)
putClientTransaction(t) putClientTransaction(transaction)
return return
} }
// Starting agent transaction. // Starting agent transaction.
if startErr := c.a.Start(id, timeOut); startErr != nil { if startErr := c.a.Start(id, timeOut); startErr != nil {
c.delete(id) c.delete(id)
e.Error = startErr event.Error = startErr
t.handle(e) transaction.handle(event)
putClientTransaction(t) putClientTransaction(transaction)
return return
} }
// Writing message to connection again. // Writing message to connection again.
_, writeErr := c.c.Write(b.buf) _, writeErr := c.c.Write(buff.buf)
if writeErr != nil { if writeErr != nil {
c.delete(id) c.delete(id)
e.Error = writeErr event.Error = writeErr
// Stopping agent transaction instead of waiting until it's deadline. // Stopping agent transaction instead of waiting until it's deadline.
// This will call handleAgentCallback with "ErrTransactionStopped" error // This will call handleAgentCallback with "ErrTransactionStopped" error
// which will be ignored. // which will be ignored.
if stopErr := c.a.Stop(id); stopErr != nil { if stopErr := c.a.Stop(id); stopErr != nil {
// Failed to stop agent transaction. Wrapping the error in StopError. // Failed to stop agent transaction. Wrapping the error in StopError.
e.Error = StopErr{ event.Error = StopErr{
Err: stopErr, Err: stopErr,
Cause: writeErr, Cause: writeErr,
} }
} }
t.handle(e) transaction.handle(event)
putClientTransaction(t) putClientTransaction(transaction)
return return
} }
} }
// Start starts transaction (if h set) and writes message to server, handler // Start starts transaction (if h set) and writes message to server, handler
// is called asynchronously. // is called asynchronously.
func (c *Client) Start(m *Message, h Handler) error { func (c *Client) Start(msg *Message, handler Handler) error {
if err := c.checkInit(); err != nil { if err := c.checkInit(); err != nil {
return err return err
} }
@@ -694,34 +711,35 @@ func (c *Client) Start(m *Message, h Handler) error {
if closed { if closed {
return ErrClientClosed return ErrClientClosed
} }
if h != nil { if handler != nil {
// Starting transaction only if h is set. Useful for indications. // Starting transaction only if h is set. Useful for indications.
t := acquireClientTransaction() t := acquireClientTransaction()
t.id = m.TransactionID t.id = msg.TransactionID
t.start = c.clock.Now() t.start = c.clock.Now()
t.h = h t.h = handler
t.rto = time.Duration(atomic.LoadInt64(&c.rto)) t.rto = time.Duration(atomic.LoadInt64(&c.rto))
t.attempt = 0 t.attempt = 0
t.raw = append(t.raw[:0], m.Raw...) t.raw = append(t.raw[:0], msg.Raw...)
t.calls = 0 t.calls = 0
d := t.nextTimeout(t.start) d := t.nextTimeout(t.start)
if err := c.start(t); err != nil { if err := c.start(t); err != nil {
return err return err
} }
if err := c.a.Start(m.TransactionID, d); err != nil { if err := c.a.Start(msg.TransactionID, d); err != nil {
return err return err
} }
} }
_, err := m.WriteTo(c.c) _, err := msg.WriteTo(c.c)
if err != nil && h != nil { if err != nil && handler != nil {
c.delete(m.TransactionID) c.delete(msg.TransactionID)
// Stopping transaction instead of waiting until deadline. // Stopping transaction instead of waiting until deadline.
if stopErr := c.a.Stop(m.TransactionID); stopErr != nil { if stopErr := c.a.Stop(msg.TransactionID); stopErr != nil {
return StopErr{ return StopErr{
Err: stopErr, Err: stopErr,
Cause: err, Cause: err,
} }
} }
} }
return err return err
} }

View File

@@ -38,11 +38,13 @@ type TestAgent struct {
func (n *TestAgent) SetHandler(h Handler) error { func (n *TestAgent) SetHandler(h Handler) error {
n.h = h n.h = h
return nil return nil
} }
func (n *TestAgent) Close() error { func (n *TestAgent) Close() error {
close(n.e) close(n.e)
return nil return nil
} }
@@ -54,6 +56,7 @@ func (n *TestAgent) Start(id [TransactionIDSize]byte, _ time.Time) error {
n.e <- Event{ n.e <- Event{
TransactionID: id, TransactionID: id,
} }
return nil return nil
} }
@@ -69,6 +72,7 @@ func (noopConnection) Write(b []byte) (int, error) {
func (noopConnection) Read([]byte) (int, error) { func (noopConnection) Read([]byte) (int, error) {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
return 0, io.EOF return 0, io.EOF
} }
@@ -136,6 +140,7 @@ func (t *testConnection) Close() error {
return errClientAlreadyStopped return errClientAlreadyStopped
} }
t.stopped = true t.stopped = true
return nil return nil
} }
@@ -148,6 +153,7 @@ func (t *testConnection) Read(b []byte) (int, error) {
if t.read != nil { if t.read != nil {
return t.read(b) return t.read(b)
} }
return copy(b, t.b), nil return copy(b, t.b), nil
} }
@@ -165,7 +171,7 @@ func TestClosedOrPanic(t *testing.T) {
}() }()
} }
func TestClient_Start(t *testing.T) { func TestClient_Start(t *testing.T) { //nolint:cyclop
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
response.Encode() response.Encode()
write := make(chan struct{}, 1) write := make(chan struct{}, 1)
@@ -178,6 +184,7 @@ func TestClient_Start(t *testing.T) {
case <-read: case <-read:
t.Log("reading") t.Log("reading")
copy(i, response.Raw) copy(i, response.Raw)
return len(response.Raw), nil return len(response.Raw), nil
case <-time.After(time.Millisecond * 10): case <-time.After(time.Millisecond * 10):
return 0, errClientReadTimedOut return 0, errClientReadTimedOut
@@ -188,33 +195,34 @@ func TestClient_Start(t *testing.T) {
select { select {
case <-write: case <-write:
t.Log("writing") t.Log("writing")
return len(bytes), nil return len(bytes), nil
case <-time.After(time.Millisecond * 10): case <-time.After(time.Millisecond * 10):
return 0, errClientWriteTimedOut return 0, errClientWriteTimedOut
} }
}, },
} }
c, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer func() { defer func() {
if err := c.Close(); err != nil { if err := client.Close(); err != nil {
t.Error(err) t.Error(err)
} }
if err := c.Close(); err == nil { if err := client.Close(); err == nil {
t.Error("second close should fail") t.Error("second close should fail")
} }
if err := c.Do(MustBuild(TransactionID), nil); err == nil { if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail") t.Error("Do after Close should fail")
} }
}() }()
m := MustBuild(response, BindingRequest) msg := MustBuild(response, BindingRequest)
t.Log("init") t.Log("init")
got := make(chan struct{}) got := make(chan struct{})
write <- struct{}{} write <- struct{}{}
t.Log("starting the first transaction") t.Log("starting the first transaction")
if err := c.Start(m, func(event Event) { if err := client.Start(msg, func(event Event) {
t.Log("got first transaction callback") t.Log("got first transaction callback")
if event.Error != nil { if event.Error != nil {
t.Error(event.Error) t.Error(event.Error)
@@ -224,7 +232,7 @@ func TestClient_Start(t *testing.T) {
t.Error(err) t.Error(err)
} }
t.Log("starting the second transaction") t.Log("starting the second transaction")
if err := c.Start(m, func(Event) { if err := client.Start(msg, func(Event) {
t.Error("should not be called") t.Error("should not be called")
}); !errors.Is(err, ErrTransactionExists) { }); !errors.Is(err, ErrTransactionExists) {
t.Errorf("unexpected error %v", err) t.Errorf("unexpected error %v", err)
@@ -247,25 +255,25 @@ func TestClient_Do(t *testing.T) {
return len(bytes), nil return len(bytes), nil
}, },
} }
c, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer func() { defer func() {
if err := c.Close(); err != nil { if err := client.Close(); err != nil {
t.Error(err) t.Error(err)
} }
if err := c.Close(); err == nil { if err := client.Close(); err == nil {
t.Error("second close should fail") t.Error("second close should fail")
} }
if err := c.Do(MustBuild(TransactionID), nil); err == nil { if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail") t.Error("Do after Close should fail")
} }
}() }()
m := MustBuild( m := MustBuild(
NewTransactionIDSetter(response.TransactionID), NewTransactionIDSetter(response.TransactionID),
) )
if err := c.Do(m, func(event Event) { if err := client.Do(m, func(event Event) {
if event.Error != nil { if event.Error != nil {
t.Error(event.Error) t.Error(event.Error)
} }
@@ -273,13 +281,13 @@ func TestClient_Do(t *testing.T) {
t.Error(err) t.Error(err)
} }
m = MustBuild(TransactionID) m = MustBuild(TransactionID)
if err := c.Do(m, nil); err != nil { if err := client.Do(m, nil); err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestCloseErr_Error(t *testing.T) { func TestCloseErr_Error(t *testing.T) {
for id, c := range []struct { for id, testCase := range []struct {
Err CloseErr Err CloseErr
Out string Out string
}{ }{
@@ -291,16 +299,16 @@ func TestCloseErr_Error(t *testing.T) {
ConnectionErr: io.ErrUnexpectedEOF, ConnectionErr: io.ErrUnexpectedEOF,
}, "failed to close: unexpected EOF (connection), <nil> (agent)"}, }, "failed to close: unexpected EOF (connection), <nil> (agent)"},
} { } {
if out := c.Err.Error(); out != c.Out { if out := testCase.Err.Error(); out != testCase.Out {
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)", t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, c.Err, out, c.Out, id, testCase.Err, out, testCase.Out,
) )
} }
} }
} }
func TestStopErr_Error(t *testing.T) { func TestStopErr_Error(t *testing.T) {
for id, c := range []struct { for id, testcase := range []struct {
Err StopErr Err StopErr
Out string Out string
}{ }{
@@ -312,9 +320,9 @@ func TestStopErr_Error(t *testing.T) {
Cause: io.ErrUnexpectedEOF, Cause: io.ErrUnexpectedEOF,
}, "error while stopping due to unexpected EOF: <nil>"}, }, "error while stopping due to unexpected EOF: <nil>"},
} { } {
if out := c.Err.Error(); out != c.Out { if out := testcase.Err.Error(); out != testcase.Out {
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)", t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, c.Err, out, c.Out, id, testcase.Err, out, testcase.Out,
) )
} }
} }
@@ -352,7 +360,7 @@ func TestClientAgentError(t *testing.T) {
return len(bytes), nil return len(bytes), nil
}, },
} }
c, err := NewClient(conn, client, err := NewClient(conn,
WithAgent(errorAgent{ WithAgent(errorAgent{
startErr: io.ErrUnexpectedEOF, startErr: io.ErrUnexpectedEOF,
}), }),
@@ -361,15 +369,15 @@ func TestClientAgentError(t *testing.T) {
log.Fatal(err) log.Fatal(err)
} }
defer func() { defer func() {
if err := c.Close(); err != nil { if err := client.Close(); err != nil {
t.Error(err) t.Error(err)
} }
}() }()
m := MustBuild(NewTransactionIDSetter(response.TransactionID)) m := MustBuild(NewTransactionIDSetter(response.TransactionID))
if err := c.Do(m, nil); err != nil { if err := client.Do(m, nil); err != nil {
t.Error(err) t.Error(err)
} }
if err := c.Do(m, func(event Event) { if err := client.Do(m, func(event Event) {
if event.Error == nil { if event.Error == nil {
t.Error("error expected") t.Error("error expected")
} }
@@ -384,20 +392,20 @@ func TestClientConnErr(t *testing.T) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
}, },
} }
c, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer func() { defer func() {
if err := c.Close(); err != nil { if err := client.Close(); err != nil {
t.Error(err) t.Error(err)
} }
}() }()
m := MustBuild(TransactionID) m := MustBuild(TransactionID)
if err := c.Do(m, nil); err == nil { if err := client.Do(m, nil); err == nil {
t.Error("error expected") t.Error("error expected")
} }
if err := c.Do(m, NoopHandler()); err == nil { if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected") t.Error("error expected")
} }
} }
@@ -408,7 +416,7 @@ func TestClientConnErrStopErr(t *testing.T) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
}, },
} }
c, err := NewClient(conn, client, err := NewClient(conn,
WithAgent(errorAgent{ WithAgent(errorAgent{
stopErr: io.ErrUnexpectedEOF, stopErr: io.ErrUnexpectedEOF,
}), }),
@@ -417,12 +425,12 @@ func TestClientConnErrStopErr(t *testing.T) {
log.Fatal(err) log.Fatal(err)
} }
defer func() { defer func() {
if err := c.Close(); err != nil { if err := client.Close(); err != nil {
t.Error(err) t.Error(err)
} }
}() }()
m := MustBuild(TransactionID) m := MustBuild(TransactionID)
if err := c.Do(m, NoopHandler()); err == nil { if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected") t.Error("error expected")
} }
} }
@@ -556,11 +564,13 @@ func (a *gcWaitAgent) Stop([TransactionIDSize]byte) error {
func (a *gcWaitAgent) Close() error { func (a *gcWaitAgent) Close() error {
close(a.gc) close(a.gc)
return nil return nil
} }
func (a *gcWaitAgent) Collect(time.Time) error { func (a *gcWaitAgent) Collect(time.Time) error {
a.gc <- struct{}{} a.gc <- struct{}{}
return nil return nil
} }
@@ -617,6 +627,7 @@ func captureLog() (*bytes.Buffer, func()) {
log.SetOutput(&buf) log.SetOutput(&buf)
f := log.Flags() f := log.Flags()
log.SetFlags(0) log.SetFlags(0)
return &buf, func() { return &buf, func() {
log.SetFlags(f) log.SetFlags(f)
log.SetOutput(os.Stderr) log.SetOutput(os.Stderr)
@@ -633,12 +644,12 @@ func TestClientFinalizer(t *testing.T) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
}, },
} }
c, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
clientFinalizer(c) clientFinalizer(client)
clientFinalizer(c) clientFinalizer(client)
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
response.Encode() response.Encode()
conn = &testConnection{ conn = &testConnection{
@@ -647,7 +658,7 @@ func TestClientFinalizer(t *testing.T) {
return len(bytes), nil return len(bytes), nil
}, },
} }
c, err = NewClient(conn, client, err = NewClient(conn,
WithAgent(errorAgent{ WithAgent(errorAgent{
closeErr: io.ErrUnexpectedEOF, closeErr: io.ErrUnexpectedEOF,
}), }),
@@ -655,7 +666,7 @@ func TestClientFinalizer(t *testing.T) {
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
clientFinalizer(c) clientFinalizer(client)
reader := bufio.NewScanner(buf) reader := bufio.NewScanner(buf)
var lines int var lines int
expectedLines := []string{ expectedLines := []string{
@@ -700,6 +711,7 @@ func (m *manualCollector) Collect(t time.Time) {
func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error { func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error {
m.f = f m.f = f
return nil return nil
} }
@@ -717,12 +729,14 @@ func (m *manualClock) Add(d time.Duration) time.Time {
v := m.current.Add(d) v := m.current.Add(d)
m.current = v m.current = v
m.mux.Unlock() m.mux.Unlock()
return v return v
} }
func (m *manualClock) Now() time.Time { func (m *manualClock) Now() time.Time {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
return m.current return m.current
} }
@@ -735,6 +749,7 @@ type manualAgent struct {
func (n *manualAgent) SetHandler(h Handler) error { func (n *manualAgent) SetHandler(h Handler) error {
n.h = h n.h = h
return nil return nil
} }
@@ -748,6 +763,7 @@ func (n *manualAgent) Process(m *Message) error {
if n.process != nil { if n.process != nil {
return n.process(m) return n.process(m)
} }
return nil return nil
} }
@@ -759,6 +775,7 @@ func (n *manualAgent) Stop(id [TransactionIDSize]byte) error {
if n.stop != nil { if n.stop != nil {
return n.stop(id) return n.stop(id)
} }
return nil return nil
} }
@@ -788,9 +805,10 @@ func TestClientRetransmission(t *testing.T) {
Message: response, Message: response,
}) })
} }
return nil return nil
} }
c, err := NewClient(connR, client, err := NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -799,7 +817,7 @@ func TestClientRetransmission(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.SetRTO(time.Second) client.SetRTO(time.Second)
gotReads := make(chan struct{}) gotReads := make(chan struct{})
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
@@ -819,7 +837,7 @@ func TestClientRetransmission(t *testing.T) {
} }
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if event.Error != nil { if event.Error != nil {
t.Error("failed") t.Error("failed")
} }
@@ -829,7 +847,9 @@ func TestClientRetransmission(t *testing.T) {
<-gotReads <-gotReads
} }
func testClientDoConcurrent(t *testing.T, concurrency int) { func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
t.Helper()
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
@@ -846,9 +866,10 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
TransactionID: id, TransactionID: id,
Message: response, Message: response,
}) })
return nil return nil
} }
c, err := NewClient(connR, client, err := NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -856,7 +877,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.SetRTO(time.Second) client.SetRTO(time.Second)
conns := new(sync.WaitGroup) conns := new(sync.WaitGroup)
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
for i := 0; i < concurrency; i++ { for i := 0; i < concurrency; i++ {
@@ -880,7 +901,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if doErr := c.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
if event.Error != nil { if event.Error != nil {
t.Error("failed") t.Error("failed")
} }
@@ -962,14 +983,14 @@ func TestClient_Close(t *testing.T) {
} }
func TestClientDefaultHandler(t *testing.T) { func TestClientDefaultHandler(t *testing.T) {
a := &TestAgent{ agent := &TestAgent{
e: make(chan Event), e: make(chan Event),
} }
id := NewTransactionID() id := NewTransactionID()
handlerCalled := make(chan struct{}) handlerCalled := make(chan struct{})
called := false called := false
c, createErr := NewClient(noopConnection{}, client, createErr := NewClient(noopConnection{},
WithAgent(a), WithAgent(agent),
WithHandler(func(e Event) { WithHandler(func(e Event) {
if called { if called {
t.Error("should not be called twice") t.Error("should not be called twice")
@@ -985,7 +1006,7 @@ func TestClientDefaultHandler(t *testing.T) {
t.Fatal(createErr) t.Fatal(createErr)
} }
go func() { go func() {
a.h(Event{ agent.h(Event{
TransactionID: id, TransactionID: id,
}) })
}() }()
@@ -995,11 +1016,11 @@ func TestClientDefaultHandler(t *testing.T) {
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
t.Fatal("timed out") t.Fatal("timed out")
} }
if closeErr := c.Close(); closeErr != nil { if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr) t.Error(closeErr)
} }
// Handler call should be ignored. // Handler call should be ignored.
a.h(Event{}) agent.h(Event{})
} }
func TestClientClosedStart(t *testing.T) { func TestClientClosedStart(t *testing.T) {
@@ -1047,9 +1068,10 @@ func TestWithNoRetransmit(t *testing.T) {
Error: ErrTransactionTimeOut, Error: ErrTransactionTimeOut,
}) })
} }
return nil return nil
} }
c, err := NewClient(connR, client, err := NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -1071,7 +1093,7 @@ func TestWithNoRetransmit(t *testing.T) {
} }
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrTransactionTimeOut) { if !errors.Is(event.Error, ErrTransactionTimeOut) {
t.Error("unexpected error") t.Error("unexpected error")
} }
@@ -1087,7 +1109,7 @@ func (c callbackClock) Now() time.Time {
return c() return c()
} }
func TestClientRTOStartErr(t *testing.T) { func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
@@ -1116,13 +1138,14 @@ func TestClientRTOStartErr(t *testing.T) {
} else { } else {
t.Log("clock returned") t.Log("clock returned")
} }
return time.Now() return time.Now()
}) })
agent := &manualAgent{} agent := &manualAgent{}
attempt := 0 attempt := 0
gotReads := make(chan struct{}) gotReads := make(chan struct{})
var ( var (
c *Client client *Client
startClientErr error startClientErr error
) )
agent.start = func(id [TransactionIDSize]byte, _ time.Time) error { agent.start = func(id [TransactionIDSize]byte, _ time.Time) error {
@@ -1146,7 +1169,7 @@ func TestClientRTOStartErr(t *testing.T) {
t.Log("clock locked") t.Log("clock locked")
<-clockLocked <-clockLocked
t.Log("closing client") t.Log("closing client")
if closeErr := c.Close(); closeErr != nil { if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr) t.Error(closeErr)
} }
t.Log("client closed, unlocking clock") t.Log("client closed, unlocking clock")
@@ -1154,9 +1177,10 @@ func TestClientRTOStartErr(t *testing.T) {
t.Log("clock unlocked") t.Log("clock unlocked")
}() }()
} }
return nil return nil
} }
c, startClientErr = NewClient(connR, client, startClientErr = NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -1186,7 +1210,7 @@ func TestClientRTOStartErr(t *testing.T) {
t.Log("starting") t.Log("starting")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrClientClosed) { if !errors.Is(event.Error, ErrClientClosed) {
t.Error(event.Error) t.Error(event.Error)
} }
@@ -1203,7 +1227,7 @@ func TestClientRTOStartErr(t *testing.T) {
} }
} }
func TestClientRTOWriteErr(t *testing.T) { func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
@@ -1232,13 +1256,14 @@ func TestClientRTOWriteErr(t *testing.T) {
} else { } else {
t.Log("clock returned") t.Log("clock returned")
} }
return time.Now() return time.Now()
}) })
agent := &manualAgent{} agent := &manualAgent{}
attempt := 0 attempt := 0
gotReads := make(chan struct{}) gotReads := make(chan struct{})
var ( var (
c *Client client *Client
startClientErr error startClientErr error
) )
agentStopErr := errClientAgentCantStop agentStopErr := errClientAgentCantStop
@@ -1274,9 +1299,10 @@ func TestClientRTOWriteErr(t *testing.T) {
t.Log("clock unlocked") t.Log("clock unlocked")
}() }()
} }
return nil return nil
} }
c, startClientErr = NewClient(connR, client, startClientErr = NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -1306,7 +1332,7 @@ func TestClientRTOWriteErr(t *testing.T) {
t.Log("starting") t.Log("starting")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
var e StopErr var e StopErr
if !errors.As(event.Error, &e) { if !errors.As(event.Error, &e) {
t.Error(event.Error) t.Error(event.Error)
@@ -1346,7 +1372,7 @@ func TestClientRTOAgentErr(t *testing.T) {
attempt := 0 attempt := 0
gotReads := make(chan struct{}) gotReads := make(chan struct{})
var ( var (
c *Client client *Client
startClientErr error startClientErr error
) )
agentStartErr := errClientStartRefused agentStartErr := errClientStartRefused
@@ -1361,9 +1387,10 @@ func TestClientRTOAgentErr(t *testing.T) {
} else { } else {
return agentStartErr return agentStartErr
} }
return nil return nil
} }
c, startClientErr = NewClient(connR, client, startClientErr = NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -1384,7 +1411,7 @@ func TestClientRTOAgentErr(t *testing.T) {
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
t.Log("starting") t.Log("starting")
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, agentStartErr) { if !errors.Is(event.Error, agentStartErr) {
t.Error(event.Error) t.Error(event.Error)
} }
@@ -1415,9 +1442,10 @@ func TestClient_HandleProcessError(t *testing.T) {
processCalled := make(chan struct{}, 1) processCalled := make(chan struct{}, 1)
agent.process = func(*Message) error { agent.process = func(*Message) error {
processCalled <- struct{}{} processCalled <- struct{}{}
return ErrAgentClosed return ErrAgentClosed
} }
c, startClientErr := NewClient(connR, client, startClientErr := NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -1440,7 +1468,7 @@ func TestClient_HandleProcessError(t *testing.T) {
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Error("reads timeout") t.Error("reads timeout")
} }
if closeErr := c.Close(); closeErr != nil { if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr) t.Error(closeErr)
} }
} }
@@ -1475,9 +1503,10 @@ func TestClientImmediateTimeout(t *testing.T) {
Error: ErrTransactionTimeOut, Error: ErrTransactionTimeOut,
}) })
} }
return nil return nil
} }
c, err := NewClient(connR, client, err := NewClient(connR,
WithAgent(agent), WithAgent(agent),
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
@@ -1498,7 +1527,7 @@ func TestClientImmediateTimeout(t *testing.T) {
} }
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
c.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
if errors.Is(e.Error, ErrTransactionTimeOut) { if errors.Is(e.Error, ErrTransactionTimeOut) {
t.Error("unexpected error") t.Error("unexpected error")
} }

View File

@@ -30,7 +30,7 @@ var (
realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals
) )
func main() { //nolint:gocognit func main() { //nolint:gocognit,cyclop
flag.Parse() flag.Parse()
uri, err := stun.ParseURI(*uriStr) uri, err := stun.ParseURI(*uriStr)
if err != nil { if err != nil {
@@ -88,7 +88,7 @@ func main() { //nolint:gocognit
log.Print("Using crypto/rand as random source for transaction id") log.Print("Using crypto/rand as random source for transaction id")
} }
for i := 0; i < *workers; i++ { for i := 0; i < *workers; i++ {
c, clientErr := stun.DialURI(uri, &stun.DialConfig{}) client, clientErr := stun.DialURI(uri, &stun.DialConfig{})
if clientErr != nil { if clientErr != nil {
log.Panicf("Failed to create client: %s", clientErr) log.Panicf("Failed to create client: %s", clientErr)
} }
@@ -105,12 +105,13 @@ func main() { //nolint:gocognit
req.Type = stun.BindingRequest req.Type = stun.BindingRequest
req.WriteHeader() req.WriteHeader()
atomic.AddInt64(&request, 1) atomic.AddInt64(&request, 1)
if doErr := c.Do(req, func(event stun.Event) { if doErr := client.Do(req, func(event stun.Event) {
if event.Error != nil { if event.Error != nil {
if !errors.Is(event.Error, stun.ErrTransactionTimeOut) { if !errors.Is(event.Error, stun.ErrTransactionTimeOut) {
log.Printf("Failed STUN transaction: %s", event.Error) log.Printf("Failed STUN transaction: %s", event.Error)
} }
atomic.AddInt64(&requestErr, 1) atomic.AddInt64(&requestErr, 1)
return return
} }
atomic.AddInt64(&requestOK, 1) atomic.AddInt64(&requestOK, 1)

View File

@@ -31,11 +31,11 @@ func main() {
} }
// we only try the first address, so restrict ourselves to IPv4 // we only try the first address, so restrict ourselves to IPv4
c, err := stun.DialURI(uri, &stun.DialConfig{}) client, err := stun.DialURI(uri, &stun.DialConfig{})
if err != nil { if err != nil {
log.Fatalf("Failed to dial: %s", err) log.Fatalf("Failed to dial: %s", err)
} }
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if res.Error != nil { if res.Error != nil {
log.Fatalf("Failed STUN transaction: %s", res.Error) log.Fatalf("Failed STUN transaction: %s", res.Error)
} }
@@ -49,7 +49,7 @@ func main() {
}); err != nil { }); err != nil {
log.Fatal("Do:", err) log.Fatal("Do:", err)
} }
if err := c.Close(); err != nil { if err := client.Close(); err != nil {
log.Fatalf("Failed to close connection: %s", err) log.Fatalf("Failed to close connection: %s", err)
} }
} }

View File

@@ -84,7 +84,7 @@ func multiplex(conn *net.UDPConn, stunAddr net.Addr, stunConn io.Reader) {
var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals
func main() { func main() { //nolint:cyclop
flag.Parse() flag.Parse()
isServer := flag.Arg(0) == "" isServer := flag.Arg(0) == ""
@@ -112,7 +112,7 @@ func main() {
stunL, stunR := net.Pipe() stunL, stunR := net.Pipe()
c, err := stun.NewClient(stunR) client, err := stun.NewClient(stunR)
if err != nil { if err != nil {
log.Panicf("Failed to create client: %s", err) log.Panicf("Failed to create client: %s", err)
} }
@@ -134,7 +134,7 @@ func main() {
// This can fail if your NAT Server is strict and will use separate ports // This can fail if your NAT Server is strict and will use separate ports
// for application data and STUN // for application data and STUN
var gotAddr stun.XORMappedAddress var gotAddr stun.XORMappedAddress
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if res.Error != nil { if res.Error != nil {
log.Panicf("Failed STUN transaction: %s", res.Error) log.Panicf("Failed STUN transaction: %s", res.Error)
} }
@@ -153,7 +153,7 @@ func main() {
// Any ping-pong will work, but we are just making binding requests. // Any ping-pong will work, but we are just making binding requests.
// Note that STUN Server is not mandatory for keep alive, application // Note that STUN Server is not mandatory for keep alive, application
// data will keep alive that binding too. // data will keep alive that binding too.
go keepAlive(c) go keepAlive(client)
notify := make(chan os.Signal, 1) notify := make(chan os.Signal, 1)
signal.Notify(notify, os.Interrupt, syscall.SIGTERM) signal.Notify(notify, os.Interrupt, syscall.SIGTERM)
@@ -168,6 +168,7 @@ func main() {
} }
case <-notify: case <-notify:
log.Println("Stopping") log.Println("Stopping")
return return
} }
} }
@@ -203,10 +204,12 @@ func main() {
case m := <-messages: case m := <-messages:
log.Printf("Got response from %s: %s", m.addr, m.text) log.Printf("Got response from %s: %s", m.addr, m.text)
return return
case <-notify: case <-notify:
log.Print("Stopping") log.Print("Stopping")
return return
} }
} }

View File

@@ -30,10 +30,14 @@ func (c *stunServerConn) Close() error {
} }
var ( var (
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address") //nolint:gochecknoglobals //nolint:gochecknoglobals
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response") //nolint:gochecknoglobals addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address")
verbose = flag.Int("verbose", 1, "the verbosity level") //nolint:gochecknoglobals //nolint:gochecknoglobals
log logging.LeveledLogger //nolint:gochecknoglobals timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response")
//nolint:gochecknoglobals
verbose = flag.Int("verbose", 1, "the verbosity level")
//nolint:gochecknoglobals
log logging.LeveledLogger
) )
const ( const (
@@ -70,11 +74,12 @@ func main() {
} }
} }
// RFC5780: 4.3. Determining NAT Mapping Behavior // RFC5780: 4.3. Determining NAT Mapping Behavior.
func mappingTests(addrStr string) error { func mappingTests(addrStr string) error { //nolint:cyclop
mapTestConn, err := connect(addrStr) mapTestConn, err := connect(addrStr)
if err != nil { if err != nil {
log.Warnf("Error creating STUN connection: %s", err) log.Warnf("Error creating STUN connection: %s", err)
return err return err
} }
@@ -91,11 +96,13 @@ func mappingTests(addrStr string) error {
resps1 := parse(resp) resps1 := parse(resp)
if resps1.xorAddr == nil || resps1.otherAddr == nil { if resps1.xorAddr == nil || resps1.otherAddr == nil {
log.Info("Error: NAT discovery feature not supported by this server") log.Info("Error: NAT discovery feature not supported by this server")
return errNoOtherAddress return errNoOtherAddress
} }
addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String()) addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
if err != nil { if err != nil {
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr) log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr)
return err return err
} }
mapTestConn.OtherAddr = addr mapTestConn.OtherAddr = addr
@@ -104,6 +111,7 @@ func mappingTests(addrStr string) error {
// Assert mapping behavior // Assert mapping behavior
if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() { if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() {
log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)") log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)")
return nil return nil
} }
@@ -121,6 +129,7 @@ func mappingTests(addrStr string) error {
log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr) log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr)
if resps2.xorAddr.String() == resps1.xorAddr.String() { if resps2.xorAddr.String() == resps1.xorAddr.String() {
log.Warn("=> NAT mapping behavior: endpoint independent") log.Warn("=> NAT mapping behavior: endpoint independent")
return nil return nil
} }
@@ -143,11 +152,12 @@ func mappingTests(addrStr string) error {
return mapTestConn.Close() return mapTestConn.Close()
} }
// RFC5780: 4.4. Determining NAT Filtering Behavior // RFC5780: 4.4. Determining NAT Filtering Behavior.
func filteringTests(addrStr string) error { func filteringTests(addrStr string) error { //nolint:cyclop
mapTestConn, err := connect(addrStr) mapTestConn, err := connect(addrStr)
if err != nil { if err != nil {
log.Warnf("Error creating STUN connection: %s", err) log.Warnf("Error creating STUN connection: %s", err)
return err return err
} }
@@ -162,11 +172,13 @@ func filteringTests(addrStr string) error {
resps := parse(resp) resps := parse(resp)
if resps.xorAddr == nil || resps.otherAddr == nil { if resps.xorAddr == nil || resps.otherAddr == nil {
log.Warn("Error: NAT discovery feature not supported by this server") log.Warn("Error: NAT discovery feature not supported by this server")
return errNoOtherAddress return errNoOtherAddress
} }
addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String()) addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String())
if err != nil { if err != nil {
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr) log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr)
return err return err
} }
mapTestConn.OtherAddr = addr mapTestConn.OtherAddr = addr
@@ -180,6 +192,7 @@ func filteringTests(addrStr string) error {
if err == nil { if err == nil {
parse(resp) // just to print out the resp parse(resp) // just to print out the resp
log.Warn("=> NAT filtering behavior: endpoint independent") log.Warn("=> NAT filtering behavior: endpoint independent")
return nil return nil
} else if !errors.Is(err, errTimedOut) { } else if !errors.Is(err, errTimedOut) {
return err // something else went wrong return err // something else went wrong
@@ -201,7 +214,7 @@ func filteringTests(addrStr string) error {
return mapTestConn.Close() return mapTestConn.Close()
} }
// Parse a STUN message // Parse a STUN message.
func parse(msg *stun.Message) (ret struct { func parse(msg *stun.Message) (ret struct {
xorAddr *stun.XORMappedAddress xorAddr *stun.XORMappedAddress
otherAddr *stun.OtherAddress otherAddr *stun.OtherAddress
@@ -249,15 +262,17 @@ func parse(msg *stun.Message) (ret struct {
log.Debugf("\t%v (l=%v)", attr, attr.Length) log.Debugf("\t%v (l=%v)", attr, attr.Length)
} }
} }
return ret return ret
} }
// Given an address string, returns a StunServerConn // Given an address string, returns a StunServerConn.
func connect(addrStr string) (*stunServerConn, error) { func connect(addrStr string) (*stunServerConn, error) {
log.Infof("Connecting to STUN server: %s", addrStr) log.Infof("Connecting to STUN server: %s", addrStr)
addr, err := net.ResolveUDPAddr("udp4", addrStr) addr, err := net.ResolveUDPAddr("udp4", addrStr)
if err != nil { if err != nil {
log.Warnf("Error resolving address: %s", err) log.Warnf("Error resolving address: %s", err)
return nil, err return nil, err
} }
@@ -278,7 +293,7 @@ func connect(addrStr string) (*stunServerConn, error) {
}, nil }, nil
} }
// Send request and wait for response or timeout // Send request and wait for response or timeout.
func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) { func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) {
_ = msg.NewTransactionID() _ = msg.NewTransactionID()
log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize) log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize)
@@ -289,6 +304,7 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
_, err := c.conn.WriteTo(msg.Raw, addr) _, err := c.conn.WriteTo(msg.Raw, addr)
if err != nil { if err != nil {
log.Warnf("Error sending request to %v", addr) log.Warnf("Error sending request to %v", addr)
return nil, err return nil, err
} }
@@ -298,9 +314,11 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
if !ok { if !ok {
return nil, errResponseMessage return nil, errResponseMessage
} }
return m, nil return m, nil
case <-time.After(time.Duration(*timeoutPtr) * time.Second): case <-time.After(time.Duration(*timeoutPtr) * time.Second):
log.Infof("Timed out waiting for response from server %v", addr) log.Infof("Timed out waiting for response from server %v", addr)
return nil, errTimedOut return nil, errTimedOut
} }
} }
@@ -315,6 +333,7 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
n, addr, err := conn.ReadFromUDP(buf) n, addr, err := conn.ReadFromUDP(buf)
if err != nil { if err != nil {
close(messages) close(messages)
return return
} }
log.Infof("Response from %v: (%v bytes)", addr, n) log.Infof("Response from %v: (%v bytes)", addr, n)
@@ -326,11 +345,13 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
if err != nil { if err != nil {
log.Infof("Error decoding message: %v", err) log.Infof("Error decoding message: %v", err)
close(messages) close(messages)
return return
} }
messages <- m messages <- m
} }
}() }()
return return
} }

View File

@@ -26,7 +26,7 @@ const (
timeoutMillis = 500 timeoutMillis = 500
) )
func main() { //nolint:gocognit func main() { //nolint:gocognit,cyclop
flag.Parse() flag.Parse()
srvAddr, err := net.ResolveUDPAddr(udp, *server) srvAddr, err := net.ResolveUDPAddr(udp, *server)
@@ -87,11 +87,13 @@ func main() { //nolint:gocognit
decErr := m.Decode() decErr := m.Decode()
if decErr != nil { if decErr != nil {
log.Println("decode:", decErr) log.Println("decode:", decErr)
break break
} }
var xorAddr stun.XORMappedAddress var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil { if getErr := xorAddr.GetFrom(m); getErr != nil {
log.Println("getFrom:", getErr) log.Println("getFrom:", getErr)
break break
} }
@@ -160,6 +162,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
n, _, err := conn.ReadFromUDP(buf) n, _, err := conn.ReadFromUDP(buf)
if err != nil { if err != nil {
close(messages) close(messages)
return return
} }
buf = buf[:n] buf = buf[:n]
@@ -167,6 +170,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
messages <- buf messages <- buf
} }
}() }()
return messages return messages
} }

View File

@@ -14,7 +14,7 @@ import (
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
) )
func test(network string) { func test(network string) { //nolint:cyclop
addr := resolve(network) addr := resolve(network)
fmt.Println("START", strings.ToUpper(addr.Network())) //nolint fmt.Println("START", strings.ToUpper(addr.Network())) //nolint
var ( var (

View File

@@ -11,7 +11,7 @@ import (
// ErrorCodeAttribute represents ERROR-CODE attribute. // ErrorCodeAttribute represents ERROR-CODE attribute.
// //
// RFC 5389 Section 15.6 // RFC 5389 Section 15.6.
type ErrorCodeAttribute struct { type ErrorCodeAttribute struct {
Code ErrorCode Code ErrorCode
Reason []byte Reason []byte
@@ -31,7 +31,7 @@ const (
) )
// AddTo adds ERROR-CODE to m. // AddTo adds ERROR-CODE to m.
func (c ErrorCodeAttribute) AddTo(m *Message) error { func (c ErrorCodeAttribute) AddTo(msg *Message) error {
value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB) value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB)
if err := CheckOverflow(AttrErrorCode, if err := CheckOverflow(AttrErrorCode,
len(c.Reason)+errorCodeReasonStart, len(c.Reason)+errorCodeReasonStart,
@@ -45,26 +45,28 @@ func (c ErrorCodeAttribute) AddTo(m *Message) error {
value[errorCodeClassByte] = class value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number value[errorCodeNumberByte] = number
copy(value[errorCodeReasonStart:], c.Reason) copy(value[errorCodeReasonStart:], c.Reason)
m.Add(AttrErrorCode, value) msg.Add(AttrErrorCode, value)
return nil return nil
} }
// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid. // GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
func (c *ErrorCodeAttribute) GetFrom(m *Message) error { func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
v, err := m.Get(AttrErrorCode) value, err := m.Get(AttrErrorCode)
if err != nil { if err != nil {
return err return err
} }
if len(v) < errorCodeReasonStart { if len(value) < errorCodeReasonStart {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
var ( var (
class = uint16(v[errorCodeClassByte]) class = uint16(value[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte]) number = uint16(value[errorCodeNumberByte])
code = int(class*errorCodeModulo + number) code = int(class*errorCodeModulo + number)
) )
c.Code = ErrorCode(code) c.Code = ErrorCode(code)
c.Reason = v[errorCodeReasonStart:] c.Reason = value[errorCodeReasonStart:]
return nil return nil
} }
@@ -86,6 +88,7 @@ func (c ErrorCode) AddTo(m *Message) error {
Code: c, Code: c,
Reason: reason, Reason: reason,
} }
return a.AddTo(m) return a.AddTo(m)
} }
@@ -108,7 +111,7 @@ const (
// Error codes from RFC 5766. // Error codes from RFC 5766.
// //
// RFC 5766 Section 15 // RFC 5766 Section 15.
const ( const (
CodeForbidden ErrorCode = 403 // Forbidden CodeForbidden ErrorCode = 403 // Forbidden
CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch
@@ -120,7 +123,7 @@ const (
// Error codes from RFC 6062. // Error codes from RFC 6062.
// //
// RFC 6062 Section 6.3 // RFC 6062 Section 6.3.
const ( const (
CodeConnAlreadyExists ErrorCode = 446 CodeConnAlreadyExists ErrorCode = 446
CodeConnTimeoutOrFailure ErrorCode = 447 CodeConnTimeoutOrFailure ErrorCode = 447
@@ -128,7 +131,7 @@ const (
// Error codes from RFC 6156. // Error codes from RFC 6156.
// //
// RFC 6156 Section 10.2 // RFC 6156 Section 10.2.
const ( const (
CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported
CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch

View File

@@ -92,23 +92,23 @@ func TestMessage_AddErrorCode(t *testing.T) {
} }
func TestErrorCode(t *testing.T) { func TestErrorCode(t *testing.T) {
a := &ErrorCodeAttribute{ attr := &ErrorCodeAttribute{
Code: 404, Code: 404,
Reason: []byte("not found!"), Reason: []byte("not found!"),
} }
if a.String() != "404: not found!" { if attr.String() != "404: not found!" {
t.Error("bad string", a) t.Error("bad string", attr)
} }
m := New() m := New()
cod := ErrorCode(666) cod := ErrorCode(666)
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) { if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
t.Error("should be ErrNoDefaultReason", err) t.Error("should be ErrNoDefaultReason", err)
} }
if err := a.GetFrom(m); err == nil { if err := attr.GetFrom(m); err == nil {
t.Error("attr should not be in message") t.Error("attr should not be in message")
} }
a.Reason = make([]byte, 2048) attr.Reason = make([]byte, 2048)
if err := a.AddTo(m); err == nil { if err := attr.AddTo(m); err == nil {
t.Error("should error") t.Error("should error")
} }
} }

View File

@@ -10,7 +10,7 @@ import (
// FingerprintAttr represents FINGERPRINT attribute. // FingerprintAttr represents FINGERPRINT attribute.
// //
// RFC 5389 Section 15.5 // RFC 5389 Section 15.5.
type FingerprintAttr struct{} type FingerprintAttr struct{}
// ErrFingerprintMismatch means that computed fingerprint differs from expected. // ErrFingerprintMismatch means that computed fingerprint differs from expected.
@@ -50,6 +50,7 @@ func (FingerprintAttr) AddTo(m *Message) error {
bin.PutUint32(b, val) bin.PutUint32(b, val)
m.Length = l m.Length = l
m.Add(AttrFingerprint, b) m.Add(AttrFingerprint, b)
return nil return nil
} }
@@ -66,5 +67,6 @@ func (FingerprintAttr) Check(m *Message) error {
val := bin.Uint32(b) val := bin.Uint32(b)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
expected := FingerprintValue(m.Raw[:attrStart]) expected := FingerprintValue(m.Raw[:attrStart])
return checkFingerprint(val, expected) return checkFingerprint(val, expected)
} }

View File

@@ -13,20 +13,20 @@ import (
func BenchmarkFingerprint_AddTo(b *testing.B) { func BenchmarkFingerprint_AddTo(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) msg := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
addAttr(b, m, addr) addAttr(b, msg, addr)
addAttr(b, m, s) addAttr(b, msg, s)
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(msg.Raw)))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Fingerprint.AddTo(m) //nolint:errcheck,gosec Fingerprint.AddTo(msg) //nolint:errcheck,gosec
m.WriteLength() msg.WriteLength()
m.Length -= attributeHeaderSize + fingerprintSize msg.Length -= attributeHeaderSize + fingerprintSize
m.Raw = m.Raw[:m.Length+messageHeaderSize] msg.Raw = msg.Raw[:msg.Length+messageHeaderSize]
m.Attributes = m.Attributes[:len(m.Attributes)-1] msg.Attributes = msg.Attributes[:len(msg.Attributes)-1]
} }
} }

View File

@@ -109,6 +109,7 @@ func FuzzSetters(f *testing.F) {
if !IsAttrSizeOverflow(err) { if !IsAttrSizeOverflow(err) {
t.Fatal(err) t.Fatal(err)
} }
return return
} }
@@ -150,5 +151,6 @@ func (a attributes) pick(v byte) struct {
t AttrType t AttrType
} { } {
idx := int(v) % len(a) idx := int(v) % len(a)
return a[idx] return a[idx]
} }

View File

@@ -44,6 +44,7 @@ func (m *Message) Build(setters ...Setter) error {
return err return err
} }
} }
return nil return nil
} }
@@ -54,6 +55,7 @@ func (m *Message) Check(checkers ...Checker) error {
return err return err
} }
} }
return nil return nil
} }
@@ -64,6 +66,7 @@ func (m *Message) Parse(getters ...Getter) error {
return err return err
} }
} }
return nil return nil
} }
@@ -73,6 +76,7 @@ func MustBuild(setters ...Setter) *Message {
if err != nil { if err != nil {
panic(err) //nolint panic(err) //nolint
} }
return m return m
} }
@@ -82,6 +86,7 @@ func Build(setters ...Setter) (*Message, error) {
if err := m.Build(setters...); err != nil { if err := m.Build(setters...); err != nil {
return nil, err return nil, err
} }
return m, nil return m, nil
} }
@@ -105,5 +110,6 @@ func (m *Message) ForEach(t AttrType, f func(m *Message) error) error {
return err return err
} }
} }
return nil return nil
} }

View File

@@ -13,7 +13,7 @@ import (
func BenchmarkBuildOverhead(b *testing.B) { func BenchmarkBuildOverhead(b *testing.B) {
var ( var (
t = BindingRequest msgType = BindingRequest
username = NewUsername("username") username = NewUsername("username")
nonce = NewNonce("nonce") nonce = NewNonce("nonce")
realm = NewRealm("example.org") realm = NewRealm("example.org")
@@ -22,14 +22,14 @@ func BenchmarkBuildOverhead(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) m := new(Message)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.Build(&t, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec m.Build(&msgType, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
} }
}) })
b.Run("BuildNonPointer", func(b *testing.B) { b.Run("BuildNonPointer", func(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) m := new(Message)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.Build(t, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec m.Build(msgType, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
} }
}) })
b.Run("Raw", func(b *testing.B) { b.Run("Raw", func(b *testing.B) {
@@ -38,7 +38,7 @@ func BenchmarkBuildOverhead(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.Reset() m.Reset()
m.WriteHeader() m.WriteHeader()
m.SetType(t) m.SetType(msgType)
username.AddTo(m) //nolint:errcheck,gosec username.AddTo(m) //nolint:errcheck,gosec
nonce.AddTo(m) //nolint:errcheck,gosec nonce.AddTo(m) //nolint:errcheck,gosec
realm.AddTo(m) //nolint:errcheck,gosec realm.AddTo(m) //nolint:errcheck,gosec
@@ -52,7 +52,7 @@ func TestMessage_Apply(t *testing.T) {
integrity = NewShortTermIntegrity("password") integrity = NewShortTermIntegrity("password")
decoded = new(Message) decoded = new(Message)
) )
m, err := Build(BindingRequest, TransactionID, msg, err := Build(BindingRequest, TransactionID,
NewUsername("username"), NewUsername("username"),
NewNonce("nonce"), NewNonce("nonce"),
NewRealm("example.org"), NewRealm("example.org"),
@@ -62,13 +62,13 @@ func TestMessage_Apply(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("failed to build:", err) t.Fatal("failed to build:", err)
} }
if err = m.Check(Fingerprint, integrity); err != nil { if err = msg.Check(Fingerprint, integrity); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, err := decoded.Write(m.Raw); err != nil { if _, err := decoded.Write(msg.Raw); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !decoded.Equal(m) { if !decoded.Equal(msg) {
t.Error("not equal") t.Error("not equal")
} }
if err := integrity.Check(decoded); err != nil { if err := integrity.Check(decoded); err != nil {
@@ -96,32 +96,32 @@ func (e errReturner) GetFrom(*Message) error {
func TestHelpersErrorHandling(t *testing.T) { func TestHelpersErrorHandling(t *testing.T) {
m := New() m := New()
e := errReturner{Err: errTError} errReturn := errReturner{Err: errTError}
if err := m.Build(e); !errors.Is(err, e.Err) { if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", e.Err) t.Error(err, "!=", errReturn.Err)
} }
if err := m.Check(e); !errors.Is(err, e.Err) { if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", e.Err) t.Error(err, "!=", errReturn.Err)
} }
if err := m.Parse(e); !errors.Is(err, e.Err) { if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", e.Err) t.Error(err, "!=", errReturn.Err)
} }
t.Run("MustBuild", func(t *testing.T) { t.Run("MustBuild", func(t *testing.T) {
t.Run("Positive", func(*testing.T) { t.Run("Positive", func(*testing.T) {
MustBuild(NewTransactionIDSetter(transactionID{})) MustBuild(NewTransactionIDSetter(transactionID{}))
}) })
defer func() { defer func() {
if p, ok := recover().(error); !ok || !errors.Is(p, e.Err) { if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) {
t.Errorf("%s != %s", t.Errorf("%s != %s",
p, e.Err, p, errReturn.Err,
) )
} }
}() }()
MustBuild(e) MustBuild(errReturn)
}) })
} }
func TestMessage_ForEach(t *testing.T) { func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
initial := New() initial := New()
if err := initial.Build( if err := initial.Build(
NewRealm("realm1"), NewRealm("realm2"), NewRealm("realm1"), NewRealm("realm2"),
@@ -135,6 +135,7 @@ func TestMessage_ForEach(t *testing.T) {
); err != nil { ); err != nil {
t.Fatal(err) t.Fatal(err)
} }
return m return m
} }
t.Run("NoResults", func(t *testing.T) { t.Run("NoResults", func(t *testing.T) {
@@ -144,6 +145,7 @@ func TestMessage_ForEach(t *testing.T) {
} }
if err := m.ForEach(AttrUsername, func(*Message) error { if err := m.ForEach(AttrUsername, func(*Message) error {
t.Error("should not be called") t.Error("should not be called")
return nil return nil
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -160,6 +162,7 @@ func TestMessage_ForEach(t *testing.T) {
t.Error("called multiple times") t.Error("called multiple times")
} }
calls++ calls++
return ErrAttributeNotFound return ErrAttributeNotFound
}); !errors.Is(err, ErrAttributeNotFound) { }); !errors.Is(err, ErrAttributeNotFound) {
t.Fatal(err) t.Fatal(err)
@@ -169,14 +172,15 @@ func TestMessage_ForEach(t *testing.T) {
} }
}) })
t.Run("Positive", func(t *testing.T) { t.Run("Positive", func(t *testing.T) {
m := newMessage() msg := newMessage()
var realms []string var realms []string
if err := m.ForEach(AttrRealm, func(m *Message) error { if err := msg.ForEach(AttrRealm, func(m *Message) error {
var realm Realm var realm Realm
if err := realm.GetFrom(m); err != nil { if err := realm.GetFrom(m); err != nil {
return err return err
} }
realms = append(realms, realm.String()) realms = append(realms, realm.String())
return nil return nil
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -190,18 +194,18 @@ func TestMessage_ForEach(t *testing.T) {
if realms[1] != "realm2" { if realms[1] != "realm2" {
t.Error("bad value for 2 realm") t.Error("bad value for 2 realm")
} }
if !m.Equal(initial) { if !msg.Equal(initial) {
t.Error("m should be equal to initial") t.Error("m should be equal to initial")
} }
t.Run("ZeroAlloc", func(t *testing.T) { t.Run("ZeroAlloc", func(t *testing.T) {
m = newMessage() msg = newMessage()
var realm Realm var realm Realm
testutil.ShouldNotAllocate(t, func() { testutil.ShouldNotAllocate(t, func() {
if err := m.ForEach(AttrRealm, realm.GetFrom); err != nil { if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}) })
if !m.Equal(initial) { if !msg.Equal(initial) {
t.Error("m should be equal to initial") t.Error("m should be equal to initial")
} }
}) })
@@ -216,6 +220,7 @@ func ExampleMessage_ForEach() {
return err return err
} }
fmt.Println(r) fmt.Println(r)
return nil return nil
}); err != nil { }); err != nil {
fmt.Println("error:", err) fmt.Println("error:", err)

View File

@@ -14,22 +14,24 @@ import (
"testing" "testing"
) )
func loadCSV(t testing.TB, name string) [][]string { func loadCSV(tb testing.TB, name string) [][]string {
t.Helper() tb.Helper()
data := loadData(t, name)
data := loadData(tb, name)
r := csv.NewReader(bytes.NewReader(data)) r := csv.NewReader(bytes.NewReader(data))
r.Comment = '#' r.Comment = '#'
records, err := r.ReadAll() records, err := r.ReadAll()
if err != nil { if err != nil {
t.Fatal(err) tb.Fatal(err)
} }
return records return records
} }
func TestIANA(t *testing.T) { func TestIANA(t *testing.T) { //nolint:cyclop
t.Run("Methods", func(t *testing.T) { t.Run("Methods", func(t *testing.T) {
records := loadCSV(t, "stun-parameters-2.csv") records := loadCSV(t, "stun-parameters-2.csv")
m := make(map[string]Method) methods := make(map[string]Method)
for _, r := range records[1:] { for _, r := range records[1:] {
var ( var (
v = r[0] v = r[0]
@@ -43,10 +45,10 @@ func TestIANA(t *testing.T) {
t.Fatal(parseErr) t.Fatal(parseErr)
} }
t.Logf("value: 0x%x, name: %s", val, name) t.Logf("value: 0x%x, name: %s", val, name)
m[name] = Method(val) methods[name] = Method(val) //nolint:gosec // G115
} }
for val, name := range methodName() { for val, name := range methodName() {
mapped, ok := m[name] mapped, ok := methods[name]
if !ok { if !ok {
t.Errorf("failed to find method %s in IANA", name) t.Errorf("failed to find method %s in IANA", name)
} }
@@ -57,7 +59,7 @@ func TestIANA(t *testing.T) {
}) })
t.Run("Attributes", func(t *testing.T) { t.Run("Attributes", func(t *testing.T) {
records := loadCSV(t, "stun-parameters-4.csv") records := loadCSV(t, "stun-parameters-4.csv")
m := map[string]AttrType{} attrTypes := map[string]AttrType{}
for _, r := range records[1:] { for _, r := range records[1:] {
var ( var (
v = r[0] v = r[0]
@@ -71,16 +73,16 @@ func TestIANA(t *testing.T) {
t.Fatal(parseErr) t.Fatal(parseErr)
} }
t.Logf("value: 0x%x, name: %s", val, name) t.Logf("value: 0x%x, name: %s", val, name)
m[name] = AttrType(val) attrTypes[name] = AttrType(val) //nolint:gosec // G115
} }
// Not registered in IANA. // Not registered in IANA.
for k, v := range map[string]AttrType{ for k, v := range map[string]AttrType{
"ORIGIN": 0x802F, "ORIGIN": 0x802F,
} { } {
m[k] = v attrTypes[k] = v
} }
for val, name := range attrNames() { for val, name := range attrNames() {
mapped, ok := m[name] mapped, ok := attrTypes[name]
if !ok { if !ok {
t.Errorf("failed to find attribute %s in IANA", name) t.Errorf("failed to find attribute %s in IANA", name)
} }
@@ -91,7 +93,7 @@ func TestIANA(t *testing.T) {
}) })
t.Run("ErrorCodes", func(t *testing.T) { t.Run("ErrorCodes", func(t *testing.T) {
records := loadCSV(t, "stun-parameters-6.csv") records := loadCSV(t, "stun-parameters-6.csv")
m := map[string]ErrorCode{} errorCodes := map[string]ErrorCode{}
for _, r := range records[1:] { for _, r := range records[1:] {
var ( var (
v = r[0] v = r[0]
@@ -105,11 +107,11 @@ func TestIANA(t *testing.T) {
t.Fatal(parseErr) t.Fatal(parseErr)
} }
t.Logf("value: 0x%x, name: %s", val, name) t.Logf("value: 0x%x, name: %s", val, name)
m[name] = ErrorCode(val) errorCodes[name] = ErrorCode(val)
} }
for val, nameB := range errorReasons { for val, nameB := range errorReasons {
name := string(nameB) name := string(nameB)
mapped, ok := m[name] mapped, ok := errorCodes[name]
if !ok { if !ok {
t.Errorf("failed to find error code %s in IANA", name) t.Errorf("failed to find error code %s in IANA", name)
} }

View File

@@ -22,6 +22,7 @@ func NewLongTermIntegrity(username, realm, password string) MessageIntegrity {
k := strings.Join([]string{username, realm, password}, credentialsSep) k := strings.Join([]string{username, realm, password}, credentialsSep)
h := md5.New() //nolint:gosec h := md5.New() //nolint:gosec
fmt.Fprint(h, k) //nolint:errcheck fmt.Fprint(h, k) //nolint:errcheck
return MessageIntegrity(h.Sum(nil)) return MessageIntegrity(h.Sum(nil))
} }
@@ -36,13 +37,14 @@ func NewShortTermIntegrity(password string) MessageIntegrity {
// AddTo and Check methods are using zero-allocation version of hmac, see // AddTo and Check methods are using zero-allocation version of hmac, see
// newHMAC function and internal/hmac/pool.go. // newHMAC function and internal/hmac/pool.go.
// //
// RFC 5389 Section 15.4 // RFC 5389 Section 15.4.
type MessageIntegrity []byte type MessageIntegrity []byte
func newHMAC(key, message, buf []byte) []byte { func newHMAC(key, message, buf []byte) []byte {
mac := hmac.AcquireSHA1(key) mac := hmac.AcquireSHA1(key)
writeOrPanic(mac, message) writeOrPanic(mac, message)
defer hmac.PutSHA1(mac) defer hmac.PutSHA1(mac)
return mac.Sum(buf) return mac.Sum(buf)
} }
@@ -59,8 +61,8 @@ var ErrFingerprintBeforeIntegrity = errors.New("FINGERPRINT before MESSAGE-INTEG
// AddTo adds MESSAGE-INTEGRITY attribute to message. // AddTo adds MESSAGE-INTEGRITY attribute to message.
// //
// CPU costly, see BenchmarkMessageIntegrity_AddTo. // CPU costly, see BenchmarkMessageIntegrity_AddTo.
func (i MessageIntegrity) AddTo(m *Message) error { func (i MessageIntegrity) AddTo(msg *Message) error {
for _, a := range m.Attributes { for _, a := range msg.Attributes {
// Message should not contain FINGERPRINT attribute // Message should not contain FINGERPRINT attribute
// before MESSAGE-INTEGRITY. // before MESSAGE-INTEGRITY.
if a.Type == AttrFingerprint { if a.Type == AttrFingerprint {
@@ -70,19 +72,20 @@ func (i MessageIntegrity) AddTo(m *Message) error {
// The text used as input to HMAC is the STUN message, // The text used as input to HMAC is the STUN message,
// including the header, up to and including the attribute preceding the // including the header, up to and including the attribute preceding the
// MESSAGE-INTEGRITY attribute. // MESSAGE-INTEGRITY attribute.
length := m.Length length := msg.Length
// Adjusting m.Length to contain MESSAGE-INTEGRITY TLV. // Adjusting m.Length to contain MESSAGE-INTEGRITY TLV.
m.Length += messageIntegritySize + attributeHeaderSize msg.Length += messageIntegritySize + attributeHeaderSize
m.WriteLength() // writing length to m.Raw msg.WriteLength() // writing length to m.Raw
v := newHMAC(i, m.Raw, m.Raw[len(m.Raw):]) // calculating HMAC for adjusted m.Raw v := newHMAC(i, msg.Raw, msg.Raw[len(msg.Raw):]) // calculating HMAC for adjusted m.Raw
m.Length = length // changing m.Length back msg.Length = length // changing m.Length back
// Copy hmac value to temporary variable to protect it from resetting // Copy hmac value to temporary variable to protect it from resetting
// while processing m.Add call. // while processing m.Add call.
vBuf := make([]byte, sha1.Size) vBuf := make([]byte, sha1.Size)
copy(vBuf, v) copy(vBuf, v)
m.Add(AttrMessageIntegrity, vBuf) msg.Add(AttrMessageIntegrity, vBuf)
return nil return nil
} }
@@ -92,8 +95,8 @@ var ErrIntegrityMismatch = errors.New("integrity check failed")
// Check checks MESSAGE-INTEGRITY attribute. // Check checks MESSAGE-INTEGRITY attribute.
// //
// CPU costly, see BenchmarkMessageIntegrity_Check. // CPU costly, see BenchmarkMessageIntegrity_Check.
func (i MessageIntegrity) Check(m *Message) error { func (i MessageIntegrity) Check(msg *Message) error {
v, err := m.Get(AttrMessageIntegrity) val, err := msg.Get(AttrMessageIntegrity)
if err != nil { if err != nil {
return err return err
} }
@@ -101,11 +104,11 @@ func (i MessageIntegrity) Check(m *Message) error {
// Adjusting length in header to match m.Raw that was // Adjusting length in header to match m.Raw that was
// used when computing HMAC. // used when computing HMAC.
var ( var (
length = m.Length length = msg.Length
afterIntegrity = false afterIntegrity = false
sizeReduced int sizeReduced int
) )
for _, a := range m.Attributes { for _, a := range msg.Attributes {
if afterIntegrity { if afterIntegrity {
sizeReduced += nearestPaddedValueLength(int(a.Length)) sizeReduced += nearestPaddedValueLength(int(a.Length))
sizeReduced += attributeHeaderSize sizeReduced += attributeHeaderSize
@@ -114,13 +117,14 @@ func (i MessageIntegrity) Check(m *Message) error {
afterIntegrity = true afterIntegrity = true
} }
} }
m.Length -= uint32(sizeReduced) msg.Length -= uint32(sizeReduced) //nolint:gosec // G115
m.WriteLength() msg.WriteLength()
// startOfHMAC should be first byte of integrity attribute. // startOfHMAC should be first byte of integrity attribute.
startOfHMAC := messageHeaderSize + m.Length - (attributeHeaderSize + messageIntegritySize) startOfHMAC := messageHeaderSize + msg.Length - (attributeHeaderSize + messageIntegritySize)
b := m.Raw[:startOfHMAC] // data before integrity attribute b := msg.Raw[:startOfHMAC] // data before integrity attribute
expected := newHMAC(i, b, m.Raw[len(m.Raw):]) expected := newHMAC(i, b, msg.Raw[len(msg.Raw):])
m.Length = length msg.Length = length
m.WriteLength() // writing length back msg.WriteLength() // writing length back
return checkHMAC(v, expected)
return checkHMAC(val, expected)
} }

View File

@@ -10,18 +10,18 @@ import (
) )
func TestMessageIntegrity_AddTo_Simple(t *testing.T) { func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
i := NewLongTermIntegrity("user", "realm", "pass") integrity := NewLongTermIntegrity("user", "realm", "pass")
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb") expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(expected, i) { if !bytes.Equal(expected, integrity) {
t.Error(ErrIntegrityMismatch) t.Error(ErrIntegrityMismatch)
} }
t.Run("Check", func(t *testing.T) { t.Run("Check", func(t *testing.T) {
m := new(Message) m := new(Message)
m.WriteHeader() m.WriteHeader()
if err := i.AddTo(m); err != nil { if err := integrity.AddTo(m); err != nil {
t.Error(err) t.Error(err)
} }
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
@@ -31,39 +31,39 @@ func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
if err := dM.Decode(); err != nil { if err := dM.Decode(); err != nil {
t.Error(err) t.Error(err)
} }
if err := i.Check(dM); err != nil { if err := integrity.Check(dM); err != nil {
t.Error(err) t.Error(err)
} }
dM.Raw[24] += 12 // HMAC now invalid dM.Raw[24] += 12 // HMAC now invalid
if i.Check(dM) == nil { if integrity.Check(dM) == nil {
t.Error("should be invalid") t.Error("should be invalid")
} }
}) })
} }
func TestMessageIntegrityWithFingerprint(t *testing.T) { func TestMessageIntegrityWithFingerprint(t *testing.T) {
m := new(Message) msg := new(Message)
m.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} msg.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
m.WriteHeader() msg.WriteHeader()
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
i := NewShortTermIntegrity("pwd") integrity := NewShortTermIntegrity("pwd")
if i.String() != "KEY: 0x707764" { if integrity.String() != "KEY: 0x707764" {
t.Error("bad string", i) t.Error("bad string", integrity)
} }
if err := i.Check(m); err == nil { if err := integrity.Check(msg); err == nil {
t.Error("should error") t.Error("should error")
} }
if err := i.AddTo(m); err != nil { if err := integrity.AddTo(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := Fingerprint.AddTo(m); err != nil { if err := Fingerprint.AddTo(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := i.Check(m); err != nil { if err := integrity.Check(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
m.Raw[24] = 33 msg.Raw[24] = 33
if err := i.Check(m); err == nil { if err := integrity.Check(msg); err == nil {
t.Fatal("mismatch expected") t.Fatal("mismatch expected")
} }
} }

View File

@@ -64,6 +64,7 @@ func (h *hmac) Sum(in []byte) []byte {
h.outer.Write(h.opad) //nolint:errcheck,gosec h.outer.Write(h.opad) //nolint:errcheck,gosec
} }
h.outer.Write(in[origLen:]) //nolint:errcheck,gosec h.outer.Write(in[origLen:]) //nolint:errcheck,gosec
return h.outer.Sum(in[:origLen]) return h.outer.Sum(in[:origLen])
} }
@@ -79,6 +80,7 @@ func (h *hmac) Reset() {
if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert
panic(err) //nolint panic(err) //nolint
} }
return return
} }

View File

@@ -22,7 +22,7 @@ type hmacTest struct {
blocksize int blocksize int
} }
func hmacTests() []hmacTest { func hmacTests() []hmacTest { //nolint:maintidx
return []hmacTest{ return []hmacTest{
// Tests from US FIPS 198 // Tests from US FIPS 198
// https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf // https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf
@@ -523,41 +523,42 @@ func hmacTests() []hmacTest {
func TestHMAC(t *testing.T) { func TestHMAC(t *testing.T) {
for i, tt := range hmacTests() { for i, tt := range hmacTests() {
h := New(tt.hash, tt.key) hsh := New(tt.hash, tt.key)
if s := h.Size(); s != tt.size { if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size) t.Errorf("Size: got %v, want %v", s, tt.size)
} }
if b := h.BlockSize(); b != tt.blocksize { if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
} }
for j := 0; j < 4; j++ { for j := 0; j < 4; j++ { //nolint:varnamelen
n, err := h.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue continue
} }
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
} }
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
h.Reset() hsh.Reset()
// Third and fourth iteration: make sure hmac works on // Third and fourth iteration: make sure hmac works on
// hashes without MarshalBinary/UnmarshalBinary // hashes without MarshalBinary/UnmarshalBinary
if j == 1 { if j == 1 {
h = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key) hsh = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
} }
} }
} }
} }
// justHash implements just the hash.Hash methods and nothing else // justHash implements just the hash.Hash methods and nothing else.
type justHash struct { type justHash struct {
hash.Hash hash.Hash
} }

View File

@@ -40,6 +40,7 @@ func (h *hmac) resetTo(key []byte) {
var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals
New: func() interface{} { New: func() interface{} {
h := New(sha1.New, make([]byte, sha1.BlockSize)) h := New(sha1.New, make([]byte, sha1.BlockSize))
return h return h
}, },
} }
@@ -49,6 +50,7 @@ func AcquireSHA1(key []byte) hash.Hash {
h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert
assertHMACSize(h, sha1.Size, sha1.BlockSize) assertHMACSize(h, sha1.Size, sha1.BlockSize)
h.resetTo(key) h.resetTo(key)
return h return h
} }
@@ -62,6 +64,7 @@ func PutSHA1(h hash.Hash) {
var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals
New: func() interface{} { New: func() interface{} {
h := New(sha256.New, make([]byte, sha256.BlockSize)) h := New(sha256.New, make([]byte, sha256.BlockSize))
return h return h
}, },
} }
@@ -71,6 +74,7 @@ func AcquireSHA256(key []byte) hash.Hash {
h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert
assertHMACSize(h, sha256.Size, sha256.BlockSize) assertHMACSize(h, sha256.Size, sha256.BlockSize)
h.resetTo(key) h.resetTo(key)
return h return h
} }

View File

@@ -42,100 +42,103 @@ func BenchmarkHMACSHA1_512_Pool(b *testing.B) {
func TestHMACReset(t *testing.T) { func TestHMACReset(t *testing.T) {
for i, tt := range hmacTests() { for i, tt := range hmacTests() {
h := New(tt.hash, tt.key) hsh := New(tt.hash, tt.key)
h.(*hmac).resetTo(tt.key) //nolint:forcetypeassert hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
if s := h.Size(); s != tt.size { if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size) t.Errorf("Size: got %v, want %v", s, tt.size)
} }
if b := h.BlockSize(); b != tt.blocksize { if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
} }
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
n, err := h.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue continue
} }
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
} }
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
h.Reset() hsh.Reset()
} }
} }
} }
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
for i, tt := range hmacTests() { for i, tt := range hmacTests() {
if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size { if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size {
continue continue
} }
h := AcquireSHA1(tt.key) hsh := AcquireSHA1(tt.key)
if s := h.Size(); s != tt.size { if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size) t.Errorf("Size: got %v, want %v", s, tt.size)
} }
if b := h.BlockSize(); b != tt.blocksize { if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
} }
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
n, err := h.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue continue
} }
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
} }
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
h.Reset() hsh.Reset()
} }
PutSHA1(h) PutSHA1(hsh)
} }
} }
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
for i, tt := range hmacTests() { for i, tt := range hmacTests() {
if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size { if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size {
continue continue
} }
h := AcquireSHA256(tt.key) hsh := AcquireSHA256(tt.key)
if s := h.Size(); s != tt.size { if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size) t.Errorf("Size: got %v, want %v", s, tt.size)
} }
if b := h.BlockSize(); b != tt.blocksize { if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
} }
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
n, err := h.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue continue
} }
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
} }
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
h.Reset() hsh.Reset()
} }
PutSHA256(h) PutSHA256(hsh)
} }
} }

View File

@@ -10,8 +10,11 @@ import (
// ShouldNotAllocate fails if f allocates. // ShouldNotAllocate fails if f allocates.
func ShouldNotAllocate(t *testing.T, f func()) { func ShouldNotAllocate(t *testing.T, f func()) {
t.Helper()
if Race { if Race {
t.Skip("skip while running with -race") t.Skip("skip while running with -race")
return return
} }
if a := testing.AllocsPerRun(10, f); a > 0 { if a := testing.AllocsPerRun(10, f); a > 0 {

View File

@@ -32,6 +32,7 @@ const (
// as source. // as source.
func NewTransactionID() (b [TransactionIDSize]byte) { func NewTransactionID() (b [TransactionIDSize]byte) {
readFullOrPanic(rand.Reader, b[:]) readFullOrPanic(rand.Reader, b[:])
return b return b
} }
@@ -45,6 +46,7 @@ func IsMessage(b []byte) bool {
// New returns *Message with pre-allocated Raw. // New returns *Message with pre-allocated Raw.
func New() *Message { func New() *Message {
const defaultRawCapacity = 120 const defaultRawCapacity = 120
return &Message{ return &Message{
Raw: make([]byte, messageHeaderSize, defaultRawCapacity), Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
} }
@@ -59,6 +61,7 @@ func Decode(data []byte, m *Message) error {
return ErrDecodeToNil return ErrDecodeToNil
} }
m.Raw = append(m.Raw[:0], data...) m.Raw = append(m.Raw[:0], data...)
return m.Decode() return m.Decode()
} }
@@ -82,6 +85,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
// contract induced by other implementations. // contract induced by other implementations.
b := make([]byte, len(m.Raw)) b := make([]byte, len(m.Raw))
copy(b, m.Raw) copy(b, m.Raw)
return b, nil return b, nil
} }
@@ -89,6 +93,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
func (m *Message) UnmarshalBinary(data []byte) error { func (m *Message) UnmarshalBinary(data []byte) error {
// We can't retain data, copy is expected by interface contract. // We can't retain data, copy is expected by interface contract.
m.Raw = append(m.Raw[:0], data...) m.Raw = append(m.Raw[:0], data...)
return m.Decode() return m.Decode()
} }
@@ -108,6 +113,7 @@ func (m *Message) GobDecode(data []byte) error {
func (m *Message) AddTo(b *Message) error { func (m *Message) AddTo(b *Message) error {
b.TransactionID = m.TransactionID b.TransactionID = m.TransactionID
b.WriteTransactionID() b.WriteTransactionID()
return nil return nil
} }
@@ -118,6 +124,7 @@ func (m *Message) NewTransactionID() error {
if err == nil { if err == nil {
m.WriteTransactionID() m.WriteTransactionID()
} }
return err return err
} }
@@ -127,6 +134,7 @@ func (m *Message) String() string {
for k, a := range m.Attributes { for k, a := range m.Attributes {
aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type) aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
} }
return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo) return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
} }
@@ -144,6 +152,7 @@ func (m *Message) grow(n int) {
} }
if cap(m.Raw) >= n { if cap(m.Raw) >= n {
m.Raw = m.Raw[:n] m.Raw = m.Raw[:n]
return return
} }
m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...) m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
@@ -153,7 +162,7 @@ func (m *Message) grow(n int) {
// //
// Value of attribute is copied to internal buffer so // Value of attribute is copied to internal buffer so
// it is safe to reuse v. // it is safe to reuse v.
func (m *Message) Add(t AttrType, v []byte) { func (m *Message) Add(attrType AttrType, val []byte) {
// Allocating buffer for TLV (type-length-value). // Allocating buffer for TLV (type-length-value).
// T = t, L = len(v), V = v. // T = t, L = len(v), V = v.
// m.Raw will look like: // m.Raw will look like:
@@ -163,31 +172,33 @@ func (m *Message) Add(t AttrType, v []byte) {
// [first:last] <- same as previous // [first:last] <- same as previous
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
// T L V // T L V
allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V) allocSize := attributeHeaderSize + len(val) // ~ len(TLV) = len(TL) + len(V)
first := messageHeaderSize + int(m.Length) // first byte number first := messageHeaderSize + int(m.Length) // first byte number
last := first + allocSize // last byte number last := first + allocSize // last byte number
m.grow(last) // growing cap(Raw) to fit TLV m.grow(last) // growing cap(Raw) to fit TLV
m.Raw = m.Raw[:last] // now len(Raw) = last m.Raw = m.Raw[:last] // now len(Raw) = last
m.Length += uint32(allocSize) // rendering length change //nolint:gosec // G115
m.Length += uint32(allocSize) // rendering length change
// Sub-slicing internal buffer to simplify encoding. // Sub-slicing internal buffer to simplify encoding.
buf := m.Raw[first:last] // slice for TLV buf := m.Raw[first:last] // slice for TLV
value := buf[attributeHeaderSize:] // slice for V value := buf[attributeHeaderSize:] // slice for V
attr := RawAttribute{ attr := RawAttribute{
Type: t, // T Type: attrType, // T
Length: uint16(len(v)), // L //nolint:gosec // G115
Value: value, // V Length: uint16(len(val)), // L
Value: value, // V
} }
// Encoding attribute TLV to allocated buffer. // Encoding attribute TLV to allocated buffer.
bin.PutUint16(buf[0:2], attr.Type.Value()) // T bin.PutUint16(buf[0:2], attr.Type.Value()) // T
bin.PutUint16(buf[2:4], attr.Length) // L bin.PutUint16(buf[2:4], attr.Length) // L
copy(value, v) // V copy(value, val) // V
// Checking that attribute value needs padding. // Checking that attribute value needs padding.
if attr.Length%padding != 0 { if attr.Length%padding != 0 {
// Performing padding. // Performing padding.
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v) bytesToAdd := nearestPaddedValueLength(len(val)) - len(val)
last += bytesToAdd last += bytesToAdd
m.grow(last) m.grow(last)
// setting all padding bytes to zero // setting all padding bytes to zero
@@ -197,7 +208,8 @@ func (m *Message) Add(t AttrType, v []byte) {
for i := range buf { for i := range buf {
buf[i] = 0 buf[i] = 0
} }
m.Raw = m.Raw[:last] // increasing buffer length m.Raw = m.Raw[:last] // increasing buffer length
//nolint:gosec // G115
m.Length += uint32(bytesToAdd) // rendering length change m.Length += uint32(bytesToAdd) // rendering length change
} }
m.Attributes = append(m.Attributes, attr) m.Attributes = append(m.Attributes, attr)
@@ -213,6 +225,7 @@ func attrSliceEqual(a, b Attributes) bool {
} }
if attrB.Equal(attr) { if attrB.Equal(attr) {
found = true found = true
break break
} }
} }
@@ -220,56 +233,59 @@ func attrSliceEqual(a, b Attributes) bool {
return false return false
} }
} }
return true return true
} }
func attrEqual(a, b Attributes) bool { func attrEqual(attrA, attrB Attributes) bool {
if a == nil && b == nil { if attrA == nil && attrB == nil {
return true return true
} }
if a == nil || b == nil { if attrA == nil || attrB == nil {
return false return false
} }
if len(a) != len(b) { if len(attrA) != len(attrB) {
return false return false
} }
if !attrSliceEqual(a, b) { if !attrSliceEqual(attrA, attrB) {
return false return false
} }
if !attrSliceEqual(b, a) { if !attrSliceEqual(attrB, attrA) {
return false return false
} }
return true return true
} }
// Equal returns true if Message b equals to m. // Equal returns true if Message msg equals to m.
// Ignores m.Raw. // Ignores m.Raw.
func (m *Message) Equal(b *Message) bool { func (m *Message) Equal(msg *Message) bool {
if m == nil && b == nil { if m == nil && msg == nil {
return true return true
} }
if m == nil || b == nil { if m == nil || msg == nil {
return false return false
} }
if m.Type != b.Type { if m.Type != msg.Type {
return false return false
} }
if m.TransactionID != b.TransactionID { if m.TransactionID != msg.TransactionID {
return false return false
} }
if m.Length != b.Length { if m.Length != msg.Length {
return false return false
} }
if !attrEqual(m.Attributes, b.Attributes) { if !attrEqual(m.Attributes, msg.Attributes) {
return false return false
} }
return true return true
} }
// WriteLength writes m.Length to m.Raw. // WriteLength writes m.Length to m.Raw.
func (m *Message) WriteLength() { func (m *Message) WriteLength() {
m.grow(4) m.grow(4)
bin.PutUint16(m.Raw[2:4], uint16(m.Length)) bin.PutUint16(m.Raw[2:4], uint16(m.Length)) //nolint:gosec // G115
} }
// WriteHeader writes header to underlying buffer. Not goroutine-safe. // WriteHeader writes header to underlying buffer. Not goroutine-safe.
@@ -322,6 +338,7 @@ func (m *Message) Encode() {
// call result. // call result.
func (m *Message) WriteTo(w io.Writer) (int64, error) { func (m *Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.Raw) n, err := w.Write(m.Raw)
return int64(n), err return int64(n), err
} }
@@ -340,6 +357,7 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) {
return int64(n), err return int64(n), err
} }
m.Raw = tBuf[:n] m.Raw = tBuf[:n]
return int64(n), m.Decode() return int64(n), m.Decode()
} }
@@ -355,22 +373,24 @@ func (m *Message) Decode() error {
return ErrUnexpectedHeaderEOF return ErrUnexpectedHeaderEOF
} }
var ( var (
t = bin.Uint16(buf[0:2]) // first 2 bytes msgType = bin.Uint16(buf[0:2]) // first 2 bytes
size = int(bin.Uint16(buf[2:4])) // second 2 bytes size = int(bin.Uint16(buf[2:4])) // second 2 bytes
cookie = bin.Uint32(buf[4:8]) // last 4 bytes cookie = bin.Uint32(buf[4:8]) // last 4 bytes
fullSize = messageHeaderSize + size // len(m.Raw) fullSize = messageHeaderSize + size // len(m.Raw)
) )
if cookie != magicCookie { if cookie != magicCookie {
msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie) msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
return newDecodeErr("message", "cookie", msg) return newDecodeErr("message", "cookie", msg)
} }
if len(buf) < fullSize { if len(buf) < fullSize {
msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize) msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
return newAttrDecodeErr("message", msg) return newAttrDecodeErr("message", msg)
} }
// saving header data // saving header data
m.Type.ReadValue(t) m.Type.ReadValue(msgType)
m.Length = uint32(size) m.Length = uint32(size) //nolint:gosec // G115
copy(m.TransactionID[:], buf[8:messageHeaderSize]) copy(m.TransactionID[:], buf[8:messageHeaderSize])
m.Attributes = m.Attributes[:0] m.Attributes = m.Attributes[:0]
@@ -382,28 +402,31 @@ func (m *Message) Decode() error {
// checking that we have enough bytes to read header // checking that we have enough bytes to read header
if len(b) < attributeHeaderSize { if len(b) < attributeHeaderSize {
msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize) msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
return newAttrDecodeErr("header", msg) return newAttrDecodeErr("header", msg)
} }
var ( var (
a = RawAttribute{ attr = RawAttribute{
Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
Length: bin.Uint16(b[2:4]), // second 2 bytes Length: bin.Uint16(b[2:4]), // second 2 bytes
} }
aL = int(a.Length) // attribute length aL = int(attr.Length) // attribute length
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding) aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
) )
b = b[attributeHeaderSize:] // slicing again to simplify value read b = b[attributeHeaderSize:] // slicing again to simplify value read
offset += attributeHeaderSize offset += attributeHeaderSize
if len(b) < aBuffL { // checking size if len(b) < aBuffL { // checking size
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type) msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, attr.Type)
return newAttrDecodeErr("value", msg) return newAttrDecodeErr("value", msg)
} }
a.Value = b[:aL] attr.Value = b[:aL]
offset += aBuffL offset += aBuffL
b = b[aBuffL:] b = b[aBuffL:]
m.Attributes = append(m.Attributes, a) m.Attributes = append(m.Attributes, attr)
} }
return nil return nil
} }
@@ -412,12 +435,14 @@ func (m *Message) Decode() error {
// Any error is unrecoverable, but message could be partially decoded. // Any error is unrecoverable, but message could be partially decoded.
func (m *Message) Write(tBuf []byte) (int, error) { func (m *Message) Write(tBuf []byte) (int, error) {
m.Raw = append(m.Raw[:0], tBuf...) m.Raw = append(m.Raw[:0], tBuf...)
return len(tBuf), m.Decode() return len(tBuf), m.Decode()
} }
// CloneTo clones m to b securing any further m mutations. // CloneTo clones m to b securing any further m mutations.
func (m *Message) CloneTo(b *Message) error { func (m *Message) CloneTo(b *Message) error {
b.Raw = append(b.Raw[:0], m.Raw...) b.Raw = append(b.Raw[:0], m.Raw...)
return b.Decode() return b.Decode()
} }
@@ -436,7 +461,7 @@ const (
var ( var (
// Binding request message type. // Binding request message type.
BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
// Binding success response message type // Binding success response message type.
BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
// Binding error response message type. // Binding error response message type.
BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
@@ -501,6 +526,7 @@ func (m Method) String() string {
// Falling back to hex representation. // Falling back to hex representation.
s = fmt.Sprintf("0x%x", uint16(m)) s = fmt.Sprintf("0x%x", uint16(m))
} }
return s return s
} }
@@ -513,6 +539,7 @@ type MessageType struct {
// AddTo sets m type to t. // AddTo sets m type to t.
func (t MessageType) AddTo(m *Message) error { func (t MessageType) AddTo(m *Message) error {
m.SetType(t) m.SetType(t)
return nil return nil
} }
@@ -554,13 +581,13 @@ func (t MessageType) Value() uint16 {
// Warning: Abandon all hope ye who enter here. // Warning: Abandon all hope ye who enter here.
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
m := uint16(t.Method) msg := uint16(t.Method)
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits) a := msg & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A) b := msg & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B) d := msg & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
m = a + (b << methodBShift) + (d << methodDShift) msg = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is first bit. // C0 is zero bit of C, C1 is first bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1 // C0 = C * 0b01, C1 = (C * 0b10) >> 1
@@ -573,7 +600,7 @@ func (t MessageType) Value() uint16 {
c1 := (c & c1Bit) << classC1Shift c1 := (c & c1Bit) << classC1Shift
class := c0 + c1 class := c0 + c1
return m + class return msg + class
} }
// ReadValue decodes uint16 into MessageType. // ReadValue decodes uint16 into MessageType.
@@ -604,6 +631,7 @@ func (m *Message) Contains(t AttrType) bool {
return true return true
} }
} }
return false return false
} }
@@ -618,5 +646,6 @@ func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
func (t transactionIDValueSetter) AddTo(m *Message) error { func (t transactionIDValueSetter) AddTo(m *Message) error {
m.TransactionID = t m.TransactionID = t
m.WriteTransactionID() m.WriteTransactionID()
return nil return nil
} }

View File

@@ -27,9 +27,11 @@ type attributeEncoder interface {
AddTo(m *Message) error AddTo(m *Message) error
} }
func addAttr(t testing.TB, m *Message, a attributeEncoder) { func addAttr(tb testing.TB, m *Message, a attributeEncoder) {
tb.Helper()
if err := a.AddTo(m); err != nil { if err := a.AddTo(m); err != nil {
t.Error(err) tb.Error(err)
} }
} }
@@ -129,21 +131,21 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
} }
func TestMessage_WriteTo(t *testing.T) { func TestMessage_WriteTo(t *testing.T) {
m := New() msg := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID() msg.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader() msg.WriteHeader()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if _, err := m.WriteTo(buf); err != nil { if _, err := msg.WriteTo(buf); err != nil {
t.Fatal(err) t.Fatal(err)
} }
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadFrom(buf); err != nil { if _, err := mDecoded.ReadFrom(buf); err != nil {
t.Error(err) t.Error(err)
} }
if !mDecoded.Equal(m) { if !mDecoded.Equal(msg) {
t.Error(mDecoded, "!", m) t.Error(mDecoded, "!", msg)
} }
} }
@@ -275,23 +277,23 @@ func BenchmarkMessage_WriteTo(b *testing.B) {
func BenchmarkMessage_ReadFrom(b *testing.B) { func BenchmarkMessage_ReadFrom(b *testing.B) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest} mType := MessageType{Method: MethodBinding, Class: ClassRequest}
m := &Message{ msg := &Message{
Type: mType, Type: mType,
Length: 0, Length: 0,
TransactionID: [TransactionIDSize]byte{ TransactionID: [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
}, },
} }
m.WriteHeader() msg.WriteHeader()
b.ReportAllocs() b.ReportAllocs()
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(msg.Raw)))
reader := m.reader() reader := msg.reader()
mRec := New() mRec := New()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if _, err := mRec.ReadFrom(reader); err != nil { if _, err := mRec.ReadFrom(reader); err != nil {
b.Fatal(err) b.Fatal(err)
} }
reader.Reset(m.Raw) reader.Reset(msg.Raw)
mRec.Reset() mRec.Reset()
} }
} }
@@ -342,7 +344,7 @@ func TestMessageClass_String(t *testing.T) {
} }
func TestAttrType_String(t *testing.T) { func TestAttrType_String(t *testing.T) {
v := [...]AttrType{ attrType := [...]AttrType{
AttrMappedAddress, AttrMappedAddress,
AttrUsername, AttrUsername,
AttrErrorCode, AttrErrorCode,
@@ -355,7 +357,7 @@ func TestAttrType_String(t *testing.T) {
AttrAlternateServer, AttrAlternateServer,
AttrFingerprint, AttrFingerprint,
} }
for _, k := range v { for _, k := range attrType {
if k.String() == "" { if k.String() == "" {
t.Error(k, "bad stringer") t.Error(k, "bad stringer")
} }
@@ -379,80 +381,80 @@ func TestMethod_String(t *testing.T) {
} }
func TestAttribute_Equal(t *testing.T) { func TestAttribute_Equal(t *testing.T) {
a := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
b := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
if !a.Equal(b) { if !attr1.Equal(attr2) {
t.Error("should equal") t.Error("should equal")
} }
if a.Equal(RawAttribute{Type: 0x2}) { if attr1.Equal(RawAttribute{Type: 0x2}) {
t.Error("should not equal") t.Error("should not equal")
} }
if a.Equal(RawAttribute{Length: 0x2}) { if attr1.Equal(RawAttribute{Length: 0x2}) {
t.Error("should not equal") t.Error("should not equal")
} }
if a.Equal(RawAttribute{Length: 0x3}) { if attr1.Equal(RawAttribute{Length: 0x3}) {
t.Error("should not equal") t.Error("should not equal")
} }
if a.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) { if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
t.Error("should not equal") t.Error("should not equal")
} }
} }
func TestMessage_Equal(t *testing.T) { func TestMessage_Equal(t *testing.T) { //nolint:cyclop
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attrs := Attributes{attr} attrs := Attributes{attr}
a := &Message{Attributes: attrs, Length: 4 + 2} msg1 := &Message{Attributes: attrs, Length: 4 + 2}
b := &Message{Attributes: attrs, Length: 4 + 2} msg2 := &Message{Attributes: attrs, Length: 4 + 2}
if !a.Equal(b) { if !msg1.Equal(msg2) {
t.Error("should equal") t.Error("should equal")
} }
if a.Equal(&Message{Type: MessageType{Class: 128}}) { if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
t.Error("should not equal") t.Error("should not equal")
} }
tID := [TransactionIDSize]byte{ tID := [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
} }
if a.Equal(&Message{TransactionID: tID}) { if msg1.Equal(&Message{TransactionID: tID}) {
t.Error("should not equal") t.Error("should not equal")
} }
if a.Equal(&Message{Length: 3}) { if msg1.Equal(&Message{Length: 3}) {
t.Error("should not equal") t.Error("should not equal")
} }
tAttrs := Attributes{ tAttrs := Attributes{
{Length: 1, Value: []byte{0x1}, Type: 0x1}, {Length: 1, Value: []byte{0x1}, Type: 0x1},
} }
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal") t.Error("should not equal")
} }
tAttrs = Attributes{ tAttrs = Attributes{
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2}, {Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
} }
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal") t.Error("should not equal")
} }
if !(*Message)(nil).Equal(nil) { if !(*Message)(nil).Equal(nil) {
t.Error("nil should be equal to nil") t.Error("nil should be equal to nil")
} }
if a.Equal(nil) { if msg1.Equal(nil) {
t.Error("non-nil should not be equal to nil") t.Error("non-nil should not be equal to nil")
} }
t.Run("Nil attributes", func(t *testing.T) { t.Run("Nil attributes", func(t *testing.T) {
a := &Message{ msg1 := &Message{
Attributes: nil, Attributes: nil,
Length: 4 + 2, Length: 4 + 2,
} }
b := &Message{ msg2 := &Message{
Attributes: attrs, Attributes: attrs,
Length: 4 + 2, Length: 4 + 2,
} }
if a.Equal(b) { if msg1.Equal(msg2) {
t.Error("should not equal") t.Error("should not equal")
} }
if b.Equal(a) { if msg2.Equal(msg1) {
t.Error("should not equal") t.Error("should not equal")
} }
b.Attributes = nil msg2.Attributes = nil
if !a.Equal(b) { if !msg1.Equal(msg2) {
t.Error("should equal") t.Error("should equal")
} }
}) })
@@ -547,6 +549,8 @@ func BenchmarkIsMessage(b *testing.B) {
} }
func loadData(tb testing.TB, name string) []byte { func loadData(tb testing.TB, name string) []byte {
tb.Helper()
name = filepath.Join("testdata", name) name = filepath.Join("testdata", name)
f, err := os.Open(name) //nolint:gosec f, err := os.Open(name) //nolint:gosec
if err != nil { if err != nil {
@@ -561,6 +565,7 @@ func loadData(tb testing.TB, name string) []byte {
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
return v return v
} }
@@ -582,7 +587,7 @@ func TestMessageFromBrowsers(t *testing.T) {
t.Fatal("failed to skip header of csv: ", err) t.Fatal("failed to skip header of csv: ", err)
} }
crcTable := crc64.MakeTable(crc64.ISO) crcTable := crc64.MakeTable(crc64.ISO)
m := New() msg := New()
for { for {
line, err := reader.Read() line, err := reader.Read()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@@ -602,10 +607,10 @@ func TestMessageFromBrowsers(t *testing.T) {
if b != crc64.Checksum(data, crcTable) { if b != crc64.Checksum(data, crcTable) {
t.Error("crc64 check failed for ", line[1]) t.Error("crc64 check failed for ", line[1])
} }
if _, err = m.Write(data); err != nil { if _, err = msg.Write(data); err != nil {
t.Error("failed to decode ", line[1], " as message: ", err) t.Error("failed to decode ", line[1], " as message: ", err)
} }
m.Reset() msg.Reset()
} }
} }
@@ -622,42 +627,42 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
func BenchmarkMessageFull(b *testing.B) { func BenchmarkMessageFull(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) msg := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := addr.AddTo(m); err != nil { if err := addr.AddTo(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
if err := s.AddTo(m); err != nil { if err := s.AddTo(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
m.WriteAttributes() msg.WriteAttributes()
m.WriteHeader() msg.WriteHeader()
Fingerprint.AddTo(m) //nolint:errcheck,gosec Fingerprint.AddTo(msg) //nolint:errcheck,gosec
m.WriteHeader() msg.WriteHeader()
m.Reset() msg.Reset()
} }
} }
func BenchmarkMessageFullHardcore(b *testing.B) { func BenchmarkMessageFullHardcore(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) msg := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := addr.AddTo(m); err != nil { if err := addr.AddTo(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
if err := s.AddTo(m); err != nil { if err := s.AddTo(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
m.WriteHeader() msg.WriteHeader()
m.Reset() msg.Reset()
} }
} }
@@ -689,8 +694,8 @@ func TestMessage_Contains(t *testing.T) {
func ExampleMessage() { func ExampleMessage() {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
m := new(Message) msg := new(Message)
m.Build(BindingRequest, //nolint:errcheck,gosec msg.Build(BindingRequest, //nolint:errcheck,gosec
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
@@ -707,8 +712,8 @@ func ExampleMessage() {
// m.Build(&software) // no allocations // m.Build(&software) // no allocations
// If you pass software as value, there will be 1 allocation. // If you pass software as value, there will be 1 allocation.
// This rule is correct for all setters. // This rule is correct for all setters.
fmt.Println(m, "buff length:", len(m.Raw)) fmt.Println(msg, "buff length:", len(msg.Raw))
n, err := m.WriteTo(buf) n, err := msg.WriteTo(buf)
fmt.Println("wrote", n, "err", err) fmt.Println("wrote", n, "err", err)
// Decoding from buf new *Message. // Decoding from buf new *Message.
@@ -743,6 +748,7 @@ func ExampleMessage() {
fmt.Println("fingerprint: failed") fmt.Println("fingerprint: failed")
} }
//nolint:lll
// Output: // Output:
// Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68 // Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68
// wrote 68 err <nil> // wrote 68 err <nil>
@@ -811,8 +817,8 @@ func TestAllocationsGetters(t *testing.T) {
NewShortTermIntegrity("pwd"), NewShortTermIntegrity("pwd"),
Fingerprint, Fingerprint,
} }
m := New() msg := New()
if err := m.Build(setters...); err != nil { if err := msg.Build(setters...); err != nil {
t.Error("failed to build", err) t.Error("failed to build", err)
} }
getters := []Getter{ getters := []Getter{
@@ -826,7 +832,7 @@ func TestAllocationsGetters(t *testing.T) {
g := g g := g
i := i i := i
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
if err := g.GetFrom(m); err != nil { if err := g.GetFrom(msg); err != nil {
t.Errorf("[%d] failed to get", i) t.Errorf("[%d] failed to get", i)
} }
}) })
@@ -837,8 +843,8 @@ func TestAllocationsGetters(t *testing.T) {
} }
func TestMessageFullSize(t *testing.T) { func TestMessageFullSize(t *testing.T) {
m := new(Message) msg := new(Message)
if err := m.Build(BindingRequest, if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
@@ -848,18 +854,18 @@ func TestMessageFullSize(t *testing.T) {
); err != nil { ); err != nil {
t.Fatal(err) t.Fatal(err)
} }
m.Raw = m.Raw[:len(m.Raw)-10] msg.Raw = msg.Raw[:len(msg.Raw)-10]
decoder := new(Message) decoder := new(Message)
decoder.Raw = m.Raw[:len(m.Raw)-10] decoder.Raw = msg.Raw[:len(msg.Raw)-10]
if err := decoder.Decode(); err == nil { if err := decoder.Decode(); err == nil {
t.Error("decode on truncated buffer should error") t.Error("decode on truncated buffer should error")
} }
} }
func TestMessage_CloneTo(t *testing.T) { func TestMessage_CloneTo(t *testing.T) {
m := new(Message) msg := new(Message)
if err := m.Build(BindingRequest, if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
@@ -869,29 +875,29 @@ func TestMessage_CloneTo(t *testing.T) {
); err != nil { ); err != nil {
t.Fatal(err) t.Fatal(err)
} }
m.Encode() msg.Encode()
b := new(Message) msg2 := new(Message)
if err := m.CloneTo(b); err != nil { if err := msg.CloneTo(msg2); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !b.Equal(m) { if !msg2.Equal(msg) {
t.Fatal("not equal") t.Fatal("not equal")
} }
// Corrupting m and checking that b is not corrupted. // Corrupting m and checking that b is not corrupted.
s, ok := b.Attributes.Get(AttrSoftware) s, ok := msg2.Attributes.Get(AttrSoftware)
if !ok { if !ok {
t.Fatal("no software attribute") t.Fatal("no software attribute")
} }
s.Value[0] = 'k' s.Value[0] = 'k'
if b.Equal(m) { if msg2.Equal(msg) {
t.Fatal("should not be equal") t.Fatal("should not be equal")
} }
} }
func BenchmarkMessage_CloneTo(b *testing.B) { func BenchmarkMessage_CloneTo(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) msg := new(Message)
if err := m.Build(BindingRequest, if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
@@ -901,19 +907,19 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
); err != nil { ); err != nil {
b.Fatal(err) b.Fatal(err)
} }
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(msg.Raw)))
a := new(Message) a := new(Message)
m.CloneTo(a) //nolint:errcheck,gosec msg.CloneTo(a) //nolint:errcheck,gosec
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := m.CloneTo(a); err != nil { if err := msg.CloneTo(a); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
} }
func TestMessage_AddTo(t *testing.T) { func TestMessage_AddTo(t *testing.T) {
m := new(Message) msg := new(Message)
if err := m.Build(BindingRequest, if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
@@ -921,19 +927,19 @@ func TestMessage_AddTo(t *testing.T) {
); err != nil { ); err != nil {
t.Fatal(err) t.Fatal(err)
} }
m.Encode() msg.Encode()
b := new(Message) b := new(Message)
if err := m.CloneTo(b); err != nil { if err := msg.CloneTo(b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
m.TransactionID = [TransactionIDSize]byte{ msg.TransactionID = [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
} }
if b.Equal(m) { if b.Equal(msg) {
t.Fatal("should not be equal") t.Fatal("should not be equal")
} }
m.AddTo(b) //nolint:errcheck,gosec msg.AddTo(b) //nolint:errcheck,gosec
if !b.Equal(m) { if !b.Equal(msg) {
t.Fatal("should be equal") t.Fatal("should be equal")
} }
} }
@@ -964,22 +970,22 @@ func TestDecode(t *testing.T) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
}) })
m := New() msg := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID() msg.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader() msg.WriteHeader()
mDecoded := New() mDecoded := New()
if err := Decode(m.Raw, mDecoded); err != nil { if err := Decode(msg.Raw, mDecoded); err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
if !mDecoded.Equal(m) { if !mDecoded.Equal(msg) {
t.Error("decoded result is not equal to encoded message") t.Error("decoded result is not equal to encoded message")
} }
t.Run("ZeroAlloc", func(t *testing.T) { t.Run("ZeroAlloc", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
mDecoded.Reset() mDecoded.Reset()
if err := Decode(m.Raw, mDecoded); err != nil { if err := Decode(msg.Raw, mDecoded); err != nil {
t.Error(err) t.Error(err)
} }
}) })
@@ -1008,22 +1014,22 @@ func BenchmarkDecode(b *testing.B) {
} }
func TestMessage_MarshalBinary(t *testing.T) { func TestMessage_MarshalBinary(t *testing.T) {
m := MustBuild( msg := MustBuild(
NewSoftware("software"), NewSoftware("software"),
&XORMappedAddress{ &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
}, },
) )
data, err := m.MarshalBinary() data, err := msg.MarshalBinary()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Reset m.Raw to check retention. // Reset m.Raw to check retention.
for i := range m.Raw { for i := range msg.Raw {
m.Raw[i] = 0 msg.Raw[i] = 0
} }
if err := m.UnmarshalBinary(data); err != nil { if err := msg.UnmarshalBinary(data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1031,28 +1037,28 @@ func TestMessage_MarshalBinary(t *testing.T) {
for i := range data { for i := range data {
data[i] = 0 data[i] = 0
} }
if err := m.Decode(); err != nil { if err := msg.Decode(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestMessage_GobDecode(t *testing.T) { func TestMessage_GobDecode(t *testing.T) {
m := MustBuild( msg := MustBuild(
NewSoftware("software"), NewSoftware("software"),
&XORMappedAddress{ &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
}, },
) )
data, err := m.GobEncode() data, err := msg.GobEncode()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Reset m.Raw to check retention. // Reset m.Raw to check retention.
for i := range m.Raw { for i := range msg.Raw {
m.Raw[i] = 0 msg.Raw[i] = 0
} }
if err := m.GobDecode(data); err != nil { if err := msg.GobDecode(data); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1060,7 +1066,7 @@ func TestMessage_GobDecode(t *testing.T) {
for i := range data { for i := range data {
data[i] = 0 data[i] = 0
} }
if err := m.Decode(); err != nil { if err := msg.Decode(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
) )
func TestRFC5769(t *testing.T) { func TestRFC5769(t *testing.T) { //nolint:cyclop
// Test Vectors for Session Traversal Utilities for NAT (STUN) // Test Vectors for Session Traversal Utilities for NAT (STUN)
// see https://tools.ietf.org/html/rfc5769 // see https://tools.ietf.org/html/rfc5769
t.Run("Request", func(t *testing.T) { t.Run("Request", func(t *testing.T) {
@@ -46,7 +46,7 @@ func TestRFC5769(t *testing.T) {
t.Error("check failed: ", err) t.Error("check failed: ", err)
} }
t.Run("Long-Term credentials", func(t *testing.T) { t.Run("Long-Term credentials", func(t *testing.T) {
m := &Message{ msg := &Message{
Raw: []byte("\x00\x01\x00\x60" + Raw: []byte("\x00\x01\x00\x60" +
"\x21\x12\xa4\x42" + "\x21\x12\xa4\x42" +
"\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" + "\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" +
@@ -64,11 +64,11 @@ func TestRFC5769(t *testing.T) {
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66", "\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
), ),
} }
if err := m.Decode(); err != nil { if err := msg.Decode(); err != nil {
t.Error(err) t.Error(err)
} }
u := new(Username) u := new(Username)
if err := u.GetFrom(m); err != nil { if err := u.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9" expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
@@ -76,14 +76,14 @@ func TestRFC5769(t *testing.T) {
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername) t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
} }
n := new(Nonce) n := new(Nonce)
if err := n.GetFrom(m); err != nil { if err := n.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" { if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
t.Error("bad nonce") t.Error("bad nonce")
} }
r := new(Realm) r := new(Realm)
if err := r.GetFrom(m); err != nil { if err := r.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if r.String() != "example.org" { //nolint:goconst if r.String() != "example.org" { //nolint:goconst
@@ -95,14 +95,14 @@ func TestRFC5769(t *testing.T) {
"example.org", "example.org",
"TheMatrIX", "TheMatrIX",
) )
if err := i.Check(m); err != nil { if err := i.Check(msg); err != nil {
t.Error(err) t.Error(err)
} }
}) })
}) })
t.Run("Response", func(t *testing.T) { t.Run("Response", func(t *testing.T) {
t.Run("IPv4", func(t *testing.T) { t.Run("IPv4", func(t *testing.T) {
m := &Message{ msg := &Message{
Raw: []byte("\x01\x01\x00\x3c" + Raw: []byte("\x01\x01\x00\x3c" +
"\x21\x12\xa4\x42" + "\x21\x12\xa4\x42" +
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" + "\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
@@ -117,21 +117,21 @@ func TestRFC5769(t *testing.T) {
"\xc0\x7d\x4c\x96", "\xc0\x7d\x4c\x96",
), ),
} }
if err := m.Decode(); err != nil { if err := msg.Decode(); err != nil {
t.Error(err) t.Error(err)
} }
software := new(Software) software := new(Software)
if err := software.GetFrom(m); err != nil { if err := software.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if software.String() != "test vector" { if software.String() != "test vector" {
t.Error("bad software: ", software) t.Error("bad software: ", software)
} }
if err := Fingerprint.Check(m); err != nil { if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err) t.Error("Check failed: ", err)
} }
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err := addr.GetFrom(m); err != nil { if err := addr.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) { if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
@@ -140,12 +140,12 @@ func TestRFC5769(t *testing.T) {
if addr.Port != 32853 { if addr.Port != 32853 {
t.Error("bad Port") t.Error("bad Port")
} }
if err := Fingerprint.Check(m); err != nil { if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err) t.Error("check failed: ", err)
} }
}) })
t.Run("IPv6", func(t *testing.T) { t.Run("IPv6", func(t *testing.T) {
m := &Message{ msg := &Message{
Raw: []byte("\x01\x01\x00\x48" + Raw: []byte("\x01\x01\x00\x48" +
"\x21\x12\xa4\x42" + "\x21\x12\xa4\x42" +
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" + "\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
@@ -162,21 +162,21 @@ func TestRFC5769(t *testing.T) {
"\xc8\xfb\x0b\x4c", "\xc8\xfb\x0b\x4c",
), ),
} }
if err := m.Decode(); err != nil { if err := msg.Decode(); err != nil {
t.Error(err) t.Error(err)
} }
software := new(Software) software := new(Software)
if err := software.GetFrom(m); err != nil { if err := software.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if software.String() != "test vector" { if software.String() != "test vector" {
t.Error("bad software: ", software) t.Error("bad software: ", software)
} }
if err := Fingerprint.Check(m); err != nil { if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err) t.Error("Check failed: ", err)
} }
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err := addr.GetFrom(m); err != nil { if err := addr.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) { if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
@@ -185,7 +185,7 @@ func TestRFC5769(t *testing.T) {
if addr.Port != 32853 { if addr.Port != 32853 {
t.Error("bad Port") t.Error("bad Port")
} }
if err := Fingerprint.Check(m); err != nil { if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err) t.Error("check failed: ", err)
} }
}) })

View File

@@ -27,6 +27,7 @@ func readFullOrPanic(r io.Reader, v []byte) int {
if err != nil { if err != nil {
panic(err) //nolint panic(err) //nolint
} }
return n return n
} }
@@ -35,6 +36,7 @@ func writeOrPanic(w io.Writer, v []byte) int {
if err != nil { if err != nil {
panic(err) //nolint panic(err) //nolint
} }
return n return n
} }

View File

@@ -13,14 +13,19 @@ import (
var errUDPServerUnsupportedNetwork = errors.New("unsupported network") var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
// NewUDPServer creates an udp server for testing. The supplied handler function will be called with the request // NewUDPServer creates an udp server for testing.
// The supplied handler function will be called with the request
// and should be used to emulate the server behavior. // and should be used to emulate the server behavior.
//
//nolint:cyclop
func NewUDPServer( func NewUDPServer(
t *testing.T, t *testing.T,
network string, network string,
maxMessageSize int, maxMessageSize int,
handler func(req []byte) ([]byte, error), handler func(req []byte) ([]byte, error),
) (net.Addr, func(t *testing.T), error) { ) (net.Addr, func(t *testing.T), error) {
t.Helper()
var ip string var ip string
switch network { switch network {
case "udp4": case "udp4":
@@ -50,28 +55,34 @@ func NewUDPServer(
n, addr, err := udpConn.ReadFrom(bs) n, addr, err := udpConn.ReadFrom(bs)
if err != nil { if err != nil {
errCh <- err errCh <- err
return return
} }
resp, err := handler(bs[:n]) resp, err := handler(bs[:n])
if err != nil { if err != nil {
errCh <- err errCh <- err
return return
} }
_, err = udpConn.WriteTo(resp, addr) _, err = udpConn.WriteTo(resp, addr)
if err != nil { if err != nil {
errCh <- err errCh <- err
return return
} }
} }
}() }()
return serverAddr, func(t *testing.T) { return serverAddr, func(t *testing.T) {
t.Helper()
select { select {
case err := <-errCh: case err := <-errCh:
if err != nil { if err != nil {
t.Fatal(err) //nolint t.Fatal(err)
return return
} }
default: default:
@@ -79,7 +90,7 @@ func NewUDPServer(
err := udpConn.Close() err := udpConn.Close()
if err != nil { if err != nil {
t.Fatal(err) //nolint t.Fatal(err)
} }
<-errCh <-errCh

View File

@@ -10,7 +10,7 @@ func NewUsername(username string) Username {
// Username represents USERNAME attribute. // Username represents USERNAME attribute.
// //
// RFC 5389 Section 15.3 // RFC 5389 Section 15.3.
type Username []byte type Username []byte
func (u Username) String() string { func (u Username) String() string {
@@ -37,7 +37,7 @@ func NewRealm(realm string) Realm {
// Realm represents REALM attribute. // Realm represents REALM attribute.
// //
// RFC 5389 Section 15.7 // RFC 5389 Section 15.7.
type Realm []byte type Realm []byte
func (n Realm) String() string { func (n Realm) String() string {
@@ -60,7 +60,7 @@ const softwareRawMaxB = 763
// Software is SOFTWARE attribute. // Software is SOFTWARE attribute.
// //
// RFC 5389 Section 15.10 // RFC 5389 Section 15.10.
type Software []byte type Software []byte
func (s Software) String() string { func (s Software) String() string {
@@ -84,7 +84,7 @@ func (s *Software) GetFrom(m *Message) error {
// Nonce represents NONCE attribute. // Nonce represents NONCE attribute.
// //
// RFC 5389 Section 15.8 // RFC 5389 Section 15.8.
type Nonce []byte type Nonce []byte
// NewNonce returns new Nonce from string. // NewNonce returns new Nonce from string.
@@ -118,6 +118,7 @@ func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error {
return err return err
} }
m.Add(t, v) m.Add(t, v)
return nil return nil
} }
@@ -128,5 +129,6 @@ func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error {
return err return err
} }
*v = a *v = a
return nil return nil
} }

View File

@@ -13,26 +13,26 @@ import (
) )
func TestSoftware_GetFrom(t *testing.T) { func TestSoftware_GetFrom(t *testing.T) {
m := New() msg := New()
v := "Client v0.0.1" val := "Client v0.0.1"
m.Add(AttrSoftware, []byte(v)) msg.Add(AttrSoftware, []byte(val))
m.WriteHeader() msg.WriteHeader()
m2 := &Message{ m2 := &Message{
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
} }
software := new(Software) software := new(Software)
if _, err := m2.ReadFrom(m.reader()); err != nil { if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err) t.Error(err)
} }
if err := software.GetFrom(m); err != nil { if err := software.GetFrom(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if software.String() != v { if software.String() != val {
t.Errorf("Expected %q, got %q.", v, software) t.Errorf("Expected %q, got %q.", val, software)
} }
sAttr, ok := m.Attributes.Get(AttrSoftware) sAttr, ok := msg.Attributes.Get(AttrSoftware)
if !ok { if !ok {
t.Error("software attribute should be found") t.Error("software attribute should be found")
} }
@@ -90,22 +90,22 @@ func BenchmarkUsername_GetFrom(b *testing.B) {
func TestUsername(t *testing.T) { func TestUsername(t *testing.T) {
username := "username" username := "username"
u := NewUsername(username) uName := NewUsername(username)
m := new(Message) msg := new(Message)
m.WriteHeader() msg.WriteHeader()
t.Run("Bad length", func(t *testing.T) { t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600) badU := make(Username, 600)
if err := badU.AddTo(m); !IsAttrSizeOverflow(err) { if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
} }
}) })
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := u.AddTo(m); err != nil { if err := uName.AddTo(msg); err != nil {
t.Error("errored:", err) t.Error("errored:", err)
} }
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(Username) got := new(Username)
if err := got.GetFrom(m); err != nil { if err := got.GetFrom(msg); err != nil {
t.Error("errored:", err) t.Error("errored:", err)
} }
if got.String() != username { if got.String() != username {
@@ -136,10 +136,10 @@ func TestUsername(t *testing.T) {
} }
func TestRealm_GetFrom(t *testing.T) { func TestRealm_GetFrom(t *testing.T) {
m := New() msg := New()
v := "realm" val := "realm"
m.Add(AttrRealm, []byte(v)) msg.Add(AttrRealm, []byte(val))
m.WriteHeader() msg.WriteHeader()
m2 := &Message{ m2 := &Message{
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
@@ -148,17 +148,17 @@ func TestRealm_GetFrom(t *testing.T) {
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) { if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
} }
if _, err := m2.ReadFrom(m.reader()); err != nil { if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err) t.Error(err)
} }
if err := r.GetFrom(m); err != nil { if err := r.GetFrom(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if r.String() != v { if r.String() != val {
t.Errorf("Expected %q, got %q.", v, r) t.Errorf("Expected %q, got %q.", val, r)
} }
rAttr, ok := m.Attributes.Get(AttrRealm) rAttr, ok := msg.Attributes.Get(AttrRealm)
if !ok { if !ok {
t.Error("realm attribute should be found") t.Error("realm attribute should be found")
} }
@@ -180,26 +180,26 @@ func TestRealm_AddTo_Invalid(t *testing.T) {
} }
func TestNonce_GetFrom(t *testing.T) { func TestNonce_GetFrom(t *testing.T) {
m := New() msg := New()
v := "example.org" val := "example.org"
m.Add(AttrNonce, []byte(v)) msg.Add(AttrNonce, []byte(val))
m.WriteHeader() msg.WriteHeader()
m2 := &Message{ m2 := &Message{
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
} }
var nonce Nonce var nonce Nonce
if _, err := m2.ReadFrom(m.reader()); err != nil { if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err) t.Error(err)
} }
if err := nonce.GetFrom(m); err != nil { if err := nonce.GetFrom(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if nonce.String() != v { if nonce.String() != val {
t.Errorf("Expected %q, got %q.", v, nonce) t.Errorf("Expected %q, got %q.", val, nonce)
} }
nAttr, ok := m.Attributes.Get(AttrNonce) nAttr, ok := msg.Attributes.Get(AttrNonce)
if !ok { if !ok {
t.Error("nonce attribute should be found") t.Error("nonce attribute should be found")
} }

View File

@@ -7,7 +7,7 @@ import "errors"
// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute. // UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute.
// //
// RFC 5389 Section 15.9 // RFC 5389 Section 15.9.
type UnknownAttributes []AttrType type UnknownAttributes []AttrType
func (a UnknownAttributes) String() string { func (a UnknownAttributes) String() string {
@@ -22,6 +22,7 @@ func (a UnknownAttributes) String() string {
s += ", " s += ", "
} }
} }
return s return s
} }
@@ -39,6 +40,7 @@ func (a UnknownAttributes) AddTo(m *Message) error {
bin.PutUint16(v[first:last], t.Value()) bin.PutUint16(v[first:last], t.Value())
} }
m.Add(AttrUnknownAttributes, v) m.Add(AttrUnknownAttributes, v)
return nil return nil
} }
@@ -62,5 +64,6 @@ func (a *UnknownAttributes) GetFrom(m *Message) error {
*a = append(*a, AttrType(bin.Uint16(v[first:last]))) *a = append(*a, AttrType(bin.Uint16(v[first:last])))
first = last first = last
} }
return nil return nil
} }

View File

@@ -8,26 +8,26 @@ import (
) )
func TestUnknownAttributes(t *testing.T) { func TestUnknownAttributes(t *testing.T) {
m := new(Message) msg := new(Message)
a := &UnknownAttributes{ attr := &UnknownAttributes{
AttrDontFragment, AttrDontFragment,
AttrChannelNumber, AttrChannelNumber,
} }
if a.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" { if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
t.Error("bad String:", a) t.Error("bad String:", attr)
} }
if (UnknownAttributes{}).String() != "<nil>" { if (UnknownAttributes{}).String() != "<nil>" {
t.Error("bad blank string") t.Error("bad blank string")
} }
if err := a.AddTo(m); err != nil { if err := attr.AddTo(msg); err != nil {
t.Error(err) t.Error(err)
} }
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
attrs := make(UnknownAttributes, 10) attrs := make(UnknownAttributes, 10)
if err := attrs.GetFrom(m); err != nil { if err := attrs.GetFrom(msg); err != nil {
t.Error(err) t.Error(err)
} }
for i, at := range *a { for i, at := range *attr {
if at != attrs[i] { if at != attrs[i] {
t.Error("expected", at, "!=", attrs[i]) t.Error("expected", at, "!=", attrs[i])
} }
@@ -44,8 +44,8 @@ func TestUnknownAttributes(t *testing.T) {
} }
func BenchmarkUnknownAttributes(b *testing.B) { func BenchmarkUnknownAttributes(b *testing.B) {
m := new(Message) msg := new(Message)
a := UnknownAttributes{ attr := UnknownAttributes{
AttrDontFragment, AttrDontFragment,
AttrChannelNumber, AttrChannelNumber,
AttrRealm, AttrRealm,
@@ -54,20 +54,20 @@ func BenchmarkUnknownAttributes(b *testing.B) {
b.Run("AddTo", func(b *testing.B) { b.Run("AddTo", func(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := a.AddTo(m); err != nil { if err := attr.AddTo(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
m.Reset() msg.Reset()
} }
}) })
b.Run("GetFrom", func(b *testing.B) { b.Run("GetFrom", func(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
if err := a.AddTo(m); err != nil { if err := attr.AddTo(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
attrs := make(UnknownAttributes, 0, 10) attrs := make(UnknownAttributes, 0, 10)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := attrs.GetFrom(m); err != nil { if err := attrs.GetFrom(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
attrs = attrs[:0] attrs = attrs[:0]

47
uri.go
View File

@@ -124,7 +124,7 @@ func (t ProtoType) String() string {
} }
} }
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI // URI represents a STUN (rfc7064) or TURN (rfc7065) URI.
type URI struct { type URI struct {
Scheme SchemeType Scheme SchemeType
Host string Host string
@@ -137,73 +137,76 @@ type URI struct {
// ParseURI parses a STUN or TURN urls following the ABNF syntax described in // ParseURI parses a STUN or TURN urls following the ABNF syntax described in
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065 // https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
// respectively. // respectively.
func ParseURI(raw string) (*URI, error) { //nolint:gocognit func ParseURI(raw string) (*URI, error) { //nolint:gocognit,cyclop
rawParts, err := url.Parse(raw) rawParts, err := url.Parse(raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var u URI var uri URI
u.Scheme = NewSchemeType(rawParts.Scheme) uri.Scheme = NewSchemeType(rawParts.Scheme)
if u.Scheme == SchemeTypeUnknown { if uri.Scheme == SchemeTypeUnknown {
return nil, ErrSchemeType return nil, ErrSchemeType
} }
var rawPort string var rawPort string
if u.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { if uri.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { //nolint:nestif
var e *net.AddrError var e *net.AddrError
if errors.As(err, &e) { if errors.As(err, &e) {
if e.Err == "missing port in address" { if e.Err == "missing port in address" {
nextRawURL := u.Scheme.String() + ":" + rawParts.Opaque nextRawURL := uri.Scheme.String() + ":" + rawParts.Opaque
switch { switch {
case u.Scheme == SchemeTypeSTUN || u.Scheme == SchemeTypeTURN: case uri.Scheme == SchemeTypeSTUN || uri.Scheme == SchemeTypeTURN:
nextRawURL += ":3478" nextRawURL += ":3478"
if rawParts.RawQuery != "" { if rawParts.RawQuery != "" {
nextRawURL += "?" + rawParts.RawQuery nextRawURL += "?" + rawParts.RawQuery
} }
return ParseURI(nextRawURL) return ParseURI(nextRawURL)
case u.Scheme == SchemeTypeSTUNS || u.Scheme == SchemeTypeTURNS: case uri.Scheme == SchemeTypeSTUNS || uri.Scheme == SchemeTypeTURNS:
nextRawURL += ":5349" nextRawURL += ":5349"
if rawParts.RawQuery != "" { if rawParts.RawQuery != "" {
nextRawURL += "?" + rawParts.RawQuery nextRawURL += "?" + rawParts.RawQuery
} }
return ParseURI(nextRawURL) return ParseURI(nextRawURL)
} }
} }
} }
return nil, err return nil, err
} }
if u.Host == "" { if uri.Host == "" {
return nil, ErrHost return nil, ErrHost
} }
if u.Port, err = strconv.Atoi(rawPort); err != nil { if uri.Port, err = strconv.Atoi(rawPort); err != nil {
return nil, ErrPort return nil, ErrPort
} }
switch u.Scheme { switch uri.Scheme {
case SchemeTypeSTUN: case SchemeTypeSTUN:
qArgs, err := url.ParseQuery(rawParts.RawQuery) qArgs, err := url.ParseQuery(rawParts.RawQuery)
if err != nil || len(qArgs) > 0 { if err != nil || len(qArgs) > 0 {
return nil, ErrSTUNQuery return nil, ErrSTUNQuery
} }
u.Proto = ProtoTypeUDP uri.Proto = ProtoTypeUDP
case SchemeTypeSTUNS: case SchemeTypeSTUNS:
qArgs, err := url.ParseQuery(rawParts.RawQuery) qArgs, err := url.ParseQuery(rawParts.RawQuery)
if err != nil || len(qArgs) > 0 { if err != nil || len(qArgs) > 0 {
return nil, ErrSTUNQuery return nil, ErrSTUNQuery
} }
u.Proto = ProtoTypeTCP uri.Proto = ProtoTypeTCP
case SchemeTypeTURN: case SchemeTypeTURN:
proto, err := parseProto(rawParts.RawQuery) proto, err := parseProto(rawParts.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u.Proto = proto uri.Proto = proto
if u.Proto == ProtoTypeUnknown { if uri.Proto == ProtoTypeUnknown {
u.Proto = ProtoTypeUDP uri.Proto = ProtoTypeUDP
} }
case SchemeTypeTURNS: case SchemeTypeTURNS:
proto, err := parseProto(rawParts.RawQuery) proto, err := parseProto(rawParts.RawQuery)
@@ -211,15 +214,15 @@ func ParseURI(raw string) (*URI, error) { //nolint:gocognit
return nil, err return nil, err
} }
u.Proto = proto uri.Proto = proto
if u.Proto == ProtoTypeUnknown { if uri.Proto == ProtoTypeUnknown {
u.Proto = ProtoTypeTCP uri.Proto = ProtoTypeTCP
} }
case SchemeTypeUnknown: case SchemeTypeUnknown:
} }
return &u, nil return &uri, nil
} }
func parseProto(raw string) (ProtoType, error) { func parseProto(raw string) (ProtoType, error) {
@@ -233,6 +236,7 @@ func parseProto(raw string) (ProtoType, error) {
if proto = NewProtoType(rawProto); proto == ProtoType(0) { if proto = NewProtoType(rawProto); proto == ProtoType(0) {
return ProtoTypeUnknown, ErrProtoType return ProtoTypeUnknown, ErrProtoType
} }
return proto, nil return proto, nil
} }
@@ -248,6 +252,7 @@ func (u URI) String() string {
if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS { if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS {
rawURL += "?transport=" + u.Proto.String() rawURL += "?transport=" + u.Proto.String()
} }
return rawURL return rawURL
} }

View File

@@ -35,8 +35,16 @@ func TestParseURL(t *testing.T) {
{"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP}, {"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP},
{"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP}, {"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
{"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP}, {"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
{"turn:google.de?transport=udp", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP}, {
{"turns:google.de?transport=tcp", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP}, "turn:google.de?transport=udp",
"turn:google.de:3478?transport=udp",
SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP,
},
{
"turns:google.de?transport=tcp",
"turns:google.de:5349?transport=tcp",
SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP,
},
} }
for i, testCase := range testCases { for i, testCase := range testCases {

View File

@@ -20,7 +20,7 @@ const (
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. // XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
// //
// RFC 5389 Section 15.2 // RFC 5389 Section 15.2.
type XORMappedAddress struct { type XORMappedAddress struct {
IP net.IP IP net.IP
Port int Port int
@@ -43,14 +43,15 @@ func isZeros(p net.IP) bool {
return false return false
} }
} }
return true return true
} }
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}. // ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
var ErrBadIPLength = errors.New("invalid length of IP value") var ErrBadIPLength = errors.New("invalid length of IP value")
// AddToAs adds XOR-MAPPED-ADDRESS value to m as t attribute. // AddToAs adds XOR-MAPPED-ADDRESS value to msg as attr attribute.
func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error { func (a XORMappedAddress) AddToAs(msg *Message, attr AttrType) error {
var ( var (
family = familyIPv4 family = familyIPv4
ip = a.IP ip = a.IP
@@ -67,12 +68,13 @@ func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
value := make([]byte, 32+128) value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len) xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:]) copy(xorValue[4:], msg.TransactionID[:])
bin.PutUint32(xorValue[0:4], magicCookie) bin.PutUint32(xorValue[0:4], magicCookie)
bin.PutUint16(value[0:2], family) bin.PutUint16(value[0:2], family)
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) //nolint:gosec // G115, false positive, port
xor.XorBytes(value[4:4+len(ip)], ip, xorValue) xor.XorBytes(value[4:4+len(ip)], ip, xorValue)
m.Add(t, value[:4+len(ip)]) msg.Add(attr, value[:4+len(ip)])
return nil return nil
} }
@@ -83,13 +85,13 @@ func (a XORMappedAddress) AddTo(m *Message) error {
} }
// GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message // GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message
// getting it as for t type. // getting it as for attr type.
func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error { func (a *XORMappedAddress) GetFromAs(msg *Message, attr AttrType) error {
v, err := m.Get(t) value, err := msg.Get(attr)
if err != nil { if err != nil {
return err return err
} }
family := bin.Uint16(v[0:2]) family := bin.Uint16(value[0:2])
if family != familyIPv6 && family != familyIPv4 { if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family", return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family), fmt.Sprintf("bad value %d", family),
@@ -110,17 +112,18 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
for i := range a.IP { for i := range a.IP {
a.IP[i] = 0 a.IP[i] = 0
} }
if len(v) <= 4 { if len(value) <= 4 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
if err := CheckOverflow(t, len(v[4:]), len(a.IP)); err != nil { if err := CheckOverflow(attr, len(value[4:]), len(a.IP)); err != nil {
return err return err
} }
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16) a.Port = int(bin.Uint16(value[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 4+TransactionIDSize) xorValue := make([]byte, 4+TransactionIDSize)
bin.PutUint32(xorValue[0:4], magicCookie) bin.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:]) copy(xorValue[4:], msg.TransactionID[:])
xor.XorBytes(a.IP, v[4:], xorValue) xor.XorBytes(a.IP, value[4:], xorValue)
return nil return nil
} }

View File

@@ -29,21 +29,21 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
} }
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) { func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
m := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
copy(m.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
addrValue, err := hex.DecodeString("00019cd5f49f38ae") addrValue, err := hex.DecodeString("00019cd5f49f38ae")
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
m.Add(AttrXORMappedAddress, addrValue) msg.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := addr.GetFrom(m); err != nil { if err := addr.GetFrom(msg); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }
@@ -94,54 +94,54 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
} }
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) { func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
m := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
copy(m.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236") expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254 expectedPort := 21254
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err = addr.GetFrom(m); err == nil { if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should be nil") t.Fatal(err, "should be nil")
} }
addr.IP = expectedIP addr.IP = expectedIP
addr.Port = expectedPort addr.Port = expectedPort
addr.AddTo(m) //nolint:errcheck,gosec addr.AddTo(msg) //nolint:errcheck,gosec
m.WriteHeader() msg.WriteHeader()
mRes := New() mRes := New()
binary.BigEndian.PutUint16(m.Raw[20+4:20+4+2], 0x21) binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil { if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = addr.GetFrom(m); err == nil { if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should not be nil") t.Fatal(err, "should not be nil")
} }
} }
func TestXORMappedAddress_AddTo(t *testing.T) { func TestXORMappedAddress_AddTo(t *testing.T) {
m := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
copy(m.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236") expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254 expectedPort := 21254
addr := &XORMappedAddress{ addr := &XORMappedAddress{
IP: net.ParseIP("213.141.156.236"), IP: net.ParseIP("213.141.156.236"),
Port: expectedPort, Port: expectedPort,
} }
if err = addr.AddTo(m); err != nil { if err = addr.AddTo(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
m.WriteHeader() msg.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.Write(m.Raw); err != nil { if _, err = mRes.Write(msg.Raw); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err = addr.GetFrom(mRes); err != nil { if err = addr.GetFrom(mRes); err != nil {
@@ -156,27 +156,27 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
} }
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) { func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
m := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
copy(m.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009") expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
expectedPort := 21254 expectedPort := 21254
addr := &XORMappedAddress{ addr := &XORMappedAddress{
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"), IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
Port: 21254, Port: 21254,
} }
addr.AddTo(m) //nolint:errcheck,gosec addr.AddTo(msg) //nolint:errcheck,gosec
m.WriteHeader() msg.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil { if _, err = mRes.ReadFrom(msg.reader()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
gotAddr := new(XORMappedAddress) gotAddr := new(XORMappedAddress)
if err = gotAddr.GetFrom(m); err != nil { if err = gotAddr.GetFrom(msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !gotAddr.IP.Equal(expectedIP) { if !gotAddr.IP.Equal(expectedIP) {