mirror of
https://github.com/pion/stun.git
synced 2025-09-26 20:01:18 +08:00
Upgrade golangci-lint, more linters
Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
@@ -25,17 +25,32 @@ linters-settings:
|
||||
- ^os.Exit$
|
||||
- ^panic$
|
||||
- ^print(ln)?$
|
||||
varnamelen:
|
||||
max-distance: 12
|
||||
min-name-length: 2
|
||||
ignore-type-assert-ok: true
|
||||
ignore-map-index-ok: true
|
||||
ignore-chan-recv-ok: true
|
||||
ignore-decls:
|
||||
- i int
|
||||
- n int
|
||||
- w io.Writer
|
||||
- r io.Reader
|
||||
- b []byte
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||
- bidichk # Checks for dangerous unicode character sequences
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
- contextcheck # check the function whether use a non-inherited context
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- decorder # check declaration order and count of types, constants, variables and functions
|
||||
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||
- dupl # Tool for code clone detection
|
||||
- durationcheck # check for two durations multiplied together
|
||||
- err113 # Golang linter to check the errors handling expressions
|
||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
|
||||
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
|
||||
@@ -46,18 +61,17 @@ linters:
|
||||
- forcetypeassert # finds forced type assertions
|
||||
- gci # Gci control golang package import order and make it always deterministic.
|
||||
- gochecknoglobals # Checks that no globals are present in Go code
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
- gocognit # Computes and checks the cognitive complexity of functions
|
||||
- goconst # Finds repeated strings that could be replaced by a constant
|
||||
- gocritic # The most opinionated Go source code linter
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- godox # Tool for detection of FIXME, TODO and other comment keywords
|
||||
- err113 # Golang linter to check the errors handling expressions
|
||||
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
||||
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
|
||||
- goheader # Checks is file header matches to pattern
|
||||
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
|
||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
|
||||
- gosec # Inspects source code for security problems
|
||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||
@@ -65,9 +79,15 @@ linters:
|
||||
- grouper # An analyzer to analyze expression groups.
|
||||
- importas # Enforces consistent import aliases
|
||||
- ineffassign # Detects when assignments to existing variables are not used
|
||||
- lll # Reports long lines
|
||||
- maintidx # maintidx measures the maintainability index of each function.
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
- misspell # Finds commonly misspelled English words in comments
|
||||
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
||||
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
|
||||
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
- noctx # noctx finds sending http request without context.Context
|
||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
||||
- revive # golint replacement, finds style mistakes
|
||||
@@ -75,28 +95,22 @@ linters:
|
||||
- stylecheck # Stylecheck is a replacement for golint
|
||||
- tagliatelle # Checks the struct tags.
|
||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unconvert # Remove unnecessary type conversions
|
||||
- unparam # Reports unused function parameters
|
||||
- unused # Checks Go code for unused constants, variables, functions and types
|
||||
- varnamelen # checks that the length of a variable's name matches its scope
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
- whitespace # Tool for detection of leading and trailing whitespace
|
||||
disable:
|
||||
- depguard # Go linter that checks if package imports are in a list of acceptable packages
|
||||
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- funlen # Tool for detection of long functions
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- gomnd # An analyzer to detect magic numbers.
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- interfacebloat # A linter that checks length of interface.
|
||||
- ireturn # Accept Interfaces, Return Concrete Types
|
||||
- lll # Reports long lines
|
||||
- maintidx # maintidx measures the maintainability index of each function.
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
- mnd # An analyzer to detect magic numbers
|
||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
|
||||
- prealloc # Finds slice declarations that could potentially be preallocated
|
||||
@@ -104,8 +118,7 @@ linters:
|
||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||
- testpackage # linter that makes you use a separate _test package
|
||||
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
- varnamelen # checks that the length of a variable's name matches its scope
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- wrapcheck # Checks that errors returned from external packages are wrapped
|
||||
- wsl # Whitespace Linter - Forces you to use empty lines!
|
||||
|
||||
@@ -123,3 +136,4 @@ issues:
|
||||
- path: cmd
|
||||
linters:
|
||||
- forbidigo
|
||||
|
||||
|
32
addr.go
32
addr.go
@@ -15,7 +15,7 @@ import (
|
||||
// This attribute is used only by servers for achieving backwards
|
||||
// compatibility with RFC 3489 clients.
|
||||
//
|
||||
// RFC 5389 Section 15.1
|
||||
// RFC 5389 Section 15.1.
|
||||
type MappedAddress struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
@@ -23,7 +23,7 @@ type MappedAddress struct {
|
||||
|
||||
// AlternateServer represents ALTERNATE-SERVER attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.11
|
||||
// RFC 5389 Section 15.11.
|
||||
type AlternateServer struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
@@ -31,7 +31,7 @@ type AlternateServer struct {
|
||||
|
||||
// ResponseOrigin represents RESPONSE-ORIGIN attribute.
|
||||
//
|
||||
// RFC 5780 Section 7.3
|
||||
// RFC 5780 Section 7.3.
|
||||
type ResponseOrigin struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
@@ -39,7 +39,7 @@ type ResponseOrigin struct {
|
||||
|
||||
// OtherAddress represents OTHER-ADDRESS attribute.
|
||||
//
|
||||
// RFC 5780 Section 7.4
|
||||
// RFC 5780 Section 7.4.
|
||||
type OtherAddress struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
@@ -48,12 +48,14 @@ type OtherAddress struct {
|
||||
// AddTo adds ALTERNATE-SERVER attribute to message.
|
||||
func (s *AlternateServer) AddTo(m *Message) error {
|
||||
a := (*MappedAddress)(s)
|
||||
|
||||
return a.AddToAs(m, AttrAlternateServer)
|
||||
}
|
||||
|
||||
// GetFrom decodes ALTERNATE-SERVER from message.
|
||||
func (s *AlternateServer) GetFrom(m *Message) error {
|
||||
a := (*MappedAddress)(s)
|
||||
|
||||
return a.GetFromAs(m, AttrAlternateServer)
|
||||
}
|
||||
|
||||
@@ -63,14 +65,14 @@ func (a MappedAddress) String() string {
|
||||
|
||||
// GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t.
|
||||
func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
|
||||
v, err := m.Get(t)
|
||||
value, err := m.Get(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v) <= 4 {
|
||||
if len(value) <= 4 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
family := bin.Uint16(v[0:2])
|
||||
family := bin.Uint16(value[0:2])
|
||||
if family != familyIPv6 && family != familyIPv4 {
|
||||
return newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
@@ -91,13 +93,14 @@ func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
|
||||
for i := range a.IP {
|
||||
a.IP[i] = 0
|
||||
}
|
||||
a.Port = int(bin.Uint16(v[2:4]))
|
||||
copy(a.IP, v[4:])
|
||||
a.Port = int(bin.Uint16(value[2:4]))
|
||||
copy(a.IP, value[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddToAs adds MAPPED-ADDRESS value to m as t attribute.
|
||||
func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
|
||||
func (a *MappedAddress) AddToAs(msg *Message, attrType AttrType) error {
|
||||
var (
|
||||
family = familyIPv4
|
||||
ip = a.IP
|
||||
@@ -114,9 +117,10 @@ func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
|
||||
value := make([]byte, 128)
|
||||
value[0] = 0 // first 8 bits are zeroes
|
||||
bin.PutUint16(value[0:2], family)
|
||||
bin.PutUint16(value[2:4], uint16(a.Port))
|
||||
bin.PutUint16(value[2:4], uint16(a.Port)) //nolint:gosec //G115
|
||||
copy(value[4:], ip)
|
||||
m.Add(t, value[:4+len(ip)])
|
||||
msg.Add(attrType, value[:4+len(ip)])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,12 +137,14 @@ func (a *MappedAddress) GetFrom(m *Message) error {
|
||||
// AddTo adds OTHER-ADDRESS attribute to message.
|
||||
func (o *OtherAddress) AddTo(m *Message) error {
|
||||
a := (*MappedAddress)(o)
|
||||
|
||||
return a.AddToAs(m, AttrOtherAddress)
|
||||
}
|
||||
|
||||
// GetFrom decodes OTHER-ADDRESS from message.
|
||||
func (o *OtherAddress) GetFrom(m *Message) error {
|
||||
a := (*MappedAddress)(o)
|
||||
|
||||
return a.GetFromAs(m, AttrOtherAddress)
|
||||
}
|
||||
|
||||
@@ -149,12 +155,14 @@ func (o OtherAddress) String() string {
|
||||
// AddTo adds RESPONSE-ORIGIN attribute to message.
|
||||
func (o *ResponseOrigin) AddTo(m *Message) error {
|
||||
a := (*MappedAddress)(o)
|
||||
|
||||
return a.AddToAs(m, AttrResponseOrigin)
|
||||
}
|
||||
|
||||
// GetFrom decodes RESPONSE-ORIGIN from message.
|
||||
func (o *ResponseOrigin) GetFrom(m *Message) error {
|
||||
a := (*MappedAddress)(o)
|
||||
|
||||
return a.GetFromAs(m, AttrResponseOrigin)
|
||||
}
|
||||
|
||||
|
12
addr_test.go
12
addr_test.go
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMappedAddress(t *testing.T) {
|
||||
m := new(Message)
|
||||
msg := new(Message)
|
||||
addr := &MappedAddress{
|
||||
IP: net.ParseIP("122.12.34.5"),
|
||||
Port: 5412,
|
||||
@@ -23,17 +23,17 @@ func TestMappedAddress(t *testing.T) {
|
||||
badAddr := &MappedAddress{
|
||||
IP: net.IP{1, 2, 3},
|
||||
}
|
||||
if err := badAddr.AddTo(m); err == nil {
|
||||
if err := badAddr.AddTo(msg); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
})
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
if err := addr.AddTo(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(MappedAddress)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
if err := got.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
@@ -46,9 +46,9 @@ func TestMappedAddress(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("Bad family", func(t *testing.T) {
|
||||
v, _ := m.Attributes.Get(AttrMappedAddress)
|
||||
v, _ := msg.Attributes.Get(AttrMappedAddress)
|
||||
v.Value[0] = 32
|
||||
if err := got.GetFrom(m); err == nil {
|
||||
if err := got.GetFrom(msg); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
})
|
||||
|
16
agent.go
16
agent.go
@@ -24,6 +24,7 @@ func NewAgent(h Handler) *Agent {
|
||||
transactions: make(map[transactionID]agentTransaction),
|
||||
handler: h,
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
@@ -80,6 +81,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
|
||||
a.mux.Lock()
|
||||
if a.closed {
|
||||
a.mux.Unlock()
|
||||
|
||||
return ErrAgentClosed
|
||||
}
|
||||
t, exists := a.transactions[id]
|
||||
@@ -93,6 +95,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
|
||||
TransactionID: t.id,
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -124,6 +127,7 @@ func (a *Agent) Start(id [TransactionIDSize]byte, deadline time.Time) error {
|
||||
id: id,
|
||||
deadline: deadline,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -147,6 +151,7 @@ func (a *Agent) Collect(gcTime time.Time) error {
|
||||
// All transactions should be already closed
|
||||
// during Close() call.
|
||||
a.mux.Unlock()
|
||||
|
||||
return ErrAgentClosed
|
||||
}
|
||||
// Adding all transactions with deadline before gcTime
|
||||
@@ -175,24 +180,27 @@ func (a *Agent) Collect(gcTime time.Time) error {
|
||||
event.TransactionID = id
|
||||
h(event)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process incoming message, synchronously passing it to handler.
|
||||
func (a *Agent) Process(m *Message) error {
|
||||
e := Event{
|
||||
event := Event{
|
||||
TransactionID: m.TransactionID,
|
||||
Message: m,
|
||||
}
|
||||
a.mux.Lock()
|
||||
if a.closed {
|
||||
a.mux.Unlock()
|
||||
|
||||
return ErrAgentClosed
|
||||
}
|
||||
h := a.handler
|
||||
delete(a.transactions, m.TransactionID)
|
||||
a.mux.Unlock()
|
||||
h(e)
|
||||
h(event)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -201,10 +209,12 @@ func (a *Agent) SetHandler(h Handler) error {
|
||||
a.mux.Lock()
|
||||
if a.closed {
|
||||
a.mux.Unlock()
|
||||
|
||||
return ErrAgentClosed
|
||||
}
|
||||
a.handler = h
|
||||
a.mux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -217,6 +227,7 @@ func (a *Agent) Close() error {
|
||||
a.mux.Lock()
|
||||
if a.closed {
|
||||
a.mux.Unlock()
|
||||
|
||||
return ErrAgentClosed
|
||||
}
|
||||
for _, t := range a.transactions {
|
||||
@@ -227,6 +238,7 @@ func (a *Agent) Close() error {
|
||||
a.closed = true
|
||||
a.handler = nil
|
||||
a.mux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -10,49 +10,49 @@ import (
|
||||
)
|
||||
|
||||
func TestAgent_ProcessInTransaction(t *testing.T) {
|
||||
m := New()
|
||||
a := NewAgent(func(e Event) {
|
||||
msg := New()
|
||||
agent := NewAgent(func(e Event) {
|
||||
if e.Error != nil {
|
||||
t.Errorf("got error: %s", e.Error)
|
||||
}
|
||||
if !e.Message.Equal(m) {
|
||||
t.Errorf("%s (got) != %s (expected)", e.Message, m)
|
||||
if !e.Message.Equal(msg) {
|
||||
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
|
||||
}
|
||||
})
|
||||
if err := m.NewTransactionID(); err != nil {
|
||||
if err := msg.NewTransactionID(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := a.Start(m.TransactionID, time.Time{}); err != nil {
|
||||
if err := agent.Start(msg.TransactionID, time.Time{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := a.Process(m); err != nil {
|
||||
if err := agent.Process(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_Process(t *testing.T) {
|
||||
m := New()
|
||||
a := NewAgent(func(e Event) {
|
||||
msg := New()
|
||||
agent := NewAgent(func(e Event) {
|
||||
if e.Error != nil {
|
||||
t.Errorf("got error: %s", e.Error)
|
||||
}
|
||||
if !e.Message.Equal(m) {
|
||||
t.Errorf("%s (got) != %s (expected)", e.Message, m)
|
||||
if !e.Message.Equal(msg) {
|
||||
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
|
||||
}
|
||||
})
|
||||
if err := m.NewTransactionID(); err != nil {
|
||||
if err := msg.NewTransactionID(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := a.Process(m); err != nil {
|
||||
if err := agent.Process(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := a.Process(m); !errors.Is(err, ErrAgentClosed) {
|
||||
if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("closed agent should return <%s>, but got <%s>",
|
||||
ErrAgentClosed, err,
|
||||
)
|
||||
@@ -60,27 +60,27 @@ func TestAgent_Process(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgent_Start(t *testing.T) {
|
||||
a := NewAgent(nil)
|
||||
agent := NewAgent(nil)
|
||||
id := NewTransactionID()
|
||||
deadline := time.Now().AddDate(0, 0, 1)
|
||||
if err := a.Start(id, deadline); err != nil {
|
||||
if err := agent.Start(id, deadline); err != nil {
|
||||
t.Errorf("failed to statt transaction: %s", err)
|
||||
}
|
||||
if err := a.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
|
||||
if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
|
||||
t.Errorf("duplicate start should return <%s>, got <%s>",
|
||||
ErrTransactionExists, err,
|
||||
)
|
||||
}
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
id = NewTransactionID()
|
||||
if err := a.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
|
||||
if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("start on closed agent should return <%s>, got <%s>",
|
||||
ErrAgentClosed, err,
|
||||
)
|
||||
}
|
||||
if err := a.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
|
||||
if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
|
||||
ErrAgentClosed, err,
|
||||
)
|
||||
@@ -89,18 +89,18 @@ func TestAgent_Start(t *testing.T) {
|
||||
|
||||
func TestAgent_Stop(t *testing.T) {
|
||||
called := make(chan Event, 1)
|
||||
a := NewAgent(func(e Event) {
|
||||
agent := NewAgent(func(e Event) {
|
||||
called <- e
|
||||
})
|
||||
if err := a.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
|
||||
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
|
||||
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
|
||||
}
|
||||
id := NewTransactionID()
|
||||
timeout := time.Millisecond * 200
|
||||
if err := a.Start(id, time.Now().Add(timeout)); err != nil {
|
||||
if err := agent.Start(id, time.Now().Add(timeout)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := a.Stop(id); err != nil {
|
||||
if err := agent.Stop(id); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
select {
|
||||
@@ -113,19 +113,19 @@ func TestAgent_Stop(t *testing.T) {
|
||||
case <-time.After(timeout * 2):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := a.Close(); !errors.Is(err, ErrAgentClosed) {
|
||||
if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
|
||||
}
|
||||
if err := a.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
|
||||
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_GC(t *testing.T) {
|
||||
a := NewAgent(nil)
|
||||
func TestAgent_GC(t *testing.T) { //nolint:cyclop
|
||||
agent := NewAgent(nil)
|
||||
shouldTimeOutID := make(map[transactionID]bool)
|
||||
deadline := time.Date(2027, time.November, 21,
|
||||
23, 0, 0, 0,
|
||||
@@ -133,7 +133,7 @@ func TestAgent_GC(t *testing.T) {
|
||||
)
|
||||
gcDeadline := deadline.Add(-time.Second)
|
||||
deadlineNotGC := gcDeadline.AddDate(0, 0, -1)
|
||||
a.SetHandler(func(e Event) { //nolint:errcheck,gosec
|
||||
agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
|
||||
id := e.TransactionID
|
||||
shouldTimeOut, found := shouldTimeOutID[id]
|
||||
if !found {
|
||||
@@ -149,67 +149,67 @@ func TestAgent_GC(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
id := NewTransactionID()
|
||||
shouldTimeOutID[id] = false
|
||||
if err := a.Start(id, deadline); err != nil {
|
||||
if err := agent.Start(id, deadline); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
id := NewTransactionID()
|
||||
shouldTimeOutID[id] = true
|
||||
if err := a.Start(id, deadlineNotGC); err != nil {
|
||||
if err := agent.Start(id, deadlineNotGC); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := a.Collect(gcDeadline); err != nil {
|
||||
if err := agent.Collect(gcDeadline); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := a.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
|
||||
if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAgent_GC(b *testing.B) {
|
||||
a := NewAgent(nil)
|
||||
agent := NewAgent(nil)
|
||||
deadline := time.Now().AddDate(0, 0, 1)
|
||||
for i := 0; i < agentCollectCap; i++ {
|
||||
if err := a.Start(NewTransactionID(), deadline); err != nil {
|
||||
if err := agent.Start(NewTransactionID(), deadline); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}()
|
||||
b.ReportAllocs()
|
||||
gcDeadline := deadline.Add(-time.Second)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := a.Collect(gcDeadline); err != nil {
|
||||
if err := agent.Collect(gcDeadline); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAgent_Process(b *testing.B) {
|
||||
a := NewAgent(nil)
|
||||
agent := NewAgent(nil)
|
||||
deadline := time.Now().AddDate(0, 0, 1)
|
||||
for i := 0; i < 1000; i++ {
|
||||
if err := a.Start(NewTransactionID(), deadline); err != nil {
|
||||
if err := agent.Start(NewTransactionID(), deadline); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if err := a.Close(); err != nil {
|
||||
if err := agent.Close(); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}()
|
||||
b.ReportAllocs()
|
||||
m := MustBuild(TransactionID)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := a.Process(m); err != nil {
|
||||
if err := agent.Process(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@@ -21,6 +21,7 @@ func (a Attributes) Get(t AttrType) (RawAttribute, bool) {
|
||||
return candidate, true
|
||||
}
|
||||
}
|
||||
|
||||
return RawAttribute{}, false
|
||||
}
|
||||
|
||||
@@ -77,7 +78,7 @@ const (
|
||||
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
|
||||
)
|
||||
|
||||
// Attributes from RFC 5780 NAT Behavior Discovery
|
||||
// Attributes from RFC 5780 NAT Behavior Discovery.
|
||||
const (
|
||||
AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST
|
||||
AttrPadding AttrType = 0x0026 // PADDING
|
||||
@@ -166,6 +167,7 @@ func (t AttrType) String() string {
|
||||
// Just return hex representation of unknown attribute type.
|
||||
return fmt.Sprintf("0x%x", uint16(t))
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -186,6 +188,7 @@ type RawAttribute struct {
|
||||
// the Length field.
|
||||
func (a RawAttribute) AddTo(m *Message) error {
|
||||
m.Add(a.Type, a.Value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -205,6 +208,7 @@ func (a RawAttribute) Equal(b RawAttribute) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -224,6 +228,7 @@ func (m *Message) Get(t AttrType) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, ErrAttributeNotFound
|
||||
}
|
||||
|
||||
return v.Value, nil
|
||||
}
|
||||
|
||||
@@ -240,6 +245,7 @@ func nearestPaddedValueLength(l int) int {
|
||||
if n < l {
|
||||
n += padding
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -250,5 +256,6 @@ func compatAttrType(val uint16) AttrType {
|
||||
if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN
|
||||
return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on)
|
||||
}
|
||||
|
||||
return AttrType(val)
|
||||
}
|
||||
|
@@ -44,13 +44,13 @@ func TestRawAttribute_AddTo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMessage_GetNoAllocs(t *testing.T) {
|
||||
m := New()
|
||||
NewSoftware("c").AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteHeader()
|
||||
msg := New()
|
||||
NewSoftware("c").AddTo(msg) //nolint:errcheck,gosec
|
||||
msg.WriteHeader()
|
||||
|
||||
t.Run("Default", func(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
m.Get(AttrSoftware) //nolint:errcheck,gosec
|
||||
msg.Get(AttrSoftware) //nolint:errcheck,gosec
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Error("allocated memory, but should not")
|
||||
@@ -58,7 +58,7 @@ func TestMessage_GetNoAllocs(t *testing.T) {
|
||||
})
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
m.Get(AttrOrigin) //nolint:errcheck,gosec
|
||||
msg.Get(AttrOrigin) //nolint:errcheck,gosec
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Error("allocated memory, but should not")
|
||||
|
@@ -17,6 +17,7 @@ func CheckSize(_ AttrType, got, expected int) error {
|
||||
if got == expected {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrAttributeSizeInvalid
|
||||
}
|
||||
|
||||
@@ -24,6 +25,7 @@ func checkHMAC(got, expected []byte) error {
|
||||
if hmac.Equal(got, expected) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrIntegrityMismatch
|
||||
}
|
||||
|
||||
@@ -31,6 +33,7 @@ func checkFingerprint(got, expected uint32) error {
|
||||
if got == expected {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrFingerprintMismatch
|
||||
}
|
||||
|
||||
@@ -44,6 +47,7 @@ func CheckOverflow(_ AttrType, got, maxVal int) error {
|
||||
if got <= maxVal {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrAttributeSizeOverflow
|
||||
}
|
||||
|
||||
|
126
client.go
126
client.go
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
)
|
||||
|
||||
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI
|
||||
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI.
|
||||
var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
|
||||
|
||||
// Dial connects to the address on the named network and then
|
||||
@@ -31,10 +31,11 @@ func Dial(network, address string) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClient(conn)
|
||||
}
|
||||
|
||||
// DialConfig is used to pass configuration to DialURI()
|
||||
// DialConfig is used to pass configuration to DialURI().
|
||||
type DialConfig struct {
|
||||
DTLSConfig dtls.Config
|
||||
TLSConfig tls.Config
|
||||
@@ -44,7 +45,7 @@ type DialConfig struct {
|
||||
|
||||
// DialURI connect to the STUN/TURN URI and then
|
||||
// initializes Client on that connection, returning error if any.
|
||||
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
|
||||
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop
|
||||
var conn Connection
|
||||
var err error
|
||||
|
||||
@@ -203,7 +204,7 @@ const (
|
||||
// provide any API for it, so if you need to read application data, wrap the
|
||||
// connection with your (de-)multiplexer and pass the wrapper as conn.
|
||||
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
|
||||
c := &Client{
|
||||
client := &Client{
|
||||
close: make(chan struct{}),
|
||||
c: conn,
|
||||
clock: systemClock(),
|
||||
@@ -214,32 +215,33 @@ func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
|
||||
closeConn: true,
|
||||
}
|
||||
for _, o := range options {
|
||||
o(c)
|
||||
o(client)
|
||||
}
|
||||
if c.c == nil {
|
||||
if client.c == nil {
|
||||
return nil, ErrNoConnection
|
||||
}
|
||||
if c.a == nil {
|
||||
c.a = NewAgent(nil)
|
||||
if client.a == nil {
|
||||
client.a = NewAgent(nil)
|
||||
}
|
||||
if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
|
||||
if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.collector == nil {
|
||||
c.collector = &tickerCollector{
|
||||
if client.collector == nil {
|
||||
client.collector = &tickerCollector{
|
||||
close: make(chan struct{}),
|
||||
clock: c.clock,
|
||||
clock: client.clock,
|
||||
}
|
||||
}
|
||||
if err := c.collector.Start(c.rtoRate, func(t time.Time) {
|
||||
closedOrPanic(c.a.Collect(t))
|
||||
if err := client.collector.Start(client.rtoRate, func(t time.Time) {
|
||||
closedOrPanic(client.a.Collect(t))
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.wg.Add(1)
|
||||
go c.readUntilClosed()
|
||||
runtime.SetFinalizer(c, clientFinalizer)
|
||||
return c, nil
|
||||
client.wg.Add(1)
|
||||
go client.readUntilClosed()
|
||||
runtime.SetFinalizer(client, clientFinalizer)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func clientFinalizer(c *Client) {
|
||||
@@ -252,6 +254,7 @@ func clientFinalizer(c *Client) {
|
||||
}
|
||||
if err == nil {
|
||||
log.Println("client: called finalizer on non-closed client") // nolint
|
||||
|
||||
return
|
||||
}
|
||||
log.Println("client: called finalizer on non-closed client:", err) // nolint
|
||||
@@ -353,6 +356,7 @@ func (c *Client) start(t *clientTransaction) error {
|
||||
return ErrTransactionExists
|
||||
}
|
||||
c.t[t.id] = t
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -399,6 +403,7 @@ func sprintErr(err error) string {
|
||||
if err == nil {
|
||||
return "<nil>" //nolint:goconst
|
||||
}
|
||||
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
@@ -455,18 +460,21 @@ func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error
|
||||
select {
|
||||
case <-a.close:
|
||||
t.Stop()
|
||||
|
||||
return
|
||||
case <-t.C:
|
||||
f(a.clock.Now())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *tickerCollector) Close() error {
|
||||
close(a.close)
|
||||
a.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -481,6 +489,7 @@ func (c *Client) Close() error {
|
||||
c.mux.Lock()
|
||||
if c.closed {
|
||||
c.mux.Unlock()
|
||||
|
||||
return ErrClientClosed
|
||||
}
|
||||
c.closed = true
|
||||
@@ -498,6 +507,7 @@ func (c *Client) Close() error {
|
||||
if agentErr == nil && connErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return CloseErr{
|
||||
AgentErr: agentErr,
|
||||
ConnectionErr: connErr,
|
||||
@@ -566,6 +576,7 @@ func (c *Client) checkInit() error {
|
||||
if c == nil || c.c == nil || c.a == nil || c.close == nil {
|
||||
return ErrClientNotInitialized
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -590,6 +601,7 @@ func (c *Client) Do(m *Message, f func(Event)) error {
|
||||
return err
|
||||
}
|
||||
h.wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -611,80 +623,85 @@ var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
|
||||
},
|
||||
}
|
||||
|
||||
func (c *Client) handleAgentCallback(e Event) {
|
||||
func (c *Client) handleAgentCallback(event Event) { //nolint:cyclop
|
||||
c.mux.Lock()
|
||||
if c.closed {
|
||||
c.mux.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
t, found := c.t[e.TransactionID]
|
||||
transaction, found := c.t[event.TransactionID]
|
||||
if found {
|
||||
delete(c.t, t.id)
|
||||
delete(c.t, transaction.id)
|
||||
}
|
||||
c.mux.Unlock()
|
||||
if !found {
|
||||
if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
|
||||
c.handler(e)
|
||||
if c.handler != nil && !errors.Is(event.Error, ErrTransactionStopped) {
|
||||
c.handler(event)
|
||||
}
|
||||
// Ignoring.
|
||||
return
|
||||
}
|
||||
if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
|
||||
if atomic.LoadInt32(&c.maxAttempts) <= transaction.attempt || event.Error == nil {
|
||||
// Transaction completed.
|
||||
t.handle(e)
|
||||
putClientTransaction(t)
|
||||
transaction.handle(event)
|
||||
putClientTransaction(transaction)
|
||||
|
||||
return
|
||||
}
|
||||
// Doing re-transmission.
|
||||
t.attempt++
|
||||
b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
|
||||
b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
|
||||
defer bufferPool.Put(b)
|
||||
transaction.attempt++
|
||||
buff := bufferPool.Get().(*buffer) //nolint:forcetypeassert
|
||||
buff.buf = buff.buf[:copy(buff.buf[:cap(buff.buf)], transaction.raw)]
|
||||
defer bufferPool.Put(buff)
|
||||
var (
|
||||
now = c.clock.Now()
|
||||
timeOut = t.nextTimeout(now)
|
||||
id = t.id
|
||||
timeOut = transaction.nextTimeout(now)
|
||||
id = transaction.id
|
||||
)
|
||||
// Starting client transaction.
|
||||
if startErr := c.start(t); startErr != nil {
|
||||
if startErr := c.start(transaction); startErr != nil {
|
||||
c.delete(id)
|
||||
e.Error = startErr
|
||||
t.handle(e)
|
||||
putClientTransaction(t)
|
||||
event.Error = startErr
|
||||
transaction.handle(event)
|
||||
putClientTransaction(transaction)
|
||||
|
||||
return
|
||||
}
|
||||
// Starting agent transaction.
|
||||
if startErr := c.a.Start(id, timeOut); startErr != nil {
|
||||
c.delete(id)
|
||||
e.Error = startErr
|
||||
t.handle(e)
|
||||
putClientTransaction(t)
|
||||
event.Error = startErr
|
||||
transaction.handle(event)
|
||||
putClientTransaction(transaction)
|
||||
|
||||
return
|
||||
}
|
||||
// Writing message to connection again.
|
||||
_, writeErr := c.c.Write(b.buf)
|
||||
_, writeErr := c.c.Write(buff.buf)
|
||||
if writeErr != nil {
|
||||
c.delete(id)
|
||||
e.Error = writeErr
|
||||
event.Error = writeErr
|
||||
// Stopping agent transaction instead of waiting until it's deadline.
|
||||
// This will call handleAgentCallback with "ErrTransactionStopped" error
|
||||
// which will be ignored.
|
||||
if stopErr := c.a.Stop(id); stopErr != nil {
|
||||
// Failed to stop agent transaction. Wrapping the error in StopError.
|
||||
e.Error = StopErr{
|
||||
event.Error = StopErr{
|
||||
Err: stopErr,
|
||||
Cause: writeErr,
|
||||
}
|
||||
}
|
||||
t.handle(e)
|
||||
putClientTransaction(t)
|
||||
transaction.handle(event)
|
||||
putClientTransaction(transaction)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts transaction (if h set) and writes message to server, handler
|
||||
// is called asynchronously.
|
||||
func (c *Client) Start(m *Message, h Handler) error {
|
||||
func (c *Client) Start(msg *Message, handler Handler) error {
|
||||
if err := c.checkInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -694,34 +711,35 @@ func (c *Client) Start(m *Message, h Handler) error {
|
||||
if closed {
|
||||
return ErrClientClosed
|
||||
}
|
||||
if h != nil {
|
||||
if handler != nil {
|
||||
// Starting transaction only if h is set. Useful for indications.
|
||||
t := acquireClientTransaction()
|
||||
t.id = m.TransactionID
|
||||
t.id = msg.TransactionID
|
||||
t.start = c.clock.Now()
|
||||
t.h = h
|
||||
t.h = handler
|
||||
t.rto = time.Duration(atomic.LoadInt64(&c.rto))
|
||||
t.attempt = 0
|
||||
t.raw = append(t.raw[:0], m.Raw...)
|
||||
t.raw = append(t.raw[:0], msg.Raw...)
|
||||
t.calls = 0
|
||||
d := t.nextTimeout(t.start)
|
||||
if err := c.start(t); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.a.Start(m.TransactionID, d); err != nil {
|
||||
if err := c.a.Start(msg.TransactionID, d); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := m.WriteTo(c.c)
|
||||
if err != nil && h != nil {
|
||||
c.delete(m.TransactionID)
|
||||
_, err := msg.WriteTo(c.c)
|
||||
if err != nil && handler != nil {
|
||||
c.delete(msg.TransactionID)
|
||||
// Stopping transaction instead of waiting until deadline.
|
||||
if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
|
||||
if stopErr := c.a.Stop(msg.TransactionID); stopErr != nil {
|
||||
return StopErr{
|
||||
Err: stopErr,
|
||||
Cause: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
163
client_test.go
163
client_test.go
@@ -38,11 +38,13 @@ type TestAgent struct {
|
||||
|
||||
func (n *TestAgent) SetHandler(h Handler) error {
|
||||
n.h = h
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *TestAgent) Close() error {
|
||||
close(n.e)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,6 +56,7 @@ func (n *TestAgent) Start(id [TransactionIDSize]byte, _ time.Time) error {
|
||||
n.e <- Event{
|
||||
TransactionID: id,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -69,6 +72,7 @@ func (noopConnection) Write(b []byte) (int, error) {
|
||||
|
||||
func (noopConnection) Read([]byte) (int, error) {
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
@@ -136,6 +140,7 @@ func (t *testConnection) Close() error {
|
||||
return errClientAlreadyStopped
|
||||
}
|
||||
t.stopped = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -148,6 +153,7 @@ func (t *testConnection) Read(b []byte) (int, error) {
|
||||
if t.read != nil {
|
||||
return t.read(b)
|
||||
}
|
||||
|
||||
return copy(b, t.b), nil
|
||||
}
|
||||
|
||||
@@ -165,7 +171,7 @@ func TestClosedOrPanic(t *testing.T) {
|
||||
}()
|
||||
}
|
||||
|
||||
func TestClient_Start(t *testing.T) {
|
||||
func TestClient_Start(t *testing.T) { //nolint:cyclop
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
response.Encode()
|
||||
write := make(chan struct{}, 1)
|
||||
@@ -178,6 +184,7 @@ func TestClient_Start(t *testing.T) {
|
||||
case <-read:
|
||||
t.Log("reading")
|
||||
copy(i, response.Raw)
|
||||
|
||||
return len(response.Raw), nil
|
||||
case <-time.After(time.Millisecond * 10):
|
||||
return 0, errClientReadTimedOut
|
||||
@@ -188,33 +195,34 @@ func TestClient_Start(t *testing.T) {
|
||||
select {
|
||||
case <-write:
|
||||
t.Log("writing")
|
||||
|
||||
return len(bytes), nil
|
||||
case <-time.After(time.Millisecond * 10):
|
||||
return 0, errClientWriteTimedOut
|
||||
}
|
||||
},
|
||||
}
|
||||
c, err := NewClient(conn)
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := c.Close(); err == nil {
|
||||
if err := client.Close(); err == nil {
|
||||
t.Error("second close should fail")
|
||||
}
|
||||
if err := c.Do(MustBuild(TransactionID), nil); err == nil {
|
||||
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
|
||||
t.Error("Do after Close should fail")
|
||||
}
|
||||
}()
|
||||
m := MustBuild(response, BindingRequest)
|
||||
msg := MustBuild(response, BindingRequest)
|
||||
t.Log("init")
|
||||
got := make(chan struct{})
|
||||
write <- struct{}{}
|
||||
t.Log("starting the first transaction")
|
||||
if err := c.Start(m, func(event Event) {
|
||||
if err := client.Start(msg, func(event Event) {
|
||||
t.Log("got first transaction callback")
|
||||
if event.Error != nil {
|
||||
t.Error(event.Error)
|
||||
@@ -224,7 +232,7 @@ func TestClient_Start(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log("starting the second transaction")
|
||||
if err := c.Start(m, func(Event) {
|
||||
if err := client.Start(msg, func(Event) {
|
||||
t.Error("should not be called")
|
||||
}); !errors.Is(err, ErrTransactionExists) {
|
||||
t.Errorf("unexpected error %v", err)
|
||||
@@ -247,25 +255,25 @@ func TestClient_Do(t *testing.T) {
|
||||
return len(bytes), nil
|
||||
},
|
||||
}
|
||||
c, err := NewClient(conn)
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := c.Close(); err == nil {
|
||||
if err := client.Close(); err == nil {
|
||||
t.Error("second close should fail")
|
||||
}
|
||||
if err := c.Do(MustBuild(TransactionID), nil); err == nil {
|
||||
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
|
||||
t.Error("Do after Close should fail")
|
||||
}
|
||||
}()
|
||||
m := MustBuild(
|
||||
NewTransactionIDSetter(response.TransactionID),
|
||||
)
|
||||
if err := c.Do(m, func(event Event) {
|
||||
if err := client.Do(m, func(event Event) {
|
||||
if event.Error != nil {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
@@ -273,13 +281,13 @@ func TestClient_Do(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
m = MustBuild(TransactionID)
|
||||
if err := c.Do(m, nil); err != nil {
|
||||
if err := client.Do(m, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseErr_Error(t *testing.T) {
|
||||
for id, c := range []struct {
|
||||
for id, testCase := range []struct {
|
||||
Err CloseErr
|
||||
Out string
|
||||
}{
|
||||
@@ -291,16 +299,16 @@ func TestCloseErr_Error(t *testing.T) {
|
||||
ConnectionErr: io.ErrUnexpectedEOF,
|
||||
}, "failed to close: unexpected EOF (connection), <nil> (agent)"},
|
||||
} {
|
||||
if out := c.Err.Error(); out != c.Out {
|
||||
if out := testCase.Err.Error(); out != testCase.Out {
|
||||
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
||||
id, c.Err, out, c.Out,
|
||||
id, testCase.Err, out, testCase.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopErr_Error(t *testing.T) {
|
||||
for id, c := range []struct {
|
||||
for id, testcase := range []struct {
|
||||
Err StopErr
|
||||
Out string
|
||||
}{
|
||||
@@ -312,9 +320,9 @@ func TestStopErr_Error(t *testing.T) {
|
||||
Cause: io.ErrUnexpectedEOF,
|
||||
}, "error while stopping due to unexpected EOF: <nil>"},
|
||||
} {
|
||||
if out := c.Err.Error(); out != c.Out {
|
||||
if out := testcase.Err.Error(); out != testcase.Out {
|
||||
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
||||
id, c.Err, out, c.Out,
|
||||
id, testcase.Err, out, testcase.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -352,7 +360,7 @@ func TestClientAgentError(t *testing.T) {
|
||||
return len(bytes), nil
|
||||
},
|
||||
}
|
||||
c, err := NewClient(conn,
|
||||
client, err := NewClient(conn,
|
||||
WithAgent(errorAgent{
|
||||
startErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
@@ -361,15 +369,15 @@ func TestClientAgentError(t *testing.T) {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m := MustBuild(NewTransactionIDSetter(response.TransactionID))
|
||||
if err := c.Do(m, nil); err != nil {
|
||||
if err := client.Do(m, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := c.Do(m, func(event Event) {
|
||||
if err := client.Do(m, func(event Event) {
|
||||
if event.Error == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
@@ -384,20 +392,20 @@ func TestClientConnErr(t *testing.T) {
|
||||
return 0, io.ErrClosedPipe
|
||||
},
|
||||
}
|
||||
c, err := NewClient(conn)
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m := MustBuild(TransactionID)
|
||||
if err := c.Do(m, nil); err == nil {
|
||||
if err := client.Do(m, nil); err == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
if err := c.Do(m, NoopHandler()); err == nil {
|
||||
if err := client.Do(m, NoopHandler()); err == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
}
|
||||
@@ -408,7 +416,7 @@ func TestClientConnErrStopErr(t *testing.T) {
|
||||
return 0, io.ErrClosedPipe
|
||||
},
|
||||
}
|
||||
c, err := NewClient(conn,
|
||||
client, err := NewClient(conn,
|
||||
WithAgent(errorAgent{
|
||||
stopErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
@@ -417,12 +425,12 @@ func TestClientConnErrStopErr(t *testing.T) {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m := MustBuild(TransactionID)
|
||||
if err := c.Do(m, NoopHandler()); err == nil {
|
||||
if err := client.Do(m, NoopHandler()); err == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
}
|
||||
@@ -556,11 +564,13 @@ func (a *gcWaitAgent) Stop([TransactionIDSize]byte) error {
|
||||
|
||||
func (a *gcWaitAgent) Close() error {
|
||||
close(a.gc)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *gcWaitAgent) Collect(time.Time) error {
|
||||
a.gc <- struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -617,6 +627,7 @@ func captureLog() (*bytes.Buffer, func()) {
|
||||
log.SetOutput(&buf)
|
||||
f := log.Flags()
|
||||
log.SetFlags(0)
|
||||
|
||||
return &buf, func() {
|
||||
log.SetFlags(f)
|
||||
log.SetOutput(os.Stderr)
|
||||
@@ -633,12 +644,12 @@ func TestClientFinalizer(t *testing.T) {
|
||||
return 0, io.ErrClosedPipe
|
||||
},
|
||||
}
|
||||
c, err := NewClient(conn)
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
clientFinalizer(c)
|
||||
clientFinalizer(c)
|
||||
clientFinalizer(client)
|
||||
clientFinalizer(client)
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
response.Encode()
|
||||
conn = &testConnection{
|
||||
@@ -647,7 +658,7 @@ func TestClientFinalizer(t *testing.T) {
|
||||
return len(bytes), nil
|
||||
},
|
||||
}
|
||||
c, err = NewClient(conn,
|
||||
client, err = NewClient(conn,
|
||||
WithAgent(errorAgent{
|
||||
closeErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
@@ -655,7 +666,7 @@ func TestClientFinalizer(t *testing.T) {
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
clientFinalizer(c)
|
||||
clientFinalizer(client)
|
||||
reader := bufio.NewScanner(buf)
|
||||
var lines int
|
||||
expectedLines := []string{
|
||||
@@ -700,6 +711,7 @@ func (m *manualCollector) Collect(t time.Time) {
|
||||
|
||||
func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error {
|
||||
m.f = f
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -717,12 +729,14 @@ func (m *manualClock) Add(d time.Duration) time.Time {
|
||||
v := m.current.Add(d)
|
||||
m.current = v
|
||||
m.mux.Unlock()
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (m *manualClock) Now() time.Time {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
return m.current
|
||||
}
|
||||
|
||||
@@ -735,6 +749,7 @@ type manualAgent struct {
|
||||
|
||||
func (n *manualAgent) SetHandler(h Handler) error {
|
||||
n.h = h
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -748,6 +763,7 @@ func (n *manualAgent) Process(m *Message) error {
|
||||
if n.process != nil {
|
||||
return n.process(m)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -759,6 +775,7 @@ func (n *manualAgent) Stop(id [TransactionIDSize]byte) error {
|
||||
if n.stop != nil {
|
||||
return n.stop(id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -788,9 +805,10 @@ func TestClientRetransmission(t *testing.T) {
|
||||
Message: response,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
c, err := NewClient(connR,
|
||||
client, err := NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -799,7 +817,7 @@ func TestClientRetransmission(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.SetRTO(time.Second)
|
||||
client.SetRTO(time.Second)
|
||||
gotReads := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
@@ -819,7 +837,7 @@ func TestClientRetransmission(t *testing.T) {
|
||||
}
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if event.Error != nil {
|
||||
t.Error("failed")
|
||||
}
|
||||
@@ -829,7 +847,9 @@ func TestClientRetransmission(t *testing.T) {
|
||||
<-gotReads
|
||||
}
|
||||
|
||||
func testClientDoConcurrent(t *testing.T, concurrency int) {
|
||||
func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
|
||||
t.Helper()
|
||||
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
@@ -846,9 +866,10 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
|
||||
TransactionID: id,
|
||||
Message: response,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
c, err := NewClient(connR,
|
||||
client, err := NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -856,7 +877,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.SetRTO(time.Second)
|
||||
client.SetRTO(time.Second)
|
||||
conns := new(sync.WaitGroup)
|
||||
wg := new(sync.WaitGroup)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
@@ -880,7 +901,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if doErr := c.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
|
||||
if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
|
||||
if event.Error != nil {
|
||||
t.Error("failed")
|
||||
}
|
||||
@@ -962,14 +983,14 @@ func TestClient_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientDefaultHandler(t *testing.T) {
|
||||
a := &TestAgent{
|
||||
agent := &TestAgent{
|
||||
e: make(chan Event),
|
||||
}
|
||||
id := NewTransactionID()
|
||||
handlerCalled := make(chan struct{})
|
||||
called := false
|
||||
c, createErr := NewClient(noopConnection{},
|
||||
WithAgent(a),
|
||||
client, createErr := NewClient(noopConnection{},
|
||||
WithAgent(agent),
|
||||
WithHandler(func(e Event) {
|
||||
if called {
|
||||
t.Error("should not be called twice")
|
||||
@@ -985,7 +1006,7 @@ func TestClientDefaultHandler(t *testing.T) {
|
||||
t.Fatal(createErr)
|
||||
}
|
||||
go func() {
|
||||
a.h(Event{
|
||||
agent.h(Event{
|
||||
TransactionID: id,
|
||||
})
|
||||
}()
|
||||
@@ -995,11 +1016,11 @@ func TestClientDefaultHandler(t *testing.T) {
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
if closeErr := c.Close(); closeErr != nil {
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
}
|
||||
// Handler call should be ignored.
|
||||
a.h(Event{})
|
||||
agent.h(Event{})
|
||||
}
|
||||
|
||||
func TestClientClosedStart(t *testing.T) {
|
||||
@@ -1047,9 +1068,10 @@ func TestWithNoRetransmit(t *testing.T) {
|
||||
Error: ErrTransactionTimeOut,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
c, err := NewClient(connR,
|
||||
client, err := NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -1071,7 +1093,7 @@ func TestWithNoRetransmit(t *testing.T) {
|
||||
}
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if !errors.Is(event.Error, ErrTransactionTimeOut) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
@@ -1087,7 +1109,7 @@ func (c callbackClock) Now() time.Time {
|
||||
return c()
|
||||
}
|
||||
|
||||
func TestClientRTOStartErr(t *testing.T) {
|
||||
func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
@@ -1116,13 +1138,14 @@ func TestClientRTOStartErr(t *testing.T) {
|
||||
} else {
|
||||
t.Log("clock returned")
|
||||
}
|
||||
|
||||
return time.Now()
|
||||
})
|
||||
agent := &manualAgent{}
|
||||
attempt := 0
|
||||
gotReads := make(chan struct{})
|
||||
var (
|
||||
c *Client
|
||||
client *Client
|
||||
startClientErr error
|
||||
)
|
||||
agent.start = func(id [TransactionIDSize]byte, _ time.Time) error {
|
||||
@@ -1146,7 +1169,7 @@ func TestClientRTOStartErr(t *testing.T) {
|
||||
t.Log("clock locked")
|
||||
<-clockLocked
|
||||
t.Log("closing client")
|
||||
if closeErr := c.Close(); closeErr != nil {
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
}
|
||||
t.Log("client closed, unlocking clock")
|
||||
@@ -1154,9 +1177,10 @@ func TestClientRTOStartErr(t *testing.T) {
|
||||
t.Log("clock unlocked")
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
c, startClientErr = NewClient(connR,
|
||||
client, startClientErr = NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -1186,7 +1210,7 @@ func TestClientRTOStartErr(t *testing.T) {
|
||||
t.Log("starting")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if !errors.Is(event.Error, ErrClientClosed) {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
@@ -1203,7 +1227,7 @@ func TestClientRTOStartErr(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRTOWriteErr(t *testing.T) {
|
||||
func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
@@ -1232,13 +1256,14 @@ func TestClientRTOWriteErr(t *testing.T) {
|
||||
} else {
|
||||
t.Log("clock returned")
|
||||
}
|
||||
|
||||
return time.Now()
|
||||
})
|
||||
agent := &manualAgent{}
|
||||
attempt := 0
|
||||
gotReads := make(chan struct{})
|
||||
var (
|
||||
c *Client
|
||||
client *Client
|
||||
startClientErr error
|
||||
)
|
||||
agentStopErr := errClientAgentCantStop
|
||||
@@ -1274,9 +1299,10 @@ func TestClientRTOWriteErr(t *testing.T) {
|
||||
t.Log("clock unlocked")
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
c, startClientErr = NewClient(connR,
|
||||
client, startClientErr = NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -1306,7 +1332,7 @@ func TestClientRTOWriteErr(t *testing.T) {
|
||||
t.Log("starting")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
var e StopErr
|
||||
if !errors.As(event.Error, &e) {
|
||||
t.Error(event.Error)
|
||||
@@ -1346,7 +1372,7 @@ func TestClientRTOAgentErr(t *testing.T) {
|
||||
attempt := 0
|
||||
gotReads := make(chan struct{})
|
||||
var (
|
||||
c *Client
|
||||
client *Client
|
||||
startClientErr error
|
||||
)
|
||||
agentStartErr := errClientStartRefused
|
||||
@@ -1361,9 +1387,10 @@ func TestClientRTOAgentErr(t *testing.T) {
|
||||
} else {
|
||||
return agentStartErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
c, startClientErr = NewClient(connR,
|
||||
client, startClientErr = NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -1384,7 +1411,7 @@ func TestClientRTOAgentErr(t *testing.T) {
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
t.Log("starting")
|
||||
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if !errors.Is(event.Error, agentStartErr) {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
@@ -1415,9 +1442,10 @@ func TestClient_HandleProcessError(t *testing.T) {
|
||||
processCalled := make(chan struct{}, 1)
|
||||
agent.process = func(*Message) error {
|
||||
processCalled <- struct{}{}
|
||||
|
||||
return ErrAgentClosed
|
||||
}
|
||||
c, startClientErr := NewClient(connR,
|
||||
client, startClientErr := NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -1440,7 +1468,7 @@ func TestClient_HandleProcessError(t *testing.T) {
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Error("reads timeout")
|
||||
}
|
||||
if closeErr := c.Close(); closeErr != nil {
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
}
|
||||
}
|
||||
@@ -1475,9 +1503,10 @@ func TestClientImmediateTimeout(t *testing.T) {
|
||||
Error: ErrTransactionTimeOut,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
c, err := NewClient(connR,
|
||||
client, err := NewClient(connR,
|
||||
WithAgent(agent),
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
@@ -1498,7 +1527,7 @@ func TestClientImmediateTimeout(t *testing.T) {
|
||||
}
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
c.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
|
||||
client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
|
||||
if errors.Is(e.Error, ErrTransactionTimeOut) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ var (
|
||||
realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals
|
||||
)
|
||||
|
||||
func main() { //nolint:gocognit
|
||||
func main() { //nolint:gocognit,cyclop
|
||||
flag.Parse()
|
||||
uri, err := stun.ParseURI(*uriStr)
|
||||
if err != nil {
|
||||
@@ -88,7 +88,7 @@ func main() { //nolint:gocognit
|
||||
log.Print("Using crypto/rand as random source for transaction id")
|
||||
}
|
||||
for i := 0; i < *workers; i++ {
|
||||
c, clientErr := stun.DialURI(uri, &stun.DialConfig{})
|
||||
client, clientErr := stun.DialURI(uri, &stun.DialConfig{})
|
||||
if clientErr != nil {
|
||||
log.Panicf("Failed to create client: %s", clientErr)
|
||||
}
|
||||
@@ -105,12 +105,13 @@ func main() { //nolint:gocognit
|
||||
req.Type = stun.BindingRequest
|
||||
req.WriteHeader()
|
||||
atomic.AddInt64(&request, 1)
|
||||
if doErr := c.Do(req, func(event stun.Event) {
|
||||
if doErr := client.Do(req, func(event stun.Event) {
|
||||
if event.Error != nil {
|
||||
if !errors.Is(event.Error, stun.ErrTransactionTimeOut) {
|
||||
log.Printf("Failed STUN transaction: %s", event.Error)
|
||||
}
|
||||
atomic.AddInt64(&requestErr, 1)
|
||||
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&requestOK, 1)
|
||||
|
@@ -31,11 +31,11 @@ func main() {
|
||||
}
|
||||
|
||||
// we only try the first address, so restrict ourselves to IPv4
|
||||
c, err := stun.DialURI(uri, &stun.DialConfig{})
|
||||
client, err := stun.DialURI(uri, &stun.DialConfig{})
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to dial: %s", err)
|
||||
}
|
||||
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||
if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||
if res.Error != nil {
|
||||
log.Fatalf("Failed STUN transaction: %s", res.Error)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func main() {
|
||||
}); err != nil {
|
||||
log.Fatal("Do:", err)
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
if err := client.Close(); err != nil {
|
||||
log.Fatalf("Failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@@ -84,7 +84,7 @@ func multiplex(conn *net.UDPConn, stunAddr net.Addr, stunConn io.Reader) {
|
||||
|
||||
var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals
|
||||
|
||||
func main() {
|
||||
func main() { //nolint:cyclop
|
||||
flag.Parse()
|
||||
|
||||
isServer := flag.Arg(0) == ""
|
||||
@@ -112,7 +112,7 @@ func main() {
|
||||
|
||||
stunL, stunR := net.Pipe()
|
||||
|
||||
c, err := stun.NewClient(stunR)
|
||||
client, err := stun.NewClient(stunR)
|
||||
if err != nil {
|
||||
log.Panicf("Failed to create client: %s", err)
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func main() {
|
||||
// This can fail if your NAT Server is strict and will use separate ports
|
||||
// for application data and STUN
|
||||
var gotAddr stun.XORMappedAddress
|
||||
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||
if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
|
||||
if res.Error != nil {
|
||||
log.Panicf("Failed STUN transaction: %s", res.Error)
|
||||
}
|
||||
@@ -153,7 +153,7 @@ func main() {
|
||||
// Any ping-pong will work, but we are just making binding requests.
|
||||
// Note that STUN Server is not mandatory for keep alive, application
|
||||
// data will keep alive that binding too.
|
||||
go keepAlive(c)
|
||||
go keepAlive(client)
|
||||
|
||||
notify := make(chan os.Signal, 1)
|
||||
signal.Notify(notify, os.Interrupt, syscall.SIGTERM)
|
||||
@@ -168,6 +168,7 @@ func main() {
|
||||
}
|
||||
case <-notify:
|
||||
log.Println("Stopping")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -203,10 +204,12 @@ func main() {
|
||||
|
||||
case m := <-messages:
|
||||
log.Printf("Got response from %s: %s", m.addr, m.text)
|
||||
|
||||
return
|
||||
|
||||
case <-notify:
|
||||
log.Print("Stopping")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@@ -30,10 +30,14 @@ func (c *stunServerConn) Close() error {
|
||||
}
|
||||
|
||||
var (
|
||||
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address") //nolint:gochecknoglobals
|
||||
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response") //nolint:gochecknoglobals
|
||||
verbose = flag.Int("verbose", 1, "the verbosity level") //nolint:gochecknoglobals
|
||||
log logging.LeveledLogger //nolint:gochecknoglobals
|
||||
//nolint:gochecknoglobals
|
||||
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address")
|
||||
//nolint:gochecknoglobals
|
||||
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response")
|
||||
//nolint:gochecknoglobals
|
||||
verbose = flag.Int("verbose", 1, "the verbosity level")
|
||||
//nolint:gochecknoglobals
|
||||
log logging.LeveledLogger
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -70,11 +74,12 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// RFC5780: 4.3. Determining NAT Mapping Behavior
|
||||
func mappingTests(addrStr string) error {
|
||||
// RFC5780: 4.3. Determining NAT Mapping Behavior.
|
||||
func mappingTests(addrStr string) error { //nolint:cyclop
|
||||
mapTestConn, err := connect(addrStr)
|
||||
if err != nil {
|
||||
log.Warnf("Error creating STUN connection: %s", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -91,11 +96,13 @@ func mappingTests(addrStr string) error {
|
||||
resps1 := parse(resp)
|
||||
if resps1.xorAddr == nil || resps1.otherAddr == nil {
|
||||
log.Info("Error: NAT discovery feature not supported by this server")
|
||||
|
||||
return errNoOtherAddress
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
|
||||
if err != nil {
|
||||
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr)
|
||||
|
||||
return err
|
||||
}
|
||||
mapTestConn.OtherAddr = addr
|
||||
@@ -104,6 +111,7 @@ func mappingTests(addrStr string) error {
|
||||
// Assert mapping behavior
|
||||
if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() {
|
||||
log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -121,6 +129,7 @@ func mappingTests(addrStr string) error {
|
||||
log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr)
|
||||
if resps2.xorAddr.String() == resps1.xorAddr.String() {
|
||||
log.Warn("=> NAT mapping behavior: endpoint independent")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -143,11 +152,12 @@ func mappingTests(addrStr string) error {
|
||||
return mapTestConn.Close()
|
||||
}
|
||||
|
||||
// RFC5780: 4.4. Determining NAT Filtering Behavior
|
||||
func filteringTests(addrStr string) error {
|
||||
// RFC5780: 4.4. Determining NAT Filtering Behavior.
|
||||
func filteringTests(addrStr string) error { //nolint:cyclop
|
||||
mapTestConn, err := connect(addrStr)
|
||||
if err != nil {
|
||||
log.Warnf("Error creating STUN connection: %s", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -162,11 +172,13 @@ func filteringTests(addrStr string) error {
|
||||
resps := parse(resp)
|
||||
if resps.xorAddr == nil || resps.otherAddr == nil {
|
||||
log.Warn("Error: NAT discovery feature not supported by this server")
|
||||
|
||||
return errNoOtherAddress
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String())
|
||||
if err != nil {
|
||||
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr)
|
||||
|
||||
return err
|
||||
}
|
||||
mapTestConn.OtherAddr = addr
|
||||
@@ -180,6 +192,7 @@ func filteringTests(addrStr string) error {
|
||||
if err == nil {
|
||||
parse(resp) // just to print out the resp
|
||||
log.Warn("=> NAT filtering behavior: endpoint independent")
|
||||
|
||||
return nil
|
||||
} else if !errors.Is(err, errTimedOut) {
|
||||
return err // something else went wrong
|
||||
@@ -201,7 +214,7 @@ func filteringTests(addrStr string) error {
|
||||
return mapTestConn.Close()
|
||||
}
|
||||
|
||||
// Parse a STUN message
|
||||
// Parse a STUN message.
|
||||
func parse(msg *stun.Message) (ret struct {
|
||||
xorAddr *stun.XORMappedAddress
|
||||
otherAddr *stun.OtherAddress
|
||||
@@ -249,15 +262,17 @@ func parse(msg *stun.Message) (ret struct {
|
||||
log.Debugf("\t%v (l=%v)", attr, attr.Length)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Given an address string, returns a StunServerConn
|
||||
// Given an address string, returns a StunServerConn.
|
||||
func connect(addrStr string) (*stunServerConn, error) {
|
||||
log.Infof("Connecting to STUN server: %s", addrStr)
|
||||
addr, err := net.ResolveUDPAddr("udp4", addrStr)
|
||||
if err != nil {
|
||||
log.Warnf("Error resolving address: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -278,7 +293,7 @@ func connect(addrStr string) (*stunServerConn, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send request and wait for response or timeout
|
||||
// Send request and wait for response or timeout.
|
||||
func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) {
|
||||
_ = msg.NewTransactionID()
|
||||
log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize)
|
||||
@@ -289,6 +304,7 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
|
||||
_, err := c.conn.WriteTo(msg.Raw, addr)
|
||||
if err != nil {
|
||||
log.Warnf("Error sending request to %v", addr)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -298,9 +314,11 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
|
||||
if !ok {
|
||||
return nil, errResponseMessage
|
||||
}
|
||||
|
||||
return m, nil
|
||||
case <-time.After(time.Duration(*timeoutPtr) * time.Second):
|
||||
log.Infof("Timed out waiting for response from server %v", addr)
|
||||
|
||||
return nil, errTimedOut
|
||||
}
|
||||
}
|
||||
@@ -315,6 +333,7 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
|
||||
n, addr, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
close(messages)
|
||||
|
||||
return
|
||||
}
|
||||
log.Infof("Response from %v: (%v bytes)", addr, n)
|
||||
@@ -326,11 +345,13 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
|
||||
if err != nil {
|
||||
log.Infof("Error decoding message: %v", err)
|
||||
close(messages)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
messages <- m
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
@@ -26,7 +26,7 @@ const (
|
||||
timeoutMillis = 500
|
||||
)
|
||||
|
||||
func main() { //nolint:gocognit
|
||||
func main() { //nolint:gocognit,cyclop
|
||||
flag.Parse()
|
||||
|
||||
srvAddr, err := net.ResolveUDPAddr(udp, *server)
|
||||
@@ -87,11 +87,13 @@ func main() { //nolint:gocognit
|
||||
decErr := m.Decode()
|
||||
if decErr != nil {
|
||||
log.Println("decode:", decErr)
|
||||
|
||||
break
|
||||
}
|
||||
var xorAddr stun.XORMappedAddress
|
||||
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
||||
log.Println("getFrom:", getErr)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
@@ -160,6 +162,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
|
||||
n, _, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
close(messages)
|
||||
|
||||
return
|
||||
}
|
||||
buf = buf[:n]
|
||||
@@ -167,6 +170,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
|
||||
messages <- buf
|
||||
}
|
||||
}()
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
|
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/pion/stun/v3"
|
||||
)
|
||||
|
||||
func test(network string) {
|
||||
func test(network string) { //nolint:cyclop
|
||||
addr := resolve(network)
|
||||
fmt.Println("START", strings.ToUpper(addr.Network())) //nolint
|
||||
var (
|
||||
|
25
errorcode.go
25
errorcode.go
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
// ErrorCodeAttribute represents ERROR-CODE attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.6
|
||||
// RFC 5389 Section 15.6.
|
||||
type ErrorCodeAttribute struct {
|
||||
Code ErrorCode
|
||||
Reason []byte
|
||||
@@ -31,7 +31,7 @@ const (
|
||||
)
|
||||
|
||||
// AddTo adds ERROR-CODE to m.
|
||||
func (c ErrorCodeAttribute) AddTo(m *Message) error {
|
||||
func (c ErrorCodeAttribute) AddTo(msg *Message) error {
|
||||
value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB)
|
||||
if err := CheckOverflow(AttrErrorCode,
|
||||
len(c.Reason)+errorCodeReasonStart,
|
||||
@@ -45,26 +45,28 @@ func (c ErrorCodeAttribute) AddTo(m *Message) error {
|
||||
value[errorCodeClassByte] = class
|
||||
value[errorCodeNumberByte] = number
|
||||
copy(value[errorCodeReasonStart:], c.Reason)
|
||||
m.Add(AttrErrorCode, value)
|
||||
msg.Add(AttrErrorCode, value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
|
||||
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
||||
v, err := m.Get(AttrErrorCode)
|
||||
value, err := m.Get(AttrErrorCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v) < errorCodeReasonStart {
|
||||
if len(value) < errorCodeReasonStart {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
var (
|
||||
class = uint16(v[errorCodeClassByte])
|
||||
number = uint16(v[errorCodeNumberByte])
|
||||
class = uint16(value[errorCodeClassByte])
|
||||
number = uint16(value[errorCodeNumberByte])
|
||||
code = int(class*errorCodeModulo + number)
|
||||
)
|
||||
c.Code = ErrorCode(code)
|
||||
c.Reason = v[errorCodeReasonStart:]
|
||||
c.Reason = value[errorCodeReasonStart:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -86,6 +88,7 @@ func (c ErrorCode) AddTo(m *Message) error {
|
||||
Code: c,
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
return a.AddTo(m)
|
||||
}
|
||||
|
||||
@@ -108,7 +111,7 @@ const (
|
||||
|
||||
// Error codes from RFC 5766.
|
||||
//
|
||||
// RFC 5766 Section 15
|
||||
// RFC 5766 Section 15.
|
||||
const (
|
||||
CodeForbidden ErrorCode = 403 // Forbidden
|
||||
CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch
|
||||
@@ -120,7 +123,7 @@ const (
|
||||
|
||||
// Error codes from RFC 6062.
|
||||
//
|
||||
// RFC 6062 Section 6.3
|
||||
// RFC 6062 Section 6.3.
|
||||
const (
|
||||
CodeConnAlreadyExists ErrorCode = 446
|
||||
CodeConnTimeoutOrFailure ErrorCode = 447
|
||||
@@ -128,7 +131,7 @@ const (
|
||||
|
||||
// Error codes from RFC 6156.
|
||||
//
|
||||
// RFC 6156 Section 10.2
|
||||
// RFC 6156 Section 10.2.
|
||||
const (
|
||||
CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported
|
||||
CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch
|
||||
|
@@ -92,23 +92,23 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestErrorCode(t *testing.T) {
|
||||
a := &ErrorCodeAttribute{
|
||||
attr := &ErrorCodeAttribute{
|
||||
Code: 404,
|
||||
Reason: []byte("not found!"),
|
||||
}
|
||||
if a.String() != "404: not found!" {
|
||||
t.Error("bad string", a)
|
||||
if attr.String() != "404: not found!" {
|
||||
t.Error("bad string", attr)
|
||||
}
|
||||
m := New()
|
||||
cod := ErrorCode(666)
|
||||
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
|
||||
t.Error("should be ErrNoDefaultReason", err)
|
||||
}
|
||||
if err := a.GetFrom(m); err == nil {
|
||||
if err := attr.GetFrom(m); err == nil {
|
||||
t.Error("attr should not be in message")
|
||||
}
|
||||
a.Reason = make([]byte, 2048)
|
||||
if err := a.AddTo(m); err == nil {
|
||||
attr.Reason = make([]byte, 2048)
|
||||
if err := attr.AddTo(m); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
// FingerprintAttr represents FINGERPRINT attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.5
|
||||
// RFC 5389 Section 15.5.
|
||||
type FingerprintAttr struct{}
|
||||
|
||||
// ErrFingerprintMismatch means that computed fingerprint differs from expected.
|
||||
@@ -50,6 +50,7 @@ func (FingerprintAttr) AddTo(m *Message) error {
|
||||
bin.PutUint32(b, val)
|
||||
m.Length = l
|
||||
m.Add(AttrFingerprint, b)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -66,5 +67,6 @@ func (FingerprintAttr) Check(m *Message) error {
|
||||
val := bin.Uint32(b)
|
||||
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
|
||||
expected := FingerprintValue(m.Raw[:attrStart])
|
||||
|
||||
return checkFingerprint(val, expected)
|
||||
}
|
||||
|
@@ -13,20 +13,20 @@ import (
|
||||
|
||||
func BenchmarkFingerprint_AddTo(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
msg := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
addAttr(b, m, addr)
|
||||
addAttr(b, m, s)
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
addAttr(b, msg, addr)
|
||||
addAttr(b, msg, s)
|
||||
b.SetBytes(int64(len(msg.Raw)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
Fingerprint.AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteLength()
|
||||
m.Length -= attributeHeaderSize + fingerprintSize
|
||||
m.Raw = m.Raw[:m.Length+messageHeaderSize]
|
||||
m.Attributes = m.Attributes[:len(m.Attributes)-1]
|
||||
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
|
||||
msg.WriteLength()
|
||||
msg.Length -= attributeHeaderSize + fingerprintSize
|
||||
msg.Raw = msg.Raw[:msg.Length+messageHeaderSize]
|
||||
msg.Attributes = msg.Attributes[:len(msg.Attributes)-1]
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -109,6 +109,7 @@ func FuzzSetters(f *testing.F) {
|
||||
if !IsAttrSizeOverflow(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,5 +151,6 @@ func (a attributes) pick(v byte) struct {
|
||||
t AttrType
|
||||
} {
|
||||
idx := int(v) % len(a)
|
||||
|
||||
return a[idx]
|
||||
}
|
||||
|
@@ -44,6 +44,7 @@ func (m *Message) Build(setters ...Setter) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,6 +55,7 @@ func (m *Message) Check(checkers ...Checker) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,6 +66,7 @@ func (m *Message) Parse(getters ...Getter) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,6 +76,7 @@ func MustBuild(setters ...Setter) *Message {
|
||||
if err != nil {
|
||||
panic(err) //nolint
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -82,6 +86,7 @@ func Build(setters ...Setter) (*Message, error) {
|
||||
if err := m.Build(setters...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -105,5 +110,6 @@ func (m *Message) ForEach(t AttrType, f func(m *Message) error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func BenchmarkBuildOverhead(b *testing.B) {
|
||||
var (
|
||||
t = BindingRequest
|
||||
msgType = BindingRequest
|
||||
username = NewUsername("username")
|
||||
nonce = NewNonce("nonce")
|
||||
realm = NewRealm("example.org")
|
||||
@@ -22,14 +22,14 @@ func BenchmarkBuildOverhead(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Build(&t, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
|
||||
m.Build(&msgType, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
|
||||
}
|
||||
})
|
||||
b.Run("BuildNonPointer", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Build(t, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
|
||||
m.Build(msgType, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
|
||||
}
|
||||
})
|
||||
b.Run("Raw", func(b *testing.B) {
|
||||
@@ -38,7 +38,7 @@ func BenchmarkBuildOverhead(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Reset()
|
||||
m.WriteHeader()
|
||||
m.SetType(t)
|
||||
m.SetType(msgType)
|
||||
username.AddTo(m) //nolint:errcheck,gosec
|
||||
nonce.AddTo(m) //nolint:errcheck,gosec
|
||||
realm.AddTo(m) //nolint:errcheck,gosec
|
||||
@@ -52,7 +52,7 @@ func TestMessage_Apply(t *testing.T) {
|
||||
integrity = NewShortTermIntegrity("password")
|
||||
decoded = new(Message)
|
||||
)
|
||||
m, err := Build(BindingRequest, TransactionID,
|
||||
msg, err := Build(BindingRequest, TransactionID,
|
||||
NewUsername("username"),
|
||||
NewNonce("nonce"),
|
||||
NewRealm("example.org"),
|
||||
@@ -62,13 +62,13 @@ func TestMessage_Apply(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal("failed to build:", err)
|
||||
}
|
||||
if err = m.Check(Fingerprint, integrity); err != nil {
|
||||
if err = msg.Check(Fingerprint, integrity); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := decoded.Write(m.Raw); err != nil {
|
||||
if _, err := decoded.Write(msg.Raw); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !decoded.Equal(m) {
|
||||
if !decoded.Equal(msg) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
if err := integrity.Check(decoded); err != nil {
|
||||
@@ -96,32 +96,32 @@ func (e errReturner) GetFrom(*Message) error {
|
||||
|
||||
func TestHelpersErrorHandling(t *testing.T) {
|
||||
m := New()
|
||||
e := errReturner{Err: errTError}
|
||||
if err := m.Build(e); !errors.Is(err, e.Err) {
|
||||
t.Error(err, "!=", e.Err)
|
||||
errReturn := errReturner{Err: errTError}
|
||||
if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) {
|
||||
t.Error(err, "!=", errReturn.Err)
|
||||
}
|
||||
if err := m.Check(e); !errors.Is(err, e.Err) {
|
||||
t.Error(err, "!=", e.Err)
|
||||
if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
|
||||
t.Error(err, "!=", errReturn.Err)
|
||||
}
|
||||
if err := m.Parse(e); !errors.Is(err, e.Err) {
|
||||
t.Error(err, "!=", e.Err)
|
||||
if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
|
||||
t.Error(err, "!=", errReturn.Err)
|
||||
}
|
||||
t.Run("MustBuild", func(t *testing.T) {
|
||||
t.Run("Positive", func(*testing.T) {
|
||||
MustBuild(NewTransactionIDSetter(transactionID{}))
|
||||
})
|
||||
defer func() {
|
||||
if p, ok := recover().(error); !ok || !errors.Is(p, e.Err) {
|
||||
if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) {
|
||||
t.Errorf("%s != %s",
|
||||
p, e.Err,
|
||||
p, errReturn.Err,
|
||||
)
|
||||
}
|
||||
}()
|
||||
MustBuild(e)
|
||||
MustBuild(errReturn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessage_ForEach(t *testing.T) {
|
||||
func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
|
||||
initial := New()
|
||||
if err := initial.Build(
|
||||
NewRealm("realm1"), NewRealm("realm2"),
|
||||
@@ -135,6 +135,7 @@ func TestMessage_ForEach(t *testing.T) {
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
t.Run("NoResults", func(t *testing.T) {
|
||||
@@ -144,6 +145,7 @@ func TestMessage_ForEach(t *testing.T) {
|
||||
}
|
||||
if err := m.ForEach(AttrUsername, func(*Message) error {
|
||||
t.Error("should not be called")
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -160,6 +162,7 @@ func TestMessage_ForEach(t *testing.T) {
|
||||
t.Error("called multiple times")
|
||||
}
|
||||
calls++
|
||||
|
||||
return ErrAttributeNotFound
|
||||
}); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Fatal(err)
|
||||
@@ -169,14 +172,15 @@ func TestMessage_ForEach(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("Positive", func(t *testing.T) {
|
||||
m := newMessage()
|
||||
msg := newMessage()
|
||||
var realms []string
|
||||
if err := m.ForEach(AttrRealm, func(m *Message) error {
|
||||
if err := msg.ForEach(AttrRealm, func(m *Message) error {
|
||||
var realm Realm
|
||||
if err := realm.GetFrom(m); err != nil {
|
||||
return err
|
||||
}
|
||||
realms = append(realms, realm.String())
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -190,18 +194,18 @@ func TestMessage_ForEach(t *testing.T) {
|
||||
if realms[1] != "realm2" {
|
||||
t.Error("bad value for 2 realm")
|
||||
}
|
||||
if !m.Equal(initial) {
|
||||
if !msg.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
t.Run("ZeroAlloc", func(t *testing.T) {
|
||||
m = newMessage()
|
||||
msg = newMessage()
|
||||
var realm Realm
|
||||
testutil.ShouldNotAllocate(t, func() {
|
||||
if err := m.ForEach(AttrRealm, realm.GetFrom); err != nil {
|
||||
if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
if !m.Equal(initial) {
|
||||
if !msg.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
})
|
||||
@@ -216,6 +220,7 @@ func ExampleMessage_ForEach() {
|
||||
return err
|
||||
}
|
||||
fmt.Println(r)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
fmt.Println("error:", err)
|
||||
|
32
iana_test.go
32
iana_test.go
@@ -14,22 +14,24 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func loadCSV(t testing.TB, name string) [][]string {
|
||||
t.Helper()
|
||||
data := loadData(t, name)
|
||||
func loadCSV(tb testing.TB, name string) [][]string {
|
||||
tb.Helper()
|
||||
|
||||
data := loadData(tb, name)
|
||||
r := csv.NewReader(bytes.NewReader(data))
|
||||
r.Comment = '#'
|
||||
records, err := r.ReadAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
func TestIANA(t *testing.T) {
|
||||
func TestIANA(t *testing.T) { //nolint:cyclop
|
||||
t.Run("Methods", func(t *testing.T) {
|
||||
records := loadCSV(t, "stun-parameters-2.csv")
|
||||
m := make(map[string]Method)
|
||||
methods := make(map[string]Method)
|
||||
for _, r := range records[1:] {
|
||||
var (
|
||||
v = r[0]
|
||||
@@ -43,10 +45,10 @@ func TestIANA(t *testing.T) {
|
||||
t.Fatal(parseErr)
|
||||
}
|
||||
t.Logf("value: 0x%x, name: %s", val, name)
|
||||
m[name] = Method(val)
|
||||
methods[name] = Method(val) //nolint:gosec // G115
|
||||
}
|
||||
for val, name := range methodName() {
|
||||
mapped, ok := m[name]
|
||||
mapped, ok := methods[name]
|
||||
if !ok {
|
||||
t.Errorf("failed to find method %s in IANA", name)
|
||||
}
|
||||
@@ -57,7 +59,7 @@ func TestIANA(t *testing.T) {
|
||||
})
|
||||
t.Run("Attributes", func(t *testing.T) {
|
||||
records := loadCSV(t, "stun-parameters-4.csv")
|
||||
m := map[string]AttrType{}
|
||||
attrTypes := map[string]AttrType{}
|
||||
for _, r := range records[1:] {
|
||||
var (
|
||||
v = r[0]
|
||||
@@ -71,16 +73,16 @@ func TestIANA(t *testing.T) {
|
||||
t.Fatal(parseErr)
|
||||
}
|
||||
t.Logf("value: 0x%x, name: %s", val, name)
|
||||
m[name] = AttrType(val)
|
||||
attrTypes[name] = AttrType(val) //nolint:gosec // G115
|
||||
}
|
||||
// Not registered in IANA.
|
||||
for k, v := range map[string]AttrType{
|
||||
"ORIGIN": 0x802F,
|
||||
} {
|
||||
m[k] = v
|
||||
attrTypes[k] = v
|
||||
}
|
||||
for val, name := range attrNames() {
|
||||
mapped, ok := m[name]
|
||||
mapped, ok := attrTypes[name]
|
||||
if !ok {
|
||||
t.Errorf("failed to find attribute %s in IANA", name)
|
||||
}
|
||||
@@ -91,7 +93,7 @@ func TestIANA(t *testing.T) {
|
||||
})
|
||||
t.Run("ErrorCodes", func(t *testing.T) {
|
||||
records := loadCSV(t, "stun-parameters-6.csv")
|
||||
m := map[string]ErrorCode{}
|
||||
errorCodes := map[string]ErrorCode{}
|
||||
for _, r := range records[1:] {
|
||||
var (
|
||||
v = r[0]
|
||||
@@ -105,11 +107,11 @@ func TestIANA(t *testing.T) {
|
||||
t.Fatal(parseErr)
|
||||
}
|
||||
t.Logf("value: 0x%x, name: %s", val, name)
|
||||
m[name] = ErrorCode(val)
|
||||
errorCodes[name] = ErrorCode(val)
|
||||
}
|
||||
for val, nameB := range errorReasons {
|
||||
name := string(nameB)
|
||||
mapped, ok := m[name]
|
||||
mapped, ok := errorCodes[name]
|
||||
if !ok {
|
||||
t.Errorf("failed to find error code %s in IANA", name)
|
||||
}
|
||||
|
46
integrity.go
46
integrity.go
@@ -22,6 +22,7 @@ func NewLongTermIntegrity(username, realm, password string) MessageIntegrity {
|
||||
k := strings.Join([]string{username, realm, password}, credentialsSep)
|
||||
h := md5.New() //nolint:gosec
|
||||
fmt.Fprint(h, k) //nolint:errcheck
|
||||
|
||||
return MessageIntegrity(h.Sum(nil))
|
||||
}
|
||||
|
||||
@@ -36,13 +37,14 @@ func NewShortTermIntegrity(password string) MessageIntegrity {
|
||||
// AddTo and Check methods are using zero-allocation version of hmac, see
|
||||
// newHMAC function and internal/hmac/pool.go.
|
||||
//
|
||||
// RFC 5389 Section 15.4
|
||||
// RFC 5389 Section 15.4.
|
||||
type MessageIntegrity []byte
|
||||
|
||||
func newHMAC(key, message, buf []byte) []byte {
|
||||
mac := hmac.AcquireSHA1(key)
|
||||
writeOrPanic(mac, message)
|
||||
defer hmac.PutSHA1(mac)
|
||||
|
||||
return mac.Sum(buf)
|
||||
}
|
||||
|
||||
@@ -59,8 +61,8 @@ var ErrFingerprintBeforeIntegrity = errors.New("FINGERPRINT before MESSAGE-INTEG
|
||||
// AddTo adds MESSAGE-INTEGRITY attribute to message.
|
||||
//
|
||||
// CPU costly, see BenchmarkMessageIntegrity_AddTo.
|
||||
func (i MessageIntegrity) AddTo(m *Message) error {
|
||||
for _, a := range m.Attributes {
|
||||
func (i MessageIntegrity) AddTo(msg *Message) error {
|
||||
for _, a := range msg.Attributes {
|
||||
// Message should not contain FINGERPRINT attribute
|
||||
// before MESSAGE-INTEGRITY.
|
||||
if a.Type == AttrFingerprint {
|
||||
@@ -70,19 +72,20 @@ func (i MessageIntegrity) AddTo(m *Message) error {
|
||||
// The text used as input to HMAC is the STUN message,
|
||||
// including the header, up to and including the attribute preceding the
|
||||
// MESSAGE-INTEGRITY attribute.
|
||||
length := m.Length
|
||||
length := msg.Length
|
||||
// Adjusting m.Length to contain MESSAGE-INTEGRITY TLV.
|
||||
m.Length += messageIntegritySize + attributeHeaderSize
|
||||
m.WriteLength() // writing length to m.Raw
|
||||
v := newHMAC(i, m.Raw, m.Raw[len(m.Raw):]) // calculating HMAC for adjusted m.Raw
|
||||
m.Length = length // changing m.Length back
|
||||
msg.Length += messageIntegritySize + attributeHeaderSize
|
||||
msg.WriteLength() // writing length to m.Raw
|
||||
v := newHMAC(i, msg.Raw, msg.Raw[len(msg.Raw):]) // calculating HMAC for adjusted m.Raw
|
||||
msg.Length = length // changing m.Length back
|
||||
|
||||
// Copy hmac value to temporary variable to protect it from resetting
|
||||
// while processing m.Add call.
|
||||
vBuf := make([]byte, sha1.Size)
|
||||
copy(vBuf, v)
|
||||
|
||||
m.Add(AttrMessageIntegrity, vBuf)
|
||||
msg.Add(AttrMessageIntegrity, vBuf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -92,8 +95,8 @@ var ErrIntegrityMismatch = errors.New("integrity check failed")
|
||||
// Check checks MESSAGE-INTEGRITY attribute.
|
||||
//
|
||||
// CPU costly, see BenchmarkMessageIntegrity_Check.
|
||||
func (i MessageIntegrity) Check(m *Message) error {
|
||||
v, err := m.Get(AttrMessageIntegrity)
|
||||
func (i MessageIntegrity) Check(msg *Message) error {
|
||||
val, err := msg.Get(AttrMessageIntegrity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -101,11 +104,11 @@ func (i MessageIntegrity) Check(m *Message) error {
|
||||
// Adjusting length in header to match m.Raw that was
|
||||
// used when computing HMAC.
|
||||
var (
|
||||
length = m.Length
|
||||
length = msg.Length
|
||||
afterIntegrity = false
|
||||
sizeReduced int
|
||||
)
|
||||
for _, a := range m.Attributes {
|
||||
for _, a := range msg.Attributes {
|
||||
if afterIntegrity {
|
||||
sizeReduced += nearestPaddedValueLength(int(a.Length))
|
||||
sizeReduced += attributeHeaderSize
|
||||
@@ -114,13 +117,14 @@ func (i MessageIntegrity) Check(m *Message) error {
|
||||
afterIntegrity = true
|
||||
}
|
||||
}
|
||||
m.Length -= uint32(sizeReduced)
|
||||
m.WriteLength()
|
||||
msg.Length -= uint32(sizeReduced) //nolint:gosec // G115
|
||||
msg.WriteLength()
|
||||
// startOfHMAC should be first byte of integrity attribute.
|
||||
startOfHMAC := messageHeaderSize + m.Length - (attributeHeaderSize + messageIntegritySize)
|
||||
b := m.Raw[:startOfHMAC] // data before integrity attribute
|
||||
expected := newHMAC(i, b, m.Raw[len(m.Raw):])
|
||||
m.Length = length
|
||||
m.WriteLength() // writing length back
|
||||
return checkHMAC(v, expected)
|
||||
startOfHMAC := messageHeaderSize + msg.Length - (attributeHeaderSize + messageIntegritySize)
|
||||
b := msg.Raw[:startOfHMAC] // data before integrity attribute
|
||||
expected := newHMAC(i, b, msg.Raw[len(msg.Raw):])
|
||||
msg.Length = length
|
||||
msg.WriteLength() // writing length back
|
||||
|
||||
return checkHMAC(val, expected)
|
||||
}
|
||||
|
@@ -10,18 +10,18 @@ import (
|
||||
)
|
||||
|
||||
func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
|
||||
i := NewLongTermIntegrity("user", "realm", "pass")
|
||||
integrity := NewLongTermIntegrity("user", "realm", "pass")
|
||||
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(expected, i) {
|
||||
if !bytes.Equal(expected, integrity) {
|
||||
t.Error(ErrIntegrityMismatch)
|
||||
}
|
||||
t.Run("Check", func(t *testing.T) {
|
||||
m := new(Message)
|
||||
m.WriteHeader()
|
||||
if err := i.AddTo(m); err != nil {
|
||||
if err := integrity.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
|
||||
@@ -31,39 +31,39 @@ func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
|
||||
if err := dM.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := i.Check(dM); err != nil {
|
||||
if err := integrity.Check(dM); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
dM.Raw[24] += 12 // HMAC now invalid
|
||||
if i.Check(dM) == nil {
|
||||
if integrity.Check(dM) == nil {
|
||||
t.Error("should be invalid")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessageIntegrityWithFingerprint(t *testing.T) {
|
||||
m := new(Message)
|
||||
m.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
|
||||
m.WriteHeader()
|
||||
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
|
||||
i := NewShortTermIntegrity("pwd")
|
||||
if i.String() != "KEY: 0x707764" {
|
||||
t.Error("bad string", i)
|
||||
msg := new(Message)
|
||||
msg.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
|
||||
msg.WriteHeader()
|
||||
NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
|
||||
integrity := NewShortTermIntegrity("pwd")
|
||||
if integrity.String() != "KEY: 0x707764" {
|
||||
t.Error("bad string", integrity)
|
||||
}
|
||||
if err := i.Check(m); err == nil {
|
||||
if err := integrity.Check(msg); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
if err := i.AddTo(m); err != nil {
|
||||
if err := integrity.AddTo(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := Fingerprint.AddTo(m); err != nil {
|
||||
if err := Fingerprint.AddTo(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := i.Check(m); err != nil {
|
||||
if err := integrity.Check(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.Raw[24] = 33
|
||||
if err := i.Check(m); err == nil {
|
||||
msg.Raw[24] = 33
|
||||
if err := integrity.Check(msg); err == nil {
|
||||
t.Fatal("mismatch expected")
|
||||
}
|
||||
}
|
||||
|
@@ -64,6 +64,7 @@ func (h *hmac) Sum(in []byte) []byte {
|
||||
h.outer.Write(h.opad) //nolint:errcheck,gosec
|
||||
}
|
||||
h.outer.Write(in[origLen:]) //nolint:errcheck,gosec
|
||||
|
||||
return h.outer.Sum(in[:origLen])
|
||||
}
|
||||
|
||||
@@ -79,6 +80,7 @@ func (h *hmac) Reset() {
|
||||
if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert
|
||||
panic(err) //nolint
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -22,7 +22,7 @@ type hmacTest struct {
|
||||
blocksize int
|
||||
}
|
||||
|
||||
func hmacTests() []hmacTest {
|
||||
func hmacTests() []hmacTest { //nolint:maintidx
|
||||
return []hmacTest{
|
||||
// Tests from US FIPS 198
|
||||
// https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf
|
||||
@@ -523,41 +523,42 @@ func hmacTests() []hmacTest {
|
||||
|
||||
func TestHMAC(t *testing.T) {
|
||||
for i, tt := range hmacTests() {
|
||||
h := New(tt.hash, tt.key)
|
||||
if s := h.Size(); s != tt.size {
|
||||
hsh := New(tt.hash, tt.key)
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := h.BlockSize(); b != tt.blocksize {
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
for j := 0; j < 4; j++ {
|
||||
n, err := h.Write(tt.in)
|
||||
for j := 0; j < 4; j++ { //nolint:varnamelen
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
h.Reset()
|
||||
hsh.Reset()
|
||||
|
||||
// Third and fourth iteration: make sure hmac works on
|
||||
// hashes without MarshalBinary/UnmarshalBinary
|
||||
if j == 1 {
|
||||
h = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
|
||||
hsh = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// justHash implements just the hash.Hash methods and nothing else
|
||||
// justHash implements just the hash.Hash methods and nothing else.
|
||||
type justHash struct {
|
||||
hash.Hash
|
||||
}
|
||||
|
@@ -40,6 +40,7 @@ func (h *hmac) resetTo(key []byte) {
|
||||
var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals
|
||||
New: func() interface{} {
|
||||
h := New(sha1.New, make([]byte, sha1.BlockSize))
|
||||
|
||||
return h
|
||||
},
|
||||
}
|
||||
@@ -49,6 +50,7 @@ func AcquireSHA1(key []byte) hash.Hash {
|
||||
h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert
|
||||
assertHMACSize(h, sha1.Size, sha1.BlockSize)
|
||||
h.resetTo(key)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -62,6 +64,7 @@ func PutSHA1(h hash.Hash) {
|
||||
var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals
|
||||
New: func() interface{} {
|
||||
h := New(sha256.New, make([]byte, sha256.BlockSize))
|
||||
|
||||
return h
|
||||
},
|
||||
}
|
||||
@@ -71,6 +74,7 @@ func AcquireSHA256(key []byte) hash.Hash {
|
||||
h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert
|
||||
assertHMACSize(h, sha256.Size, sha256.BlockSize)
|
||||
h.resetTo(key)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
|
@@ -42,100 +42,103 @@ func BenchmarkHMACSHA1_512_Pool(b *testing.B) {
|
||||
|
||||
func TestHMACReset(t *testing.T) {
|
||||
for i, tt := range hmacTests() {
|
||||
h := New(tt.hash, tt.key)
|
||||
h.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
|
||||
if s := h.Size(); s != tt.size {
|
||||
hsh := New(tt.hash, tt.key)
|
||||
hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := h.BlockSize(); b != tt.blocksize {
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
for j := 0; j < 2; j++ {
|
||||
n, err := h.Write(tt.in)
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
h.Reset()
|
||||
hsh.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl
|
||||
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
|
||||
for i, tt := range hmacTests() {
|
||||
if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size {
|
||||
continue
|
||||
}
|
||||
h := AcquireSHA1(tt.key)
|
||||
if s := h.Size(); s != tt.size {
|
||||
hsh := AcquireSHA1(tt.key)
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := h.BlockSize(); b != tt.blocksize {
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
for j := 0; j < 2; j++ {
|
||||
n, err := h.Write(tt.in)
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
h.Reset()
|
||||
hsh.Reset()
|
||||
}
|
||||
PutSHA1(h)
|
||||
PutSHA1(hsh)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl
|
||||
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
|
||||
for i, tt := range hmacTests() {
|
||||
if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size {
|
||||
continue
|
||||
}
|
||||
h := AcquireSHA256(tt.key)
|
||||
if s := h.Size(); s != tt.size {
|
||||
hsh := AcquireSHA256(tt.key)
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := h.BlockSize(); b != tt.blocksize {
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
for j := 0; j < 2; j++ {
|
||||
n, err := h.Write(tt.in)
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", h.Sum(nil))
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
h.Reset()
|
||||
hsh.Reset()
|
||||
}
|
||||
PutSHA256(h)
|
||||
PutSHA256(hsh)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -10,8 +10,11 @@ import (
|
||||
|
||||
// ShouldNotAllocate fails if f allocates.
|
||||
func ShouldNotAllocate(t *testing.T, f func()) {
|
||||
t.Helper()
|
||||
|
||||
if Race {
|
||||
t.Skip("skip while running with -race")
|
||||
|
||||
return
|
||||
}
|
||||
if a := testing.AllocsPerRun(10, f); a > 0 {
|
||||
|
115
message.go
115
message.go
@@ -32,6 +32,7 @@ const (
|
||||
// as source.
|
||||
func NewTransactionID() (b [TransactionIDSize]byte) {
|
||||
readFullOrPanic(rand.Reader, b[:])
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -45,6 +46,7 @@ func IsMessage(b []byte) bool {
|
||||
// New returns *Message with pre-allocated Raw.
|
||||
func New() *Message {
|
||||
const defaultRawCapacity = 120
|
||||
|
||||
return &Message{
|
||||
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
|
||||
}
|
||||
@@ -59,6 +61,7 @@ func Decode(data []byte, m *Message) error {
|
||||
return ErrDecodeToNil
|
||||
}
|
||||
m.Raw = append(m.Raw[:0], data...)
|
||||
|
||||
return m.Decode()
|
||||
}
|
||||
|
||||
@@ -82,6 +85,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
|
||||
// contract induced by other implementations.
|
||||
b := make([]byte, len(m.Raw))
|
||||
copy(b, m.Raw)
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
@@ -89,6 +93,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
|
||||
func (m *Message) UnmarshalBinary(data []byte) error {
|
||||
// We can't retain data, copy is expected by interface contract.
|
||||
m.Raw = append(m.Raw[:0], data...)
|
||||
|
||||
return m.Decode()
|
||||
}
|
||||
|
||||
@@ -108,6 +113,7 @@ func (m *Message) GobDecode(data []byte) error {
|
||||
func (m *Message) AddTo(b *Message) error {
|
||||
b.TransactionID = m.TransactionID
|
||||
b.WriteTransactionID()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -118,6 +124,7 @@ func (m *Message) NewTransactionID() error {
|
||||
if err == nil {
|
||||
m.WriteTransactionID()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -127,6 +134,7 @@ func (m *Message) String() string {
|
||||
for k, a := range m.Attributes {
|
||||
aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
|
||||
}
|
||||
|
||||
@@ -144,6 +152,7 @@ func (m *Message) grow(n int) {
|
||||
}
|
||||
if cap(m.Raw) >= n {
|
||||
m.Raw = m.Raw[:n]
|
||||
|
||||
return
|
||||
}
|
||||
m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
|
||||
@@ -153,7 +162,7 @@ func (m *Message) grow(n int) {
|
||||
//
|
||||
// Value of attribute is copied to internal buffer so
|
||||
// it is safe to reuse v.
|
||||
func (m *Message) Add(t AttrType, v []byte) {
|
||||
func (m *Message) Add(attrType AttrType, val []byte) {
|
||||
// Allocating buffer for TLV (type-length-value).
|
||||
// T = t, L = len(v), V = v.
|
||||
// m.Raw will look like:
|
||||
@@ -163,31 +172,33 @@ func (m *Message) Add(t AttrType, v []byte) {
|
||||
// [first:last] <- same as previous
|
||||
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
|
||||
// T L V
|
||||
allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V)
|
||||
first := messageHeaderSize + int(m.Length) // first byte number
|
||||
last := first + allocSize // last byte number
|
||||
m.grow(last) // growing cap(Raw) to fit TLV
|
||||
m.Raw = m.Raw[:last] // now len(Raw) = last
|
||||
m.Length += uint32(allocSize) // rendering length change
|
||||
allocSize := attributeHeaderSize + len(val) // ~ len(TLV) = len(TL) + len(V)
|
||||
first := messageHeaderSize + int(m.Length) // first byte number
|
||||
last := first + allocSize // last byte number
|
||||
m.grow(last) // growing cap(Raw) to fit TLV
|
||||
m.Raw = m.Raw[:last] // now len(Raw) = last
|
||||
//nolint:gosec // G115
|
||||
m.Length += uint32(allocSize) // rendering length change
|
||||
|
||||
// Sub-slicing internal buffer to simplify encoding.
|
||||
buf := m.Raw[first:last] // slice for TLV
|
||||
value := buf[attributeHeaderSize:] // slice for V
|
||||
attr := RawAttribute{
|
||||
Type: t, // T
|
||||
Length: uint16(len(v)), // L
|
||||
Value: value, // V
|
||||
Type: attrType, // T
|
||||
//nolint:gosec // G115
|
||||
Length: uint16(len(val)), // L
|
||||
Value: value, // V
|
||||
}
|
||||
|
||||
// Encoding attribute TLV to allocated buffer.
|
||||
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
|
||||
bin.PutUint16(buf[2:4], attr.Length) // L
|
||||
copy(value, v) // V
|
||||
copy(value, val) // V
|
||||
|
||||
// Checking that attribute value needs padding.
|
||||
if attr.Length%padding != 0 {
|
||||
// Performing padding.
|
||||
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
|
||||
bytesToAdd := nearestPaddedValueLength(len(val)) - len(val)
|
||||
last += bytesToAdd
|
||||
m.grow(last)
|
||||
// setting all padding bytes to zero
|
||||
@@ -197,7 +208,8 @@ func (m *Message) Add(t AttrType, v []byte) {
|
||||
for i := range buf {
|
||||
buf[i] = 0
|
||||
}
|
||||
m.Raw = m.Raw[:last] // increasing buffer length
|
||||
m.Raw = m.Raw[:last] // increasing buffer length
|
||||
//nolint:gosec // G115
|
||||
m.Length += uint32(bytesToAdd) // rendering length change
|
||||
}
|
||||
m.Attributes = append(m.Attributes, attr)
|
||||
@@ -213,6 +225,7 @@ func attrSliceEqual(a, b Attributes) bool {
|
||||
}
|
||||
if attrB.Equal(attr) {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -220,56 +233,59 @@ func attrSliceEqual(a, b Attributes) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func attrEqual(a, b Attributes) bool {
|
||||
if a == nil && b == nil {
|
||||
func attrEqual(attrA, attrB Attributes) bool {
|
||||
if attrA == nil && attrB == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
if attrA == nil || attrB == nil {
|
||||
return false
|
||||
}
|
||||
if len(a) != len(b) {
|
||||
if len(attrA) != len(attrB) {
|
||||
return false
|
||||
}
|
||||
if !attrSliceEqual(a, b) {
|
||||
if !attrSliceEqual(attrA, attrB) {
|
||||
return false
|
||||
}
|
||||
if !attrSliceEqual(b, a) {
|
||||
if !attrSliceEqual(attrB, attrA) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Equal returns true if Message b equals to m.
|
||||
// Equal returns true if Message msg equals to m.
|
||||
// Ignores m.Raw.
|
||||
func (m *Message) Equal(b *Message) bool {
|
||||
if m == nil && b == nil {
|
||||
func (m *Message) Equal(msg *Message) bool {
|
||||
if m == nil && msg == nil {
|
||||
return true
|
||||
}
|
||||
if m == nil || b == nil {
|
||||
if m == nil || msg == nil {
|
||||
return false
|
||||
}
|
||||
if m.Type != b.Type {
|
||||
if m.Type != msg.Type {
|
||||
return false
|
||||
}
|
||||
if m.TransactionID != b.TransactionID {
|
||||
if m.TransactionID != msg.TransactionID {
|
||||
return false
|
||||
}
|
||||
if m.Length != b.Length {
|
||||
if m.Length != msg.Length {
|
||||
return false
|
||||
}
|
||||
if !attrEqual(m.Attributes, b.Attributes) {
|
||||
if !attrEqual(m.Attributes, msg.Attributes) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// WriteLength writes m.Length to m.Raw.
|
||||
func (m *Message) WriteLength() {
|
||||
m.grow(4)
|
||||
bin.PutUint16(m.Raw[2:4], uint16(m.Length))
|
||||
bin.PutUint16(m.Raw[2:4], uint16(m.Length)) //nolint:gosec // G115
|
||||
}
|
||||
|
||||
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
|
||||
@@ -322,6 +338,7 @@ func (m *Message) Encode() {
|
||||
// call result.
|
||||
func (m *Message) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := w.Write(m.Raw)
|
||||
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
@@ -340,6 +357,7 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) {
|
||||
return int64(n), err
|
||||
}
|
||||
m.Raw = tBuf[:n]
|
||||
|
||||
return int64(n), m.Decode()
|
||||
}
|
||||
|
||||
@@ -355,22 +373,24 @@ func (m *Message) Decode() error {
|
||||
return ErrUnexpectedHeaderEOF
|
||||
}
|
||||
var (
|
||||
t = bin.Uint16(buf[0:2]) // first 2 bytes
|
||||
msgType = bin.Uint16(buf[0:2]) // first 2 bytes
|
||||
size = int(bin.Uint16(buf[2:4])) // second 2 bytes
|
||||
cookie = bin.Uint32(buf[4:8]) // last 4 bytes
|
||||
fullSize = messageHeaderSize + size // len(m.Raw)
|
||||
)
|
||||
if cookie != magicCookie {
|
||||
msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
|
||||
|
||||
return newDecodeErr("message", "cookie", msg)
|
||||
}
|
||||
if len(buf) < fullSize {
|
||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
|
||||
|
||||
return newAttrDecodeErr("message", msg)
|
||||
}
|
||||
// saving header data
|
||||
m.Type.ReadValue(t)
|
||||
m.Length = uint32(size)
|
||||
m.Type.ReadValue(msgType)
|
||||
m.Length = uint32(size) //nolint:gosec // G115
|
||||
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
||||
|
||||
m.Attributes = m.Attributes[:0]
|
||||
@@ -382,28 +402,31 @@ func (m *Message) Decode() error {
|
||||
// checking that we have enough bytes to read header
|
||||
if len(b) < attributeHeaderSize {
|
||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
|
||||
|
||||
return newAttrDecodeErr("header", msg)
|
||||
}
|
||||
var (
|
||||
a = RawAttribute{
|
||||
attr = RawAttribute{
|
||||
Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
|
||||
Length: bin.Uint16(b[2:4]), // second 2 bytes
|
||||
}
|
||||
aL = int(a.Length) // attribute length
|
||||
aL = int(attr.Length) // attribute length
|
||||
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
|
||||
)
|
||||
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
||||
offset += attributeHeaderSize
|
||||
if len(b) < aBuffL { // checking size
|
||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
|
||||
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, attr.Type)
|
||||
|
||||
return newAttrDecodeErr("value", msg)
|
||||
}
|
||||
a.Value = b[:aL]
|
||||
attr.Value = b[:aL]
|
||||
offset += aBuffL
|
||||
b = b[aBuffL:]
|
||||
|
||||
m.Attributes = append(m.Attributes, a)
|
||||
m.Attributes = append(m.Attributes, attr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -412,12 +435,14 @@ func (m *Message) Decode() error {
|
||||
// Any error is unrecoverable, but message could be partially decoded.
|
||||
func (m *Message) Write(tBuf []byte) (int, error) {
|
||||
m.Raw = append(m.Raw[:0], tBuf...)
|
||||
|
||||
return len(tBuf), m.Decode()
|
||||
}
|
||||
|
||||
// CloneTo clones m to b securing any further m mutations.
|
||||
func (m *Message) CloneTo(b *Message) error {
|
||||
b.Raw = append(b.Raw[:0], m.Raw...)
|
||||
|
||||
return b.Decode()
|
||||
}
|
||||
|
||||
@@ -436,7 +461,7 @@ const (
|
||||
var (
|
||||
// Binding request message type.
|
||||
BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
|
||||
// Binding success response message type
|
||||
// Binding success response message type.
|
||||
BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
|
||||
// Binding error response message type.
|
||||
BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
|
||||
@@ -501,6 +526,7 @@ func (m Method) String() string {
|
||||
// Falling back to hex representation.
|
||||
s = fmt.Sprintf("0x%x", uint16(m))
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -513,6 +539,7 @@ type MessageType struct {
|
||||
// AddTo sets m type to t.
|
||||
func (t MessageType) AddTo(m *Message) error {
|
||||
m.SetType(t)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -554,13 +581,13 @@ func (t MessageType) Value() uint16 {
|
||||
|
||||
// Warning: Abandon all hope ye who enter here.
|
||||
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
|
||||
m := uint16(t.Method)
|
||||
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
|
||||
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
|
||||
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
|
||||
msg := uint16(t.Method)
|
||||
a := msg & methodABits // A = M * 0b0000000000001111 (right 4 bits)
|
||||
b := msg & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
|
||||
d := msg & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
|
||||
|
||||
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
|
||||
m = a + (b << methodBShift) + (d << methodDShift)
|
||||
msg = a + (b << methodBShift) + (d << methodDShift)
|
||||
|
||||
// C0 is zero bit of C, C1 is first bit.
|
||||
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
|
||||
@@ -573,7 +600,7 @@ func (t MessageType) Value() uint16 {
|
||||
c1 := (c & c1Bit) << classC1Shift
|
||||
class := c0 + c1
|
||||
|
||||
return m + class
|
||||
return msg + class
|
||||
}
|
||||
|
||||
// ReadValue decodes uint16 into MessageType.
|
||||
@@ -604,6 +631,7 @@ func (m *Message) Contains(t AttrType) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -618,5 +646,6 @@ func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
|
||||
func (t transactionIDValueSetter) AddTo(m *Message) error {
|
||||
m.TransactionID = t
|
||||
m.WriteTransactionID()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
222
message_test.go
222
message_test.go
@@ -27,9 +27,11 @@ type attributeEncoder interface {
|
||||
AddTo(m *Message) error
|
||||
}
|
||||
|
||||
func addAttr(t testing.TB, m *Message, a attributeEncoder) {
|
||||
func addAttr(tb testing.TB, m *Message, a attributeEncoder) {
|
||||
tb.Helper()
|
||||
|
||||
if err := a.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
tb.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,21 +131,21 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMessage_WriteTo(t *testing.T) {
|
||||
m := New()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.WriteHeader()
|
||||
msg := New()
|
||||
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
msg.TransactionID = NewTransactionID()
|
||||
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
msg.WriteHeader()
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := m.WriteTo(buf); err != nil {
|
||||
if _, err := msg.WriteTo(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(buf); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !mDecoded.Equal(m) {
|
||||
t.Error(mDecoded, "!", m)
|
||||
if !mDecoded.Equal(msg) {
|
||||
t.Error(mDecoded, "!", msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,23 +277,23 @@ func BenchmarkMessage_WriteTo(b *testing.B) {
|
||||
|
||||
func BenchmarkMessage_ReadFrom(b *testing.B) {
|
||||
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m := &Message{
|
||||
msg := &Message{
|
||||
Type: mType,
|
||||
Length: 0,
|
||||
TransactionID: [TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
},
|
||||
}
|
||||
m.WriteHeader()
|
||||
msg.WriteHeader()
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
reader := m.reader()
|
||||
b.SetBytes(int64(len(msg.Raw)))
|
||||
reader := msg.reader()
|
||||
mRec := New()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := mRec.ReadFrom(reader); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
reader.Reset(m.Raw)
|
||||
reader.Reset(msg.Raw)
|
||||
mRec.Reset()
|
||||
}
|
||||
}
|
||||
@@ -342,7 +344,7 @@ func TestMessageClass_String(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAttrType_String(t *testing.T) {
|
||||
v := [...]AttrType{
|
||||
attrType := [...]AttrType{
|
||||
AttrMappedAddress,
|
||||
AttrUsername,
|
||||
AttrErrorCode,
|
||||
@@ -355,7 +357,7 @@ func TestAttrType_String(t *testing.T) {
|
||||
AttrAlternateServer,
|
||||
AttrFingerprint,
|
||||
}
|
||||
for _, k := range v {
|
||||
for _, k := range attrType {
|
||||
if k.String() == "" {
|
||||
t.Error(k, "bad stringer")
|
||||
}
|
||||
@@ -379,80 +381,80 @@ func TestMethod_String(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAttribute_Equal(t *testing.T) {
|
||||
a := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||
b := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||
if !a.Equal(b) {
|
||||
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||
if !attr1.Equal(attr2) {
|
||||
t.Error("should equal")
|
||||
}
|
||||
if a.Equal(RawAttribute{Type: 0x2}) {
|
||||
if attr1.Equal(RawAttribute{Type: 0x2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if a.Equal(RawAttribute{Length: 0x2}) {
|
||||
if attr1.Equal(RawAttribute{Length: 0x2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if a.Equal(RawAttribute{Length: 0x3}) {
|
||||
if attr1.Equal(RawAttribute{Length: 0x3}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if a.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
|
||||
if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_Equal(t *testing.T) {
|
||||
func TestMessage_Equal(t *testing.T) { //nolint:cyclop
|
||||
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||
attrs := Attributes{attr}
|
||||
a := &Message{Attributes: attrs, Length: 4 + 2}
|
||||
b := &Message{Attributes: attrs, Length: 4 + 2}
|
||||
if !a.Equal(b) {
|
||||
msg1 := &Message{Attributes: attrs, Length: 4 + 2}
|
||||
msg2 := &Message{Attributes: attrs, Length: 4 + 2}
|
||||
if !msg1.Equal(msg2) {
|
||||
t.Error("should equal")
|
||||
}
|
||||
if a.Equal(&Message{Type: MessageType{Class: 128}}) {
|
||||
if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
tID := [TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
}
|
||||
if a.Equal(&Message{TransactionID: tID}) {
|
||||
if msg1.Equal(&Message{TransactionID: tID}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if a.Equal(&Message{Length: 3}) {
|
||||
if msg1.Equal(&Message{Length: 3}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
tAttrs := Attributes{
|
||||
{Length: 1, Value: []byte{0x1}, Type: 0x1},
|
||||
}
|
||||
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
tAttrs = Attributes{
|
||||
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
|
||||
}
|
||||
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if !(*Message)(nil).Equal(nil) {
|
||||
t.Error("nil should be equal to nil")
|
||||
}
|
||||
if a.Equal(nil) {
|
||||
if msg1.Equal(nil) {
|
||||
t.Error("non-nil should not be equal to nil")
|
||||
}
|
||||
t.Run("Nil attributes", func(t *testing.T) {
|
||||
a := &Message{
|
||||
msg1 := &Message{
|
||||
Attributes: nil,
|
||||
Length: 4 + 2,
|
||||
}
|
||||
b := &Message{
|
||||
msg2 := &Message{
|
||||
Attributes: attrs,
|
||||
Length: 4 + 2,
|
||||
}
|
||||
if a.Equal(b) {
|
||||
if msg1.Equal(msg2) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if b.Equal(a) {
|
||||
if msg2.Equal(msg1) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
b.Attributes = nil
|
||||
if !a.Equal(b) {
|
||||
msg2.Attributes = nil
|
||||
if !msg1.Equal(msg2) {
|
||||
t.Error("should equal")
|
||||
}
|
||||
})
|
||||
@@ -547,6 +549,8 @@ func BenchmarkIsMessage(b *testing.B) {
|
||||
}
|
||||
|
||||
func loadData(tb testing.TB, name string) []byte {
|
||||
tb.Helper()
|
||||
|
||||
name = filepath.Join("testdata", name)
|
||||
f, err := os.Open(name) //nolint:gosec
|
||||
if err != nil {
|
||||
@@ -561,6 +565,7 @@ func loadData(tb testing.TB, name string) []byte {
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -582,7 +587,7 @@ func TestMessageFromBrowsers(t *testing.T) {
|
||||
t.Fatal("failed to skip header of csv: ", err)
|
||||
}
|
||||
crcTable := crc64.MakeTable(crc64.ISO)
|
||||
m := New()
|
||||
msg := New()
|
||||
for {
|
||||
line, err := reader.Read()
|
||||
if errors.Is(err, io.EOF) {
|
||||
@@ -602,10 +607,10 @@ func TestMessageFromBrowsers(t *testing.T) {
|
||||
if b != crc64.Checksum(data, crcTable) {
|
||||
t.Error("crc64 check failed for ", line[1])
|
||||
}
|
||||
if _, err = m.Write(data); err != nil {
|
||||
if _, err = msg.Write(data); err != nil {
|
||||
t.Error("failed to decode ", line[1], " as message: ", err)
|
||||
}
|
||||
m.Reset()
|
||||
msg.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -622,42 +627,42 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
|
||||
|
||||
func BenchmarkMessageFull(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
msg := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
if err := addr.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := s.AddTo(m); err != nil {
|
||||
if err := s.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.WriteAttributes()
|
||||
m.WriteHeader()
|
||||
Fingerprint.AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteHeader()
|
||||
m.Reset()
|
||||
msg.WriteAttributes()
|
||||
msg.WriteHeader()
|
||||
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
|
||||
msg.WriteHeader()
|
||||
msg.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessageFullHardcore(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
msg := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
if err := addr.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := s.AddTo(m); err != nil {
|
||||
if err := s.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.WriteHeader()
|
||||
m.Reset()
|
||||
msg.WriteHeader()
|
||||
msg.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,8 +694,8 @@ func TestMessage_Contains(t *testing.T) {
|
||||
|
||||
func ExampleMessage() {
|
||||
buf := new(bytes.Buffer)
|
||||
m := new(Message)
|
||||
m.Build(BindingRequest, //nolint:errcheck,gosec
|
||||
msg := new(Message)
|
||||
msg.Build(BindingRequest, //nolint:errcheck,gosec
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
@@ -707,8 +712,8 @@ func ExampleMessage() {
|
||||
// m.Build(&software) // no allocations
|
||||
// If you pass software as value, there will be 1 allocation.
|
||||
// This rule is correct for all setters.
|
||||
fmt.Println(m, "buff length:", len(m.Raw))
|
||||
n, err := m.WriteTo(buf)
|
||||
fmt.Println(msg, "buff length:", len(msg.Raw))
|
||||
n, err := msg.WriteTo(buf)
|
||||
fmt.Println("wrote", n, "err", err)
|
||||
|
||||
// Decoding from buf new *Message.
|
||||
@@ -743,6 +748,7 @@ func ExampleMessage() {
|
||||
fmt.Println("fingerprint: failed")
|
||||
}
|
||||
|
||||
//nolint:lll
|
||||
// Output:
|
||||
// Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68
|
||||
// wrote 68 err <nil>
|
||||
@@ -811,8 +817,8 @@ func TestAllocationsGetters(t *testing.T) {
|
||||
NewShortTermIntegrity("pwd"),
|
||||
Fingerprint,
|
||||
}
|
||||
m := New()
|
||||
if err := m.Build(setters...); err != nil {
|
||||
msg := New()
|
||||
if err := msg.Build(setters...); err != nil {
|
||||
t.Error("failed to build", err)
|
||||
}
|
||||
getters := []Getter{
|
||||
@@ -826,7 +832,7 @@ func TestAllocationsGetters(t *testing.T) {
|
||||
g := g
|
||||
i := i
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
if err := g.GetFrom(m); err != nil {
|
||||
if err := g.GetFrom(msg); err != nil {
|
||||
t.Errorf("[%d] failed to get", i)
|
||||
}
|
||||
})
|
||||
@@ -837,8 +843,8 @@ func TestAllocationsGetters(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMessageFullSize(t *testing.T) {
|
||||
m := new(Message)
|
||||
if err := m.Build(BindingRequest,
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
@@ -848,18 +854,18 @@ func TestMessageFullSize(t *testing.T) {
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.Raw = m.Raw[:len(m.Raw)-10]
|
||||
msg.Raw = msg.Raw[:len(msg.Raw)-10]
|
||||
|
||||
decoder := new(Message)
|
||||
decoder.Raw = m.Raw[:len(m.Raw)-10]
|
||||
decoder.Raw = msg.Raw[:len(msg.Raw)-10]
|
||||
if err := decoder.Decode(); err == nil {
|
||||
t.Error("decode on truncated buffer should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_CloneTo(t *testing.T) {
|
||||
m := new(Message)
|
||||
if err := m.Build(BindingRequest,
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
@@ -869,29 +875,29 @@ func TestMessage_CloneTo(t *testing.T) {
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.Encode()
|
||||
b := new(Message)
|
||||
if err := m.CloneTo(b); err != nil {
|
||||
msg.Encode()
|
||||
msg2 := new(Message)
|
||||
if err := msg.CloneTo(msg2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !b.Equal(m) {
|
||||
if !msg2.Equal(msg) {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
// Corrupting m and checking that b is not corrupted.
|
||||
s, ok := b.Attributes.Get(AttrSoftware)
|
||||
s, ok := msg2.Attributes.Get(AttrSoftware)
|
||||
if !ok {
|
||||
t.Fatal("no software attribute")
|
||||
}
|
||||
s.Value[0] = 'k'
|
||||
if b.Equal(m) {
|
||||
if msg2.Equal(msg) {
|
||||
t.Fatal("should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessage_CloneTo(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
if err := m.Build(BindingRequest,
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
@@ -901,19 +907,19 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
|
||||
); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
b.SetBytes(int64(len(msg.Raw)))
|
||||
a := new(Message)
|
||||
m.CloneTo(a) //nolint:errcheck,gosec
|
||||
msg.CloneTo(a) //nolint:errcheck,gosec
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := m.CloneTo(a); err != nil {
|
||||
if err := msg.CloneTo(a); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_AddTo(t *testing.T) {
|
||||
m := new(Message)
|
||||
if err := m.Build(BindingRequest,
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
@@ -921,19 +927,19 @@ func TestMessage_AddTo(t *testing.T) {
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.Encode()
|
||||
msg.Encode()
|
||||
b := new(Message)
|
||||
if err := m.CloneTo(b); err != nil {
|
||||
if err := msg.CloneTo(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.TransactionID = [TransactionIDSize]byte{
|
||||
msg.TransactionID = [TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
|
||||
}
|
||||
if b.Equal(m) {
|
||||
if b.Equal(msg) {
|
||||
t.Fatal("should not be equal")
|
||||
}
|
||||
m.AddTo(b) //nolint:errcheck,gosec
|
||||
if !b.Equal(m) {
|
||||
msg.AddTo(b) //nolint:errcheck,gosec
|
||||
if !b.Equal(msg) {
|
||||
t.Fatal("should be equal")
|
||||
}
|
||||
}
|
||||
@@ -964,22 +970,22 @@ func TestDecode(t *testing.T) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
m := New()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.WriteHeader()
|
||||
msg := New()
|
||||
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
msg.TransactionID = NewTransactionID()
|
||||
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
msg.WriteHeader()
|
||||
mDecoded := New()
|
||||
if err := Decode(m.Raw, mDecoded); err != nil {
|
||||
if err := Decode(msg.Raw, mDecoded); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !mDecoded.Equal(m) {
|
||||
if !mDecoded.Equal(msg) {
|
||||
t.Error("decoded result is not equal to encoded message")
|
||||
}
|
||||
t.Run("ZeroAlloc", func(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
mDecoded.Reset()
|
||||
if err := Decode(m.Raw, mDecoded); err != nil {
|
||||
if err := Decode(msg.Raw, mDecoded); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
@@ -1008,22 +1014,22 @@ func BenchmarkDecode(b *testing.B) {
|
||||
}
|
||||
|
||||
func TestMessage_MarshalBinary(t *testing.T) {
|
||||
m := MustBuild(
|
||||
msg := MustBuild(
|
||||
NewSoftware("software"),
|
||||
&XORMappedAddress{
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
},
|
||||
)
|
||||
data, err := m.MarshalBinary()
|
||||
data, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reset m.Raw to check retention.
|
||||
for i := range m.Raw {
|
||||
m.Raw[i] = 0
|
||||
for i := range msg.Raw {
|
||||
msg.Raw[i] = 0
|
||||
}
|
||||
if err := m.UnmarshalBinary(data); err != nil {
|
||||
if err := msg.UnmarshalBinary(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1031,28 +1037,28 @@ func TestMessage_MarshalBinary(t *testing.T) {
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
if err := m.Decode(); err != nil {
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_GobDecode(t *testing.T) {
|
||||
m := MustBuild(
|
||||
msg := MustBuild(
|
||||
NewSoftware("software"),
|
||||
&XORMappedAddress{
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
},
|
||||
)
|
||||
data, err := m.GobEncode()
|
||||
data, err := msg.GobEncode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reset m.Raw to check retention.
|
||||
for i := range m.Raw {
|
||||
m.Raw[i] = 0
|
||||
for i := range msg.Raw {
|
||||
msg.Raw[i] = 0
|
||||
}
|
||||
if err := m.GobDecode(data); err != nil {
|
||||
if err := msg.GobDecode(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1060,7 +1066,7 @@ func TestMessage_GobDecode(t *testing.T) {
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
if err := m.Decode(); err != nil {
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRFC5769(t *testing.T) {
|
||||
func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||
// Test Vectors for Session Traversal Utilities for NAT (STUN)
|
||||
// see https://tools.ietf.org/html/rfc5769
|
||||
t.Run("Request", func(t *testing.T) {
|
||||
@@ -46,7 +46,7 @@ func TestRFC5769(t *testing.T) {
|
||||
t.Error("check failed: ", err)
|
||||
}
|
||||
t.Run("Long-Term credentials", func(t *testing.T) {
|
||||
m := &Message{
|
||||
msg := &Message{
|
||||
Raw: []byte("\x00\x01\x00\x60" +
|
||||
"\x21\x12\xa4\x42" +
|
||||
"\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" +
|
||||
@@ -64,11 +64,11 @@ func TestRFC5769(t *testing.T) {
|
||||
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
|
||||
),
|
||||
}
|
||||
if err := m.Decode(); err != nil {
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
u := new(Username)
|
||||
if err := u.GetFrom(m); err != nil {
|
||||
if err := u.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
|
||||
@@ -76,14 +76,14 @@ func TestRFC5769(t *testing.T) {
|
||||
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
|
||||
}
|
||||
n := new(Nonce)
|
||||
if err := n.GetFrom(m); err != nil {
|
||||
if err := n.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
|
||||
t.Error("bad nonce")
|
||||
}
|
||||
r := new(Realm)
|
||||
if err := r.GetFrom(m); err != nil {
|
||||
if err := r.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if r.String() != "example.org" { //nolint:goconst
|
||||
@@ -95,14 +95,14 @@ func TestRFC5769(t *testing.T) {
|
||||
"example.org",
|
||||
"TheMatrIX",
|
||||
)
|
||||
if err := i.Check(m); err != nil {
|
||||
if err := i.Check(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("Response", func(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
m := &Message{
|
||||
msg := &Message{
|
||||
Raw: []byte("\x01\x01\x00\x3c" +
|
||||
"\x21\x12\xa4\x42" +
|
||||
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
|
||||
@@ -117,21 +117,21 @@ func TestRFC5769(t *testing.T) {
|
||||
"\xc0\x7d\x4c\x96",
|
||||
),
|
||||
}
|
||||
if err := m.Decode(); err != nil {
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
software := new(Software)
|
||||
if err := software.GetFrom(m); err != nil {
|
||||
if err := software.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if software.String() != "test vector" {
|
||||
t.Error("bad software: ", software)
|
||||
}
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("Check failed: ", err)
|
||||
}
|
||||
addr := new(XORMappedAddress)
|
||||
if err := addr.GetFrom(m); err != nil {
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
|
||||
@@ -140,12 +140,12 @@ func TestRFC5769(t *testing.T) {
|
||||
if addr.Port != 32853 {
|
||||
t.Error("bad Port")
|
||||
}
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("check failed: ", err)
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
m := &Message{
|
||||
msg := &Message{
|
||||
Raw: []byte("\x01\x01\x00\x48" +
|
||||
"\x21\x12\xa4\x42" +
|
||||
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
|
||||
@@ -162,21 +162,21 @@ func TestRFC5769(t *testing.T) {
|
||||
"\xc8\xfb\x0b\x4c",
|
||||
),
|
||||
}
|
||||
if err := m.Decode(); err != nil {
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
software := new(Software)
|
||||
if err := software.GetFrom(m); err != nil {
|
||||
if err := software.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if software.String() != "test vector" {
|
||||
t.Error("bad software: ", software)
|
||||
}
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("Check failed: ", err)
|
||||
}
|
||||
addr := new(XORMappedAddress)
|
||||
if err := addr.GetFrom(m); err != nil {
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
|
||||
@@ -185,7 +185,7 @@ func TestRFC5769(t *testing.T) {
|
||||
if addr.Port != 32853 {
|
||||
t.Error("bad Port")
|
||||
}
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("check failed: ", err)
|
||||
}
|
||||
})
|
||||
|
2
stun.go
2
stun.go
@@ -27,6 +27,7 @@ func readFullOrPanic(r io.Reader, v []byte) int {
|
||||
if err != nil {
|
||||
panic(err) //nolint
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -35,6 +36,7 @@ func writeOrPanic(w io.Writer, v []byte) int {
|
||||
if err != nil {
|
||||
panic(err) //nolint
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
|
@@ -13,14 +13,19 @@ import (
|
||||
|
||||
var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
|
||||
|
||||
// NewUDPServer creates an udp server for testing. The supplied handler function will be called with the request
|
||||
// NewUDPServer creates an udp server for testing.
|
||||
// The supplied handler function will be called with the request
|
||||
// and should be used to emulate the server behavior.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func NewUDPServer(
|
||||
t *testing.T,
|
||||
network string,
|
||||
maxMessageSize int,
|
||||
handler func(req []byte) ([]byte, error),
|
||||
) (net.Addr, func(t *testing.T), error) {
|
||||
t.Helper()
|
||||
|
||||
var ip string
|
||||
switch network {
|
||||
case "udp4":
|
||||
@@ -50,28 +55,34 @@ func NewUDPServer(
|
||||
n, addr, err := udpConn.ReadFrom(bs)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := handler(bs[:n])
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_, err = udpConn.WriteTo(resp, addr)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return serverAddr, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Fatal(err) //nolint
|
||||
t.Fatal(err)
|
||||
|
||||
return
|
||||
}
|
||||
default:
|
||||
@@ -79,7 +90,7 @@ func NewUDPServer(
|
||||
|
||||
err := udpConn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err) //nolint
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
<-errCh
|
||||
|
10
textattrs.go
10
textattrs.go
@@ -10,7 +10,7 @@ func NewUsername(username string) Username {
|
||||
|
||||
// Username represents USERNAME attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.3
|
||||
// RFC 5389 Section 15.3.
|
||||
type Username []byte
|
||||
|
||||
func (u Username) String() string {
|
||||
@@ -37,7 +37,7 @@ func NewRealm(realm string) Realm {
|
||||
|
||||
// Realm represents REALM attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.7
|
||||
// RFC 5389 Section 15.7.
|
||||
type Realm []byte
|
||||
|
||||
func (n Realm) String() string {
|
||||
@@ -60,7 +60,7 @@ const softwareRawMaxB = 763
|
||||
|
||||
// Software is SOFTWARE attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.10
|
||||
// RFC 5389 Section 15.10.
|
||||
type Software []byte
|
||||
|
||||
func (s Software) String() string {
|
||||
@@ -84,7 +84,7 @@ func (s *Software) GetFrom(m *Message) error {
|
||||
|
||||
// Nonce represents NONCE attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.8
|
||||
// RFC 5389 Section 15.8.
|
||||
type Nonce []byte
|
||||
|
||||
// NewNonce returns new Nonce from string.
|
||||
@@ -118,6 +118,7 @@ func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error {
|
||||
return err
|
||||
}
|
||||
m.Add(t, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -128,5 +129,6 @@ func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error {
|
||||
return err
|
||||
}
|
||||
*v = a
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -13,26 +13,26 @@ import (
|
||||
)
|
||||
|
||||
func TestSoftware_GetFrom(t *testing.T) {
|
||||
m := New()
|
||||
v := "Client v0.0.1"
|
||||
m.Add(AttrSoftware, []byte(v))
|
||||
m.WriteHeader()
|
||||
msg := New()
|
||||
val := "Client v0.0.1"
|
||||
msg.Add(AttrSoftware, []byte(val))
|
||||
msg.WriteHeader()
|
||||
|
||||
m2 := &Message{
|
||||
Raw: make([]byte, 0, 256),
|
||||
}
|
||||
software := new(Software)
|
||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
||||
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := software.GetFrom(m); err != nil {
|
||||
if err := software.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if software.String() != v {
|
||||
t.Errorf("Expected %q, got %q.", v, software)
|
||||
if software.String() != val {
|
||||
t.Errorf("Expected %q, got %q.", val, software)
|
||||
}
|
||||
|
||||
sAttr, ok := m.Attributes.Get(AttrSoftware)
|
||||
sAttr, ok := msg.Attributes.Get(AttrSoftware)
|
||||
if !ok {
|
||||
t.Error("software attribute should be found")
|
||||
}
|
||||
@@ -90,22 +90,22 @@ func BenchmarkUsername_GetFrom(b *testing.B) {
|
||||
|
||||
func TestUsername(t *testing.T) {
|
||||
username := "username"
|
||||
u := NewUsername(username)
|
||||
m := new(Message)
|
||||
m.WriteHeader()
|
||||
uName := NewUsername(username)
|
||||
msg := new(Message)
|
||||
msg.WriteHeader()
|
||||
t.Run("Bad length", func(t *testing.T) {
|
||||
badU := make(Username, 600)
|
||||
if err := badU.AddTo(m); !IsAttrSizeOverflow(err) {
|
||||
if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) {
|
||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := u.AddTo(m); err != nil {
|
||||
if err := uName.AddTo(msg); err != nil {
|
||||
t.Error("errored:", err)
|
||||
}
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(Username)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
if err := got.GetFrom(msg); err != nil {
|
||||
t.Error("errored:", err)
|
||||
}
|
||||
if got.String() != username {
|
||||
@@ -136,10 +136,10 @@ func TestUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRealm_GetFrom(t *testing.T) {
|
||||
m := New()
|
||||
v := "realm"
|
||||
m.Add(AttrRealm, []byte(v))
|
||||
m.WriteHeader()
|
||||
msg := New()
|
||||
val := "realm"
|
||||
msg.Add(AttrRealm, []byte(val))
|
||||
msg.WriteHeader()
|
||||
|
||||
m2 := &Message{
|
||||
Raw: make([]byte, 0, 256),
|
||||
@@ -148,17 +148,17 @@ func TestRealm_GetFrom(t *testing.T) {
|
||||
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
||||
}
|
||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
||||
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := r.GetFrom(m); err != nil {
|
||||
if err := r.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.String() != v {
|
||||
t.Errorf("Expected %q, got %q.", v, r)
|
||||
if r.String() != val {
|
||||
t.Errorf("Expected %q, got %q.", val, r)
|
||||
}
|
||||
|
||||
rAttr, ok := m.Attributes.Get(AttrRealm)
|
||||
rAttr, ok := msg.Attributes.Get(AttrRealm)
|
||||
if !ok {
|
||||
t.Error("realm attribute should be found")
|
||||
}
|
||||
@@ -180,26 +180,26 @@ func TestRealm_AddTo_Invalid(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNonce_GetFrom(t *testing.T) {
|
||||
m := New()
|
||||
v := "example.org"
|
||||
m.Add(AttrNonce, []byte(v))
|
||||
m.WriteHeader()
|
||||
msg := New()
|
||||
val := "example.org"
|
||||
msg.Add(AttrNonce, []byte(val))
|
||||
msg.WriteHeader()
|
||||
|
||||
m2 := &Message{
|
||||
Raw: make([]byte, 0, 256),
|
||||
}
|
||||
var nonce Nonce
|
||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
||||
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := nonce.GetFrom(m); err != nil {
|
||||
if err := nonce.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nonce.String() != v {
|
||||
t.Errorf("Expected %q, got %q.", v, nonce)
|
||||
if nonce.String() != val {
|
||||
t.Errorf("Expected %q, got %q.", val, nonce)
|
||||
}
|
||||
|
||||
nAttr, ok := m.Attributes.Get(AttrNonce)
|
||||
nAttr, ok := msg.Attributes.Get(AttrNonce)
|
||||
if !ok {
|
||||
t.Error("nonce attribute should be found")
|
||||
}
|
||||
|
@@ -7,7 +7,7 @@ import "errors"
|
||||
|
||||
// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.9
|
||||
// RFC 5389 Section 15.9.
|
||||
type UnknownAttributes []AttrType
|
||||
|
||||
func (a UnknownAttributes) String() string {
|
||||
@@ -22,6 +22,7 @@ func (a UnknownAttributes) String() string {
|
||||
s += ", "
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -39,6 +40,7 @@ func (a UnknownAttributes) AddTo(m *Message) error {
|
||||
bin.PutUint16(v[first:last], t.Value())
|
||||
}
|
||||
m.Add(AttrUnknownAttributes, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,5 +64,6 @@ func (a *UnknownAttributes) GetFrom(m *Message) error {
|
||||
*a = append(*a, AttrType(bin.Uint16(v[first:last])))
|
||||
first = last
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -8,26 +8,26 @@ import (
|
||||
)
|
||||
|
||||
func TestUnknownAttributes(t *testing.T) {
|
||||
m := new(Message)
|
||||
a := &UnknownAttributes{
|
||||
msg := new(Message)
|
||||
attr := &UnknownAttributes{
|
||||
AttrDontFragment,
|
||||
AttrChannelNumber,
|
||||
}
|
||||
if a.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
|
||||
t.Error("bad String:", a)
|
||||
if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
|
||||
t.Error("bad String:", attr)
|
||||
}
|
||||
if (UnknownAttributes{}).String() != "<nil>" {
|
||||
t.Error("bad blank string")
|
||||
}
|
||||
if err := a.AddTo(m); err != nil {
|
||||
if err := attr.AddTo(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
attrs := make(UnknownAttributes, 10)
|
||||
if err := attrs.GetFrom(m); err != nil {
|
||||
if err := attrs.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
for i, at := range *a {
|
||||
for i, at := range *attr {
|
||||
if at != attrs[i] {
|
||||
t.Error("expected", at, "!=", attrs[i])
|
||||
}
|
||||
@@ -44,8 +44,8 @@ func TestUnknownAttributes(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkUnknownAttributes(b *testing.B) {
|
||||
m := new(Message)
|
||||
a := UnknownAttributes{
|
||||
msg := new(Message)
|
||||
attr := UnknownAttributes{
|
||||
AttrDontFragment,
|
||||
AttrChannelNumber,
|
||||
AttrRealm,
|
||||
@@ -54,20 +54,20 @@ func BenchmarkUnknownAttributes(b *testing.B) {
|
||||
b.Run("AddTo", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := a.AddTo(m); err != nil {
|
||||
if err := attr.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.Reset()
|
||||
msg.Reset()
|
||||
}
|
||||
})
|
||||
b.Run("GetFrom", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
if err := a.AddTo(m); err != nil {
|
||||
if err := attr.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
attrs := make(UnknownAttributes, 0, 10)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := attrs.GetFrom(m); err != nil {
|
||||
if err := attrs.GetFrom(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
attrs = attrs[:0]
|
||||
|
47
uri.go
47
uri.go
@@ -124,7 +124,7 @@ func (t ProtoType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI
|
||||
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI.
|
||||
type URI struct {
|
||||
Scheme SchemeType
|
||||
Host string
|
||||
@@ -137,73 +137,76 @@ type URI struct {
|
||||
// ParseURI parses a STUN or TURN urls following the ABNF syntax described in
|
||||
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
|
||||
// respectively.
|
||||
func ParseURI(raw string) (*URI, error) { //nolint:gocognit
|
||||
func ParseURI(raw string) (*URI, error) { //nolint:gocognit,cyclop
|
||||
rawParts, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var u URI
|
||||
u.Scheme = NewSchemeType(rawParts.Scheme)
|
||||
if u.Scheme == SchemeTypeUnknown {
|
||||
var uri URI
|
||||
uri.Scheme = NewSchemeType(rawParts.Scheme)
|
||||
if uri.Scheme == SchemeTypeUnknown {
|
||||
return nil, ErrSchemeType
|
||||
}
|
||||
|
||||
var rawPort string
|
||||
if u.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil {
|
||||
if uri.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { //nolint:nestif
|
||||
var e *net.AddrError
|
||||
if errors.As(err, &e) {
|
||||
if e.Err == "missing port in address" {
|
||||
nextRawURL := u.Scheme.String() + ":" + rawParts.Opaque
|
||||
nextRawURL := uri.Scheme.String() + ":" + rawParts.Opaque
|
||||
switch {
|
||||
case u.Scheme == SchemeTypeSTUN || u.Scheme == SchemeTypeTURN:
|
||||
case uri.Scheme == SchemeTypeSTUN || uri.Scheme == SchemeTypeTURN:
|
||||
nextRawURL += ":3478"
|
||||
if rawParts.RawQuery != "" {
|
||||
nextRawURL += "?" + rawParts.RawQuery
|
||||
}
|
||||
|
||||
return ParseURI(nextRawURL)
|
||||
case u.Scheme == SchemeTypeSTUNS || u.Scheme == SchemeTypeTURNS:
|
||||
case uri.Scheme == SchemeTypeSTUNS || uri.Scheme == SchemeTypeTURNS:
|
||||
nextRawURL += ":5349"
|
||||
if rawParts.RawQuery != "" {
|
||||
nextRawURL += "?" + rawParts.RawQuery
|
||||
}
|
||||
|
||||
return ParseURI(nextRawURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
if uri.Host == "" {
|
||||
return nil, ErrHost
|
||||
}
|
||||
|
||||
if u.Port, err = strconv.Atoi(rawPort); err != nil {
|
||||
if uri.Port, err = strconv.Atoi(rawPort); err != nil {
|
||||
return nil, ErrPort
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
switch uri.Scheme {
|
||||
case SchemeTypeSTUN:
|
||||
qArgs, err := url.ParseQuery(rawParts.RawQuery)
|
||||
if err != nil || len(qArgs) > 0 {
|
||||
return nil, ErrSTUNQuery
|
||||
}
|
||||
u.Proto = ProtoTypeUDP
|
||||
uri.Proto = ProtoTypeUDP
|
||||
case SchemeTypeSTUNS:
|
||||
qArgs, err := url.ParseQuery(rawParts.RawQuery)
|
||||
if err != nil || len(qArgs) > 0 {
|
||||
return nil, ErrSTUNQuery
|
||||
}
|
||||
u.Proto = ProtoTypeTCP
|
||||
uri.Proto = ProtoTypeTCP
|
||||
case SchemeTypeTURN:
|
||||
proto, err := parseProto(rawParts.RawQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.Proto = proto
|
||||
if u.Proto == ProtoTypeUnknown {
|
||||
u.Proto = ProtoTypeUDP
|
||||
uri.Proto = proto
|
||||
if uri.Proto == ProtoTypeUnknown {
|
||||
uri.Proto = ProtoTypeUDP
|
||||
}
|
||||
case SchemeTypeTURNS:
|
||||
proto, err := parseProto(rawParts.RawQuery)
|
||||
@@ -211,15 +214,15 @@ func ParseURI(raw string) (*URI, error) { //nolint:gocognit
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.Proto = proto
|
||||
if u.Proto == ProtoTypeUnknown {
|
||||
u.Proto = ProtoTypeTCP
|
||||
uri.Proto = proto
|
||||
if uri.Proto == ProtoTypeUnknown {
|
||||
uri.Proto = ProtoTypeTCP
|
||||
}
|
||||
|
||||
case SchemeTypeUnknown:
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
return &uri, nil
|
||||
}
|
||||
|
||||
func parseProto(raw string) (ProtoType, error) {
|
||||
@@ -233,6 +236,7 @@ func parseProto(raw string) (ProtoType, error) {
|
||||
if proto = NewProtoType(rawProto); proto == ProtoType(0) {
|
||||
return ProtoTypeUnknown, ErrProtoType
|
||||
}
|
||||
|
||||
return proto, nil
|
||||
}
|
||||
|
||||
@@ -248,6 +252,7 @@ func (u URI) String() string {
|
||||
if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS {
|
||||
rawURL += "?transport=" + u.Proto.String()
|
||||
}
|
||||
|
||||
return rawURL
|
||||
}
|
||||
|
||||
|
12
uri_test.go
12
uri_test.go
@@ -35,8 +35,16 @@ func TestParseURL(t *testing.T) {
|
||||
{"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP},
|
||||
{"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
|
||||
{"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
|
||||
{"turn:google.de?transport=udp", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
|
||||
{"turns:google.de?transport=tcp", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
|
||||
{
|
||||
"turn:google.de?transport=udp",
|
||||
"turn:google.de:3478?transport=udp",
|
||||
SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP,
|
||||
},
|
||||
{
|
||||
"turns:google.de?transport=tcp",
|
||||
"turns:google.de:5349?transport=tcp",
|
||||
SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
|
33
xoraddr.go
33
xoraddr.go
@@ -20,7 +20,7 @@ const (
|
||||
|
||||
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
|
||||
//
|
||||
// RFC 5389 Section 15.2
|
||||
// RFC 5389 Section 15.2.
|
||||
type XORMappedAddress struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
@@ -43,14 +43,15 @@ func isZeros(p net.IP) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
|
||||
var ErrBadIPLength = errors.New("invalid length of IP value")
|
||||
|
||||
// AddToAs adds XOR-MAPPED-ADDRESS value to m as t attribute.
|
||||
func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
|
||||
// AddToAs adds XOR-MAPPED-ADDRESS value to msg as attr attribute.
|
||||
func (a XORMappedAddress) AddToAs(msg *Message, attr AttrType) error {
|
||||
var (
|
||||
family = familyIPv4
|
||||
ip = a.IP
|
||||
@@ -67,12 +68,13 @@ func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
|
||||
value := make([]byte, 32+128)
|
||||
value[0] = 0 // first 8 bits are zeroes
|
||||
xorValue := make([]byte, net.IPv6len)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
copy(xorValue[4:], msg.TransactionID[:])
|
||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||
bin.PutUint16(value[0:2], family)
|
||||
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
|
||||
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) //nolint:gosec // G115, false positive, port
|
||||
xor.XorBytes(value[4:4+len(ip)], ip, xorValue)
|
||||
m.Add(t, value[:4+len(ip)])
|
||||
msg.Add(attr, value[:4+len(ip)])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,13 +85,13 @@ func (a XORMappedAddress) AddTo(m *Message) error {
|
||||
}
|
||||
|
||||
// GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message
|
||||
// getting it as for t type.
|
||||
func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
|
||||
v, err := m.Get(t)
|
||||
// getting it as for attr type.
|
||||
func (a *XORMappedAddress) GetFromAs(msg *Message, attr AttrType) error {
|
||||
value, err := msg.Get(attr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
family := bin.Uint16(v[0:2])
|
||||
family := bin.Uint16(value[0:2])
|
||||
if family != familyIPv6 && family != familyIPv4 {
|
||||
return newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
@@ -110,17 +112,18 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
|
||||
for i := range a.IP {
|
||||
a.IP[i] = 0
|
||||
}
|
||||
if len(v) <= 4 {
|
||||
if len(value) <= 4 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
if err := CheckOverflow(t, len(v[4:]), len(a.IP)); err != nil {
|
||||
if err := CheckOverflow(attr, len(value[4:]), len(a.IP)); err != nil {
|
||||
return err
|
||||
}
|
||||
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
||||
a.Port = int(bin.Uint16(value[2:4])) ^ (magicCookie >> 16)
|
||||
xorValue := make([]byte, 4+TransactionIDSize)
|
||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
xor.XorBytes(a.IP, v[4:], xorValue)
|
||||
copy(xorValue[4:], msg.TransactionID[:])
|
||||
xor.XorBytes(a.IP, value[4:], xorValue)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -29,21 +29,21 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
|
||||
m := New()
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
m.Add(AttrXORMappedAddress, addrValue)
|
||||
msg.Add(AttrXORMappedAddress, addrValue)
|
||||
addr := new(XORMappedAddress)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.GetFrom(m); err != nil {
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -94,54 +94,54 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
||||
m := New()
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("213.141.156.236")
|
||||
expectedPort := 21254
|
||||
addr := new(XORMappedAddress)
|
||||
|
||||
if err = addr.GetFrom(m); err == nil {
|
||||
if err = addr.GetFrom(msg); err == nil {
|
||||
t.Fatal(err, "should be nil")
|
||||
}
|
||||
|
||||
addr.IP = expectedIP
|
||||
addr.Port = expectedPort
|
||||
addr.AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteHeader()
|
||||
addr.AddTo(msg) //nolint:errcheck,gosec
|
||||
msg.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
binary.BigEndian.PutUint16(m.Raw[20+4:20+4+2], 0x21)
|
||||
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
|
||||
binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
|
||||
if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = addr.GetFrom(m); err == nil {
|
||||
if err = addr.GetFrom(msg); err == nil {
|
||||
t.Fatal(err, "should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_AddTo(t *testing.T) {
|
||||
m := New()
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("213.141.156.236")
|
||||
expectedPort := 21254
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.ParseIP("213.141.156.236"),
|
||||
Port: expectedPort,
|
||||
}
|
||||
if err = addr.AddTo(m); err != nil {
|
||||
if err = addr.AddTo(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.WriteHeader()
|
||||
msg.WriteHeader()
|
||||
mRes := New()
|
||||
if _, err = mRes.Write(m.Raw); err != nil {
|
||||
if _, err = mRes.Write(msg.Raw); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = addr.GetFrom(mRes); err != nil {
|
||||
@@ -156,27 +156,27 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
|
||||
m := New()
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
|
||||
expectedPort := 21254
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
|
||||
Port: 21254,
|
||||
}
|
||||
addr.AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteHeader()
|
||||
addr.AddTo(msg) //nolint:errcheck,gosec
|
||||
msg.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
if _, err = mRes.ReadFrom(msg.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotAddr := new(XORMappedAddress)
|
||||
if err = gotAddr.GetFrom(m); err != nil {
|
||||
if err = gotAddr.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !gotAddr.IP.Equal(expectedIP) {
|
||||
|
Reference in New Issue
Block a user