diff --git a/.golangci.yml b/.golangci.yml index a3235be..e618944 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,17 +25,32 @@ linters-settings: - ^os.Exit$ - ^panic$ - ^print(ln)?$ + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. @@ -46,18 +61,17 @@ linters: - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - - err113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code @@ -65,9 +79,15 @@ linters: - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes @@ -75,28 +95,22 @@ linters: - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - funlen # Tool for detection of long functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - gomnd # An analyzer to detect magic numbers. + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - - lll # Reports long lines - - maintidx # maintidx measures the maintainability index of each function. - - makezero # Finds slice declarations with non-zero initial length - - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated @@ -104,8 +118,7 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - varnamelen # checks that the length of a variable's name matches its scope + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! @@ -123,3 +136,4 @@ issues: - path: cmd linters: - forbidigo + diff --git a/addr.go b/addr.go index d15e2bb..b5813ad 100644 --- a/addr.go +++ b/addr.go @@ -15,7 +15,7 @@ import ( // This attribute is used only by servers for achieving backwards // compatibility with RFC 3489 clients. // -// RFC 5389 Section 15.1 +// RFC 5389 Section 15.1. type MappedAddress struct { IP net.IP Port int @@ -23,7 +23,7 @@ type MappedAddress struct { // AlternateServer represents ALTERNATE-SERVER attribute. // -// RFC 5389 Section 15.11 +// RFC 5389 Section 15.11. type AlternateServer struct { IP net.IP Port int @@ -31,7 +31,7 @@ type AlternateServer struct { // ResponseOrigin represents RESPONSE-ORIGIN attribute. // -// RFC 5780 Section 7.3 +// RFC 5780 Section 7.3. type ResponseOrigin struct { IP net.IP Port int @@ -39,7 +39,7 @@ type ResponseOrigin struct { // OtherAddress represents OTHER-ADDRESS attribute. // -// RFC 5780 Section 7.4 +// RFC 5780 Section 7.4. type OtherAddress struct { IP net.IP Port int @@ -48,12 +48,14 @@ type OtherAddress struct { // AddTo adds ALTERNATE-SERVER attribute to message. func (s *AlternateServer) AddTo(m *Message) error { a := (*MappedAddress)(s) + return a.AddToAs(m, AttrAlternateServer) } // GetFrom decodes ALTERNATE-SERVER from message. func (s *AlternateServer) GetFrom(m *Message) error { a := (*MappedAddress)(s) + return a.GetFromAs(m, AttrAlternateServer) } @@ -63,14 +65,14 @@ func (a MappedAddress) String() string { // GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t. func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error { - v, err := m.Get(t) + value, err := m.Get(t) if err != nil { return err } - if len(v) <= 4 { + if len(value) <= 4 { return io.ErrUnexpectedEOF } - family := bin.Uint16(v[0:2]) + family := bin.Uint16(value[0:2]) if family != familyIPv6 && family != familyIPv4 { return newDecodeErr("xor-mapped address", "family", fmt.Sprintf("bad value %d", family), @@ -91,13 +93,14 @@ func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error { for i := range a.IP { a.IP[i] = 0 } - a.Port = int(bin.Uint16(v[2:4])) - copy(a.IP, v[4:]) + a.Port = int(bin.Uint16(value[2:4])) + copy(a.IP, value[4:]) + return nil } // AddToAs adds MAPPED-ADDRESS value to m as t attribute. -func (a *MappedAddress) AddToAs(m *Message, t AttrType) error { +func (a *MappedAddress) AddToAs(msg *Message, attrType AttrType) error { var ( family = familyIPv4 ip = a.IP @@ -114,9 +117,10 @@ func (a *MappedAddress) AddToAs(m *Message, t AttrType) error { value := make([]byte, 128) value[0] = 0 // first 8 bits are zeroes bin.PutUint16(value[0:2], family) - bin.PutUint16(value[2:4], uint16(a.Port)) + bin.PutUint16(value[2:4], uint16(a.Port)) //nolint:gosec //G115 copy(value[4:], ip) - m.Add(t, value[:4+len(ip)]) + msg.Add(attrType, value[:4+len(ip)]) + return nil } @@ -133,12 +137,14 @@ func (a *MappedAddress) GetFrom(m *Message) error { // AddTo adds OTHER-ADDRESS attribute to message. func (o *OtherAddress) AddTo(m *Message) error { a := (*MappedAddress)(o) + return a.AddToAs(m, AttrOtherAddress) } // GetFrom decodes OTHER-ADDRESS from message. func (o *OtherAddress) GetFrom(m *Message) error { a := (*MappedAddress)(o) + return a.GetFromAs(m, AttrOtherAddress) } @@ -149,12 +155,14 @@ func (o OtherAddress) String() string { // AddTo adds RESPONSE-ORIGIN attribute to message. func (o *ResponseOrigin) AddTo(m *Message) error { a := (*MappedAddress)(o) + return a.AddToAs(m, AttrResponseOrigin) } // GetFrom decodes RESPONSE-ORIGIN from message. func (o *ResponseOrigin) GetFrom(m *Message) error { a := (*MappedAddress)(o) + return a.GetFromAs(m, AttrResponseOrigin) } diff --git a/addr_test.go b/addr_test.go index 3984b5d..945426f 100644 --- a/addr_test.go +++ b/addr_test.go @@ -11,7 +11,7 @@ import ( ) func TestMappedAddress(t *testing.T) { - m := new(Message) + msg := new(Message) addr := &MappedAddress{ IP: net.ParseIP("122.12.34.5"), Port: 5412, @@ -23,17 +23,17 @@ func TestMappedAddress(t *testing.T) { badAddr := &MappedAddress{ IP: net.IP{1, 2, 3}, } - if err := badAddr.AddTo(m); err == nil { + if err := badAddr.AddTo(msg); err == nil { t.Error("should error") } }) t.Run("AddTo", func(t *testing.T) { - if err := addr.AddTo(m); err != nil { + if err := addr.AddTo(msg); err != nil { t.Error(err) } t.Run("GetFrom", func(t *testing.T) { got := new(MappedAddress) - if err := got.GetFrom(m); err != nil { + if err := got.GetFrom(msg); err != nil { t.Error(err) } if !got.IP.Equal(addr.IP) { @@ -46,9 +46,9 @@ func TestMappedAddress(t *testing.T) { } }) t.Run("Bad family", func(t *testing.T) { - v, _ := m.Attributes.Get(AttrMappedAddress) + v, _ := msg.Attributes.Get(AttrMappedAddress) v.Value[0] = 32 - if err := got.GetFrom(m); err == nil { + if err := got.GetFrom(msg); err == nil { t.Error("should error") } }) diff --git a/agent.go b/agent.go index bc160c8..4f07dce 100644 --- a/agent.go +++ b/agent.go @@ -24,6 +24,7 @@ func NewAgent(h Handler) *Agent { transactions: make(map[transactionID]agentTransaction), handler: h, } + return a } @@ -80,6 +81,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error { a.mux.Lock() if a.closed { a.mux.Unlock() + return ErrAgentClosed } t, exists := a.transactions[id] @@ -93,6 +95,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error { TransactionID: t.id, Error: err, }) + return nil } @@ -124,6 +127,7 @@ func (a *Agent) Start(id [TransactionIDSize]byte, deadline time.Time) error { id: id, deadline: deadline, } + return nil } @@ -147,6 +151,7 @@ func (a *Agent) Collect(gcTime time.Time) error { // All transactions should be already closed // during Close() call. a.mux.Unlock() + return ErrAgentClosed } // Adding all transactions with deadline before gcTime @@ -175,24 +180,27 @@ func (a *Agent) Collect(gcTime time.Time) error { event.TransactionID = id h(event) } + return nil } // Process incoming message, synchronously passing it to handler. func (a *Agent) Process(m *Message) error { - e := Event{ + event := Event{ TransactionID: m.TransactionID, Message: m, } a.mux.Lock() if a.closed { a.mux.Unlock() + return ErrAgentClosed } h := a.handler delete(a.transactions, m.TransactionID) a.mux.Unlock() - h(e) + h(event) + return nil } @@ -201,10 +209,12 @@ func (a *Agent) SetHandler(h Handler) error { a.mux.Lock() if a.closed { a.mux.Unlock() + return ErrAgentClosed } a.handler = h a.mux.Unlock() + return nil } @@ -217,6 +227,7 @@ func (a *Agent) Close() error { a.mux.Lock() if a.closed { a.mux.Unlock() + return ErrAgentClosed } for _, t := range a.transactions { @@ -227,6 +238,7 @@ func (a *Agent) Close() error { a.closed = true a.handler = nil a.mux.Unlock() + return nil } diff --git a/agent_test.go b/agent_test.go index b473945..9777083 100644 --- a/agent_test.go +++ b/agent_test.go @@ -10,49 +10,49 @@ import ( ) func TestAgent_ProcessInTransaction(t *testing.T) { - m := New() - a := NewAgent(func(e Event) { + msg := New() + agent := NewAgent(func(e Event) { if e.Error != nil { t.Errorf("got error: %s", e.Error) } - if !e.Message.Equal(m) { - t.Errorf("%s (got) != %s (expected)", e.Message, m) + if !e.Message.Equal(msg) { + t.Errorf("%s (got) != %s (expected)", e.Message, msg) } }) - if err := m.NewTransactionID(); err != nil { + if err := msg.NewTransactionID(); err != nil { t.Fatal(err) } - if err := a.Start(m.TransactionID, time.Time{}); err != nil { + if err := agent.Start(msg.TransactionID, time.Time{}); err != nil { t.Fatal(err) } - if err := a.Process(m); err != nil { + if err := agent.Process(msg); err != nil { t.Error(err) } - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { t.Error(err) } } func TestAgent_Process(t *testing.T) { - m := New() - a := NewAgent(func(e Event) { + msg := New() + agent := NewAgent(func(e Event) { if e.Error != nil { t.Errorf("got error: %s", e.Error) } - if !e.Message.Equal(m) { - t.Errorf("%s (got) != %s (expected)", e.Message, m) + if !e.Message.Equal(msg) { + t.Errorf("%s (got) != %s (expected)", e.Message, msg) } }) - if err := m.NewTransactionID(); err != nil { + if err := msg.NewTransactionID(); err != nil { t.Fatal(err) } - if err := a.Process(m); err != nil { + if err := agent.Process(msg); err != nil { t.Error(err) } - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { t.Error(err) } - if err := a.Process(m); !errors.Is(err, ErrAgentClosed) { + if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) { t.Errorf("closed agent should return <%s>, but got <%s>", ErrAgentClosed, err, ) @@ -60,27 +60,27 @@ func TestAgent_Process(t *testing.T) { } func TestAgent_Start(t *testing.T) { - a := NewAgent(nil) + agent := NewAgent(nil) id := NewTransactionID() deadline := time.Now().AddDate(0, 0, 1) - if err := a.Start(id, deadline); err != nil { + if err := agent.Start(id, deadline); err != nil { t.Errorf("failed to statt transaction: %s", err) } - if err := a.Start(id, deadline); !errors.Is(err, ErrTransactionExists) { + if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) { t.Errorf("duplicate start should return <%s>, got <%s>", ErrTransactionExists, err, ) } - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { t.Error(err) } id = NewTransactionID() - if err := a.Start(id, deadline); !errors.Is(err, ErrAgentClosed) { + if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) { t.Errorf("start on closed agent should return <%s>, got <%s>", ErrAgentClosed, err, ) } - if err := a.SetHandler(nil); !errors.Is(err, ErrAgentClosed) { + if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) { t.Errorf("SetHandler on closed agent should return <%s>, got <%s>", ErrAgentClosed, err, ) @@ -89,18 +89,18 @@ func TestAgent_Start(t *testing.T) { func TestAgent_Stop(t *testing.T) { called := make(chan Event, 1) - a := NewAgent(func(e Event) { + agent := NewAgent(func(e Event) { called <- e }) - if err := a.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) { + if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) { t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists) } id := NewTransactionID() timeout := time.Millisecond * 200 - if err := a.Start(id, time.Now().Add(timeout)); err != nil { + if err := agent.Start(id, time.Now().Add(timeout)); err != nil { t.Fatal(err) } - if err := a.Stop(id); err != nil { + if err := agent.Stop(id); err != nil { t.Fatal(err) } select { @@ -113,19 +113,19 @@ func TestAgent_Stop(t *testing.T) { case <-time.After(timeout * 2): t.Fatal("timed out") } - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { t.Fatal(err) } - if err := a.Close(); !errors.Is(err, ErrAgentClosed) { + if err := agent.Close(); !errors.Is(err, ErrAgentClosed) { t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed) } - if err := a.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) { + if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) { t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed) } } -func TestAgent_GC(t *testing.T) { - a := NewAgent(nil) +func TestAgent_GC(t *testing.T) { //nolint:cyclop + agent := NewAgent(nil) shouldTimeOutID := make(map[transactionID]bool) deadline := time.Date(2027, time.November, 21, 23, 0, 0, 0, @@ -133,7 +133,7 @@ func TestAgent_GC(t *testing.T) { ) gcDeadline := deadline.Add(-time.Second) deadlineNotGC := gcDeadline.AddDate(0, 0, -1) - a.SetHandler(func(e Event) { //nolint:errcheck,gosec + agent.SetHandler(func(e Event) { //nolint:errcheck,gosec id := e.TransactionID shouldTimeOut, found := shouldTimeOutID[id] if !found { @@ -149,67 +149,67 @@ func TestAgent_GC(t *testing.T) { for i := 0; i < 5; i++ { id := NewTransactionID() shouldTimeOutID[id] = false - if err := a.Start(id, deadline); err != nil { + if err := agent.Start(id, deadline); err != nil { t.Fatal(err) } } for i := 0; i < 5; i++ { id := NewTransactionID() shouldTimeOutID[id] = true - if err := a.Start(id, deadlineNotGC); err != nil { + if err := agent.Start(id, deadlineNotGC); err != nil { t.Fatal(err) } } - if err := a.Collect(gcDeadline); err != nil { + if err := agent.Collect(gcDeadline); err != nil { t.Fatal(err) } - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { t.Error(err) } - if err := a.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) { + if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) { t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err) } } func BenchmarkAgent_GC(b *testing.B) { - a := NewAgent(nil) + agent := NewAgent(nil) deadline := time.Now().AddDate(0, 0, 1) for i := 0; i < agentCollectCap; i++ { - if err := a.Start(NewTransactionID(), deadline); err != nil { + if err := agent.Start(NewTransactionID(), deadline); err != nil { b.Fatal(err) } } defer func() { - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { b.Error(err) } }() b.ReportAllocs() gcDeadline := deadline.Add(-time.Second) for i := 0; i < b.N; i++ { - if err := a.Collect(gcDeadline); err != nil { + if err := agent.Collect(gcDeadline); err != nil { b.Fatal(err) } } } func BenchmarkAgent_Process(b *testing.B) { - a := NewAgent(nil) + agent := NewAgent(nil) deadline := time.Now().AddDate(0, 0, 1) for i := 0; i < 1000; i++ { - if err := a.Start(NewTransactionID(), deadline); err != nil { + if err := agent.Start(NewTransactionID(), deadline); err != nil { b.Fatal(err) } } defer func() { - if err := a.Close(); err != nil { + if err := agent.Close(); err != nil { b.Error(err) } }() b.ReportAllocs() m := MustBuild(TransactionID) for i := 0; i < b.N; i++ { - if err := a.Process(m); err != nil { + if err := agent.Process(m); err != nil { b.Fatal(err) } } diff --git a/attributes.go b/attributes.go index 8a1aa21..7db0e23 100644 --- a/attributes.go +++ b/attributes.go @@ -21,6 +21,7 @@ func (a Attributes) Get(t AttrType) (RawAttribute, bool) { return candidate, true } } + return RawAttribute{}, false } @@ -77,7 +78,7 @@ const ( AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN ) -// Attributes from RFC 5780 NAT Behavior Discovery +// Attributes from RFC 5780 NAT Behavior Discovery. const ( AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST AttrPadding AttrType = 0x0026 // PADDING @@ -166,6 +167,7 @@ func (t AttrType) String() string { // Just return hex representation of unknown attribute type. return fmt.Sprintf("0x%x", uint16(t)) } + return s } @@ -186,6 +188,7 @@ type RawAttribute struct { // the Length field. func (a RawAttribute) AddTo(m *Message) error { m.Add(a.Type, a.Value) + return nil } @@ -205,6 +208,7 @@ func (a RawAttribute) Equal(b RawAttribute) bool { return false } } + return true } @@ -224,6 +228,7 @@ func (m *Message) Get(t AttrType) ([]byte, error) { if !ok { return nil, ErrAttributeNotFound } + return v.Value, nil } @@ -240,6 +245,7 @@ func nearestPaddedValueLength(l int) int { if n < l { n += padding } + return n } @@ -250,5 +256,6 @@ func compatAttrType(val uint16) AttrType { if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on) } + return AttrType(val) } diff --git a/attributes_test.go b/attributes_test.go index c4255b2..a76e286 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -44,13 +44,13 @@ func TestRawAttribute_AddTo(t *testing.T) { } func TestMessage_GetNoAllocs(t *testing.T) { - m := New() - NewSoftware("c").AddTo(m) //nolint:errcheck,gosec - m.WriteHeader() + msg := New() + NewSoftware("c").AddTo(msg) //nolint:errcheck,gosec + msg.WriteHeader() t.Run("Default", func(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { - m.Get(AttrSoftware) //nolint:errcheck,gosec + msg.Get(AttrSoftware) //nolint:errcheck,gosec }) if allocs > 0 { t.Error("allocated memory, but should not") @@ -58,7 +58,7 @@ func TestMessage_GetNoAllocs(t *testing.T) { }) t.Run("Not found", func(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { - m.Get(AttrOrigin) //nolint:errcheck,gosec + msg.Get(AttrOrigin) //nolint:errcheck,gosec }) if allocs > 0 { t.Error("allocated memory, but should not") diff --git a/checks.go b/checks.go index b10dd34..04716a4 100644 --- a/checks.go +++ b/checks.go @@ -17,6 +17,7 @@ func CheckSize(_ AttrType, got, expected int) error { if got == expected { return nil } + return ErrAttributeSizeInvalid } @@ -24,6 +25,7 @@ func checkHMAC(got, expected []byte) error { if hmac.Equal(got, expected) { return nil } + return ErrIntegrityMismatch } @@ -31,6 +33,7 @@ func checkFingerprint(got, expected uint32) error { if got == expected { return nil } + return ErrFingerprintMismatch } @@ -44,6 +47,7 @@ func CheckOverflow(_ AttrType, got, maxVal int) error { if got <= maxVal { return nil } + return ErrAttributeSizeOverflow } diff --git a/client.go b/client.go index 24d3faf..ebd8879 100644 --- a/client.go +++ b/client.go @@ -21,7 +21,7 @@ import ( "github.com/pion/transport/v3/stdnet" ) -// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI +// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI. var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport") // Dial connects to the address on the named network and then @@ -31,10 +31,11 @@ func Dial(network, address string) (*Client, error) { if err != nil { return nil, err } + return NewClient(conn) } -// DialConfig is used to pass configuration to DialURI() +// DialConfig is used to pass configuration to DialURI(). type DialConfig struct { DTLSConfig dtls.Config TLSConfig tls.Config @@ -44,7 +45,7 @@ type DialConfig struct { // DialURI connect to the STUN/TURN URI and then // initializes Client on that connection, returning error if any. -func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { +func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop var conn Connection var err error @@ -203,7 +204,7 @@ const ( // provide any API for it, so if you need to read application data, wrap the // connection with your (de-)multiplexer and pass the wrapper as conn. func NewClient(conn Connection, options ...ClientOption) (*Client, error) { - c := &Client{ + client := &Client{ close: make(chan struct{}), c: conn, clock: systemClock(), @@ -214,32 +215,33 @@ func NewClient(conn Connection, options ...ClientOption) (*Client, error) { closeConn: true, } for _, o := range options { - o(c) + o(client) } - if c.c == nil { + if client.c == nil { return nil, ErrNoConnection } - if c.a == nil { - c.a = NewAgent(nil) + if client.a == nil { + client.a = NewAgent(nil) } - if err := c.a.SetHandler(c.handleAgentCallback); err != nil { + if err := client.a.SetHandler(client.handleAgentCallback); err != nil { return nil, err } - if c.collector == nil { - c.collector = &tickerCollector{ + if client.collector == nil { + client.collector = &tickerCollector{ close: make(chan struct{}), - clock: c.clock, + clock: client.clock, } } - if err := c.collector.Start(c.rtoRate, func(t time.Time) { - closedOrPanic(c.a.Collect(t)) + if err := client.collector.Start(client.rtoRate, func(t time.Time) { + closedOrPanic(client.a.Collect(t)) }); err != nil { return nil, err } - c.wg.Add(1) - go c.readUntilClosed() - runtime.SetFinalizer(c, clientFinalizer) - return c, nil + client.wg.Add(1) + go client.readUntilClosed() + runtime.SetFinalizer(client, clientFinalizer) + + return client, nil } func clientFinalizer(c *Client) { @@ -252,6 +254,7 @@ func clientFinalizer(c *Client) { } if err == nil { log.Println("client: called finalizer on non-closed client") // nolint + return } log.Println("client: called finalizer on non-closed client:", err) // nolint @@ -353,6 +356,7 @@ func (c *Client) start(t *clientTransaction) error { return ErrTransactionExists } c.t[t.id] = t + return nil } @@ -399,6 +403,7 @@ func sprintErr(err error) string { if err == nil { return "" //nolint:goconst } + return err.Error() } @@ -455,18 +460,21 @@ func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error select { case <-a.close: t.Stop() + return case <-t.C: f(a.clock.Now()) } } }() + return nil } func (a *tickerCollector) Close() error { close(a.close) a.wg.Wait() + return nil } @@ -481,6 +489,7 @@ func (c *Client) Close() error { c.mux.Lock() if c.closed { c.mux.Unlock() + return ErrClientClosed } c.closed = true @@ -498,6 +507,7 @@ func (c *Client) Close() error { if agentErr == nil && connErr == nil { return nil } + return CloseErr{ AgentErr: agentErr, ConnectionErr: connErr, @@ -566,6 +576,7 @@ func (c *Client) checkInit() error { if c == nil || c.c == nil || c.a == nil || c.close == nil { return ErrClientNotInitialized } + return nil } @@ -590,6 +601,7 @@ func (c *Client) Do(m *Message, f func(Event)) error { return err } h.wait() + return nil } @@ -611,80 +623,85 @@ var bufferPool = &sync.Pool{ //nolint:gochecknoglobals }, } -func (c *Client) handleAgentCallback(e Event) { +func (c *Client) handleAgentCallback(event Event) { //nolint:cyclop c.mux.Lock() if c.closed { c.mux.Unlock() + return } - t, found := c.t[e.TransactionID] + transaction, found := c.t[event.TransactionID] if found { - delete(c.t, t.id) + delete(c.t, transaction.id) } c.mux.Unlock() if !found { - if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) { - c.handler(e) + if c.handler != nil && !errors.Is(event.Error, ErrTransactionStopped) { + c.handler(event) } // Ignoring. return } - if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil { + if atomic.LoadInt32(&c.maxAttempts) <= transaction.attempt || event.Error == nil { // Transaction completed. - t.handle(e) - putClientTransaction(t) + transaction.handle(event) + putClientTransaction(transaction) + return } // Doing re-transmission. - t.attempt++ - b := bufferPool.Get().(*buffer) //nolint:forcetypeassert - b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)] - defer bufferPool.Put(b) + transaction.attempt++ + buff := bufferPool.Get().(*buffer) //nolint:forcetypeassert + buff.buf = buff.buf[:copy(buff.buf[:cap(buff.buf)], transaction.raw)] + defer bufferPool.Put(buff) var ( now = c.clock.Now() - timeOut = t.nextTimeout(now) - id = t.id + timeOut = transaction.nextTimeout(now) + id = transaction.id ) // Starting client transaction. - if startErr := c.start(t); startErr != nil { + if startErr := c.start(transaction); startErr != nil { c.delete(id) - e.Error = startErr - t.handle(e) - putClientTransaction(t) + event.Error = startErr + transaction.handle(event) + putClientTransaction(transaction) + return } // Starting agent transaction. if startErr := c.a.Start(id, timeOut); startErr != nil { c.delete(id) - e.Error = startErr - t.handle(e) - putClientTransaction(t) + event.Error = startErr + transaction.handle(event) + putClientTransaction(transaction) + return } // Writing message to connection again. - _, writeErr := c.c.Write(b.buf) + _, writeErr := c.c.Write(buff.buf) if writeErr != nil { c.delete(id) - e.Error = writeErr + event.Error = writeErr // Stopping agent transaction instead of waiting until it's deadline. // This will call handleAgentCallback with "ErrTransactionStopped" error // which will be ignored. if stopErr := c.a.Stop(id); stopErr != nil { // Failed to stop agent transaction. Wrapping the error in StopError. - e.Error = StopErr{ + event.Error = StopErr{ Err: stopErr, Cause: writeErr, } } - t.handle(e) - putClientTransaction(t) + transaction.handle(event) + putClientTransaction(transaction) + return } } // Start starts transaction (if h set) and writes message to server, handler // is called asynchronously. -func (c *Client) Start(m *Message, h Handler) error { +func (c *Client) Start(msg *Message, handler Handler) error { if err := c.checkInit(); err != nil { return err } @@ -694,34 +711,35 @@ func (c *Client) Start(m *Message, h Handler) error { if closed { return ErrClientClosed } - if h != nil { + if handler != nil { // Starting transaction only if h is set. Useful for indications. t := acquireClientTransaction() - t.id = m.TransactionID + t.id = msg.TransactionID t.start = c.clock.Now() - t.h = h + t.h = handler t.rto = time.Duration(atomic.LoadInt64(&c.rto)) t.attempt = 0 - t.raw = append(t.raw[:0], m.Raw...) + t.raw = append(t.raw[:0], msg.Raw...) t.calls = 0 d := t.nextTimeout(t.start) if err := c.start(t); err != nil { return err } - if err := c.a.Start(m.TransactionID, d); err != nil { + if err := c.a.Start(msg.TransactionID, d); err != nil { return err } } - _, err := m.WriteTo(c.c) - if err != nil && h != nil { - c.delete(m.TransactionID) + _, err := msg.WriteTo(c.c) + if err != nil && handler != nil { + c.delete(msg.TransactionID) // Stopping transaction instead of waiting until deadline. - if stopErr := c.a.Stop(m.TransactionID); stopErr != nil { + if stopErr := c.a.Stop(msg.TransactionID); stopErr != nil { return StopErr{ Err: stopErr, Cause: err, } } } + return err } diff --git a/client_test.go b/client_test.go index 826cc10..da2cc63 100644 --- a/client_test.go +++ b/client_test.go @@ -38,11 +38,13 @@ type TestAgent struct { func (n *TestAgent) SetHandler(h Handler) error { n.h = h + return nil } func (n *TestAgent) Close() error { close(n.e) + return nil } @@ -54,6 +56,7 @@ func (n *TestAgent) Start(id [TransactionIDSize]byte, _ time.Time) error { n.e <- Event{ TransactionID: id, } + return nil } @@ -69,6 +72,7 @@ func (noopConnection) Write(b []byte) (int, error) { func (noopConnection) Read([]byte) (int, error) { time.Sleep(time.Millisecond) + return 0, io.EOF } @@ -136,6 +140,7 @@ func (t *testConnection) Close() error { return errClientAlreadyStopped } t.stopped = true + return nil } @@ -148,6 +153,7 @@ func (t *testConnection) Read(b []byte) (int, error) { if t.read != nil { return t.read(b) } + return copy(b, t.b), nil } @@ -165,7 +171,7 @@ func TestClosedOrPanic(t *testing.T) { }() } -func TestClient_Start(t *testing.T) { +func TestClient_Start(t *testing.T) { //nolint:cyclop response := MustBuild(TransactionID, BindingSuccess) response.Encode() write := make(chan struct{}, 1) @@ -178,6 +184,7 @@ func TestClient_Start(t *testing.T) { case <-read: t.Log("reading") copy(i, response.Raw) + return len(response.Raw), nil case <-time.After(time.Millisecond * 10): return 0, errClientReadTimedOut @@ -188,33 +195,34 @@ func TestClient_Start(t *testing.T) { select { case <-write: t.Log("writing") + return len(bytes), nil case <-time.After(time.Millisecond * 10): return 0, errClientWriteTimedOut } }, } - c, err := NewClient(conn) + client, err := NewClient(conn) if err != nil { log.Fatal(err) } defer func() { - if err := c.Close(); err != nil { + if err := client.Close(); err != nil { t.Error(err) } - if err := c.Close(); err == nil { + if err := client.Close(); err == nil { t.Error("second close should fail") } - if err := c.Do(MustBuild(TransactionID), nil); err == nil { + if err := client.Do(MustBuild(TransactionID), nil); err == nil { t.Error("Do after Close should fail") } }() - m := MustBuild(response, BindingRequest) + msg := MustBuild(response, BindingRequest) t.Log("init") got := make(chan struct{}) write <- struct{}{} t.Log("starting the first transaction") - if err := c.Start(m, func(event Event) { + if err := client.Start(msg, func(event Event) { t.Log("got first transaction callback") if event.Error != nil { t.Error(event.Error) @@ -224,7 +232,7 @@ func TestClient_Start(t *testing.T) { t.Error(err) } t.Log("starting the second transaction") - if err := c.Start(m, func(Event) { + if err := client.Start(msg, func(Event) { t.Error("should not be called") }); !errors.Is(err, ErrTransactionExists) { t.Errorf("unexpected error %v", err) @@ -247,25 +255,25 @@ func TestClient_Do(t *testing.T) { return len(bytes), nil }, } - c, err := NewClient(conn) + client, err := NewClient(conn) if err != nil { log.Fatal(err) } defer func() { - if err := c.Close(); err != nil { + if err := client.Close(); err != nil { t.Error(err) } - if err := c.Close(); err == nil { + if err := client.Close(); err == nil { t.Error("second close should fail") } - if err := c.Do(MustBuild(TransactionID), nil); err == nil { + if err := client.Do(MustBuild(TransactionID), nil); err == nil { t.Error("Do after Close should fail") } }() m := MustBuild( NewTransactionIDSetter(response.TransactionID), ) - if err := c.Do(m, func(event Event) { + if err := client.Do(m, func(event Event) { if event.Error != nil { t.Error(event.Error) } @@ -273,13 +281,13 @@ func TestClient_Do(t *testing.T) { t.Error(err) } m = MustBuild(TransactionID) - if err := c.Do(m, nil); err != nil { + if err := client.Do(m, nil); err != nil { t.Error(err) } } func TestCloseErr_Error(t *testing.T) { - for id, c := range []struct { + for id, testCase := range []struct { Err CloseErr Out string }{ @@ -291,16 +299,16 @@ func TestCloseErr_Error(t *testing.T) { ConnectionErr: io.ErrUnexpectedEOF, }, "failed to close: unexpected EOF (connection), (agent)"}, } { - if out := c.Err.Error(); out != c.Out { + if out := testCase.Err.Error(); out != testCase.Out { t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)", - id, c.Err, out, c.Out, + id, testCase.Err, out, testCase.Out, ) } } } func TestStopErr_Error(t *testing.T) { - for id, c := range []struct { + for id, testcase := range []struct { Err StopErr Out string }{ @@ -312,9 +320,9 @@ func TestStopErr_Error(t *testing.T) { Cause: io.ErrUnexpectedEOF, }, "error while stopping due to unexpected EOF: "}, } { - if out := c.Err.Error(); out != c.Out { + if out := testcase.Err.Error(); out != testcase.Out { t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)", - id, c.Err, out, c.Out, + id, testcase.Err, out, testcase.Out, ) } } @@ -352,7 +360,7 @@ func TestClientAgentError(t *testing.T) { return len(bytes), nil }, } - c, err := NewClient(conn, + client, err := NewClient(conn, WithAgent(errorAgent{ startErr: io.ErrUnexpectedEOF, }), @@ -361,15 +369,15 @@ func TestClientAgentError(t *testing.T) { log.Fatal(err) } defer func() { - if err := c.Close(); err != nil { + if err := client.Close(); err != nil { t.Error(err) } }() m := MustBuild(NewTransactionIDSetter(response.TransactionID)) - if err := c.Do(m, nil); err != nil { + if err := client.Do(m, nil); err != nil { t.Error(err) } - if err := c.Do(m, func(event Event) { + if err := client.Do(m, func(event Event) { if event.Error == nil { t.Error("error expected") } @@ -384,20 +392,20 @@ func TestClientConnErr(t *testing.T) { return 0, io.ErrClosedPipe }, } - c, err := NewClient(conn) + client, err := NewClient(conn) if err != nil { log.Fatal(err) } defer func() { - if err := c.Close(); err != nil { + if err := client.Close(); err != nil { t.Error(err) } }() m := MustBuild(TransactionID) - if err := c.Do(m, nil); err == nil { + if err := client.Do(m, nil); err == nil { t.Error("error expected") } - if err := c.Do(m, NoopHandler()); err == nil { + if err := client.Do(m, NoopHandler()); err == nil { t.Error("error expected") } } @@ -408,7 +416,7 @@ func TestClientConnErrStopErr(t *testing.T) { return 0, io.ErrClosedPipe }, } - c, err := NewClient(conn, + client, err := NewClient(conn, WithAgent(errorAgent{ stopErr: io.ErrUnexpectedEOF, }), @@ -417,12 +425,12 @@ func TestClientConnErrStopErr(t *testing.T) { log.Fatal(err) } defer func() { - if err := c.Close(); err != nil { + if err := client.Close(); err != nil { t.Error(err) } }() m := MustBuild(TransactionID) - if err := c.Do(m, NoopHandler()); err == nil { + if err := client.Do(m, NoopHandler()); err == nil { t.Error("error expected") } } @@ -556,11 +564,13 @@ func (a *gcWaitAgent) Stop([TransactionIDSize]byte) error { func (a *gcWaitAgent) Close() error { close(a.gc) + return nil } func (a *gcWaitAgent) Collect(time.Time) error { a.gc <- struct{}{} + return nil } @@ -617,6 +627,7 @@ func captureLog() (*bytes.Buffer, func()) { log.SetOutput(&buf) f := log.Flags() log.SetFlags(0) + return &buf, func() { log.SetFlags(f) log.SetOutput(os.Stderr) @@ -633,12 +644,12 @@ func TestClientFinalizer(t *testing.T) { return 0, io.ErrClosedPipe }, } - c, err := NewClient(conn) + client, err := NewClient(conn) if err != nil { log.Panic(err) } - clientFinalizer(c) - clientFinalizer(c) + clientFinalizer(client) + clientFinalizer(client) response := MustBuild(TransactionID, BindingSuccess) response.Encode() conn = &testConnection{ @@ -647,7 +658,7 @@ func TestClientFinalizer(t *testing.T) { return len(bytes), nil }, } - c, err = NewClient(conn, + client, err = NewClient(conn, WithAgent(errorAgent{ closeErr: io.ErrUnexpectedEOF, }), @@ -655,7 +666,7 @@ func TestClientFinalizer(t *testing.T) { if err != nil { log.Panic(err) } - clientFinalizer(c) + clientFinalizer(client) reader := bufio.NewScanner(buf) var lines int expectedLines := []string{ @@ -700,6 +711,7 @@ func (m *manualCollector) Collect(t time.Time) { func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error { m.f = f + return nil } @@ -717,12 +729,14 @@ func (m *manualClock) Add(d time.Duration) time.Time { v := m.current.Add(d) m.current = v m.mux.Unlock() + return v } func (m *manualClock) Now() time.Time { m.mux.Lock() defer m.mux.Unlock() + return m.current } @@ -735,6 +749,7 @@ type manualAgent struct { func (n *manualAgent) SetHandler(h Handler) error { n.h = h + return nil } @@ -748,6 +763,7 @@ func (n *manualAgent) Process(m *Message) error { if n.process != nil { return n.process(m) } + return nil } @@ -759,6 +775,7 @@ func (n *manualAgent) Stop(id [TransactionIDSize]byte) error { if n.stop != nil { return n.stop(id) } + return nil } @@ -788,9 +805,10 @@ func TestClientRetransmission(t *testing.T) { Message: response, }) } + return nil } - c, err := NewClient(connR, + client, err := NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -799,7 +817,7 @@ func TestClientRetransmission(t *testing.T) { if err != nil { t.Fatal(err) } - c.SetRTO(time.Second) + client.SetRTO(time.Second) gotReads := make(chan struct{}) go func() { buf := make([]byte, 1500) @@ -819,7 +837,7 @@ func TestClientRetransmission(t *testing.T) { } gotReads <- struct{}{} }() - if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { + if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { if event.Error != nil { t.Error("failed") } @@ -829,7 +847,9 @@ func TestClientRetransmission(t *testing.T) { <-gotReads } -func testClientDoConcurrent(t *testing.T, concurrency int) { +func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop + t.Helper() + response := MustBuild(TransactionID, BindingSuccess) response.Encode() connL, connR := net.Pipe() @@ -846,9 +866,10 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { TransactionID: id, Message: response, }) + return nil } - c, err := NewClient(connR, + client, err := NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -856,7 +877,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { if err != nil { t.Fatal(err) } - c.SetRTO(time.Second) + client.SetRTO(time.Second) conns := new(sync.WaitGroup) wg := new(sync.WaitGroup) for i := 0; i < concurrency; i++ { @@ -880,7 +901,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { wg.Add(1) go func() { defer wg.Done() - if doErr := c.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { + if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { if event.Error != nil { t.Error("failed") } @@ -962,14 +983,14 @@ func TestClient_Close(t *testing.T) { } func TestClientDefaultHandler(t *testing.T) { - a := &TestAgent{ + agent := &TestAgent{ e: make(chan Event), } id := NewTransactionID() handlerCalled := make(chan struct{}) called := false - c, createErr := NewClient(noopConnection{}, - WithAgent(a), + client, createErr := NewClient(noopConnection{}, + WithAgent(agent), WithHandler(func(e Event) { if called { t.Error("should not be called twice") @@ -985,7 +1006,7 @@ func TestClientDefaultHandler(t *testing.T) { t.Fatal(createErr) } go func() { - a.h(Event{ + agent.h(Event{ TransactionID: id, }) }() @@ -995,11 +1016,11 @@ func TestClientDefaultHandler(t *testing.T) { case <-time.After(time.Millisecond * 100): t.Fatal("timed out") } - if closeErr := c.Close(); closeErr != nil { + if closeErr := client.Close(); closeErr != nil { t.Error(closeErr) } // Handler call should be ignored. - a.h(Event{}) + agent.h(Event{}) } func TestClientClosedStart(t *testing.T) { @@ -1047,9 +1068,10 @@ func TestWithNoRetransmit(t *testing.T) { Error: ErrTransactionTimeOut, }) } + return nil } - c, err := NewClient(connR, + client, err := NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -1071,7 +1093,7 @@ func TestWithNoRetransmit(t *testing.T) { } gotReads <- struct{}{} }() - if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { + if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { if !errors.Is(event.Error, ErrTransactionTimeOut) { t.Error("unexpected error") } @@ -1087,7 +1109,7 @@ func (c callbackClock) Now() time.Time { return c() } -func TestClientRTOStartErr(t *testing.T) { +func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop response := MustBuild(TransactionID, BindingSuccess) response.Encode() connL, connR := net.Pipe() @@ -1116,13 +1138,14 @@ func TestClientRTOStartErr(t *testing.T) { } else { t.Log("clock returned") } + return time.Now() }) agent := &manualAgent{} attempt := 0 gotReads := make(chan struct{}) var ( - c *Client + client *Client startClientErr error ) agent.start = func(id [TransactionIDSize]byte, _ time.Time) error { @@ -1146,7 +1169,7 @@ func TestClientRTOStartErr(t *testing.T) { t.Log("clock locked") <-clockLocked t.Log("closing client") - if closeErr := c.Close(); closeErr != nil { + if closeErr := client.Close(); closeErr != nil { t.Error(closeErr) } t.Log("client closed, unlocking clock") @@ -1154,9 +1177,10 @@ func TestClientRTOStartErr(t *testing.T) { t.Log("clock unlocked") }() } + return nil } - c, startClientErr = NewClient(connR, + client, startClientErr = NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -1186,7 +1210,7 @@ func TestClientRTOStartErr(t *testing.T) { t.Log("starting") done := make(chan struct{}) go func() { - if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { + if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { if !errors.Is(event.Error, ErrClientClosed) { t.Error(event.Error) } @@ -1203,7 +1227,7 @@ func TestClientRTOStartErr(t *testing.T) { } } -func TestClientRTOWriteErr(t *testing.T) { +func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop response := MustBuild(TransactionID, BindingSuccess) response.Encode() connL, connR := net.Pipe() @@ -1232,13 +1256,14 @@ func TestClientRTOWriteErr(t *testing.T) { } else { t.Log("clock returned") } + return time.Now() }) agent := &manualAgent{} attempt := 0 gotReads := make(chan struct{}) var ( - c *Client + client *Client startClientErr error ) agentStopErr := errClientAgentCantStop @@ -1274,9 +1299,10 @@ func TestClientRTOWriteErr(t *testing.T) { t.Log("clock unlocked") }() } + return nil } - c, startClientErr = NewClient(connR, + client, startClientErr = NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -1306,7 +1332,7 @@ func TestClientRTOWriteErr(t *testing.T) { t.Log("starting") done := make(chan struct{}) go func() { - if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { + if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { var e StopErr if !errors.As(event.Error, &e) { t.Error(event.Error) @@ -1346,7 +1372,7 @@ func TestClientRTOAgentErr(t *testing.T) { attempt := 0 gotReads := make(chan struct{}) var ( - c *Client + client *Client startClientErr error ) agentStartErr := errClientStartRefused @@ -1361,9 +1387,10 @@ func TestClientRTOAgentErr(t *testing.T) { } else { return agentStartErr } + return nil } - c, startClientErr = NewClient(connR, + client, startClientErr = NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -1384,7 +1411,7 @@ func TestClientRTOAgentErr(t *testing.T) { gotReads <- struct{}{} }() t.Log("starting") - if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { + if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { if !errors.Is(event.Error, agentStartErr) { t.Error(event.Error) } @@ -1415,9 +1442,10 @@ func TestClient_HandleProcessError(t *testing.T) { processCalled := make(chan struct{}, 1) agent.process = func(*Message) error { processCalled <- struct{}{} + return ErrAgentClosed } - c, startClientErr := NewClient(connR, + client, startClientErr := NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -1440,7 +1468,7 @@ func TestClient_HandleProcessError(t *testing.T) { case <-time.After(time.Second * 5): t.Error("reads timeout") } - if closeErr := c.Close(); closeErr != nil { + if closeErr := client.Close(); closeErr != nil { t.Error(closeErr) } } @@ -1475,9 +1503,10 @@ func TestClientImmediateTimeout(t *testing.T) { Error: ErrTransactionTimeOut, }) } + return nil } - c, err := NewClient(connR, + client, err := NewClient(connR, WithAgent(agent), WithClock(clock), WithCollector(collector), @@ -1498,7 +1527,7 @@ func TestClientImmediateTimeout(t *testing.T) { } gotReads <- struct{}{} }() - c.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec + client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec if errors.Is(e.Error, ErrTransactionTimeOut) { t.Error("unexpected error") } diff --git a/cmd/stun-bench/main.go b/cmd/stun-bench/main.go index e6ffdcc..3eef3bf 100644 --- a/cmd/stun-bench/main.go +++ b/cmd/stun-bench/main.go @@ -30,7 +30,7 @@ var ( realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals ) -func main() { //nolint:gocognit +func main() { //nolint:gocognit,cyclop flag.Parse() uri, err := stun.ParseURI(*uriStr) if err != nil { @@ -88,7 +88,7 @@ func main() { //nolint:gocognit log.Print("Using crypto/rand as random source for transaction id") } for i := 0; i < *workers; i++ { - c, clientErr := stun.DialURI(uri, &stun.DialConfig{}) + client, clientErr := stun.DialURI(uri, &stun.DialConfig{}) if clientErr != nil { log.Panicf("Failed to create client: %s", clientErr) } @@ -105,12 +105,13 @@ func main() { //nolint:gocognit req.Type = stun.BindingRequest req.WriteHeader() atomic.AddInt64(&request, 1) - if doErr := c.Do(req, func(event stun.Event) { + if doErr := client.Do(req, func(event stun.Event) { if event.Error != nil { if !errors.Is(event.Error, stun.ErrTransactionTimeOut) { log.Printf("Failed STUN transaction: %s", event.Error) } atomic.AddInt64(&requestErr, 1) + return } atomic.AddInt64(&requestOK, 1) diff --git a/cmd/stun-client/stun_client.go b/cmd/stun-client/stun_client.go index 0442bd7..c75dbd3 100644 --- a/cmd/stun-client/stun_client.go +++ b/cmd/stun-client/stun_client.go @@ -31,11 +31,11 @@ func main() { } // we only try the first address, so restrict ourselves to IPv4 - c, err := stun.DialURI(uri, &stun.DialConfig{}) + client, err := stun.DialURI(uri, &stun.DialConfig{}) if err != nil { log.Fatalf("Failed to dial: %s", err) } - if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { + if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { if res.Error != nil { log.Fatalf("Failed STUN transaction: %s", res.Error) } @@ -49,7 +49,7 @@ func main() { }); err != nil { log.Fatal("Do:", err) } - if err := c.Close(); err != nil { + if err := client.Close(); err != nil { log.Fatalf("Failed to close connection: %s", err) } } diff --git a/cmd/stun-multiplex/main.go b/cmd/stun-multiplex/main.go index 1212235..c1e3564 100644 --- a/cmd/stun-multiplex/main.go +++ b/cmd/stun-multiplex/main.go @@ -84,7 +84,7 @@ func multiplex(conn *net.UDPConn, stunAddr net.Addr, stunConn io.Reader) { var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals -func main() { +func main() { //nolint:cyclop flag.Parse() isServer := flag.Arg(0) == "" @@ -112,7 +112,7 @@ func main() { stunL, stunR := net.Pipe() - c, err := stun.NewClient(stunR) + client, err := stun.NewClient(stunR) if err != nil { log.Panicf("Failed to create client: %s", err) } @@ -134,7 +134,7 @@ func main() { // This can fail if your NAT Server is strict and will use separate ports // for application data and STUN var gotAddr stun.XORMappedAddress - if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { + if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) { if res.Error != nil { log.Panicf("Failed STUN transaction: %s", res.Error) } @@ -153,7 +153,7 @@ func main() { // Any ping-pong will work, but we are just making binding requests. // Note that STUN Server is not mandatory for keep alive, application // data will keep alive that binding too. - go keepAlive(c) + go keepAlive(client) notify := make(chan os.Signal, 1) signal.Notify(notify, os.Interrupt, syscall.SIGTERM) @@ -168,6 +168,7 @@ func main() { } case <-notify: log.Println("Stopping") + return } } @@ -203,10 +204,12 @@ func main() { case m := <-messages: log.Printf("Got response from %s: %s", m.addr, m.text) + return case <-notify: log.Print("Stopping") + return } } diff --git a/cmd/stun-nat-behaviour/main.go b/cmd/stun-nat-behaviour/main.go index e582114..5551796 100644 --- a/cmd/stun-nat-behaviour/main.go +++ b/cmd/stun-nat-behaviour/main.go @@ -30,10 +30,14 @@ func (c *stunServerConn) Close() error { } var ( - addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address") //nolint:gochecknoglobals - timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response") //nolint:gochecknoglobals - verbose = flag.Int("verbose", 1, "the verbosity level") //nolint:gochecknoglobals - log logging.LeveledLogger //nolint:gochecknoglobals + //nolint:gochecknoglobals + addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address") + //nolint:gochecknoglobals + timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response") + //nolint:gochecknoglobals + verbose = flag.Int("verbose", 1, "the verbosity level") + //nolint:gochecknoglobals + log logging.LeveledLogger ) const ( @@ -70,11 +74,12 @@ func main() { } } -// RFC5780: 4.3. Determining NAT Mapping Behavior -func mappingTests(addrStr string) error { +// RFC5780: 4.3. Determining NAT Mapping Behavior. +func mappingTests(addrStr string) error { //nolint:cyclop mapTestConn, err := connect(addrStr) if err != nil { log.Warnf("Error creating STUN connection: %s", err) + return err } @@ -91,11 +96,13 @@ func mappingTests(addrStr string) error { resps1 := parse(resp) if resps1.xorAddr == nil || resps1.otherAddr == nil { log.Info("Error: NAT discovery feature not supported by this server") + return errNoOtherAddress } addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String()) if err != nil { log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr) + return err } mapTestConn.OtherAddr = addr @@ -104,6 +111,7 @@ func mappingTests(addrStr string) error { // Assert mapping behavior if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() { log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)") + return nil } @@ -121,6 +129,7 @@ func mappingTests(addrStr string) error { log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr) if resps2.xorAddr.String() == resps1.xorAddr.String() { log.Warn("=> NAT mapping behavior: endpoint independent") + return nil } @@ -143,11 +152,12 @@ func mappingTests(addrStr string) error { return mapTestConn.Close() } -// RFC5780: 4.4. Determining NAT Filtering Behavior -func filteringTests(addrStr string) error { +// RFC5780: 4.4. Determining NAT Filtering Behavior. +func filteringTests(addrStr string) error { //nolint:cyclop mapTestConn, err := connect(addrStr) if err != nil { log.Warnf("Error creating STUN connection: %s", err) + return err } @@ -162,11 +172,13 @@ func filteringTests(addrStr string) error { resps := parse(resp) if resps.xorAddr == nil || resps.otherAddr == nil { log.Warn("Error: NAT discovery feature not supported by this server") + return errNoOtherAddress } addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String()) if err != nil { log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr) + return err } mapTestConn.OtherAddr = addr @@ -180,6 +192,7 @@ func filteringTests(addrStr string) error { if err == nil { parse(resp) // just to print out the resp log.Warn("=> NAT filtering behavior: endpoint independent") + return nil } else if !errors.Is(err, errTimedOut) { return err // something else went wrong @@ -201,7 +214,7 @@ func filteringTests(addrStr string) error { return mapTestConn.Close() } -// Parse a STUN message +// Parse a STUN message. func parse(msg *stun.Message) (ret struct { xorAddr *stun.XORMappedAddress otherAddr *stun.OtherAddress @@ -249,15 +262,17 @@ func parse(msg *stun.Message) (ret struct { log.Debugf("\t%v (l=%v)", attr, attr.Length) } } + return ret } -// Given an address string, returns a StunServerConn +// Given an address string, returns a StunServerConn. func connect(addrStr string) (*stunServerConn, error) { log.Infof("Connecting to STUN server: %s", addrStr) addr, err := net.ResolveUDPAddr("udp4", addrStr) if err != nil { log.Warnf("Error resolving address: %s", err) + return nil, err } @@ -278,7 +293,7 @@ func connect(addrStr string) (*stunServerConn, error) { }, nil } -// Send request and wait for response or timeout +// Send request and wait for response or timeout. func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) { _ = msg.NewTransactionID() log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize) @@ -289,6 +304,7 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess _, err := c.conn.WriteTo(msg.Raw, addr) if err != nil { log.Warnf("Error sending request to %v", addr) + return nil, err } @@ -298,9 +314,11 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess if !ok { return nil, errResponseMessage } + return m, nil case <-time.After(time.Duration(*timeoutPtr) * time.Second): log.Infof("Timed out waiting for response from server %v", addr) + return nil, errTimedOut } } @@ -315,6 +333,7 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) { n, addr, err := conn.ReadFromUDP(buf) if err != nil { close(messages) + return } log.Infof("Response from %v: (%v bytes)", addr, n) @@ -326,11 +345,13 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) { if err != nil { log.Infof("Error decoding message: %v", err) close(messages) + return } messages <- m } }() + return } diff --git a/cmd/stun-traversal/main.go b/cmd/stun-traversal/main.go index ada6a36..383e362 100644 --- a/cmd/stun-traversal/main.go +++ b/cmd/stun-traversal/main.go @@ -26,7 +26,7 @@ const ( timeoutMillis = 500 ) -func main() { //nolint:gocognit +func main() { //nolint:gocognit,cyclop flag.Parse() srvAddr, err := net.ResolveUDPAddr(udp, *server) @@ -87,11 +87,13 @@ func main() { //nolint:gocognit decErr := m.Decode() if decErr != nil { log.Println("decode:", decErr) + break } var xorAddr stun.XORMappedAddress if getErr := xorAddr.GetFrom(m); getErr != nil { log.Println("getFrom:", getErr) + break } @@ -160,6 +162,7 @@ func listen(conn *net.UDPConn) <-chan []byte { n, _, err := conn.ReadFromUDP(buf) if err != nil { close(messages) + return } buf = buf[:n] @@ -167,6 +170,7 @@ func listen(conn *net.UDPConn) <-chan []byte { messages <- buf } }() + return messages } diff --git a/e2e/main.go b/e2e/main.go index 1c3c7b9..1c0aba9 100644 --- a/e2e/main.go +++ b/e2e/main.go @@ -14,7 +14,7 @@ import ( "github.com/pion/stun/v3" ) -func test(network string) { +func test(network string) { //nolint:cyclop addr := resolve(network) fmt.Println("START", strings.ToUpper(addr.Network())) //nolint var ( diff --git a/errorcode.go b/errorcode.go index c852eed..36a9d60 100644 --- a/errorcode.go +++ b/errorcode.go @@ -11,7 +11,7 @@ import ( // ErrorCodeAttribute represents ERROR-CODE attribute. // -// RFC 5389 Section 15.6 +// RFC 5389 Section 15.6. type ErrorCodeAttribute struct { Code ErrorCode Reason []byte @@ -31,7 +31,7 @@ const ( ) // AddTo adds ERROR-CODE to m. -func (c ErrorCodeAttribute) AddTo(m *Message) error { +func (c ErrorCodeAttribute) AddTo(msg *Message) error { value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB) if err := CheckOverflow(AttrErrorCode, len(c.Reason)+errorCodeReasonStart, @@ -45,26 +45,28 @@ func (c ErrorCodeAttribute) AddTo(m *Message) error { value[errorCodeClassByte] = class value[errorCodeNumberByte] = number copy(value[errorCodeReasonStart:], c.Reason) - m.Add(AttrErrorCode, value) + msg.Add(AttrErrorCode, value) + return nil } // GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid. func (c *ErrorCodeAttribute) GetFrom(m *Message) error { - v, err := m.Get(AttrErrorCode) + value, err := m.Get(AttrErrorCode) if err != nil { return err } - if len(v) < errorCodeReasonStart { + if len(value) < errorCodeReasonStart { return io.ErrUnexpectedEOF } var ( - class = uint16(v[errorCodeClassByte]) - number = uint16(v[errorCodeNumberByte]) + class = uint16(value[errorCodeClassByte]) + number = uint16(value[errorCodeNumberByte]) code = int(class*errorCodeModulo + number) ) c.Code = ErrorCode(code) - c.Reason = v[errorCodeReasonStart:] + c.Reason = value[errorCodeReasonStart:] + return nil } @@ -86,6 +88,7 @@ func (c ErrorCode) AddTo(m *Message) error { Code: c, Reason: reason, } + return a.AddTo(m) } @@ -108,7 +111,7 @@ const ( // Error codes from RFC 5766. // -// RFC 5766 Section 15 +// RFC 5766 Section 15. const ( CodeForbidden ErrorCode = 403 // Forbidden CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch @@ -120,7 +123,7 @@ const ( // Error codes from RFC 6062. // -// RFC 6062 Section 6.3 +// RFC 6062 Section 6.3. const ( CodeConnAlreadyExists ErrorCode = 446 CodeConnTimeoutOrFailure ErrorCode = 447 @@ -128,7 +131,7 @@ const ( // Error codes from RFC 6156. // -// RFC 6156 Section 10.2 +// RFC 6156 Section 10.2. const ( CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch diff --git a/errorcode_test.go b/errorcode_test.go index 2d416df..d7a0481 100644 --- a/errorcode_test.go +++ b/errorcode_test.go @@ -92,23 +92,23 @@ func TestMessage_AddErrorCode(t *testing.T) { } func TestErrorCode(t *testing.T) { - a := &ErrorCodeAttribute{ + attr := &ErrorCodeAttribute{ Code: 404, Reason: []byte("not found!"), } - if a.String() != "404: not found!" { - t.Error("bad string", a) + if attr.String() != "404: not found!" { + t.Error("bad string", attr) } m := New() cod := ErrorCode(666) if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) { t.Error("should be ErrNoDefaultReason", err) } - if err := a.GetFrom(m); err == nil { + if err := attr.GetFrom(m); err == nil { t.Error("attr should not be in message") } - a.Reason = make([]byte, 2048) - if err := a.AddTo(m); err == nil { + attr.Reason = make([]byte, 2048) + if err := attr.AddTo(m); err == nil { t.Error("should error") } } diff --git a/fingerprint.go b/fingerprint.go index b4126d2..3c33bdd 100644 --- a/fingerprint.go +++ b/fingerprint.go @@ -10,7 +10,7 @@ import ( // FingerprintAttr represents FINGERPRINT attribute. // -// RFC 5389 Section 15.5 +// RFC 5389 Section 15.5. type FingerprintAttr struct{} // ErrFingerprintMismatch means that computed fingerprint differs from expected. @@ -50,6 +50,7 @@ func (FingerprintAttr) AddTo(m *Message) error { bin.PutUint32(b, val) m.Length = l m.Add(AttrFingerprint, b) + return nil } @@ -66,5 +67,6 @@ func (FingerprintAttr) Check(m *Message) error { val := bin.Uint32(b) attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) expected := FingerprintValue(m.Raw[:attrStart]) + return checkFingerprint(val, expected) } diff --git a/fingerprint_test.go b/fingerprint_test.go index f9c1831..137e1d1 100644 --- a/fingerprint_test.go +++ b/fingerprint_test.go @@ -13,20 +13,20 @@ import ( func BenchmarkFingerprint_AddTo(b *testing.B) { b.ReportAllocs() - m := new(Message) + msg := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ IP: net.IPv4(213, 1, 223, 5), } - addAttr(b, m, addr) - addAttr(b, m, s) - b.SetBytes(int64(len(m.Raw))) + addAttr(b, msg, addr) + addAttr(b, msg, s) + b.SetBytes(int64(len(msg.Raw))) for i := 0; i < b.N; i++ { - Fingerprint.AddTo(m) //nolint:errcheck,gosec - m.WriteLength() - m.Length -= attributeHeaderSize + fingerprintSize - m.Raw = m.Raw[:m.Length+messageHeaderSize] - m.Attributes = m.Attributes[:len(m.Attributes)-1] + Fingerprint.AddTo(msg) //nolint:errcheck,gosec + msg.WriteLength() + msg.Length -= attributeHeaderSize + fingerprintSize + msg.Raw = msg.Raw[:msg.Length+messageHeaderSize] + msg.Attributes = msg.Attributes[:len(msg.Attributes)-1] } } diff --git a/fuzz_test.go b/fuzz_test.go index 93266d5..e60353e 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -109,6 +109,7 @@ func FuzzSetters(f *testing.F) { if !IsAttrSizeOverflow(err) { t.Fatal(err) } + return } @@ -150,5 +151,6 @@ func (a attributes) pick(v byte) struct { t AttrType } { idx := int(v) % len(a) + return a[idx] } diff --git a/helpers.go b/helpers.go index d405650..75b2f3f 100644 --- a/helpers.go +++ b/helpers.go @@ -44,6 +44,7 @@ func (m *Message) Build(setters ...Setter) error { return err } } + return nil } @@ -54,6 +55,7 @@ func (m *Message) Check(checkers ...Checker) error { return err } } + return nil } @@ -64,6 +66,7 @@ func (m *Message) Parse(getters ...Getter) error { return err } } + return nil } @@ -73,6 +76,7 @@ func MustBuild(setters ...Setter) *Message { if err != nil { panic(err) //nolint } + return m } @@ -82,6 +86,7 @@ func Build(setters ...Setter) (*Message, error) { if err := m.Build(setters...); err != nil { return nil, err } + return m, nil } @@ -105,5 +110,6 @@ func (m *Message) ForEach(t AttrType, f func(m *Message) error) error { return err } } + return nil } diff --git a/helpers_test.go b/helpers_test.go index 5c135a0..f112ade 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -13,7 +13,7 @@ import ( func BenchmarkBuildOverhead(b *testing.B) { var ( - t = BindingRequest + msgType = BindingRequest username = NewUsername("username") nonce = NewNonce("nonce") realm = NewRealm("example.org") @@ -22,14 +22,14 @@ func BenchmarkBuildOverhead(b *testing.B) { b.ReportAllocs() m := new(Message) for i := 0; i < b.N; i++ { - m.Build(&t, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec + m.Build(&msgType, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec } }) b.Run("BuildNonPointer", func(b *testing.B) { b.ReportAllocs() m := new(Message) for i := 0; i < b.N; i++ { - m.Build(t, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec + m.Build(msgType, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec } }) b.Run("Raw", func(b *testing.B) { @@ -38,7 +38,7 @@ func BenchmarkBuildOverhead(b *testing.B) { for i := 0; i < b.N; i++ { m.Reset() m.WriteHeader() - m.SetType(t) + m.SetType(msgType) username.AddTo(m) //nolint:errcheck,gosec nonce.AddTo(m) //nolint:errcheck,gosec realm.AddTo(m) //nolint:errcheck,gosec @@ -52,7 +52,7 @@ func TestMessage_Apply(t *testing.T) { integrity = NewShortTermIntegrity("password") decoded = new(Message) ) - m, err := Build(BindingRequest, TransactionID, + msg, err := Build(BindingRequest, TransactionID, NewUsername("username"), NewNonce("nonce"), NewRealm("example.org"), @@ -62,13 +62,13 @@ func TestMessage_Apply(t *testing.T) { if err != nil { t.Fatal("failed to build:", err) } - if err = m.Check(Fingerprint, integrity); err != nil { + if err = msg.Check(Fingerprint, integrity); err != nil { t.Fatal(err) } - if _, err := decoded.Write(m.Raw); err != nil { + if _, err := decoded.Write(msg.Raw); err != nil { t.Fatal(err) } - if !decoded.Equal(m) { + if !decoded.Equal(msg) { t.Error("not equal") } if err := integrity.Check(decoded); err != nil { @@ -96,32 +96,32 @@ func (e errReturner) GetFrom(*Message) error { func TestHelpersErrorHandling(t *testing.T) { m := New() - e := errReturner{Err: errTError} - if err := m.Build(e); !errors.Is(err, e.Err) { - t.Error(err, "!=", e.Err) + errReturn := errReturner{Err: errTError} + if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) { + t.Error(err, "!=", errReturn.Err) } - if err := m.Check(e); !errors.Is(err, e.Err) { - t.Error(err, "!=", e.Err) + if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) { + t.Error(err, "!=", errReturn.Err) } - if err := m.Parse(e); !errors.Is(err, e.Err) { - t.Error(err, "!=", e.Err) + if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) { + t.Error(err, "!=", errReturn.Err) } t.Run("MustBuild", func(t *testing.T) { t.Run("Positive", func(*testing.T) { MustBuild(NewTransactionIDSetter(transactionID{})) }) defer func() { - if p, ok := recover().(error); !ok || !errors.Is(p, e.Err) { + if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) { t.Errorf("%s != %s", - p, e.Err, + p, errReturn.Err, ) } }() - MustBuild(e) + MustBuild(errReturn) }) } -func TestMessage_ForEach(t *testing.T) { +func TestMessage_ForEach(t *testing.T) { //nolint:cyclop initial := New() if err := initial.Build( NewRealm("realm1"), NewRealm("realm2"), @@ -135,6 +135,7 @@ func TestMessage_ForEach(t *testing.T) { ); err != nil { t.Fatal(err) } + return m } t.Run("NoResults", func(t *testing.T) { @@ -144,6 +145,7 @@ func TestMessage_ForEach(t *testing.T) { } if err := m.ForEach(AttrUsername, func(*Message) error { t.Error("should not be called") + return nil }); err != nil { t.Fatal(err) @@ -160,6 +162,7 @@ func TestMessage_ForEach(t *testing.T) { t.Error("called multiple times") } calls++ + return ErrAttributeNotFound }); !errors.Is(err, ErrAttributeNotFound) { t.Fatal(err) @@ -169,14 +172,15 @@ func TestMessage_ForEach(t *testing.T) { } }) t.Run("Positive", func(t *testing.T) { - m := newMessage() + msg := newMessage() var realms []string - if err := m.ForEach(AttrRealm, func(m *Message) error { + if err := msg.ForEach(AttrRealm, func(m *Message) error { var realm Realm if err := realm.GetFrom(m); err != nil { return err } realms = append(realms, realm.String()) + return nil }); err != nil { t.Fatal(err) @@ -190,18 +194,18 @@ func TestMessage_ForEach(t *testing.T) { if realms[1] != "realm2" { t.Error("bad value for 2 realm") } - if !m.Equal(initial) { + if !msg.Equal(initial) { t.Error("m should be equal to initial") } t.Run("ZeroAlloc", func(t *testing.T) { - m = newMessage() + msg = newMessage() var realm Realm testutil.ShouldNotAllocate(t, func() { - if err := m.ForEach(AttrRealm, realm.GetFrom); err != nil { + if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil { t.Fatal(err) } }) - if !m.Equal(initial) { + if !msg.Equal(initial) { t.Error("m should be equal to initial") } }) @@ -216,6 +220,7 @@ func ExampleMessage_ForEach() { return err } fmt.Println(r) + return nil }); err != nil { fmt.Println("error:", err) diff --git a/iana_test.go b/iana_test.go index 6b30dd6..c129372 100644 --- a/iana_test.go +++ b/iana_test.go @@ -14,22 +14,24 @@ import ( "testing" ) -func loadCSV(t testing.TB, name string) [][]string { - t.Helper() - data := loadData(t, name) +func loadCSV(tb testing.TB, name string) [][]string { + tb.Helper() + + data := loadData(tb, name) r := csv.NewReader(bytes.NewReader(data)) r.Comment = '#' records, err := r.ReadAll() if err != nil { - t.Fatal(err) + tb.Fatal(err) } + return records } -func TestIANA(t *testing.T) { +func TestIANA(t *testing.T) { //nolint:cyclop t.Run("Methods", func(t *testing.T) { records := loadCSV(t, "stun-parameters-2.csv") - m := make(map[string]Method) + methods := make(map[string]Method) for _, r := range records[1:] { var ( v = r[0] @@ -43,10 +45,10 @@ func TestIANA(t *testing.T) { t.Fatal(parseErr) } t.Logf("value: 0x%x, name: %s", val, name) - m[name] = Method(val) + methods[name] = Method(val) //nolint:gosec // G115 } for val, name := range methodName() { - mapped, ok := m[name] + mapped, ok := methods[name] if !ok { t.Errorf("failed to find method %s in IANA", name) } @@ -57,7 +59,7 @@ func TestIANA(t *testing.T) { }) t.Run("Attributes", func(t *testing.T) { records := loadCSV(t, "stun-parameters-4.csv") - m := map[string]AttrType{} + attrTypes := map[string]AttrType{} for _, r := range records[1:] { var ( v = r[0] @@ -71,16 +73,16 @@ func TestIANA(t *testing.T) { t.Fatal(parseErr) } t.Logf("value: 0x%x, name: %s", val, name) - m[name] = AttrType(val) + attrTypes[name] = AttrType(val) //nolint:gosec // G115 } // Not registered in IANA. for k, v := range map[string]AttrType{ "ORIGIN": 0x802F, } { - m[k] = v + attrTypes[k] = v } for val, name := range attrNames() { - mapped, ok := m[name] + mapped, ok := attrTypes[name] if !ok { t.Errorf("failed to find attribute %s in IANA", name) } @@ -91,7 +93,7 @@ func TestIANA(t *testing.T) { }) t.Run("ErrorCodes", func(t *testing.T) { records := loadCSV(t, "stun-parameters-6.csv") - m := map[string]ErrorCode{} + errorCodes := map[string]ErrorCode{} for _, r := range records[1:] { var ( v = r[0] @@ -105,11 +107,11 @@ func TestIANA(t *testing.T) { t.Fatal(parseErr) } t.Logf("value: 0x%x, name: %s", val, name) - m[name] = ErrorCode(val) + errorCodes[name] = ErrorCode(val) } for val, nameB := range errorReasons { name := string(nameB) - mapped, ok := m[name] + mapped, ok := errorCodes[name] if !ok { t.Errorf("failed to find error code %s in IANA", name) } diff --git a/integrity.go b/integrity.go index 285077f..ba61602 100644 --- a/integrity.go +++ b/integrity.go @@ -22,6 +22,7 @@ func NewLongTermIntegrity(username, realm, password string) MessageIntegrity { k := strings.Join([]string{username, realm, password}, credentialsSep) h := md5.New() //nolint:gosec fmt.Fprint(h, k) //nolint:errcheck + return MessageIntegrity(h.Sum(nil)) } @@ -36,13 +37,14 @@ func NewShortTermIntegrity(password string) MessageIntegrity { // AddTo and Check methods are using zero-allocation version of hmac, see // newHMAC function and internal/hmac/pool.go. // -// RFC 5389 Section 15.4 +// RFC 5389 Section 15.4. type MessageIntegrity []byte func newHMAC(key, message, buf []byte) []byte { mac := hmac.AcquireSHA1(key) writeOrPanic(mac, message) defer hmac.PutSHA1(mac) + return mac.Sum(buf) } @@ -59,8 +61,8 @@ var ErrFingerprintBeforeIntegrity = errors.New("FINGERPRINT before MESSAGE-INTEG // AddTo adds MESSAGE-INTEGRITY attribute to message. // // CPU costly, see BenchmarkMessageIntegrity_AddTo. -func (i MessageIntegrity) AddTo(m *Message) error { - for _, a := range m.Attributes { +func (i MessageIntegrity) AddTo(msg *Message) error { + for _, a := range msg.Attributes { // Message should not contain FINGERPRINT attribute // before MESSAGE-INTEGRITY. if a.Type == AttrFingerprint { @@ -70,19 +72,20 @@ func (i MessageIntegrity) AddTo(m *Message) error { // The text used as input to HMAC is the STUN message, // including the header, up to and including the attribute preceding the // MESSAGE-INTEGRITY attribute. - length := m.Length + length := msg.Length // Adjusting m.Length to contain MESSAGE-INTEGRITY TLV. - m.Length += messageIntegritySize + attributeHeaderSize - m.WriteLength() // writing length to m.Raw - v := newHMAC(i, m.Raw, m.Raw[len(m.Raw):]) // calculating HMAC for adjusted m.Raw - m.Length = length // changing m.Length back + msg.Length += messageIntegritySize + attributeHeaderSize + msg.WriteLength() // writing length to m.Raw + v := newHMAC(i, msg.Raw, msg.Raw[len(msg.Raw):]) // calculating HMAC for adjusted m.Raw + msg.Length = length // changing m.Length back // Copy hmac value to temporary variable to protect it from resetting // while processing m.Add call. vBuf := make([]byte, sha1.Size) copy(vBuf, v) - m.Add(AttrMessageIntegrity, vBuf) + msg.Add(AttrMessageIntegrity, vBuf) + return nil } @@ -92,8 +95,8 @@ var ErrIntegrityMismatch = errors.New("integrity check failed") // Check checks MESSAGE-INTEGRITY attribute. // // CPU costly, see BenchmarkMessageIntegrity_Check. -func (i MessageIntegrity) Check(m *Message) error { - v, err := m.Get(AttrMessageIntegrity) +func (i MessageIntegrity) Check(msg *Message) error { + val, err := msg.Get(AttrMessageIntegrity) if err != nil { return err } @@ -101,11 +104,11 @@ func (i MessageIntegrity) Check(m *Message) error { // Adjusting length in header to match m.Raw that was // used when computing HMAC. var ( - length = m.Length + length = msg.Length afterIntegrity = false sizeReduced int ) - for _, a := range m.Attributes { + for _, a := range msg.Attributes { if afterIntegrity { sizeReduced += nearestPaddedValueLength(int(a.Length)) sizeReduced += attributeHeaderSize @@ -114,13 +117,14 @@ func (i MessageIntegrity) Check(m *Message) error { afterIntegrity = true } } - m.Length -= uint32(sizeReduced) - m.WriteLength() + msg.Length -= uint32(sizeReduced) //nolint:gosec // G115 + msg.WriteLength() // startOfHMAC should be first byte of integrity attribute. - startOfHMAC := messageHeaderSize + m.Length - (attributeHeaderSize + messageIntegritySize) - b := m.Raw[:startOfHMAC] // data before integrity attribute - expected := newHMAC(i, b, m.Raw[len(m.Raw):]) - m.Length = length - m.WriteLength() // writing length back - return checkHMAC(v, expected) + startOfHMAC := messageHeaderSize + msg.Length - (attributeHeaderSize + messageIntegritySize) + b := msg.Raw[:startOfHMAC] // data before integrity attribute + expected := newHMAC(i, b, msg.Raw[len(msg.Raw):]) + msg.Length = length + msg.WriteLength() // writing length back + + return checkHMAC(val, expected) } diff --git a/integrity_test.go b/integrity_test.go index 11e421c..1ae7f9c 100644 --- a/integrity_test.go +++ b/integrity_test.go @@ -10,18 +10,18 @@ import ( ) func TestMessageIntegrity_AddTo_Simple(t *testing.T) { - i := NewLongTermIntegrity("user", "realm", "pass") + integrity := NewLongTermIntegrity("user", "realm", "pass") expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb") if err != nil { t.Fatal(err) } - if !bytes.Equal(expected, i) { + if !bytes.Equal(expected, integrity) { t.Error(ErrIntegrityMismatch) } t.Run("Check", func(t *testing.T) { m := new(Message) m.WriteHeader() - if err := i.AddTo(m); err != nil { + if err := integrity.AddTo(m); err != nil { t.Error(err) } NewSoftware("software").AddTo(m) //nolint:errcheck,gosec @@ -31,39 +31,39 @@ func TestMessageIntegrity_AddTo_Simple(t *testing.T) { if err := dM.Decode(); err != nil { t.Error(err) } - if err := i.Check(dM); err != nil { + if err := integrity.Check(dM); err != nil { t.Error(err) } dM.Raw[24] += 12 // HMAC now invalid - if i.Check(dM) == nil { + if integrity.Check(dM) == nil { t.Error("should be invalid") } }) } func TestMessageIntegrityWithFingerprint(t *testing.T) { - m := new(Message) - m.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} - m.WriteHeader() - NewSoftware("software").AddTo(m) //nolint:errcheck,gosec - i := NewShortTermIntegrity("pwd") - if i.String() != "KEY: 0x707764" { - t.Error("bad string", i) + msg := new(Message) + msg.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} + msg.WriteHeader() + NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec + integrity := NewShortTermIntegrity("pwd") + if integrity.String() != "KEY: 0x707764" { + t.Error("bad string", integrity) } - if err := i.Check(m); err == nil { + if err := integrity.Check(msg); err == nil { t.Error("should error") } - if err := i.AddTo(m); err != nil { + if err := integrity.AddTo(msg); err != nil { t.Fatal(err) } - if err := Fingerprint.AddTo(m); err != nil { + if err := Fingerprint.AddTo(msg); err != nil { t.Fatal(err) } - if err := i.Check(m); err != nil { + if err := integrity.Check(msg); err != nil { t.Fatal(err) } - m.Raw[24] = 33 - if err := i.Check(m); err == nil { + msg.Raw[24] = 33 + if err := integrity.Check(msg); err == nil { t.Fatal("mismatch expected") } } diff --git a/internal/hmac/hmac.go b/internal/hmac/hmac.go index b4db5c9..6caddbe 100644 --- a/internal/hmac/hmac.go +++ b/internal/hmac/hmac.go @@ -64,6 +64,7 @@ func (h *hmac) Sum(in []byte) []byte { h.outer.Write(h.opad) //nolint:errcheck,gosec } h.outer.Write(in[origLen:]) //nolint:errcheck,gosec + return h.outer.Sum(in[:origLen]) } @@ -79,6 +80,7 @@ func (h *hmac) Reset() { if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert panic(err) //nolint } + return } diff --git a/internal/hmac/hmac_test.go b/internal/hmac/hmac_test.go index a867f9a..59217da 100644 --- a/internal/hmac/hmac_test.go +++ b/internal/hmac/hmac_test.go @@ -22,7 +22,7 @@ type hmacTest struct { blocksize int } -func hmacTests() []hmacTest { +func hmacTests() []hmacTest { //nolint:maintidx return []hmacTest{ // Tests from US FIPS 198 // https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf @@ -523,41 +523,42 @@ func hmacTests() []hmacTest { func TestHMAC(t *testing.T) { for i, tt := range hmacTests() { - h := New(tt.hash, tt.key) - if s := h.Size(); s != tt.size { + hsh := New(tt.hash, tt.key) + if s := hsh.Size(); s != tt.size { t.Errorf("Size: got %v, want %v", s, tt.size) } - if b := h.BlockSize(); b != tt.blocksize { + if b := hsh.BlockSize(); b != tt.blocksize { t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) } - for j := 0; j < 4; j++ { - n, err := h.Write(tt.in) + for j := 0; j < 4; j++ { //nolint:varnamelen + n, err := hsh.Write(tt.in) if n != len(tt.in) || err != nil { t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) + continue } // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { - sum := fmt.Sprintf("%x", h.Sum(nil)) + sum := fmt.Sprintf("%x", hsh.Sum(nil)) if sum != tt.out { t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } } // Second iteration: make sure reset works. - h.Reset() + hsh.Reset() // Third and fourth iteration: make sure hmac works on // hashes without MarshalBinary/UnmarshalBinary if j == 1 { - h = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key) + hsh = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key) } } } } -// justHash implements just the hash.Hash methods and nothing else +// justHash implements just the hash.Hash methods and nothing else. type justHash struct { hash.Hash } diff --git a/internal/hmac/pool.go b/internal/hmac/pool.go index d2ac14a..c632251 100644 --- a/internal/hmac/pool.go +++ b/internal/hmac/pool.go @@ -40,6 +40,7 @@ func (h *hmac) resetTo(key []byte) { var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals New: func() interface{} { h := New(sha1.New, make([]byte, sha1.BlockSize)) + return h }, } @@ -49,6 +50,7 @@ func AcquireSHA1(key []byte) hash.Hash { h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert assertHMACSize(h, sha1.Size, sha1.BlockSize) h.resetTo(key) + return h } @@ -62,6 +64,7 @@ func PutSHA1(h hash.Hash) { var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals New: func() interface{} { h := New(sha256.New, make([]byte, sha256.BlockSize)) + return h }, } @@ -71,6 +74,7 @@ func AcquireSHA256(key []byte) hash.Hash { h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert assertHMACSize(h, sha256.Size, sha256.BlockSize) h.resetTo(key) + return h } diff --git a/internal/hmac/pool_test.go b/internal/hmac/pool_test.go index 7a07d8b..9fbb96f 100644 --- a/internal/hmac/pool_test.go +++ b/internal/hmac/pool_test.go @@ -42,100 +42,103 @@ func BenchmarkHMACSHA1_512_Pool(b *testing.B) { func TestHMACReset(t *testing.T) { for i, tt := range hmacTests() { - h := New(tt.hash, tt.key) - h.(*hmac).resetTo(tt.key) //nolint:forcetypeassert - if s := h.Size(); s != tt.size { + hsh := New(tt.hash, tt.key) + hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert + if s := hsh.Size(); s != tt.size { t.Errorf("Size: got %v, want %v", s, tt.size) } - if b := h.BlockSize(); b != tt.blocksize { + if b := hsh.BlockSize(); b != tt.blocksize { t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) } for j := 0; j < 2; j++ { - n, err := h.Write(tt.in) + n, err := hsh.Write(tt.in) if n != len(tt.in) || err != nil { t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) + continue } // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { - sum := fmt.Sprintf("%x", h.Sum(nil)) + sum := fmt.Sprintf("%x", hsh.Sum(nil)) if sum != tt.out { t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } } // Second iteration: make sure reset works. - h.Reset() + hsh.Reset() } } } -func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl +func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop for i, tt := range hmacTests() { if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size { continue } - h := AcquireSHA1(tt.key) - if s := h.Size(); s != tt.size { + hsh := AcquireSHA1(tt.key) + if s := hsh.Size(); s != tt.size { t.Errorf("Size: got %v, want %v", s, tt.size) } - if b := h.BlockSize(); b != tt.blocksize { + if b := hsh.BlockSize(); b != tt.blocksize { t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) } for j := 0; j < 2; j++ { - n, err := h.Write(tt.in) + n, err := hsh.Write(tt.in) if n != len(tt.in) || err != nil { t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) + continue } // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { - sum := fmt.Sprintf("%x", h.Sum(nil)) + sum := fmt.Sprintf("%x", hsh.Sum(nil)) if sum != tt.out { t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } } // Second iteration: make sure reset works. - h.Reset() + hsh.Reset() } - PutSHA1(h) + PutSHA1(hsh) } } -func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl +func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop for i, tt := range hmacTests() { if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size { continue } - h := AcquireSHA256(tt.key) - if s := h.Size(); s != tt.size { + hsh := AcquireSHA256(tt.key) + if s := hsh.Size(); s != tt.size { t.Errorf("Size: got %v, want %v", s, tt.size) } - if b := h.BlockSize(); b != tt.blocksize { + if b := hsh.BlockSize(); b != tt.blocksize { t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) } for j := 0; j < 2; j++ { - n, err := h.Write(tt.in) + n, err := hsh.Write(tt.in) if n != len(tt.in) || err != nil { t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) + continue } // Repetitive Sum() calls should return the same value for k := 0; k < 2; k++ { - sum := fmt.Sprintf("%x", h.Sum(nil)) + sum := fmt.Sprintf("%x", hsh.Sum(nil)) if sum != tt.out { t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out) } } // Second iteration: make sure reset works. - h.Reset() + hsh.Reset() } - PutSHA256(h) + PutSHA256(hsh) } } diff --git a/internal/testutil/allocs.go b/internal/testutil/allocs.go index c43d8ba..27ac5d5 100644 --- a/internal/testutil/allocs.go +++ b/internal/testutil/allocs.go @@ -10,8 +10,11 @@ import ( // ShouldNotAllocate fails if f allocates. func ShouldNotAllocate(t *testing.T, f func()) { + t.Helper() + if Race { t.Skip("skip while running with -race") + return } if a := testing.AllocsPerRun(10, f); a > 0 { diff --git a/message.go b/message.go index 6a828d6..ed15e5e 100644 --- a/message.go +++ b/message.go @@ -32,6 +32,7 @@ const ( // as source. func NewTransactionID() (b [TransactionIDSize]byte) { readFullOrPanic(rand.Reader, b[:]) + return b } @@ -45,6 +46,7 @@ func IsMessage(b []byte) bool { // New returns *Message with pre-allocated Raw. func New() *Message { const defaultRawCapacity = 120 + return &Message{ Raw: make([]byte, messageHeaderSize, defaultRawCapacity), } @@ -59,6 +61,7 @@ func Decode(data []byte, m *Message) error { return ErrDecodeToNil } m.Raw = append(m.Raw[:0], data...) + return m.Decode() } @@ -82,6 +85,7 @@ func (m Message) MarshalBinary() (data []byte, err error) { // contract induced by other implementations. b := make([]byte, len(m.Raw)) copy(b, m.Raw) + return b, nil } @@ -89,6 +93,7 @@ func (m Message) MarshalBinary() (data []byte, err error) { func (m *Message) UnmarshalBinary(data []byte) error { // We can't retain data, copy is expected by interface contract. m.Raw = append(m.Raw[:0], data...) + return m.Decode() } @@ -108,6 +113,7 @@ func (m *Message) GobDecode(data []byte) error { func (m *Message) AddTo(b *Message) error { b.TransactionID = m.TransactionID b.WriteTransactionID() + return nil } @@ -118,6 +124,7 @@ func (m *Message) NewTransactionID() error { if err == nil { m.WriteTransactionID() } + return err } @@ -127,6 +134,7 @@ func (m *Message) String() string { for k, a := range m.Attributes { aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type) } + return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo) } @@ -144,6 +152,7 @@ func (m *Message) grow(n int) { } if cap(m.Raw) >= n { m.Raw = m.Raw[:n] + return } m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...) @@ -153,7 +162,7 @@ func (m *Message) grow(n int) { // // Value of attribute is copied to internal buffer so // it is safe to reuse v. -func (m *Message) Add(t AttrType, v []byte) { +func (m *Message) Add(attrType AttrType, val []byte) { // Allocating buffer for TLV (type-length-value). // T = t, L = len(v), V = v. // m.Raw will look like: @@ -163,31 +172,33 @@ func (m *Message) Add(t AttrType, v []byte) { // [first:last] <- same as previous // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer // T L V - allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V) - first := messageHeaderSize + int(m.Length) // first byte number - last := first + allocSize // last byte number - m.grow(last) // growing cap(Raw) to fit TLV - m.Raw = m.Raw[:last] // now len(Raw) = last - m.Length += uint32(allocSize) // rendering length change + allocSize := attributeHeaderSize + len(val) // ~ len(TLV) = len(TL) + len(V) + first := messageHeaderSize + int(m.Length) // first byte number + last := first + allocSize // last byte number + m.grow(last) // growing cap(Raw) to fit TLV + m.Raw = m.Raw[:last] // now len(Raw) = last + //nolint:gosec // G115 + m.Length += uint32(allocSize) // rendering length change // Sub-slicing internal buffer to simplify encoding. buf := m.Raw[first:last] // slice for TLV value := buf[attributeHeaderSize:] // slice for V attr := RawAttribute{ - Type: t, // T - Length: uint16(len(v)), // L - Value: value, // V + Type: attrType, // T + //nolint:gosec // G115 + Length: uint16(len(val)), // L + Value: value, // V } // Encoding attribute TLV to allocated buffer. bin.PutUint16(buf[0:2], attr.Type.Value()) // T bin.PutUint16(buf[2:4], attr.Length) // L - copy(value, v) // V + copy(value, val) // V // Checking that attribute value needs padding. if attr.Length%padding != 0 { // Performing padding. - bytesToAdd := nearestPaddedValueLength(len(v)) - len(v) + bytesToAdd := nearestPaddedValueLength(len(val)) - len(val) last += bytesToAdd m.grow(last) // setting all padding bytes to zero @@ -197,7 +208,8 @@ func (m *Message) Add(t AttrType, v []byte) { for i := range buf { buf[i] = 0 } - m.Raw = m.Raw[:last] // increasing buffer length + m.Raw = m.Raw[:last] // increasing buffer length + //nolint:gosec // G115 m.Length += uint32(bytesToAdd) // rendering length change } m.Attributes = append(m.Attributes, attr) @@ -213,6 +225,7 @@ func attrSliceEqual(a, b Attributes) bool { } if attrB.Equal(attr) { found = true + break } } @@ -220,56 +233,59 @@ func attrSliceEqual(a, b Attributes) bool { return false } } + return true } -func attrEqual(a, b Attributes) bool { - if a == nil && b == nil { +func attrEqual(attrA, attrB Attributes) bool { + if attrA == nil && attrB == nil { return true } - if a == nil || b == nil { + if attrA == nil || attrB == nil { return false } - if len(a) != len(b) { + if len(attrA) != len(attrB) { return false } - if !attrSliceEqual(a, b) { + if !attrSliceEqual(attrA, attrB) { return false } - if !attrSliceEqual(b, a) { + if !attrSliceEqual(attrB, attrA) { return false } + return true } -// Equal returns true if Message b equals to m. +// Equal returns true if Message msg equals to m. // Ignores m.Raw. -func (m *Message) Equal(b *Message) bool { - if m == nil && b == nil { +func (m *Message) Equal(msg *Message) bool { + if m == nil && msg == nil { return true } - if m == nil || b == nil { + if m == nil || msg == nil { return false } - if m.Type != b.Type { + if m.Type != msg.Type { return false } - if m.TransactionID != b.TransactionID { + if m.TransactionID != msg.TransactionID { return false } - if m.Length != b.Length { + if m.Length != msg.Length { return false } - if !attrEqual(m.Attributes, b.Attributes) { + if !attrEqual(m.Attributes, msg.Attributes) { return false } + return true } // WriteLength writes m.Length to m.Raw. func (m *Message) WriteLength() { m.grow(4) - bin.PutUint16(m.Raw[2:4], uint16(m.Length)) + bin.PutUint16(m.Raw[2:4], uint16(m.Length)) //nolint:gosec // G115 } // WriteHeader writes header to underlying buffer. Not goroutine-safe. @@ -322,6 +338,7 @@ func (m *Message) Encode() { // call result. func (m *Message) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(m.Raw) + return int64(n), err } @@ -340,6 +357,7 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) { return int64(n), err } m.Raw = tBuf[:n] + return int64(n), m.Decode() } @@ -355,22 +373,24 @@ func (m *Message) Decode() error { return ErrUnexpectedHeaderEOF } var ( - t = bin.Uint16(buf[0:2]) // first 2 bytes + msgType = bin.Uint16(buf[0:2]) // first 2 bytes size = int(bin.Uint16(buf[2:4])) // second 2 bytes cookie = bin.Uint32(buf[4:8]) // last 4 bytes fullSize = messageHeaderSize + size // len(m.Raw) ) if cookie != magicCookie { msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie) + return newDecodeErr("message", "cookie", msg) } if len(buf) < fullSize { msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize) + return newAttrDecodeErr("message", msg) } // saving header data - m.Type.ReadValue(t) - m.Length = uint32(size) + m.Type.ReadValue(msgType) + m.Length = uint32(size) //nolint:gosec // G115 copy(m.TransactionID[:], buf[8:messageHeaderSize]) m.Attributes = m.Attributes[:0] @@ -382,28 +402,31 @@ func (m *Message) Decode() error { // checking that we have enough bytes to read header if len(b) < attributeHeaderSize { msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize) + return newAttrDecodeErr("header", msg) } var ( - a = RawAttribute{ + attr = RawAttribute{ Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes Length: bin.Uint16(b[2:4]), // second 2 bytes } - aL = int(a.Length) // attribute length + aL = int(attr.Length) // attribute length aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding) ) b = b[attributeHeaderSize:] // slicing again to simplify value read offset += attributeHeaderSize if len(b) < aBuffL { // checking size - msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type) + msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, attr.Type) + return newAttrDecodeErr("value", msg) } - a.Value = b[:aL] + attr.Value = b[:aL] offset += aBuffL b = b[aBuffL:] - m.Attributes = append(m.Attributes, a) + m.Attributes = append(m.Attributes, attr) } + return nil } @@ -412,12 +435,14 @@ func (m *Message) Decode() error { // Any error is unrecoverable, but message could be partially decoded. func (m *Message) Write(tBuf []byte) (int, error) { m.Raw = append(m.Raw[:0], tBuf...) + return len(tBuf), m.Decode() } // CloneTo clones m to b securing any further m mutations. func (m *Message) CloneTo(b *Message) error { b.Raw = append(b.Raw[:0], m.Raw...) + return b.Decode() } @@ -436,7 +461,7 @@ const ( var ( // Binding request message type. BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals - // Binding success response message type + // Binding success response message type. BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals // Binding error response message type. BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals @@ -501,6 +526,7 @@ func (m Method) String() string { // Falling back to hex representation. s = fmt.Sprintf("0x%x", uint16(m)) } + return s } @@ -513,6 +539,7 @@ type MessageType struct { // AddTo sets m type to t. func (t MessageType) AddTo(m *Message) error { m.SetType(t) + return nil } @@ -554,13 +581,13 @@ func (t MessageType) Value() uint16 { // Warning: Abandon all hope ye who enter here. // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). - m := uint16(t.Method) - a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits) - b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A) - d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B) + msg := uint16(t.Method) + a := msg & methodABits // A = M * 0b0000000000001111 (right 4 bits) + b := msg & methodBBits // B = M * 0b0000000001110000 (3 bits after A) + d := msg & methodDBits // D = M * 0b0000111110000000 (5 bits after B) // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). - m = a + (b << methodBShift) + (d << methodDShift) + msg = a + (b << methodBShift) + (d << methodDShift) // C0 is zero bit of C, C1 is first bit. // C0 = C * 0b01, C1 = (C * 0b10) >> 1 @@ -573,7 +600,7 @@ func (t MessageType) Value() uint16 { c1 := (c & c1Bit) << classC1Shift class := c0 + c1 - return m + class + return msg + class } // ReadValue decodes uint16 into MessageType. @@ -604,6 +631,7 @@ func (m *Message) Contains(t AttrType) bool { return true } } + return false } @@ -618,5 +646,6 @@ func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter { func (t transactionIDValueSetter) AddTo(m *Message) error { m.TransactionID = t m.WriteTransactionID() + return nil } diff --git a/message_test.go b/message_test.go index d52b06a..2ccea4e 100644 --- a/message_test.go +++ b/message_test.go @@ -27,9 +27,11 @@ type attributeEncoder interface { AddTo(m *Message) error } -func addAttr(t testing.TB, m *Message, a attributeEncoder) { +func addAttr(tb testing.TB, m *Message, a attributeEncoder) { + tb.Helper() + if err := a.AddTo(m); err != nil { - t.Error(err) + tb.Error(err) } } @@ -129,21 +131,21 @@ func TestMessageType_ReadWriteValue(t *testing.T) { } func TestMessage_WriteTo(t *testing.T) { - m := New() - m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} - m.TransactionID = NewTransactionID() - m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) - m.WriteHeader() + msg := New() + msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest} + msg.TransactionID = NewTransactionID() + msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) + msg.WriteHeader() buf := new(bytes.Buffer) - if _, err := m.WriteTo(buf); err != nil { + if _, err := msg.WriteTo(buf); err != nil { t.Fatal(err) } mDecoded := New() if _, err := mDecoded.ReadFrom(buf); err != nil { t.Error(err) } - if !mDecoded.Equal(m) { - t.Error(mDecoded, "!", m) + if !mDecoded.Equal(msg) { + t.Error(mDecoded, "!", msg) } } @@ -275,23 +277,23 @@ func BenchmarkMessage_WriteTo(b *testing.B) { func BenchmarkMessage_ReadFrom(b *testing.B) { mType := MessageType{Method: MethodBinding, Class: ClassRequest} - m := &Message{ + msg := &Message{ Type: mType, Length: 0, TransactionID: [TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, }, } - m.WriteHeader() + msg.WriteHeader() b.ReportAllocs() - b.SetBytes(int64(len(m.Raw))) - reader := m.reader() + b.SetBytes(int64(len(msg.Raw))) + reader := msg.reader() mRec := New() for i := 0; i < b.N; i++ { if _, err := mRec.ReadFrom(reader); err != nil { b.Fatal(err) } - reader.Reset(m.Raw) + reader.Reset(msg.Raw) mRec.Reset() } } @@ -342,7 +344,7 @@ func TestMessageClass_String(t *testing.T) { } func TestAttrType_String(t *testing.T) { - v := [...]AttrType{ + attrType := [...]AttrType{ AttrMappedAddress, AttrUsername, AttrErrorCode, @@ -355,7 +357,7 @@ func TestAttrType_String(t *testing.T) { AttrAlternateServer, AttrFingerprint, } - for _, k := range v { + for _, k := range attrType { if k.String() == "" { t.Error(k, "bad stringer") } @@ -379,80 +381,80 @@ func TestMethod_String(t *testing.T) { } func TestAttribute_Equal(t *testing.T) { - a := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} - b := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} - if !a.Equal(b) { + attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} + attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} + if !attr1.Equal(attr2) { t.Error("should equal") } - if a.Equal(RawAttribute{Type: 0x2}) { + if attr1.Equal(RawAttribute{Type: 0x2}) { t.Error("should not equal") } - if a.Equal(RawAttribute{Length: 0x2}) { + if attr1.Equal(RawAttribute{Length: 0x2}) { t.Error("should not equal") } - if a.Equal(RawAttribute{Length: 0x3}) { + if attr1.Equal(RawAttribute{Length: 0x3}) { t.Error("should not equal") } - if a.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) { + if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) { t.Error("should not equal") } } -func TestMessage_Equal(t *testing.T) { +func TestMessage_Equal(t *testing.T) { //nolint:cyclop attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attrs := Attributes{attr} - a := &Message{Attributes: attrs, Length: 4 + 2} - b := &Message{Attributes: attrs, Length: 4 + 2} - if !a.Equal(b) { + msg1 := &Message{Attributes: attrs, Length: 4 + 2} + msg2 := &Message{Attributes: attrs, Length: 4 + 2} + if !msg1.Equal(msg2) { t.Error("should equal") } - if a.Equal(&Message{Type: MessageType{Class: 128}}) { + if msg1.Equal(&Message{Type: MessageType{Class: 128}}) { t.Error("should not equal") } tID := [TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, } - if a.Equal(&Message{TransactionID: tID}) { + if msg1.Equal(&Message{TransactionID: tID}) { t.Error("should not equal") } - if a.Equal(&Message{Length: 3}) { + if msg1.Equal(&Message{Length: 3}) { t.Error("should not equal") } tAttrs := Attributes{ {Length: 1, Value: []byte{0x1}, Type: 0x1}, } - if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { + if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { t.Error("should not equal") } tAttrs = Attributes{ {Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2}, } - if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { + if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { t.Error("should not equal") } if !(*Message)(nil).Equal(nil) { t.Error("nil should be equal to nil") } - if a.Equal(nil) { + if msg1.Equal(nil) { t.Error("non-nil should not be equal to nil") } t.Run("Nil attributes", func(t *testing.T) { - a := &Message{ + msg1 := &Message{ Attributes: nil, Length: 4 + 2, } - b := &Message{ + msg2 := &Message{ Attributes: attrs, Length: 4 + 2, } - if a.Equal(b) { + if msg1.Equal(msg2) { t.Error("should not equal") } - if b.Equal(a) { + if msg2.Equal(msg1) { t.Error("should not equal") } - b.Attributes = nil - if !a.Equal(b) { + msg2.Attributes = nil + if !msg1.Equal(msg2) { t.Error("should equal") } }) @@ -547,6 +549,8 @@ func BenchmarkIsMessage(b *testing.B) { } func loadData(tb testing.TB, name string) []byte { + tb.Helper() + name = filepath.Join("testdata", name) f, err := os.Open(name) //nolint:gosec if err != nil { @@ -561,6 +565,7 @@ func loadData(tb testing.TB, name string) []byte { if err != nil { tb.Fatal(err) } + return v } @@ -582,7 +587,7 @@ func TestMessageFromBrowsers(t *testing.T) { t.Fatal("failed to skip header of csv: ", err) } crcTable := crc64.MakeTable(crc64.ISO) - m := New() + msg := New() for { line, err := reader.Read() if errors.Is(err, io.EOF) { @@ -602,10 +607,10 @@ func TestMessageFromBrowsers(t *testing.T) { if b != crc64.Checksum(data, crcTable) { t.Error("crc64 check failed for ", line[1]) } - if _, err = m.Write(data); err != nil { + if _, err = msg.Write(data); err != nil { t.Error("failed to decode ", line[1], " as message: ", err) } - m.Reset() + msg.Reset() } } @@ -622,42 +627,42 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) { func BenchmarkMessageFull(b *testing.B) { b.ReportAllocs() - m := new(Message) + msg := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ IP: net.IPv4(213, 1, 223, 5), } for i := 0; i < b.N; i++ { - if err := addr.AddTo(m); err != nil { + if err := addr.AddTo(msg); err != nil { b.Fatal(err) } - if err := s.AddTo(m); err != nil { + if err := s.AddTo(msg); err != nil { b.Fatal(err) } - m.WriteAttributes() - m.WriteHeader() - Fingerprint.AddTo(m) //nolint:errcheck,gosec - m.WriteHeader() - m.Reset() + msg.WriteAttributes() + msg.WriteHeader() + Fingerprint.AddTo(msg) //nolint:errcheck,gosec + msg.WriteHeader() + msg.Reset() } } func BenchmarkMessageFullHardcore(b *testing.B) { b.ReportAllocs() - m := new(Message) + msg := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ IP: net.IPv4(213, 1, 223, 5), } for i := 0; i < b.N; i++ { - if err := addr.AddTo(m); err != nil { + if err := addr.AddTo(msg); err != nil { b.Fatal(err) } - if err := s.AddTo(m); err != nil { + if err := s.AddTo(msg); err != nil { b.Fatal(err) } - m.WriteHeader() - m.Reset() + msg.WriteHeader() + msg.Reset() } } @@ -689,8 +694,8 @@ func TestMessage_Contains(t *testing.T) { func ExampleMessage() { buf := new(bytes.Buffer) - m := new(Message) - m.Build(BindingRequest, //nolint:errcheck,gosec + msg := new(Message) + msg.Build(BindingRequest, //nolint:errcheck,gosec NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), @@ -707,8 +712,8 @@ func ExampleMessage() { // m.Build(&software) // no allocations // If you pass software as value, there will be 1 allocation. // This rule is correct for all setters. - fmt.Println(m, "buff length:", len(m.Raw)) - n, err := m.WriteTo(buf) + fmt.Println(msg, "buff length:", len(msg.Raw)) + n, err := msg.WriteTo(buf) fmt.Println("wrote", n, "err", err) // Decoding from buf new *Message. @@ -743,6 +748,7 @@ func ExampleMessage() { fmt.Println("fingerprint: failed") } + //nolint:lll // Output: // Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68 // wrote 68 err @@ -811,8 +817,8 @@ func TestAllocationsGetters(t *testing.T) { NewShortTermIntegrity("pwd"), Fingerprint, } - m := New() - if err := m.Build(setters...); err != nil { + msg := New() + if err := msg.Build(setters...); err != nil { t.Error("failed to build", err) } getters := []Getter{ @@ -826,7 +832,7 @@ func TestAllocationsGetters(t *testing.T) { g := g i := i allocs := testing.AllocsPerRun(10, func() { - if err := g.GetFrom(m); err != nil { + if err := g.GetFrom(msg); err != nil { t.Errorf("[%d] failed to get", i) } }) @@ -837,8 +843,8 @@ func TestAllocationsGetters(t *testing.T) { } func TestMessageFullSize(t *testing.T) { - m := new(Message) - if err := m.Build(BindingRequest, + msg := new(Message) + if err := msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), @@ -848,18 +854,18 @@ func TestMessageFullSize(t *testing.T) { ); err != nil { t.Fatal(err) } - m.Raw = m.Raw[:len(m.Raw)-10] + msg.Raw = msg.Raw[:len(msg.Raw)-10] decoder := new(Message) - decoder.Raw = m.Raw[:len(m.Raw)-10] + decoder.Raw = msg.Raw[:len(msg.Raw)-10] if err := decoder.Decode(); err == nil { t.Error("decode on truncated buffer should error") } } func TestMessage_CloneTo(t *testing.T) { - m := new(Message) - if err := m.Build(BindingRequest, + msg := new(Message) + if err := msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), @@ -869,29 +875,29 @@ func TestMessage_CloneTo(t *testing.T) { ); err != nil { t.Fatal(err) } - m.Encode() - b := new(Message) - if err := m.CloneTo(b); err != nil { + msg.Encode() + msg2 := new(Message) + if err := msg.CloneTo(msg2); err != nil { t.Fatal(err) } - if !b.Equal(m) { + if !msg2.Equal(msg) { t.Fatal("not equal") } // Corrupting m and checking that b is not corrupted. - s, ok := b.Attributes.Get(AttrSoftware) + s, ok := msg2.Attributes.Get(AttrSoftware) if !ok { t.Fatal("no software attribute") } s.Value[0] = 'k' - if b.Equal(m) { + if msg2.Equal(msg) { t.Fatal("should not be equal") } } func BenchmarkMessage_CloneTo(b *testing.B) { b.ReportAllocs() - m := new(Message) - if err := m.Build(BindingRequest, + msg := new(Message) + if err := msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), @@ -901,19 +907,19 @@ func BenchmarkMessage_CloneTo(b *testing.B) { ); err != nil { b.Fatal(err) } - b.SetBytes(int64(len(m.Raw))) + b.SetBytes(int64(len(msg.Raw))) a := new(Message) - m.CloneTo(a) //nolint:errcheck,gosec + msg.CloneTo(a) //nolint:errcheck,gosec for i := 0; i < b.N; i++ { - if err := m.CloneTo(a); err != nil { + if err := msg.CloneTo(a); err != nil { b.Fatal(err) } } } func TestMessage_AddTo(t *testing.T) { - m := new(Message) - if err := m.Build(BindingRequest, + msg := new(Message) + if err := msg.Build(BindingRequest, NewTransactionIDSetter([TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, }), @@ -921,19 +927,19 @@ func TestMessage_AddTo(t *testing.T) { ); err != nil { t.Fatal(err) } - m.Encode() + msg.Encode() b := new(Message) - if err := m.CloneTo(b); err != nil { + if err := msg.CloneTo(b); err != nil { t.Fatal(err) } - m.TransactionID = [TransactionIDSize]byte{ + msg.TransactionID = [TransactionIDSize]byte{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, } - if b.Equal(m) { + if b.Equal(msg) { t.Fatal("should not be equal") } - m.AddTo(b) //nolint:errcheck,gosec - if !b.Equal(m) { + msg.AddTo(b) //nolint:errcheck,gosec + if !b.Equal(msg) { t.Fatal("should be equal") } } @@ -964,22 +970,22 @@ func TestDecode(t *testing.T) { t.Errorf("unexpected error: %v", err) } }) - m := New() - m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} - m.TransactionID = NewTransactionID() - m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) - m.WriteHeader() + msg := New() + msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest} + msg.TransactionID = NewTransactionID() + msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) + msg.WriteHeader() mDecoded := New() - if err := Decode(m.Raw, mDecoded); err != nil { + if err := Decode(msg.Raw, mDecoded); err != nil { t.Errorf("unexpected error: %v", err) } - if !mDecoded.Equal(m) { + if !mDecoded.Equal(msg) { t.Error("decoded result is not equal to encoded message") } t.Run("ZeroAlloc", func(t *testing.T) { allocs := testing.AllocsPerRun(10, func() { mDecoded.Reset() - if err := Decode(m.Raw, mDecoded); err != nil { + if err := Decode(msg.Raw, mDecoded); err != nil { t.Error(err) } }) @@ -1008,22 +1014,22 @@ func BenchmarkDecode(b *testing.B) { } func TestMessage_MarshalBinary(t *testing.T) { - m := MustBuild( + msg := MustBuild( NewSoftware("software"), &XORMappedAddress{ IP: net.IPv4(213, 1, 223, 5), }, ) - data, err := m.MarshalBinary() + data, err := msg.MarshalBinary() if err != nil { t.Fatal(err) } // Reset m.Raw to check retention. - for i := range m.Raw { - m.Raw[i] = 0 + for i := range msg.Raw { + msg.Raw[i] = 0 } - if err := m.UnmarshalBinary(data); err != nil { + if err := msg.UnmarshalBinary(data); err != nil { t.Fatal(err) } @@ -1031,28 +1037,28 @@ func TestMessage_MarshalBinary(t *testing.T) { for i := range data { data[i] = 0 } - if err := m.Decode(); err != nil { + if err := msg.Decode(); err != nil { t.Fatal(err) } } func TestMessage_GobDecode(t *testing.T) { - m := MustBuild( + msg := MustBuild( NewSoftware("software"), &XORMappedAddress{ IP: net.IPv4(213, 1, 223, 5), }, ) - data, err := m.GobEncode() + data, err := msg.GobEncode() if err != nil { t.Fatal(err) } // Reset m.Raw to check retention. - for i := range m.Raw { - m.Raw[i] = 0 + for i := range msg.Raw { + msg.Raw[i] = 0 } - if err := m.GobDecode(data); err != nil { + if err := msg.GobDecode(data); err != nil { t.Fatal(err) } @@ -1060,7 +1066,7 @@ func TestMessage_GobDecode(t *testing.T) { for i := range data { data[i] = 0 } - if err := m.Decode(); err != nil { + if err := msg.Decode(); err != nil { t.Fatal(err) } } diff --git a/rfc5769_test.go b/rfc5769_test.go index d6cc1d6..30ac3fb 100644 --- a/rfc5769_test.go +++ b/rfc5769_test.go @@ -8,7 +8,7 @@ import ( "testing" ) -func TestRFC5769(t *testing.T) { +func TestRFC5769(t *testing.T) { //nolint:cyclop // Test Vectors for Session Traversal Utilities for NAT (STUN) // see https://tools.ietf.org/html/rfc5769 t.Run("Request", func(t *testing.T) { @@ -46,7 +46,7 @@ func TestRFC5769(t *testing.T) { t.Error("check failed: ", err) } t.Run("Long-Term credentials", func(t *testing.T) { - m := &Message{ + msg := &Message{ Raw: []byte("\x00\x01\x00\x60" + "\x21\x12\xa4\x42" + "\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" + @@ -64,11 +64,11 @@ func TestRFC5769(t *testing.T) { "\x2e\x85\xc9\xa2\x8c\xa8\x96\x66", ), } - if err := m.Decode(); err != nil { + if err := msg.Decode(); err != nil { t.Error(err) } u := new(Username) - if err := u.GetFrom(m); err != nil { + if err := u.GetFrom(msg); err != nil { t.Error(err) } expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9" @@ -76,14 +76,14 @@ func TestRFC5769(t *testing.T) { t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername) } n := new(Nonce) - if err := n.GetFrom(m); err != nil { + if err := n.GetFrom(msg); err != nil { t.Error(err) } if n.String() != "f//499k954d6OL34oL9FSTvy64sA" { t.Error("bad nonce") } r := new(Realm) - if err := r.GetFrom(m); err != nil { + if err := r.GetFrom(msg); err != nil { t.Error(err) } if r.String() != "example.org" { //nolint:goconst @@ -95,14 +95,14 @@ func TestRFC5769(t *testing.T) { "example.org", "TheMatrIX", ) - if err := i.Check(m); err != nil { + if err := i.Check(msg); err != nil { t.Error(err) } }) }) t.Run("Response", func(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - m := &Message{ + msg := &Message{ Raw: []byte("\x01\x01\x00\x3c" + "\x21\x12\xa4\x42" + "\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" + @@ -117,21 +117,21 @@ func TestRFC5769(t *testing.T) { "\xc0\x7d\x4c\x96", ), } - if err := m.Decode(); err != nil { + if err := msg.Decode(); err != nil { t.Error(err) } software := new(Software) - if err := software.GetFrom(m); err != nil { + if err := software.GetFrom(msg); err != nil { t.Error(err) } if software.String() != "test vector" { t.Error("bad software: ", software) } - if err := Fingerprint.Check(m); err != nil { + if err := Fingerprint.Check(msg); err != nil { t.Error("Check failed: ", err) } addr := new(XORMappedAddress) - if err := addr.GetFrom(m); err != nil { + if err := addr.GetFrom(msg); err != nil { t.Error(err) } if !addr.IP.Equal(net.ParseIP("192.0.2.1")) { @@ -140,12 +140,12 @@ func TestRFC5769(t *testing.T) { if addr.Port != 32853 { t.Error("bad Port") } - if err := Fingerprint.Check(m); err != nil { + if err := Fingerprint.Check(msg); err != nil { t.Error("check failed: ", err) } }) t.Run("IPv6", func(t *testing.T) { - m := &Message{ + msg := &Message{ Raw: []byte("\x01\x01\x00\x48" + "\x21\x12\xa4\x42" + "\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" + @@ -162,21 +162,21 @@ func TestRFC5769(t *testing.T) { "\xc8\xfb\x0b\x4c", ), } - if err := m.Decode(); err != nil { + if err := msg.Decode(); err != nil { t.Error(err) } software := new(Software) - if err := software.GetFrom(m); err != nil { + if err := software.GetFrom(msg); err != nil { t.Error(err) } if software.String() != "test vector" { t.Error("bad software: ", software) } - if err := Fingerprint.Check(m); err != nil { + if err := Fingerprint.Check(msg); err != nil { t.Error("Check failed: ", err) } addr := new(XORMappedAddress) - if err := addr.GetFrom(m); err != nil { + if err := addr.GetFrom(msg); err != nil { t.Error(err) } if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) { @@ -185,7 +185,7 @@ func TestRFC5769(t *testing.T) { if addr.Port != 32853 { t.Error("bad Port") } - if err := Fingerprint.Check(m); err != nil { + if err := Fingerprint.Check(msg); err != nil { t.Error("check failed: ", err) } }) diff --git a/stun.go b/stun.go index 3a7954f..26c2a53 100644 --- a/stun.go +++ b/stun.go @@ -27,6 +27,7 @@ func readFullOrPanic(r io.Reader, v []byte) int { if err != nil { panic(err) //nolint } + return n } @@ -35,6 +36,7 @@ func writeOrPanic(w io.Writer, v []byte) int { if err != nil { panic(err) //nolint } + return n } diff --git a/stuntest/udp_server.go b/stuntest/udp_server.go index d19c179..191b74d 100644 --- a/stuntest/udp_server.go +++ b/stuntest/udp_server.go @@ -13,14 +13,19 @@ import ( var errUDPServerUnsupportedNetwork = errors.New("unsupported network") -// NewUDPServer creates an udp server for testing. The supplied handler function will be called with the request +// NewUDPServer creates an udp server for testing. +// The supplied handler function will be called with the request // and should be used to emulate the server behavior. +// +//nolint:cyclop func NewUDPServer( t *testing.T, network string, maxMessageSize int, handler func(req []byte) ([]byte, error), ) (net.Addr, func(t *testing.T), error) { + t.Helper() + var ip string switch network { case "udp4": @@ -50,28 +55,34 @@ func NewUDPServer( n, addr, err := udpConn.ReadFrom(bs) if err != nil { errCh <- err + return } resp, err := handler(bs[:n]) if err != nil { errCh <- err + return } _, err = udpConn.WriteTo(resp, addr) if err != nil { errCh <- err + return } } }() return serverAddr, func(t *testing.T) { + t.Helper() + select { case err := <-errCh: if err != nil { - t.Fatal(err) //nolint + t.Fatal(err) + return } default: @@ -79,7 +90,7 @@ func NewUDPServer( err := udpConn.Close() if err != nil { - t.Fatal(err) //nolint + t.Fatal(err) } <-errCh diff --git a/textattrs.go b/textattrs.go index a98915a..b9626f0 100644 --- a/textattrs.go +++ b/textattrs.go @@ -10,7 +10,7 @@ func NewUsername(username string) Username { // Username represents USERNAME attribute. // -// RFC 5389 Section 15.3 +// RFC 5389 Section 15.3. type Username []byte func (u Username) String() string { @@ -37,7 +37,7 @@ func NewRealm(realm string) Realm { // Realm represents REALM attribute. // -// RFC 5389 Section 15.7 +// RFC 5389 Section 15.7. type Realm []byte func (n Realm) String() string { @@ -60,7 +60,7 @@ const softwareRawMaxB = 763 // Software is SOFTWARE attribute. // -// RFC 5389 Section 15.10 +// RFC 5389 Section 15.10. type Software []byte func (s Software) String() string { @@ -84,7 +84,7 @@ func (s *Software) GetFrom(m *Message) error { // Nonce represents NONCE attribute. // -// RFC 5389 Section 15.8 +// RFC 5389 Section 15.8. type Nonce []byte // NewNonce returns new Nonce from string. @@ -118,6 +118,7 @@ func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error { return err } m.Add(t, v) + return nil } @@ -128,5 +129,6 @@ func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error { return err } *v = a + return nil } diff --git a/textattrs_test.go b/textattrs_test.go index 649aba9..3935368 100644 --- a/textattrs_test.go +++ b/textattrs_test.go @@ -13,26 +13,26 @@ import ( ) func TestSoftware_GetFrom(t *testing.T) { - m := New() - v := "Client v0.0.1" - m.Add(AttrSoftware, []byte(v)) - m.WriteHeader() + msg := New() + val := "Client v0.0.1" + msg.Add(AttrSoftware, []byte(val)) + msg.WriteHeader() m2 := &Message{ Raw: make([]byte, 0, 256), } software := new(Software) - if _, err := m2.ReadFrom(m.reader()); err != nil { + if _, err := m2.ReadFrom(msg.reader()); err != nil { t.Error(err) } - if err := software.GetFrom(m); err != nil { + if err := software.GetFrom(msg); err != nil { t.Fatal(err) } - if software.String() != v { - t.Errorf("Expected %q, got %q.", v, software) + if software.String() != val { + t.Errorf("Expected %q, got %q.", val, software) } - sAttr, ok := m.Attributes.Get(AttrSoftware) + sAttr, ok := msg.Attributes.Get(AttrSoftware) if !ok { t.Error("software attribute should be found") } @@ -90,22 +90,22 @@ func BenchmarkUsername_GetFrom(b *testing.B) { func TestUsername(t *testing.T) { username := "username" - u := NewUsername(username) - m := new(Message) - m.WriteHeader() + uName := NewUsername(username) + msg := new(Message) + msg.WriteHeader() t.Run("Bad length", func(t *testing.T) { badU := make(Username, 600) - if err := badU.AddTo(m); !IsAttrSizeOverflow(err) { + if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) { t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) } }) t.Run("AddTo", func(t *testing.T) { - if err := u.AddTo(m); err != nil { + if err := uName.AddTo(msg); err != nil { t.Error("errored:", err) } t.Run("GetFrom", func(t *testing.T) { got := new(Username) - if err := got.GetFrom(m); err != nil { + if err := got.GetFrom(msg); err != nil { t.Error("errored:", err) } if got.String() != username { @@ -136,10 +136,10 @@ func TestUsername(t *testing.T) { } func TestRealm_GetFrom(t *testing.T) { - m := New() - v := "realm" - m.Add(AttrRealm, []byte(v)) - m.WriteHeader() + msg := New() + val := "realm" + msg.Add(AttrRealm, []byte(val)) + msg.WriteHeader() m2 := &Message{ Raw: make([]byte, 0, 256), @@ -148,17 +148,17 @@ func TestRealm_GetFrom(t *testing.T) { if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) { t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) } - if _, err := m2.ReadFrom(m.reader()); err != nil { + if _, err := m2.ReadFrom(msg.reader()); err != nil { t.Error(err) } - if err := r.GetFrom(m); err != nil { + if err := r.GetFrom(msg); err != nil { t.Fatal(err) } - if r.String() != v { - t.Errorf("Expected %q, got %q.", v, r) + if r.String() != val { + t.Errorf("Expected %q, got %q.", val, r) } - rAttr, ok := m.Attributes.Get(AttrRealm) + rAttr, ok := msg.Attributes.Get(AttrRealm) if !ok { t.Error("realm attribute should be found") } @@ -180,26 +180,26 @@ func TestRealm_AddTo_Invalid(t *testing.T) { } func TestNonce_GetFrom(t *testing.T) { - m := New() - v := "example.org" - m.Add(AttrNonce, []byte(v)) - m.WriteHeader() + msg := New() + val := "example.org" + msg.Add(AttrNonce, []byte(val)) + msg.WriteHeader() m2 := &Message{ Raw: make([]byte, 0, 256), } var nonce Nonce - if _, err := m2.ReadFrom(m.reader()); err != nil { + if _, err := m2.ReadFrom(msg.reader()); err != nil { t.Error(err) } - if err := nonce.GetFrom(m); err != nil { + if err := nonce.GetFrom(msg); err != nil { t.Fatal(err) } - if nonce.String() != v { - t.Errorf("Expected %q, got %q.", v, nonce) + if nonce.String() != val { + t.Errorf("Expected %q, got %q.", val, nonce) } - nAttr, ok := m.Attributes.Get(AttrNonce) + nAttr, ok := msg.Attributes.Get(AttrNonce) if !ok { t.Error("nonce attribute should be found") } diff --git a/uattrs.go b/uattrs.go index 8b85d8a..238cd84 100644 --- a/uattrs.go +++ b/uattrs.go @@ -7,7 +7,7 @@ import "errors" // UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute. // -// RFC 5389 Section 15.9 +// RFC 5389 Section 15.9. type UnknownAttributes []AttrType func (a UnknownAttributes) String() string { @@ -22,6 +22,7 @@ func (a UnknownAttributes) String() string { s += ", " } } + return s } @@ -39,6 +40,7 @@ func (a UnknownAttributes) AddTo(m *Message) error { bin.PutUint16(v[first:last], t.Value()) } m.Add(AttrUnknownAttributes, v) + return nil } @@ -62,5 +64,6 @@ func (a *UnknownAttributes) GetFrom(m *Message) error { *a = append(*a, AttrType(bin.Uint16(v[first:last]))) first = last } + return nil } diff --git a/uattrs_test.go b/uattrs_test.go index acf5df0..c70872c 100644 --- a/uattrs_test.go +++ b/uattrs_test.go @@ -8,26 +8,26 @@ import ( ) func TestUnknownAttributes(t *testing.T) { - m := new(Message) - a := &UnknownAttributes{ + msg := new(Message) + attr := &UnknownAttributes{ AttrDontFragment, AttrChannelNumber, } - if a.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" { - t.Error("bad String:", a) + if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" { + t.Error("bad String:", attr) } if (UnknownAttributes{}).String() != "" { t.Error("bad blank string") } - if err := a.AddTo(m); err != nil { + if err := attr.AddTo(msg); err != nil { t.Error(err) } t.Run("GetFrom", func(t *testing.T) { attrs := make(UnknownAttributes, 10) - if err := attrs.GetFrom(m); err != nil { + if err := attrs.GetFrom(msg); err != nil { t.Error(err) } - for i, at := range *a { + for i, at := range *attr { if at != attrs[i] { t.Error("expected", at, "!=", attrs[i]) } @@ -44,8 +44,8 @@ func TestUnknownAttributes(t *testing.T) { } func BenchmarkUnknownAttributes(b *testing.B) { - m := new(Message) - a := UnknownAttributes{ + msg := new(Message) + attr := UnknownAttributes{ AttrDontFragment, AttrChannelNumber, AttrRealm, @@ -54,20 +54,20 @@ func BenchmarkUnknownAttributes(b *testing.B) { b.Run("AddTo", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - if err := a.AddTo(m); err != nil { + if err := attr.AddTo(msg); err != nil { b.Fatal(err) } - m.Reset() + msg.Reset() } }) b.Run("GetFrom", func(b *testing.B) { b.ReportAllocs() - if err := a.AddTo(m); err != nil { + if err := attr.AddTo(msg); err != nil { b.Fatal(err) } attrs := make(UnknownAttributes, 0, 10) for i := 0; i < b.N; i++ { - if err := attrs.GetFrom(m); err != nil { + if err := attrs.GetFrom(msg); err != nil { b.Fatal(err) } attrs = attrs[:0] diff --git a/uri.go b/uri.go index b9dab69..63a4b47 100644 --- a/uri.go +++ b/uri.go @@ -124,7 +124,7 @@ func (t ProtoType) String() string { } } -// URI represents a STUN (rfc7064) or TURN (rfc7065) URI +// URI represents a STUN (rfc7064) or TURN (rfc7065) URI. type URI struct { Scheme SchemeType Host string @@ -137,73 +137,76 @@ type URI struct { // ParseURI parses a STUN or TURN urls following the ABNF syntax described in // https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065 // respectively. -func ParseURI(raw string) (*URI, error) { //nolint:gocognit +func ParseURI(raw string) (*URI, error) { //nolint:gocognit,cyclop rawParts, err := url.Parse(raw) if err != nil { return nil, err } - var u URI - u.Scheme = NewSchemeType(rawParts.Scheme) - if u.Scheme == SchemeTypeUnknown { + var uri URI + uri.Scheme = NewSchemeType(rawParts.Scheme) + if uri.Scheme == SchemeTypeUnknown { return nil, ErrSchemeType } var rawPort string - if u.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { + if uri.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { //nolint:nestif var e *net.AddrError if errors.As(err, &e) { if e.Err == "missing port in address" { - nextRawURL := u.Scheme.String() + ":" + rawParts.Opaque + nextRawURL := uri.Scheme.String() + ":" + rawParts.Opaque switch { - case u.Scheme == SchemeTypeSTUN || u.Scheme == SchemeTypeTURN: + case uri.Scheme == SchemeTypeSTUN || uri.Scheme == SchemeTypeTURN: nextRawURL += ":3478" if rawParts.RawQuery != "" { nextRawURL += "?" + rawParts.RawQuery } + return ParseURI(nextRawURL) - case u.Scheme == SchemeTypeSTUNS || u.Scheme == SchemeTypeTURNS: + case uri.Scheme == SchemeTypeSTUNS || uri.Scheme == SchemeTypeTURNS: nextRawURL += ":5349" if rawParts.RawQuery != "" { nextRawURL += "?" + rawParts.RawQuery } + return ParseURI(nextRawURL) } } } + return nil, err } - if u.Host == "" { + if uri.Host == "" { return nil, ErrHost } - if u.Port, err = strconv.Atoi(rawPort); err != nil { + if uri.Port, err = strconv.Atoi(rawPort); err != nil { return nil, ErrPort } - switch u.Scheme { + switch uri.Scheme { case SchemeTypeSTUN: qArgs, err := url.ParseQuery(rawParts.RawQuery) if err != nil || len(qArgs) > 0 { return nil, ErrSTUNQuery } - u.Proto = ProtoTypeUDP + uri.Proto = ProtoTypeUDP case SchemeTypeSTUNS: qArgs, err := url.ParseQuery(rawParts.RawQuery) if err != nil || len(qArgs) > 0 { return nil, ErrSTUNQuery } - u.Proto = ProtoTypeTCP + uri.Proto = ProtoTypeTCP case SchemeTypeTURN: proto, err := parseProto(rawParts.RawQuery) if err != nil { return nil, err } - u.Proto = proto - if u.Proto == ProtoTypeUnknown { - u.Proto = ProtoTypeUDP + uri.Proto = proto + if uri.Proto == ProtoTypeUnknown { + uri.Proto = ProtoTypeUDP } case SchemeTypeTURNS: proto, err := parseProto(rawParts.RawQuery) @@ -211,15 +214,15 @@ func ParseURI(raw string) (*URI, error) { //nolint:gocognit return nil, err } - u.Proto = proto - if u.Proto == ProtoTypeUnknown { - u.Proto = ProtoTypeTCP + uri.Proto = proto + if uri.Proto == ProtoTypeUnknown { + uri.Proto = ProtoTypeTCP } case SchemeTypeUnknown: } - return &u, nil + return &uri, nil } func parseProto(raw string) (ProtoType, error) { @@ -233,6 +236,7 @@ func parseProto(raw string) (ProtoType, error) { if proto = NewProtoType(rawProto); proto == ProtoType(0) { return ProtoTypeUnknown, ErrProtoType } + return proto, nil } @@ -248,6 +252,7 @@ func (u URI) String() string { if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS { rawURL += "?transport=" + u.Proto.String() } + return rawURL } diff --git a/uri_test.go b/uri_test.go index 9153c16..e8cc3e9 100644 --- a/uri_test.go +++ b/uri_test.go @@ -35,8 +35,16 @@ func TestParseURL(t *testing.T) { {"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP}, {"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP}, {"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP}, - {"turn:google.de?transport=udp", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP}, - {"turns:google.de?transport=tcp", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP}, + { + "turn:google.de?transport=udp", + "turn:google.de:3478?transport=udp", + SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP, + }, + { + "turns:google.de?transport=tcp", + "turns:google.de:5349?transport=tcp", + SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP, + }, } for i, testCase := range testCases { diff --git a/xoraddr.go b/xoraddr.go index 2c5a954..273ac50 100644 --- a/xoraddr.go +++ b/xoraddr.go @@ -20,7 +20,7 @@ const ( // XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. // -// RFC 5389 Section 15.2 +// RFC 5389 Section 15.2. type XORMappedAddress struct { IP net.IP Port int @@ -43,14 +43,15 @@ func isZeros(p net.IP) bool { return false } } + return true } // ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}. var ErrBadIPLength = errors.New("invalid length of IP value") -// AddToAs adds XOR-MAPPED-ADDRESS value to m as t attribute. -func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error { +// AddToAs adds XOR-MAPPED-ADDRESS value to msg as attr attribute. +func (a XORMappedAddress) AddToAs(msg *Message, attr AttrType) error { var ( family = familyIPv4 ip = a.IP @@ -67,12 +68,13 @@ func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error { value := make([]byte, 32+128) value[0] = 0 // first 8 bits are zeroes xorValue := make([]byte, net.IPv6len) - copy(xorValue[4:], m.TransactionID[:]) + copy(xorValue[4:], msg.TransactionID[:]) bin.PutUint32(xorValue[0:4], magicCookie) bin.PutUint16(value[0:2], family) - bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) + bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) //nolint:gosec // G115, false positive, port xor.XorBytes(value[4:4+len(ip)], ip, xorValue) - m.Add(t, value[:4+len(ip)]) + msg.Add(attr, value[:4+len(ip)]) + return nil } @@ -83,13 +85,13 @@ func (a XORMappedAddress) AddTo(m *Message) error { } // GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message -// getting it as for t type. -func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error { - v, err := m.Get(t) +// getting it as for attr type. +func (a *XORMappedAddress) GetFromAs(msg *Message, attr AttrType) error { + value, err := msg.Get(attr) if err != nil { return err } - family := bin.Uint16(v[0:2]) + family := bin.Uint16(value[0:2]) if family != familyIPv6 && family != familyIPv4 { return newDecodeErr("xor-mapped address", "family", fmt.Sprintf("bad value %d", family), @@ -110,17 +112,18 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error { for i := range a.IP { a.IP[i] = 0 } - if len(v) <= 4 { + if len(value) <= 4 { return io.ErrUnexpectedEOF } - if err := CheckOverflow(t, len(v[4:]), len(a.IP)); err != nil { + if err := CheckOverflow(attr, len(value[4:]), len(a.IP)); err != nil { return err } - a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16) + a.Port = int(bin.Uint16(value[2:4])) ^ (magicCookie >> 16) xorValue := make([]byte, 4+TransactionIDSize) bin.PutUint32(xorValue[0:4], magicCookie) - copy(xorValue[4:], m.TransactionID[:]) - xor.XorBytes(a.IP, v[4:], xorValue) + copy(xorValue[4:], msg.TransactionID[:]) + xor.XorBytes(a.IP, value[4:], xorValue) + return nil } diff --git a/xoraddr_test.go b/xoraddr_test.go index 61506c4..f1017f8 100644 --- a/xoraddr_test.go +++ b/xoraddr_test.go @@ -29,21 +29,21 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) { } func BenchmarkXORMappedAddress_GetFrom(b *testing.B) { - m := New() + msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") if err != nil { b.Error(err) } - copy(m.TransactionID[:], transactionID) + copy(msg.TransactionID[:], transactionID) addrValue, err := hex.DecodeString("00019cd5f49f38ae") if err != nil { b.Error(err) } - m.Add(AttrXORMappedAddress, addrValue) + msg.Add(AttrXORMappedAddress, addrValue) addr := new(XORMappedAddress) b.ReportAllocs() for i := 0; i < b.N; i++ { - if err := addr.GetFrom(m); err != nil { + if err := addr.GetFrom(msg); err != nil { b.Fatal(err) } } @@ -94,54 +94,54 @@ func TestXORMappedAddress_GetFrom(t *testing.T) { } func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) { - m := New() + msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") if err != nil { t.Error(err) } - copy(m.TransactionID[:], transactionID) + copy(msg.TransactionID[:], transactionID) expectedIP := net.ParseIP("213.141.156.236") expectedPort := 21254 addr := new(XORMappedAddress) - if err = addr.GetFrom(m); err == nil { + if err = addr.GetFrom(msg); err == nil { t.Fatal(err, "should be nil") } addr.IP = expectedIP addr.Port = expectedPort - addr.AddTo(m) //nolint:errcheck,gosec - m.WriteHeader() + addr.AddTo(msg) //nolint:errcheck,gosec + msg.WriteHeader() mRes := New() - binary.BigEndian.PutUint16(m.Raw[20+4:20+4+2], 0x21) - if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil { + binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21) + if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil { t.Fatal(err) } - if err = addr.GetFrom(m); err == nil { + if err = addr.GetFrom(msg); err == nil { t.Fatal(err, "should not be nil") } } func TestXORMappedAddress_AddTo(t *testing.T) { - m := New() + msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") if err != nil { t.Error(err) } - copy(m.TransactionID[:], transactionID) + copy(msg.TransactionID[:], transactionID) expectedIP := net.ParseIP("213.141.156.236") expectedPort := 21254 addr := &XORMappedAddress{ IP: net.ParseIP("213.141.156.236"), Port: expectedPort, } - if err = addr.AddTo(m); err != nil { + if err = addr.AddTo(msg); err != nil { t.Fatal(err) } - m.WriteHeader() + msg.WriteHeader() mRes := New() - if _, err = mRes.Write(m.Raw); err != nil { + if _, err = mRes.Write(msg.Raw); err != nil { t.Fatal(err) } if err = addr.GetFrom(mRes); err != nil { @@ -156,27 +156,27 @@ func TestXORMappedAddress_AddTo(t *testing.T) { } func TestXORMappedAddress_AddTo_IPv6(t *testing.T) { - m := New() + msg := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") if err != nil { t.Error(err) } - copy(m.TransactionID[:], transactionID) + copy(msg.TransactionID[:], transactionID) expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009") expectedPort := 21254 addr := &XORMappedAddress{ IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"), Port: 21254, } - addr.AddTo(m) //nolint:errcheck,gosec - m.WriteHeader() + addr.AddTo(msg) //nolint:errcheck,gosec + msg.WriteHeader() mRes := New() - if _, err = mRes.ReadFrom(m.reader()); err != nil { + if _, err = mRes.ReadFrom(msg.reader()); err != nil { t.Fatal(err) } gotAddr := new(XORMappedAddress) - if err = gotAddr.GetFrom(m); err != nil { + if err = gotAddr.GetFrom(msg); err != nil { t.Fatal(err) } if !gotAddr.IP.Equal(expectedIP) {