Upgrade golangci-lint, more linters

Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
Joe Turki
2025-01-18 10:48:23 -06:00
parent be5e65e013
commit 0397f2187b
44 changed files with 883 additions and 656 deletions

View File

@@ -25,17 +25,32 @@ linters-settings:
- ^os.Exit$
- ^panic$
- ^print(ln)?$
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@@ -46,18 +61,17 @@ linters:
- forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
@@ -65,9 +79,15 @@ linters:
- grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes
@@ -75,28 +95,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- gochecknoinits # Checks that no init functions are present in Go code
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- mnd # An analyzer to detect magic numbers
- nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated
@@ -104,8 +118,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- varnamelen # checks that the length of a variable's name matches its scope
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines!
@@ -123,3 +136,4 @@ issues:
- path: cmd
linters:
- forbidigo

32
addr.go
View File

@@ -15,7 +15,7 @@ import (
// This attribute is used only by servers for achieving backwards
// compatibility with RFC 3489 clients.
//
// RFC 5389 Section 15.1
// RFC 5389 Section 15.1.
type MappedAddress struct {
IP net.IP
Port int
@@ -23,7 +23,7 @@ type MappedAddress struct {
// AlternateServer represents ALTERNATE-SERVER attribute.
//
// RFC 5389 Section 15.11
// RFC 5389 Section 15.11.
type AlternateServer struct {
IP net.IP
Port int
@@ -31,7 +31,7 @@ type AlternateServer struct {
// ResponseOrigin represents RESPONSE-ORIGIN attribute.
//
// RFC 5780 Section 7.3
// RFC 5780 Section 7.3.
type ResponseOrigin struct {
IP net.IP
Port int
@@ -39,7 +39,7 @@ type ResponseOrigin struct {
// OtherAddress represents OTHER-ADDRESS attribute.
//
// RFC 5780 Section 7.4
// RFC 5780 Section 7.4.
type OtherAddress struct {
IP net.IP
Port int
@@ -48,12 +48,14 @@ type OtherAddress struct {
// AddTo adds ALTERNATE-SERVER attribute to message.
func (s *AlternateServer) AddTo(m *Message) error {
a := (*MappedAddress)(s)
return a.AddToAs(m, AttrAlternateServer)
}
// GetFrom decodes ALTERNATE-SERVER from message.
func (s *AlternateServer) GetFrom(m *Message) error {
a := (*MappedAddress)(s)
return a.GetFromAs(m, AttrAlternateServer)
}
@@ -63,14 +65,14 @@ func (a MappedAddress) String() string {
// GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t.
func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
v, err := m.Get(t)
value, err := m.Get(t)
if err != nil {
return err
}
if len(v) <= 4 {
if len(value) <= 4 {
return io.ErrUnexpectedEOF
}
family := bin.Uint16(v[0:2])
family := bin.Uint16(value[0:2])
if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
@@ -91,13 +93,14 @@ func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error {
for i := range a.IP {
a.IP[i] = 0
}
a.Port = int(bin.Uint16(v[2:4]))
copy(a.IP, v[4:])
a.Port = int(bin.Uint16(value[2:4]))
copy(a.IP, value[4:])
return nil
}
// AddToAs adds MAPPED-ADDRESS value to m as t attribute.
func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
func (a *MappedAddress) AddToAs(msg *Message, attrType AttrType) error {
var (
family = familyIPv4
ip = a.IP
@@ -114,9 +117,10 @@ func (a *MappedAddress) AddToAs(m *Message, t AttrType) error {
value := make([]byte, 128)
value[0] = 0 // first 8 bits are zeroes
bin.PutUint16(value[0:2], family)
bin.PutUint16(value[2:4], uint16(a.Port))
bin.PutUint16(value[2:4], uint16(a.Port)) //nolint:gosec //G115
copy(value[4:], ip)
m.Add(t, value[:4+len(ip)])
msg.Add(attrType, value[:4+len(ip)])
return nil
}
@@ -133,12 +137,14 @@ func (a *MappedAddress) GetFrom(m *Message) error {
// AddTo adds OTHER-ADDRESS attribute to message.
func (o *OtherAddress) AddTo(m *Message) error {
a := (*MappedAddress)(o)
return a.AddToAs(m, AttrOtherAddress)
}
// GetFrom decodes OTHER-ADDRESS from message.
func (o *OtherAddress) GetFrom(m *Message) error {
a := (*MappedAddress)(o)
return a.GetFromAs(m, AttrOtherAddress)
}
@@ -149,12 +155,14 @@ func (o OtherAddress) String() string {
// AddTo adds RESPONSE-ORIGIN attribute to message.
func (o *ResponseOrigin) AddTo(m *Message) error {
a := (*MappedAddress)(o)
return a.AddToAs(m, AttrResponseOrigin)
}
// GetFrom decodes RESPONSE-ORIGIN from message.
func (o *ResponseOrigin) GetFrom(m *Message) error {
a := (*MappedAddress)(o)
return a.GetFromAs(m, AttrResponseOrigin)
}

View File

@@ -11,7 +11,7 @@ import (
)
func TestMappedAddress(t *testing.T) {
m := new(Message)
msg := new(Message)
addr := &MappedAddress{
IP: net.ParseIP("122.12.34.5"),
Port: 5412,
@@ -23,17 +23,17 @@ func TestMappedAddress(t *testing.T) {
badAddr := &MappedAddress{
IP: net.IP{1, 2, 3},
}
if err := badAddr.AddTo(m); err == nil {
if err := badAddr.AddTo(msg); err == nil {
t.Error("should error")
}
})
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil {
if err := addr.AddTo(msg); err != nil {
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress)
if err := got.GetFrom(m); err != nil {
if err := got.GetFrom(msg); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
@@ -46,9 +46,9 @@ func TestMappedAddress(t *testing.T) {
}
})
t.Run("Bad family", func(t *testing.T) {
v, _ := m.Attributes.Get(AttrMappedAddress)
v, _ := msg.Attributes.Get(AttrMappedAddress)
v.Value[0] = 32
if err := got.GetFrom(m); err == nil {
if err := got.GetFrom(msg); err == nil {
t.Error("should error")
}
})

View File

@@ -24,6 +24,7 @@ func NewAgent(h Handler) *Agent {
transactions: make(map[transactionID]agentTransaction),
handler: h,
}
return a
}
@@ -80,6 +81,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
a.mux.Lock()
if a.closed {
a.mux.Unlock()
return ErrAgentClosed
}
t, exists := a.transactions[id]
@@ -93,6 +95,7 @@ func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error {
TransactionID: t.id,
Error: err,
})
return nil
}
@@ -124,6 +127,7 @@ func (a *Agent) Start(id [TransactionIDSize]byte, deadline time.Time) error {
id: id,
deadline: deadline,
}
return nil
}
@@ -147,6 +151,7 @@ func (a *Agent) Collect(gcTime time.Time) error {
// All transactions should be already closed
// during Close() call.
a.mux.Unlock()
return ErrAgentClosed
}
// Adding all transactions with deadline before gcTime
@@ -175,24 +180,27 @@ func (a *Agent) Collect(gcTime time.Time) error {
event.TransactionID = id
h(event)
}
return nil
}
// Process incoming message, synchronously passing it to handler.
func (a *Agent) Process(m *Message) error {
e := Event{
event := Event{
TransactionID: m.TransactionID,
Message: m,
}
a.mux.Lock()
if a.closed {
a.mux.Unlock()
return ErrAgentClosed
}
h := a.handler
delete(a.transactions, m.TransactionID)
a.mux.Unlock()
h(e)
h(event)
return nil
}
@@ -201,10 +209,12 @@ func (a *Agent) SetHandler(h Handler) error {
a.mux.Lock()
if a.closed {
a.mux.Unlock()
return ErrAgentClosed
}
a.handler = h
a.mux.Unlock()
return nil
}
@@ -217,6 +227,7 @@ func (a *Agent) Close() error {
a.mux.Lock()
if a.closed {
a.mux.Unlock()
return ErrAgentClosed
}
for _, t := range a.transactions {
@@ -227,6 +238,7 @@ func (a *Agent) Close() error {
a.closed = true
a.handler = nil
a.mux.Unlock()
return nil
}

View File

@@ -10,49 +10,49 @@ import (
)
func TestAgent_ProcessInTransaction(t *testing.T) {
m := New()
a := NewAgent(func(e Event) {
msg := New()
agent := NewAgent(func(e Event) {
if e.Error != nil {
t.Errorf("got error: %s", e.Error)
}
if !e.Message.Equal(m) {
t.Errorf("%s (got) != %s (expected)", e.Message, m)
if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
}
})
if err := m.NewTransactionID(); err != nil {
if err := msg.NewTransactionID(); err != nil {
t.Fatal(err)
}
if err := a.Start(m.TransactionID, time.Time{}); err != nil {
if err := agent.Start(msg.TransactionID, time.Time{}); err != nil {
t.Fatal(err)
}
if err := a.Process(m); err != nil {
if err := agent.Process(msg); err != nil {
t.Error(err)
}
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
t.Error(err)
}
}
func TestAgent_Process(t *testing.T) {
m := New()
a := NewAgent(func(e Event) {
msg := New()
agent := NewAgent(func(e Event) {
if e.Error != nil {
t.Errorf("got error: %s", e.Error)
}
if !e.Message.Equal(m) {
t.Errorf("%s (got) != %s (expected)", e.Message, m)
if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
}
})
if err := m.NewTransactionID(); err != nil {
if err := msg.NewTransactionID(); err != nil {
t.Fatal(err)
}
if err := a.Process(m); err != nil {
if err := agent.Process(msg); err != nil {
t.Error(err)
}
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
t.Error(err)
}
if err := a.Process(m); !errors.Is(err, ErrAgentClosed) {
if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
t.Errorf("closed agent should return <%s>, but got <%s>",
ErrAgentClosed, err,
)
@@ -60,27 +60,27 @@ func TestAgent_Process(t *testing.T) {
}
func TestAgent_Start(t *testing.T) {
a := NewAgent(nil)
agent := NewAgent(nil)
id := NewTransactionID()
deadline := time.Now().AddDate(0, 0, 1)
if err := a.Start(id, deadline); err != nil {
if err := agent.Start(id, deadline); err != nil {
t.Errorf("failed to statt transaction: %s", err)
}
if err := a.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
t.Errorf("duplicate start should return <%s>, got <%s>",
ErrTransactionExists, err,
)
}
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
t.Error(err)
}
id = NewTransactionID()
if err := a.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("start on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err,
)
}
if err := a.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err,
)
@@ -89,18 +89,18 @@ func TestAgent_Start(t *testing.T) {
func TestAgent_Stop(t *testing.T) {
called := make(chan Event, 1)
a := NewAgent(func(e Event) {
agent := NewAgent(func(e Event) {
called <- e
})
if err := a.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
}
id := NewTransactionID()
timeout := time.Millisecond * 200
if err := a.Start(id, time.Now().Add(timeout)); err != nil {
if err := agent.Start(id, time.Now().Add(timeout)); err != nil {
t.Fatal(err)
}
if err := a.Stop(id); err != nil {
if err := agent.Stop(id); err != nil {
t.Fatal(err)
}
select {
@@ -113,19 +113,19 @@ func TestAgent_Stop(t *testing.T) {
case <-time.After(timeout * 2):
t.Fatal("timed out")
}
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
t.Fatal(err)
}
if err := a.Close(); !errors.Is(err, ErrAgentClosed) {
if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
}
if err := a.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
}
}
func TestAgent_GC(t *testing.T) {
a := NewAgent(nil)
func TestAgent_GC(t *testing.T) { //nolint:cyclop
agent := NewAgent(nil)
shouldTimeOutID := make(map[transactionID]bool)
deadline := time.Date(2027, time.November, 21,
23, 0, 0, 0,
@@ -133,7 +133,7 @@ func TestAgent_GC(t *testing.T) {
)
gcDeadline := deadline.Add(-time.Second)
deadlineNotGC := gcDeadline.AddDate(0, 0, -1)
a.SetHandler(func(e Event) { //nolint:errcheck,gosec
agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
id := e.TransactionID
shouldTimeOut, found := shouldTimeOutID[id]
if !found {
@@ -149,67 +149,67 @@ func TestAgent_GC(t *testing.T) {
for i := 0; i < 5; i++ {
id := NewTransactionID()
shouldTimeOutID[id] = false
if err := a.Start(id, deadline); err != nil {
if err := agent.Start(id, deadline); err != nil {
t.Fatal(err)
}
}
for i := 0; i < 5; i++ {
id := NewTransactionID()
shouldTimeOutID[id] = true
if err := a.Start(id, deadlineNotGC); err != nil {
if err := agent.Start(id, deadlineNotGC); err != nil {
t.Fatal(err)
}
}
if err := a.Collect(gcDeadline); err != nil {
if err := agent.Collect(gcDeadline); err != nil {
t.Fatal(err)
}
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
t.Error(err)
}
if err := a.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
}
}
func BenchmarkAgent_GC(b *testing.B) {
a := NewAgent(nil)
agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < agentCollectCap; i++ {
if err := a.Start(NewTransactionID(), deadline); err != nil {
if err := agent.Start(NewTransactionID(), deadline); err != nil {
b.Fatal(err)
}
}
defer func() {
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
b.Error(err)
}
}()
b.ReportAllocs()
gcDeadline := deadline.Add(-time.Second)
for i := 0; i < b.N; i++ {
if err := a.Collect(gcDeadline); err != nil {
if err := agent.Collect(gcDeadline); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAgent_Process(b *testing.B) {
a := NewAgent(nil)
agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < 1000; i++ {
if err := a.Start(NewTransactionID(), deadline); err != nil {
if err := agent.Start(NewTransactionID(), deadline); err != nil {
b.Fatal(err)
}
}
defer func() {
if err := a.Close(); err != nil {
if err := agent.Close(); err != nil {
b.Error(err)
}
}()
b.ReportAllocs()
m := MustBuild(TransactionID)
for i := 0; i < b.N; i++ {
if err := a.Process(m); err != nil {
if err := agent.Process(m); err != nil {
b.Fatal(err)
}
}

View File

@@ -21,6 +21,7 @@ func (a Attributes) Get(t AttrType) (RawAttribute, bool) {
return candidate, true
}
}
return RawAttribute{}, false
}
@@ -77,7 +78,7 @@ const (
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
)
// Attributes from RFC 5780 NAT Behavior Discovery
// Attributes from RFC 5780 NAT Behavior Discovery.
const (
AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST
AttrPadding AttrType = 0x0026 // PADDING
@@ -166,6 +167,7 @@ func (t AttrType) String() string {
// Just return hex representation of unknown attribute type.
return fmt.Sprintf("0x%x", uint16(t))
}
return s
}
@@ -186,6 +188,7 @@ type RawAttribute struct {
// the Length field.
func (a RawAttribute) AddTo(m *Message) error {
m.Add(a.Type, a.Value)
return nil
}
@@ -205,6 +208,7 @@ func (a RawAttribute) Equal(b RawAttribute) bool {
return false
}
}
return true
}
@@ -224,6 +228,7 @@ func (m *Message) Get(t AttrType) ([]byte, error) {
if !ok {
return nil, ErrAttributeNotFound
}
return v.Value, nil
}
@@ -240,6 +245,7 @@ func nearestPaddedValueLength(l int) int {
if n < l {
n += padding
}
return n
}
@@ -250,5 +256,6 @@ func compatAttrType(val uint16) AttrType {
if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN
return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on)
}
return AttrType(val)
}

View File

@@ -44,13 +44,13 @@ func TestRawAttribute_AddTo(t *testing.T) {
}
func TestMessage_GetNoAllocs(t *testing.T) {
m := New()
NewSoftware("c").AddTo(m) //nolint:errcheck,gosec
m.WriteHeader()
msg := New()
NewSoftware("c").AddTo(msg) //nolint:errcheck,gosec
msg.WriteHeader()
t.Run("Default", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
m.Get(AttrSoftware) //nolint:errcheck,gosec
msg.Get(AttrSoftware) //nolint:errcheck,gosec
})
if allocs > 0 {
t.Error("allocated memory, but should not")
@@ -58,7 +58,7 @@ func TestMessage_GetNoAllocs(t *testing.T) {
})
t.Run("Not found", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
m.Get(AttrOrigin) //nolint:errcheck,gosec
msg.Get(AttrOrigin) //nolint:errcheck,gosec
})
if allocs > 0 {
t.Error("allocated memory, but should not")

View File

@@ -17,6 +17,7 @@ func CheckSize(_ AttrType, got, expected int) error {
if got == expected {
return nil
}
return ErrAttributeSizeInvalid
}
@@ -24,6 +25,7 @@ func checkHMAC(got, expected []byte) error {
if hmac.Equal(got, expected) {
return nil
}
return ErrIntegrityMismatch
}
@@ -31,6 +33,7 @@ func checkFingerprint(got, expected uint32) error {
if got == expected {
return nil
}
return ErrFingerprintMismatch
}
@@ -44,6 +47,7 @@ func CheckOverflow(_ AttrType, got, maxVal int) error {
if got <= maxVal {
return nil
}
return ErrAttributeSizeOverflow
}

126
client.go
View File

@@ -21,7 +21,7 @@ import (
"github.com/pion/transport/v3/stdnet"
)
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI
// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI.
var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
// Dial connects to the address on the named network and then
@@ -31,10 +31,11 @@ func Dial(network, address string) (*Client, error) {
if err != nil {
return nil, err
}
return NewClient(conn)
}
// DialConfig is used to pass configuration to DialURI()
// DialConfig is used to pass configuration to DialURI().
type DialConfig struct {
DTLSConfig dtls.Config
TLSConfig tls.Config
@@ -44,7 +45,7 @@ type DialConfig struct {
// DialURI connect to the STUN/TURN URI and then
// initializes Client on that connection, returning error if any.
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop
var conn Connection
var err error
@@ -203,7 +204,7 @@ const (
// provide any API for it, so if you need to read application data, wrap the
// connection with your (de-)multiplexer and pass the wrapper as conn.
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
c := &Client{
client := &Client{
close: make(chan struct{}),
c: conn,
clock: systemClock(),
@@ -214,32 +215,33 @@ func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
closeConn: true,
}
for _, o := range options {
o(c)
o(client)
}
if c.c == nil {
if client.c == nil {
return nil, ErrNoConnection
}
if c.a == nil {
c.a = NewAgent(nil)
if client.a == nil {
client.a = NewAgent(nil)
}
if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
return nil, err
}
if c.collector == nil {
c.collector = &tickerCollector{
if client.collector == nil {
client.collector = &tickerCollector{
close: make(chan struct{}),
clock: c.clock,
clock: client.clock,
}
}
if err := c.collector.Start(c.rtoRate, func(t time.Time) {
closedOrPanic(c.a.Collect(t))
if err := client.collector.Start(client.rtoRate, func(t time.Time) {
closedOrPanic(client.a.Collect(t))
}); err != nil {
return nil, err
}
c.wg.Add(1)
go c.readUntilClosed()
runtime.SetFinalizer(c, clientFinalizer)
return c, nil
client.wg.Add(1)
go client.readUntilClosed()
runtime.SetFinalizer(client, clientFinalizer)
return client, nil
}
func clientFinalizer(c *Client) {
@@ -252,6 +254,7 @@ func clientFinalizer(c *Client) {
}
if err == nil {
log.Println("client: called finalizer on non-closed client") // nolint
return
}
log.Println("client: called finalizer on non-closed client:", err) // nolint
@@ -353,6 +356,7 @@ func (c *Client) start(t *clientTransaction) error {
return ErrTransactionExists
}
c.t[t.id] = t
return nil
}
@@ -399,6 +403,7 @@ func sprintErr(err error) string {
if err == nil {
return "<nil>" //nolint:goconst
}
return err.Error()
}
@@ -455,18 +460,21 @@ func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error
select {
case <-a.close:
t.Stop()
return
case <-t.C:
f(a.clock.Now())
}
}
}()
return nil
}
func (a *tickerCollector) Close() error {
close(a.close)
a.wg.Wait()
return nil
}
@@ -481,6 +489,7 @@ func (c *Client) Close() error {
c.mux.Lock()
if c.closed {
c.mux.Unlock()
return ErrClientClosed
}
c.closed = true
@@ -498,6 +507,7 @@ func (c *Client) Close() error {
if agentErr == nil && connErr == nil {
return nil
}
return CloseErr{
AgentErr: agentErr,
ConnectionErr: connErr,
@@ -566,6 +576,7 @@ func (c *Client) checkInit() error {
if c == nil || c.c == nil || c.a == nil || c.close == nil {
return ErrClientNotInitialized
}
return nil
}
@@ -590,6 +601,7 @@ func (c *Client) Do(m *Message, f func(Event)) error {
return err
}
h.wait()
return nil
}
@@ -611,80 +623,85 @@ var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
},
}
func (c *Client) handleAgentCallback(e Event) {
func (c *Client) handleAgentCallback(event Event) { //nolint:cyclop
c.mux.Lock()
if c.closed {
c.mux.Unlock()
return
}
t, found := c.t[e.TransactionID]
transaction, found := c.t[event.TransactionID]
if found {
delete(c.t, t.id)
delete(c.t, transaction.id)
}
c.mux.Unlock()
if !found {
if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
c.handler(e)
if c.handler != nil && !errors.Is(event.Error, ErrTransactionStopped) {
c.handler(event)
}
// Ignoring.
return
}
if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
if atomic.LoadInt32(&c.maxAttempts) <= transaction.attempt || event.Error == nil {
// Transaction completed.
t.handle(e)
putClientTransaction(t)
transaction.handle(event)
putClientTransaction(transaction)
return
}
// Doing re-transmission.
t.attempt++
b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
defer bufferPool.Put(b)
transaction.attempt++
buff := bufferPool.Get().(*buffer) //nolint:forcetypeassert
buff.buf = buff.buf[:copy(buff.buf[:cap(buff.buf)], transaction.raw)]
defer bufferPool.Put(buff)
var (
now = c.clock.Now()
timeOut = t.nextTimeout(now)
id = t.id
timeOut = transaction.nextTimeout(now)
id = transaction.id
)
// Starting client transaction.
if startErr := c.start(t); startErr != nil {
if startErr := c.start(transaction); startErr != nil {
c.delete(id)
e.Error = startErr
t.handle(e)
putClientTransaction(t)
event.Error = startErr
transaction.handle(event)
putClientTransaction(transaction)
return
}
// Starting agent transaction.
if startErr := c.a.Start(id, timeOut); startErr != nil {
c.delete(id)
e.Error = startErr
t.handle(e)
putClientTransaction(t)
event.Error = startErr
transaction.handle(event)
putClientTransaction(transaction)
return
}
// Writing message to connection again.
_, writeErr := c.c.Write(b.buf)
_, writeErr := c.c.Write(buff.buf)
if writeErr != nil {
c.delete(id)
e.Error = writeErr
event.Error = writeErr
// Stopping agent transaction instead of waiting until it's deadline.
// This will call handleAgentCallback with "ErrTransactionStopped" error
// which will be ignored.
if stopErr := c.a.Stop(id); stopErr != nil {
// Failed to stop agent transaction. Wrapping the error in StopError.
e.Error = StopErr{
event.Error = StopErr{
Err: stopErr,
Cause: writeErr,
}
}
t.handle(e)
putClientTransaction(t)
transaction.handle(event)
putClientTransaction(transaction)
return
}
}
// Start starts transaction (if h set) and writes message to server, handler
// is called asynchronously.
func (c *Client) Start(m *Message, h Handler) error {
func (c *Client) Start(msg *Message, handler Handler) error {
if err := c.checkInit(); err != nil {
return err
}
@@ -694,34 +711,35 @@ func (c *Client) Start(m *Message, h Handler) error {
if closed {
return ErrClientClosed
}
if h != nil {
if handler != nil {
// Starting transaction only if h is set. Useful for indications.
t := acquireClientTransaction()
t.id = m.TransactionID
t.id = msg.TransactionID
t.start = c.clock.Now()
t.h = h
t.h = handler
t.rto = time.Duration(atomic.LoadInt64(&c.rto))
t.attempt = 0
t.raw = append(t.raw[:0], m.Raw...)
t.raw = append(t.raw[:0], msg.Raw...)
t.calls = 0
d := t.nextTimeout(t.start)
if err := c.start(t); err != nil {
return err
}
if err := c.a.Start(m.TransactionID, d); err != nil {
if err := c.a.Start(msg.TransactionID, d); err != nil {
return err
}
}
_, err := m.WriteTo(c.c)
if err != nil && h != nil {
c.delete(m.TransactionID)
_, err := msg.WriteTo(c.c)
if err != nil && handler != nil {
c.delete(msg.TransactionID)
// Stopping transaction instead of waiting until deadline.
if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
if stopErr := c.a.Stop(msg.TransactionID); stopErr != nil {
return StopErr{
Err: stopErr,
Cause: err,
}
}
}
return err
}

View File

@@ -38,11 +38,13 @@ type TestAgent struct {
func (n *TestAgent) SetHandler(h Handler) error {
n.h = h
return nil
}
func (n *TestAgent) Close() error {
close(n.e)
return nil
}
@@ -54,6 +56,7 @@ func (n *TestAgent) Start(id [TransactionIDSize]byte, _ time.Time) error {
n.e <- Event{
TransactionID: id,
}
return nil
}
@@ -69,6 +72,7 @@ func (noopConnection) Write(b []byte) (int, error) {
func (noopConnection) Read([]byte) (int, error) {
time.Sleep(time.Millisecond)
return 0, io.EOF
}
@@ -136,6 +140,7 @@ func (t *testConnection) Close() error {
return errClientAlreadyStopped
}
t.stopped = true
return nil
}
@@ -148,6 +153,7 @@ func (t *testConnection) Read(b []byte) (int, error) {
if t.read != nil {
return t.read(b)
}
return copy(b, t.b), nil
}
@@ -165,7 +171,7 @@ func TestClosedOrPanic(t *testing.T) {
}()
}
func TestClient_Start(t *testing.T) {
func TestClient_Start(t *testing.T) { //nolint:cyclop
response := MustBuild(TransactionID, BindingSuccess)
response.Encode()
write := make(chan struct{}, 1)
@@ -178,6 +184,7 @@ func TestClient_Start(t *testing.T) {
case <-read:
t.Log("reading")
copy(i, response.Raw)
return len(response.Raw), nil
case <-time.After(time.Millisecond * 10):
return 0, errClientReadTimedOut
@@ -188,33 +195,34 @@ func TestClient_Start(t *testing.T) {
select {
case <-write:
t.Log("writing")
return len(bytes), nil
case <-time.After(time.Millisecond * 10):
return 0, errClientWriteTimedOut
}
},
}
c, err := NewClient(conn)
client, err := NewClient(conn)
if err != nil {
log.Fatal(err)
}
defer func() {
if err := c.Close(); err != nil {
if err := client.Close(); err != nil {
t.Error(err)
}
if err := c.Close(); err == nil {
if err := client.Close(); err == nil {
t.Error("second close should fail")
}
if err := c.Do(MustBuild(TransactionID), nil); err == nil {
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail")
}
}()
m := MustBuild(response, BindingRequest)
msg := MustBuild(response, BindingRequest)
t.Log("init")
got := make(chan struct{})
write <- struct{}{}
t.Log("starting the first transaction")
if err := c.Start(m, func(event Event) {
if err := client.Start(msg, func(event Event) {
t.Log("got first transaction callback")
if event.Error != nil {
t.Error(event.Error)
@@ -224,7 +232,7 @@ func TestClient_Start(t *testing.T) {
t.Error(err)
}
t.Log("starting the second transaction")
if err := c.Start(m, func(Event) {
if err := client.Start(msg, func(Event) {
t.Error("should not be called")
}); !errors.Is(err, ErrTransactionExists) {
t.Errorf("unexpected error %v", err)
@@ -247,25 +255,25 @@ func TestClient_Do(t *testing.T) {
return len(bytes), nil
},
}
c, err := NewClient(conn)
client, err := NewClient(conn)
if err != nil {
log.Fatal(err)
}
defer func() {
if err := c.Close(); err != nil {
if err := client.Close(); err != nil {
t.Error(err)
}
if err := c.Close(); err == nil {
if err := client.Close(); err == nil {
t.Error("second close should fail")
}
if err := c.Do(MustBuild(TransactionID), nil); err == nil {
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail")
}
}()
m := MustBuild(
NewTransactionIDSetter(response.TransactionID),
)
if err := c.Do(m, func(event Event) {
if err := client.Do(m, func(event Event) {
if event.Error != nil {
t.Error(event.Error)
}
@@ -273,13 +281,13 @@ func TestClient_Do(t *testing.T) {
t.Error(err)
}
m = MustBuild(TransactionID)
if err := c.Do(m, nil); err != nil {
if err := client.Do(m, nil); err != nil {
t.Error(err)
}
}
func TestCloseErr_Error(t *testing.T) {
for id, c := range []struct {
for id, testCase := range []struct {
Err CloseErr
Out string
}{
@@ -291,16 +299,16 @@ func TestCloseErr_Error(t *testing.T) {
ConnectionErr: io.ErrUnexpectedEOF,
}, "failed to close: unexpected EOF (connection), <nil> (agent)"},
} {
if out := c.Err.Error(); out != c.Out {
if out := testCase.Err.Error(); out != testCase.Out {
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, c.Err, out, c.Out,
id, testCase.Err, out, testCase.Out,
)
}
}
}
func TestStopErr_Error(t *testing.T) {
for id, c := range []struct {
for id, testcase := range []struct {
Err StopErr
Out string
}{
@@ -312,9 +320,9 @@ func TestStopErr_Error(t *testing.T) {
Cause: io.ErrUnexpectedEOF,
}, "error while stopping due to unexpected EOF: <nil>"},
} {
if out := c.Err.Error(); out != c.Out {
if out := testcase.Err.Error(); out != testcase.Out {
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, c.Err, out, c.Out,
id, testcase.Err, out, testcase.Out,
)
}
}
@@ -352,7 +360,7 @@ func TestClientAgentError(t *testing.T) {
return len(bytes), nil
},
}
c, err := NewClient(conn,
client, err := NewClient(conn,
WithAgent(errorAgent{
startErr: io.ErrUnexpectedEOF,
}),
@@ -361,15 +369,15 @@ func TestClientAgentError(t *testing.T) {
log.Fatal(err)
}
defer func() {
if err := c.Close(); err != nil {
if err := client.Close(); err != nil {
t.Error(err)
}
}()
m := MustBuild(NewTransactionIDSetter(response.TransactionID))
if err := c.Do(m, nil); err != nil {
if err := client.Do(m, nil); err != nil {
t.Error(err)
}
if err := c.Do(m, func(event Event) {
if err := client.Do(m, func(event Event) {
if event.Error == nil {
t.Error("error expected")
}
@@ -384,20 +392,20 @@ func TestClientConnErr(t *testing.T) {
return 0, io.ErrClosedPipe
},
}
c, err := NewClient(conn)
client, err := NewClient(conn)
if err != nil {
log.Fatal(err)
}
defer func() {
if err := c.Close(); err != nil {
if err := client.Close(); err != nil {
t.Error(err)
}
}()
m := MustBuild(TransactionID)
if err := c.Do(m, nil); err == nil {
if err := client.Do(m, nil); err == nil {
t.Error("error expected")
}
if err := c.Do(m, NoopHandler()); err == nil {
if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected")
}
}
@@ -408,7 +416,7 @@ func TestClientConnErrStopErr(t *testing.T) {
return 0, io.ErrClosedPipe
},
}
c, err := NewClient(conn,
client, err := NewClient(conn,
WithAgent(errorAgent{
stopErr: io.ErrUnexpectedEOF,
}),
@@ -417,12 +425,12 @@ func TestClientConnErrStopErr(t *testing.T) {
log.Fatal(err)
}
defer func() {
if err := c.Close(); err != nil {
if err := client.Close(); err != nil {
t.Error(err)
}
}()
m := MustBuild(TransactionID)
if err := c.Do(m, NoopHandler()); err == nil {
if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected")
}
}
@@ -556,11 +564,13 @@ func (a *gcWaitAgent) Stop([TransactionIDSize]byte) error {
func (a *gcWaitAgent) Close() error {
close(a.gc)
return nil
}
func (a *gcWaitAgent) Collect(time.Time) error {
a.gc <- struct{}{}
return nil
}
@@ -617,6 +627,7 @@ func captureLog() (*bytes.Buffer, func()) {
log.SetOutput(&buf)
f := log.Flags()
log.SetFlags(0)
return &buf, func() {
log.SetFlags(f)
log.SetOutput(os.Stderr)
@@ -633,12 +644,12 @@ func TestClientFinalizer(t *testing.T) {
return 0, io.ErrClosedPipe
},
}
c, err := NewClient(conn)
client, err := NewClient(conn)
if err != nil {
log.Panic(err)
}
clientFinalizer(c)
clientFinalizer(c)
clientFinalizer(client)
clientFinalizer(client)
response := MustBuild(TransactionID, BindingSuccess)
response.Encode()
conn = &testConnection{
@@ -647,7 +658,7 @@ func TestClientFinalizer(t *testing.T) {
return len(bytes), nil
},
}
c, err = NewClient(conn,
client, err = NewClient(conn,
WithAgent(errorAgent{
closeErr: io.ErrUnexpectedEOF,
}),
@@ -655,7 +666,7 @@ func TestClientFinalizer(t *testing.T) {
if err != nil {
log.Panic(err)
}
clientFinalizer(c)
clientFinalizer(client)
reader := bufio.NewScanner(buf)
var lines int
expectedLines := []string{
@@ -700,6 +711,7 @@ func (m *manualCollector) Collect(t time.Time) {
func (m *manualCollector) Start(_ time.Duration, f func(t time.Time)) error {
m.f = f
return nil
}
@@ -717,12 +729,14 @@ func (m *manualClock) Add(d time.Duration) time.Time {
v := m.current.Add(d)
m.current = v
m.mux.Unlock()
return v
}
func (m *manualClock) Now() time.Time {
m.mux.Lock()
defer m.mux.Unlock()
return m.current
}
@@ -735,6 +749,7 @@ type manualAgent struct {
func (n *manualAgent) SetHandler(h Handler) error {
n.h = h
return nil
}
@@ -748,6 +763,7 @@ func (n *manualAgent) Process(m *Message) error {
if n.process != nil {
return n.process(m)
}
return nil
}
@@ -759,6 +775,7 @@ func (n *manualAgent) Stop(id [TransactionIDSize]byte) error {
if n.stop != nil {
return n.stop(id)
}
return nil
}
@@ -788,9 +805,10 @@ func TestClientRetransmission(t *testing.T) {
Message: response,
})
}
return nil
}
c, err := NewClient(connR,
client, err := NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -799,7 +817,7 @@ func TestClientRetransmission(t *testing.T) {
if err != nil {
t.Fatal(err)
}
c.SetRTO(time.Second)
client.SetRTO(time.Second)
gotReads := make(chan struct{})
go func() {
buf := make([]byte, 1500)
@@ -819,7 +837,7 @@ func TestClientRetransmission(t *testing.T) {
}
gotReads <- struct{}{}
}()
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if event.Error != nil {
t.Error("failed")
}
@@ -829,7 +847,9 @@ func TestClientRetransmission(t *testing.T) {
<-gotReads
}
func testClientDoConcurrent(t *testing.T, concurrency int) {
func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
t.Helper()
response := MustBuild(TransactionID, BindingSuccess)
response.Encode()
connL, connR := net.Pipe()
@@ -846,9 +866,10 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
TransactionID: id,
Message: response,
})
return nil
}
c, err := NewClient(connR,
client, err := NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -856,7 +877,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
if err != nil {
t.Fatal(err)
}
c.SetRTO(time.Second)
client.SetRTO(time.Second)
conns := new(sync.WaitGroup)
wg := new(sync.WaitGroup)
for i := 0; i < concurrency; i++ {
@@ -880,7 +901,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) {
wg.Add(1)
go func() {
defer wg.Done()
if doErr := c.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
if event.Error != nil {
t.Error("failed")
}
@@ -962,14 +983,14 @@ func TestClient_Close(t *testing.T) {
}
func TestClientDefaultHandler(t *testing.T) {
a := &TestAgent{
agent := &TestAgent{
e: make(chan Event),
}
id := NewTransactionID()
handlerCalled := make(chan struct{})
called := false
c, createErr := NewClient(noopConnection{},
WithAgent(a),
client, createErr := NewClient(noopConnection{},
WithAgent(agent),
WithHandler(func(e Event) {
if called {
t.Error("should not be called twice")
@@ -985,7 +1006,7 @@ func TestClientDefaultHandler(t *testing.T) {
t.Fatal(createErr)
}
go func() {
a.h(Event{
agent.h(Event{
TransactionID: id,
})
}()
@@ -995,11 +1016,11 @@ func TestClientDefaultHandler(t *testing.T) {
case <-time.After(time.Millisecond * 100):
t.Fatal("timed out")
}
if closeErr := c.Close(); closeErr != nil {
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
}
// Handler call should be ignored.
a.h(Event{})
agent.h(Event{})
}
func TestClientClosedStart(t *testing.T) {
@@ -1047,9 +1068,10 @@ func TestWithNoRetransmit(t *testing.T) {
Error: ErrTransactionTimeOut,
})
}
return nil
}
c, err := NewClient(connR,
client, err := NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -1071,7 +1093,7 @@ func TestWithNoRetransmit(t *testing.T) {
}
gotReads <- struct{}{}
}()
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrTransactionTimeOut) {
t.Error("unexpected error")
}
@@ -1087,7 +1109,7 @@ func (c callbackClock) Now() time.Time {
return c()
}
func TestClientRTOStartErr(t *testing.T) {
func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
response := MustBuild(TransactionID, BindingSuccess)
response.Encode()
connL, connR := net.Pipe()
@@ -1116,13 +1138,14 @@ func TestClientRTOStartErr(t *testing.T) {
} else {
t.Log("clock returned")
}
return time.Now()
})
agent := &manualAgent{}
attempt := 0
gotReads := make(chan struct{})
var (
c *Client
client *Client
startClientErr error
)
agent.start = func(id [TransactionIDSize]byte, _ time.Time) error {
@@ -1146,7 +1169,7 @@ func TestClientRTOStartErr(t *testing.T) {
t.Log("clock locked")
<-clockLocked
t.Log("closing client")
if closeErr := c.Close(); closeErr != nil {
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
}
t.Log("client closed, unlocking clock")
@@ -1154,9 +1177,10 @@ func TestClientRTOStartErr(t *testing.T) {
t.Log("clock unlocked")
}()
}
return nil
}
c, startClientErr = NewClient(connR,
client, startClientErr = NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -1186,7 +1210,7 @@ func TestClientRTOStartErr(t *testing.T) {
t.Log("starting")
done := make(chan struct{})
go func() {
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrClientClosed) {
t.Error(event.Error)
}
@@ -1203,7 +1227,7 @@ func TestClientRTOStartErr(t *testing.T) {
}
}
func TestClientRTOWriteErr(t *testing.T) {
func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
response := MustBuild(TransactionID, BindingSuccess)
response.Encode()
connL, connR := net.Pipe()
@@ -1232,13 +1256,14 @@ func TestClientRTOWriteErr(t *testing.T) {
} else {
t.Log("clock returned")
}
return time.Now()
})
agent := &manualAgent{}
attempt := 0
gotReads := make(chan struct{})
var (
c *Client
client *Client
startClientErr error
)
agentStopErr := errClientAgentCantStop
@@ -1274,9 +1299,10 @@ func TestClientRTOWriteErr(t *testing.T) {
t.Log("clock unlocked")
}()
}
return nil
}
c, startClientErr = NewClient(connR,
client, startClientErr = NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -1306,7 +1332,7 @@ func TestClientRTOWriteErr(t *testing.T) {
t.Log("starting")
done := make(chan struct{})
go func() {
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
var e StopErr
if !errors.As(event.Error, &e) {
t.Error(event.Error)
@@ -1346,7 +1372,7 @@ func TestClientRTOAgentErr(t *testing.T) {
attempt := 0
gotReads := make(chan struct{})
var (
c *Client
client *Client
startClientErr error
)
agentStartErr := errClientStartRefused
@@ -1361,9 +1387,10 @@ func TestClientRTOAgentErr(t *testing.T) {
} else {
return agentStartErr
}
return nil
}
c, startClientErr = NewClient(connR,
client, startClientErr = NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -1384,7 +1411,7 @@ func TestClientRTOAgentErr(t *testing.T) {
gotReads <- struct{}{}
}()
t.Log("starting")
if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, agentStartErr) {
t.Error(event.Error)
}
@@ -1415,9 +1442,10 @@ func TestClient_HandleProcessError(t *testing.T) {
processCalled := make(chan struct{}, 1)
agent.process = func(*Message) error {
processCalled <- struct{}{}
return ErrAgentClosed
}
c, startClientErr := NewClient(connR,
client, startClientErr := NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -1440,7 +1468,7 @@ func TestClient_HandleProcessError(t *testing.T) {
case <-time.After(time.Second * 5):
t.Error("reads timeout")
}
if closeErr := c.Close(); closeErr != nil {
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
}
}
@@ -1475,9 +1503,10 @@ func TestClientImmediateTimeout(t *testing.T) {
Error: ErrTransactionTimeOut,
})
}
return nil
}
c, err := NewClient(connR,
client, err := NewClient(connR,
WithAgent(agent),
WithClock(clock),
WithCollector(collector),
@@ -1498,7 +1527,7 @@ func TestClientImmediateTimeout(t *testing.T) {
}
gotReads <- struct{}{}
}()
c.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
if errors.Is(e.Error, ErrTransactionTimeOut) {
t.Error("unexpected error")
}

View File

@@ -30,7 +30,7 @@ var (
realRand = flag.Bool("crypt", false, "use crypto/rand as random source") //nolint:gochecknoglobals
)
func main() { //nolint:gocognit
func main() { //nolint:gocognit,cyclop
flag.Parse()
uri, err := stun.ParseURI(*uriStr)
if err != nil {
@@ -88,7 +88,7 @@ func main() { //nolint:gocognit
log.Print("Using crypto/rand as random source for transaction id")
}
for i := 0; i < *workers; i++ {
c, clientErr := stun.DialURI(uri, &stun.DialConfig{})
client, clientErr := stun.DialURI(uri, &stun.DialConfig{})
if clientErr != nil {
log.Panicf("Failed to create client: %s", clientErr)
}
@@ -105,12 +105,13 @@ func main() { //nolint:gocognit
req.Type = stun.BindingRequest
req.WriteHeader()
atomic.AddInt64(&request, 1)
if doErr := c.Do(req, func(event stun.Event) {
if doErr := client.Do(req, func(event stun.Event) {
if event.Error != nil {
if !errors.Is(event.Error, stun.ErrTransactionTimeOut) {
log.Printf("Failed STUN transaction: %s", event.Error)
}
atomic.AddInt64(&requestErr, 1)
return
}
atomic.AddInt64(&requestOK, 1)

View File

@@ -31,11 +31,11 @@ func main() {
}
// we only try the first address, so restrict ourselves to IPv4
c, err := stun.DialURI(uri, &stun.DialConfig{})
client, err := stun.DialURI(uri, &stun.DialConfig{})
if err != nil {
log.Fatalf("Failed to dial: %s", err)
}
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if res.Error != nil {
log.Fatalf("Failed STUN transaction: %s", res.Error)
}
@@ -49,7 +49,7 @@ func main() {
}); err != nil {
log.Fatal("Do:", err)
}
if err := c.Close(); err != nil {
if err := client.Close(); err != nil {
log.Fatalf("Failed to close connection: %s", err)
}
}

View File

@@ -84,7 +84,7 @@ func multiplex(conn *net.UDPConn, stunAddr net.Addr, stunConn io.Reader) {
var stunServer = flag.String("stun", "stun.l.google.com:19302", "STUN Server to use") //nolint:gochecknoglobals
func main() {
func main() { //nolint:cyclop
flag.Parse()
isServer := flag.Arg(0) == ""
@@ -112,7 +112,7 @@ func main() {
stunL, stunR := net.Pipe()
c, err := stun.NewClient(stunR)
client, err := stun.NewClient(stunR)
if err != nil {
log.Panicf("Failed to create client: %s", err)
}
@@ -134,7 +134,7 @@ func main() {
// This can fail if your NAT Server is strict and will use separate ports
// for application data and STUN
var gotAddr stun.XORMappedAddress
if err = c.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if err = client.Do(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
if res.Error != nil {
log.Panicf("Failed STUN transaction: %s", res.Error)
}
@@ -153,7 +153,7 @@ func main() {
// Any ping-pong will work, but we are just making binding requests.
// Note that STUN Server is not mandatory for keep alive, application
// data will keep alive that binding too.
go keepAlive(c)
go keepAlive(client)
notify := make(chan os.Signal, 1)
signal.Notify(notify, os.Interrupt, syscall.SIGTERM)
@@ -168,6 +168,7 @@ func main() {
}
case <-notify:
log.Println("Stopping")
return
}
}
@@ -203,10 +204,12 @@ func main() {
case m := <-messages:
log.Printf("Got response from %s: %s", m.addr, m.text)
return
case <-notify:
log.Print("Stopping")
return
}
}

View File

@@ -30,10 +30,14 @@ func (c *stunServerConn) Close() error {
}
var (
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address") //nolint:gochecknoglobals
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response") //nolint:gochecknoglobals
verbose = flag.Int("verbose", 1, "the verbosity level") //nolint:gochecknoglobals
log logging.LeveledLogger //nolint:gochecknoglobals
//nolint:gochecknoglobals
addrStrPtr = flag.String("server", "stun.voipgate.com:3478", "STUN server address")
//nolint:gochecknoglobals
timeoutPtr = flag.Int("timeout", 3, "the number of seconds to wait for STUN server's response")
//nolint:gochecknoglobals
verbose = flag.Int("verbose", 1, "the verbosity level")
//nolint:gochecknoglobals
log logging.LeveledLogger
)
const (
@@ -70,11 +74,12 @@ func main() {
}
}
// RFC5780: 4.3. Determining NAT Mapping Behavior
func mappingTests(addrStr string) error {
// RFC5780: 4.3. Determining NAT Mapping Behavior.
func mappingTests(addrStr string) error { //nolint:cyclop
mapTestConn, err := connect(addrStr)
if err != nil {
log.Warnf("Error creating STUN connection: %s", err)
return err
}
@@ -91,11 +96,13 @@ func mappingTests(addrStr string) error {
resps1 := parse(resp)
if resps1.xorAddr == nil || resps1.otherAddr == nil {
log.Info("Error: NAT discovery feature not supported by this server")
return errNoOtherAddress
}
addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
if err != nil {
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps1.otherAddr)
return err
}
mapTestConn.OtherAddr = addr
@@ -104,6 +111,7 @@ func mappingTests(addrStr string) error {
// Assert mapping behavior
if resps1.xorAddr.String() == mapTestConn.LocalAddr.String() {
log.Warn("=> NAT mapping behavior: endpoint independent (no NAT)")
return nil
}
@@ -121,6 +129,7 @@ func mappingTests(addrStr string) error {
log.Infof("Received XOR-MAPPED-ADDRESS: %v", resps2.xorAddr)
if resps2.xorAddr.String() == resps1.xorAddr.String() {
log.Warn("=> NAT mapping behavior: endpoint independent")
return nil
}
@@ -143,11 +152,12 @@ func mappingTests(addrStr string) error {
return mapTestConn.Close()
}
// RFC5780: 4.4. Determining NAT Filtering Behavior
func filteringTests(addrStr string) error {
// RFC5780: 4.4. Determining NAT Filtering Behavior.
func filteringTests(addrStr string) error { //nolint:cyclop
mapTestConn, err := connect(addrStr)
if err != nil {
log.Warnf("Error creating STUN connection: %s", err)
return err
}
@@ -162,11 +172,13 @@ func filteringTests(addrStr string) error {
resps := parse(resp)
if resps.xorAddr == nil || resps.otherAddr == nil {
log.Warn("Error: NAT discovery feature not supported by this server")
return errNoOtherAddress
}
addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String())
if err != nil {
log.Infof("Failed resolving OTHER-ADDRESS: %v", resps.otherAddr)
return err
}
mapTestConn.OtherAddr = addr
@@ -180,6 +192,7 @@ func filteringTests(addrStr string) error {
if err == nil {
parse(resp) // just to print out the resp
log.Warn("=> NAT filtering behavior: endpoint independent")
return nil
} else if !errors.Is(err, errTimedOut) {
return err // something else went wrong
@@ -201,7 +214,7 @@ func filteringTests(addrStr string) error {
return mapTestConn.Close()
}
// Parse a STUN message
// Parse a STUN message.
func parse(msg *stun.Message) (ret struct {
xorAddr *stun.XORMappedAddress
otherAddr *stun.OtherAddress
@@ -249,15 +262,17 @@ func parse(msg *stun.Message) (ret struct {
log.Debugf("\t%v (l=%v)", attr, attr.Length)
}
}
return ret
}
// Given an address string, returns a StunServerConn
// Given an address string, returns a StunServerConn.
func connect(addrStr string) (*stunServerConn, error) {
log.Infof("Connecting to STUN server: %s", addrStr)
addr, err := net.ResolveUDPAddr("udp4", addrStr)
if err != nil {
log.Warnf("Error resolving address: %s", err)
return nil, err
}
@@ -278,7 +293,7 @@ func connect(addrStr string) (*stunServerConn, error) {
}, nil
}
// Send request and wait for response or timeout
// Send request and wait for response or timeout.
func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Message, error) {
_ = msg.NewTransactionID()
log.Infof("Sending to %v: (%v bytes)", addr, msg.Length+messageHeaderSize)
@@ -289,6 +304,7 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
_, err := c.conn.WriteTo(msg.Raw, addr)
if err != nil {
log.Warnf("Error sending request to %v", addr)
return nil, err
}
@@ -298,9 +314,11 @@ func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr) (*stun.Mess
if !ok {
return nil, errResponseMessage
}
return m, nil
case <-time.After(time.Duration(*timeoutPtr) * time.Second):
log.Infof("Timed out waiting for response from server %v", addr)
return nil, errTimedOut
}
}
@@ -315,6 +333,7 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
n, addr, err := conn.ReadFromUDP(buf)
if err != nil {
close(messages)
return
}
log.Infof("Response from %v: (%v bytes)", addr, n)
@@ -326,11 +345,13 @@ func listen(conn *net.UDPConn) (messages chan *stun.Message) {
if err != nil {
log.Infof("Error decoding message: %v", err)
close(messages)
return
}
messages <- m
}
}()
return
}

View File

@@ -26,7 +26,7 @@ const (
timeoutMillis = 500
)
func main() { //nolint:gocognit
func main() { //nolint:gocognit,cyclop
flag.Parse()
srvAddr, err := net.ResolveUDPAddr(udp, *server)
@@ -87,11 +87,13 @@ func main() { //nolint:gocognit
decErr := m.Decode()
if decErr != nil {
log.Println("decode:", decErr)
break
}
var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil {
log.Println("getFrom:", getErr)
break
}
@@ -160,6 +162,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
n, _, err := conn.ReadFromUDP(buf)
if err != nil {
close(messages)
return
}
buf = buf[:n]
@@ -167,6 +170,7 @@ func listen(conn *net.UDPConn) <-chan []byte {
messages <- buf
}
}()
return messages
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/pion/stun/v3"
)
func test(network string) {
func test(network string) { //nolint:cyclop
addr := resolve(network)
fmt.Println("START", strings.ToUpper(addr.Network())) //nolint
var (

View File

@@ -11,7 +11,7 @@ import (
// ErrorCodeAttribute represents ERROR-CODE attribute.
//
// RFC 5389 Section 15.6
// RFC 5389 Section 15.6.
type ErrorCodeAttribute struct {
Code ErrorCode
Reason []byte
@@ -31,7 +31,7 @@ const (
)
// AddTo adds ERROR-CODE to m.
func (c ErrorCodeAttribute) AddTo(m *Message) error {
func (c ErrorCodeAttribute) AddTo(msg *Message) error {
value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB)
if err := CheckOverflow(AttrErrorCode,
len(c.Reason)+errorCodeReasonStart,
@@ -45,26 +45,28 @@ func (c ErrorCodeAttribute) AddTo(m *Message) error {
value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number
copy(value[errorCodeReasonStart:], c.Reason)
m.Add(AttrErrorCode, value)
msg.Add(AttrErrorCode, value)
return nil
}
// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
v, err := m.Get(AttrErrorCode)
value, err := m.Get(AttrErrorCode)
if err != nil {
return err
}
if len(v) < errorCodeReasonStart {
if len(value) < errorCodeReasonStart {
return io.ErrUnexpectedEOF
}
var (
class = uint16(v[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte])
class = uint16(value[errorCodeClassByte])
number = uint16(value[errorCodeNumberByte])
code = int(class*errorCodeModulo + number)
)
c.Code = ErrorCode(code)
c.Reason = v[errorCodeReasonStart:]
c.Reason = value[errorCodeReasonStart:]
return nil
}
@@ -86,6 +88,7 @@ func (c ErrorCode) AddTo(m *Message) error {
Code: c,
Reason: reason,
}
return a.AddTo(m)
}
@@ -108,7 +111,7 @@ const (
// Error codes from RFC 5766.
//
// RFC 5766 Section 15
// RFC 5766 Section 15.
const (
CodeForbidden ErrorCode = 403 // Forbidden
CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch
@@ -120,7 +123,7 @@ const (
// Error codes from RFC 6062.
//
// RFC 6062 Section 6.3
// RFC 6062 Section 6.3.
const (
CodeConnAlreadyExists ErrorCode = 446
CodeConnTimeoutOrFailure ErrorCode = 447
@@ -128,7 +131,7 @@ const (
// Error codes from RFC 6156.
//
// RFC 6156 Section 10.2
// RFC 6156 Section 10.2.
const (
CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported
CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch

View File

@@ -92,23 +92,23 @@ func TestMessage_AddErrorCode(t *testing.T) {
}
func TestErrorCode(t *testing.T) {
a := &ErrorCodeAttribute{
attr := &ErrorCodeAttribute{
Code: 404,
Reason: []byte("not found!"),
}
if a.String() != "404: not found!" {
t.Error("bad string", a)
if attr.String() != "404: not found!" {
t.Error("bad string", attr)
}
m := New()
cod := ErrorCode(666)
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
t.Error("should be ErrNoDefaultReason", err)
}
if err := a.GetFrom(m); err == nil {
if err := attr.GetFrom(m); err == nil {
t.Error("attr should not be in message")
}
a.Reason = make([]byte, 2048)
if err := a.AddTo(m); err == nil {
attr.Reason = make([]byte, 2048)
if err := attr.AddTo(m); err == nil {
t.Error("should error")
}
}

View File

@@ -10,7 +10,7 @@ import (
// FingerprintAttr represents FINGERPRINT attribute.
//
// RFC 5389 Section 15.5
// RFC 5389 Section 15.5.
type FingerprintAttr struct{}
// ErrFingerprintMismatch means that computed fingerprint differs from expected.
@@ -50,6 +50,7 @@ func (FingerprintAttr) AddTo(m *Message) error {
bin.PutUint32(b, val)
m.Length = l
m.Add(AttrFingerprint, b)
return nil
}
@@ -66,5 +67,6 @@ func (FingerprintAttr) Check(m *Message) error {
val := bin.Uint32(b)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
expected := FingerprintValue(m.Raw[:attrStart])
return checkFingerprint(val, expected)
}

View File

@@ -13,20 +13,20 @@ import (
func BenchmarkFingerprint_AddTo(b *testing.B) {
b.ReportAllocs()
m := new(Message)
msg := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5),
}
addAttr(b, m, addr)
addAttr(b, m, s)
b.SetBytes(int64(len(m.Raw)))
addAttr(b, msg, addr)
addAttr(b, msg, s)
b.SetBytes(int64(len(msg.Raw)))
for i := 0; i < b.N; i++ {
Fingerprint.AddTo(m) //nolint:errcheck,gosec
m.WriteLength()
m.Length -= attributeHeaderSize + fingerprintSize
m.Raw = m.Raw[:m.Length+messageHeaderSize]
m.Attributes = m.Attributes[:len(m.Attributes)-1]
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
msg.WriteLength()
msg.Length -= attributeHeaderSize + fingerprintSize
msg.Raw = msg.Raw[:msg.Length+messageHeaderSize]
msg.Attributes = msg.Attributes[:len(msg.Attributes)-1]
}
}

View File

@@ -109,6 +109,7 @@ func FuzzSetters(f *testing.F) {
if !IsAttrSizeOverflow(err) {
t.Fatal(err)
}
return
}
@@ -150,5 +151,6 @@ func (a attributes) pick(v byte) struct {
t AttrType
} {
idx := int(v) % len(a)
return a[idx]
}

View File

@@ -44,6 +44,7 @@ func (m *Message) Build(setters ...Setter) error {
return err
}
}
return nil
}
@@ -54,6 +55,7 @@ func (m *Message) Check(checkers ...Checker) error {
return err
}
}
return nil
}
@@ -64,6 +66,7 @@ func (m *Message) Parse(getters ...Getter) error {
return err
}
}
return nil
}
@@ -73,6 +76,7 @@ func MustBuild(setters ...Setter) *Message {
if err != nil {
panic(err) //nolint
}
return m
}
@@ -82,6 +86,7 @@ func Build(setters ...Setter) (*Message, error) {
if err := m.Build(setters...); err != nil {
return nil, err
}
return m, nil
}
@@ -105,5 +110,6 @@ func (m *Message) ForEach(t AttrType, f func(m *Message) error) error {
return err
}
}
return nil
}

View File

@@ -13,7 +13,7 @@ import (
func BenchmarkBuildOverhead(b *testing.B) {
var (
t = BindingRequest
msgType = BindingRequest
username = NewUsername("username")
nonce = NewNonce("nonce")
realm = NewRealm("example.org")
@@ -22,14 +22,14 @@ func BenchmarkBuildOverhead(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
m.Build(&t, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
m.Build(&msgType, &username, &nonce, &realm, &Fingerprint) //nolint:errcheck,gosec
}
})
b.Run("BuildNonPointer", func(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
m.Build(t, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
m.Build(msgType, username, nonce, realm, Fingerprint) //nolint:errcheck,gosec //nolint:errcheck,gosec
}
})
b.Run("Raw", func(b *testing.B) {
@@ -38,7 +38,7 @@ func BenchmarkBuildOverhead(b *testing.B) {
for i := 0; i < b.N; i++ {
m.Reset()
m.WriteHeader()
m.SetType(t)
m.SetType(msgType)
username.AddTo(m) //nolint:errcheck,gosec
nonce.AddTo(m) //nolint:errcheck,gosec
realm.AddTo(m) //nolint:errcheck,gosec
@@ -52,7 +52,7 @@ func TestMessage_Apply(t *testing.T) {
integrity = NewShortTermIntegrity("password")
decoded = new(Message)
)
m, err := Build(BindingRequest, TransactionID,
msg, err := Build(BindingRequest, TransactionID,
NewUsername("username"),
NewNonce("nonce"),
NewRealm("example.org"),
@@ -62,13 +62,13 @@ func TestMessage_Apply(t *testing.T) {
if err != nil {
t.Fatal("failed to build:", err)
}
if err = m.Check(Fingerprint, integrity); err != nil {
if err = msg.Check(Fingerprint, integrity); err != nil {
t.Fatal(err)
}
if _, err := decoded.Write(m.Raw); err != nil {
if _, err := decoded.Write(msg.Raw); err != nil {
t.Fatal(err)
}
if !decoded.Equal(m) {
if !decoded.Equal(msg) {
t.Error("not equal")
}
if err := integrity.Check(decoded); err != nil {
@@ -96,32 +96,32 @@ func (e errReturner) GetFrom(*Message) error {
func TestHelpersErrorHandling(t *testing.T) {
m := New()
e := errReturner{Err: errTError}
if err := m.Build(e); !errors.Is(err, e.Err) {
t.Error(err, "!=", e.Err)
errReturn := errReturner{Err: errTError}
if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
if err := m.Check(e); !errors.Is(err, e.Err) {
t.Error(err, "!=", e.Err)
if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
if err := m.Parse(e); !errors.Is(err, e.Err) {
t.Error(err, "!=", e.Err)
if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
t.Run("MustBuild", func(t *testing.T) {
t.Run("Positive", func(*testing.T) {
MustBuild(NewTransactionIDSetter(transactionID{}))
})
defer func() {
if p, ok := recover().(error); !ok || !errors.Is(p, e.Err) {
if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) {
t.Errorf("%s != %s",
p, e.Err,
p, errReturn.Err,
)
}
}()
MustBuild(e)
MustBuild(errReturn)
})
}
func TestMessage_ForEach(t *testing.T) {
func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
initial := New()
if err := initial.Build(
NewRealm("realm1"), NewRealm("realm2"),
@@ -135,6 +135,7 @@ func TestMessage_ForEach(t *testing.T) {
); err != nil {
t.Fatal(err)
}
return m
}
t.Run("NoResults", func(t *testing.T) {
@@ -144,6 +145,7 @@ func TestMessage_ForEach(t *testing.T) {
}
if err := m.ForEach(AttrUsername, func(*Message) error {
t.Error("should not be called")
return nil
}); err != nil {
t.Fatal(err)
@@ -160,6 +162,7 @@ func TestMessage_ForEach(t *testing.T) {
t.Error("called multiple times")
}
calls++
return ErrAttributeNotFound
}); !errors.Is(err, ErrAttributeNotFound) {
t.Fatal(err)
@@ -169,14 +172,15 @@ func TestMessage_ForEach(t *testing.T) {
}
})
t.Run("Positive", func(t *testing.T) {
m := newMessage()
msg := newMessage()
var realms []string
if err := m.ForEach(AttrRealm, func(m *Message) error {
if err := msg.ForEach(AttrRealm, func(m *Message) error {
var realm Realm
if err := realm.GetFrom(m); err != nil {
return err
}
realms = append(realms, realm.String())
return nil
}); err != nil {
t.Fatal(err)
@@ -190,18 +194,18 @@ func TestMessage_ForEach(t *testing.T) {
if realms[1] != "realm2" {
t.Error("bad value for 2 realm")
}
if !m.Equal(initial) {
if !msg.Equal(initial) {
t.Error("m should be equal to initial")
}
t.Run("ZeroAlloc", func(t *testing.T) {
m = newMessage()
msg = newMessage()
var realm Realm
testutil.ShouldNotAllocate(t, func() {
if err := m.ForEach(AttrRealm, realm.GetFrom); err != nil {
if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil {
t.Fatal(err)
}
})
if !m.Equal(initial) {
if !msg.Equal(initial) {
t.Error("m should be equal to initial")
}
})
@@ -216,6 +220,7 @@ func ExampleMessage_ForEach() {
return err
}
fmt.Println(r)
return nil
}); err != nil {
fmt.Println("error:", err)

View File

@@ -14,22 +14,24 @@ import (
"testing"
)
func loadCSV(t testing.TB, name string) [][]string {
t.Helper()
data := loadData(t, name)
func loadCSV(tb testing.TB, name string) [][]string {
tb.Helper()
data := loadData(tb, name)
r := csv.NewReader(bytes.NewReader(data))
r.Comment = '#'
records, err := r.ReadAll()
if err != nil {
t.Fatal(err)
tb.Fatal(err)
}
return records
}
func TestIANA(t *testing.T) {
func TestIANA(t *testing.T) { //nolint:cyclop
t.Run("Methods", func(t *testing.T) {
records := loadCSV(t, "stun-parameters-2.csv")
m := make(map[string]Method)
methods := make(map[string]Method)
for _, r := range records[1:] {
var (
v = r[0]
@@ -43,10 +45,10 @@ func TestIANA(t *testing.T) {
t.Fatal(parseErr)
}
t.Logf("value: 0x%x, name: %s", val, name)
m[name] = Method(val)
methods[name] = Method(val) //nolint:gosec // G115
}
for val, name := range methodName() {
mapped, ok := m[name]
mapped, ok := methods[name]
if !ok {
t.Errorf("failed to find method %s in IANA", name)
}
@@ -57,7 +59,7 @@ func TestIANA(t *testing.T) {
})
t.Run("Attributes", func(t *testing.T) {
records := loadCSV(t, "stun-parameters-4.csv")
m := map[string]AttrType{}
attrTypes := map[string]AttrType{}
for _, r := range records[1:] {
var (
v = r[0]
@@ -71,16 +73,16 @@ func TestIANA(t *testing.T) {
t.Fatal(parseErr)
}
t.Logf("value: 0x%x, name: %s", val, name)
m[name] = AttrType(val)
attrTypes[name] = AttrType(val) //nolint:gosec // G115
}
// Not registered in IANA.
for k, v := range map[string]AttrType{
"ORIGIN": 0x802F,
} {
m[k] = v
attrTypes[k] = v
}
for val, name := range attrNames() {
mapped, ok := m[name]
mapped, ok := attrTypes[name]
if !ok {
t.Errorf("failed to find attribute %s in IANA", name)
}
@@ -91,7 +93,7 @@ func TestIANA(t *testing.T) {
})
t.Run("ErrorCodes", func(t *testing.T) {
records := loadCSV(t, "stun-parameters-6.csv")
m := map[string]ErrorCode{}
errorCodes := map[string]ErrorCode{}
for _, r := range records[1:] {
var (
v = r[0]
@@ -105,11 +107,11 @@ func TestIANA(t *testing.T) {
t.Fatal(parseErr)
}
t.Logf("value: 0x%x, name: %s", val, name)
m[name] = ErrorCode(val)
errorCodes[name] = ErrorCode(val)
}
for val, nameB := range errorReasons {
name := string(nameB)
mapped, ok := m[name]
mapped, ok := errorCodes[name]
if !ok {
t.Errorf("failed to find error code %s in IANA", name)
}

View File

@@ -22,6 +22,7 @@ func NewLongTermIntegrity(username, realm, password string) MessageIntegrity {
k := strings.Join([]string{username, realm, password}, credentialsSep)
h := md5.New() //nolint:gosec
fmt.Fprint(h, k) //nolint:errcheck
return MessageIntegrity(h.Sum(nil))
}
@@ -36,13 +37,14 @@ func NewShortTermIntegrity(password string) MessageIntegrity {
// AddTo and Check methods are using zero-allocation version of hmac, see
// newHMAC function and internal/hmac/pool.go.
//
// RFC 5389 Section 15.4
// RFC 5389 Section 15.4.
type MessageIntegrity []byte
func newHMAC(key, message, buf []byte) []byte {
mac := hmac.AcquireSHA1(key)
writeOrPanic(mac, message)
defer hmac.PutSHA1(mac)
return mac.Sum(buf)
}
@@ -59,8 +61,8 @@ var ErrFingerprintBeforeIntegrity = errors.New("FINGERPRINT before MESSAGE-INTEG
// AddTo adds MESSAGE-INTEGRITY attribute to message.
//
// CPU costly, see BenchmarkMessageIntegrity_AddTo.
func (i MessageIntegrity) AddTo(m *Message) error {
for _, a := range m.Attributes {
func (i MessageIntegrity) AddTo(msg *Message) error {
for _, a := range msg.Attributes {
// Message should not contain FINGERPRINT attribute
// before MESSAGE-INTEGRITY.
if a.Type == AttrFingerprint {
@@ -70,19 +72,20 @@ func (i MessageIntegrity) AddTo(m *Message) error {
// The text used as input to HMAC is the STUN message,
// including the header, up to and including the attribute preceding the
// MESSAGE-INTEGRITY attribute.
length := m.Length
length := msg.Length
// Adjusting m.Length to contain MESSAGE-INTEGRITY TLV.
m.Length += messageIntegritySize + attributeHeaderSize
m.WriteLength() // writing length to m.Raw
v := newHMAC(i, m.Raw, m.Raw[len(m.Raw):]) // calculating HMAC for adjusted m.Raw
m.Length = length // changing m.Length back
msg.Length += messageIntegritySize + attributeHeaderSize
msg.WriteLength() // writing length to m.Raw
v := newHMAC(i, msg.Raw, msg.Raw[len(msg.Raw):]) // calculating HMAC for adjusted m.Raw
msg.Length = length // changing m.Length back
// Copy hmac value to temporary variable to protect it from resetting
// while processing m.Add call.
vBuf := make([]byte, sha1.Size)
copy(vBuf, v)
m.Add(AttrMessageIntegrity, vBuf)
msg.Add(AttrMessageIntegrity, vBuf)
return nil
}
@@ -92,8 +95,8 @@ var ErrIntegrityMismatch = errors.New("integrity check failed")
// Check checks MESSAGE-INTEGRITY attribute.
//
// CPU costly, see BenchmarkMessageIntegrity_Check.
func (i MessageIntegrity) Check(m *Message) error {
v, err := m.Get(AttrMessageIntegrity)
func (i MessageIntegrity) Check(msg *Message) error {
val, err := msg.Get(AttrMessageIntegrity)
if err != nil {
return err
}
@@ -101,11 +104,11 @@ func (i MessageIntegrity) Check(m *Message) error {
// Adjusting length in header to match m.Raw that was
// used when computing HMAC.
var (
length = m.Length
length = msg.Length
afterIntegrity = false
sizeReduced int
)
for _, a := range m.Attributes {
for _, a := range msg.Attributes {
if afterIntegrity {
sizeReduced += nearestPaddedValueLength(int(a.Length))
sizeReduced += attributeHeaderSize
@@ -114,13 +117,14 @@ func (i MessageIntegrity) Check(m *Message) error {
afterIntegrity = true
}
}
m.Length -= uint32(sizeReduced)
m.WriteLength()
msg.Length -= uint32(sizeReduced) //nolint:gosec // G115
msg.WriteLength()
// startOfHMAC should be first byte of integrity attribute.
startOfHMAC := messageHeaderSize + m.Length - (attributeHeaderSize + messageIntegritySize)
b := m.Raw[:startOfHMAC] // data before integrity attribute
expected := newHMAC(i, b, m.Raw[len(m.Raw):])
m.Length = length
m.WriteLength() // writing length back
return checkHMAC(v, expected)
startOfHMAC := messageHeaderSize + msg.Length - (attributeHeaderSize + messageIntegritySize)
b := msg.Raw[:startOfHMAC] // data before integrity attribute
expected := newHMAC(i, b, msg.Raw[len(msg.Raw):])
msg.Length = length
msg.WriteLength() // writing length back
return checkHMAC(val, expected)
}

View File

@@ -10,18 +10,18 @@ import (
)
func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
i := NewLongTermIntegrity("user", "realm", "pass")
integrity := NewLongTermIntegrity("user", "realm", "pass")
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, i) {
if !bytes.Equal(expected, integrity) {
t.Error(ErrIntegrityMismatch)
}
t.Run("Check", func(t *testing.T) {
m := new(Message)
m.WriteHeader()
if err := i.AddTo(m); err != nil {
if err := integrity.AddTo(m); err != nil {
t.Error(err)
}
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
@@ -31,39 +31,39 @@ func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
if err := dM.Decode(); err != nil {
t.Error(err)
}
if err := i.Check(dM); err != nil {
if err := integrity.Check(dM); err != nil {
t.Error(err)
}
dM.Raw[24] += 12 // HMAC now invalid
if i.Check(dM) == nil {
if integrity.Check(dM) == nil {
t.Error("should be invalid")
}
})
}
func TestMessageIntegrityWithFingerprint(t *testing.T) {
m := new(Message)
m.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
m.WriteHeader()
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
i := NewShortTermIntegrity("pwd")
if i.String() != "KEY: 0x707764" {
t.Error("bad string", i)
msg := new(Message)
msg.TransactionID = [TransactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
msg.WriteHeader()
NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
integrity := NewShortTermIntegrity("pwd")
if integrity.String() != "KEY: 0x707764" {
t.Error("bad string", integrity)
}
if err := i.Check(m); err == nil {
if err := integrity.Check(msg); err == nil {
t.Error("should error")
}
if err := i.AddTo(m); err != nil {
if err := integrity.AddTo(msg); err != nil {
t.Fatal(err)
}
if err := Fingerprint.AddTo(m); err != nil {
if err := Fingerprint.AddTo(msg); err != nil {
t.Fatal(err)
}
if err := i.Check(m); err != nil {
if err := integrity.Check(msg); err != nil {
t.Fatal(err)
}
m.Raw[24] = 33
if err := i.Check(m); err == nil {
msg.Raw[24] = 33
if err := integrity.Check(msg); err == nil {
t.Fatal("mismatch expected")
}
}

View File

@@ -64,6 +64,7 @@ func (h *hmac) Sum(in []byte) []byte {
h.outer.Write(h.opad) //nolint:errcheck,gosec
}
h.outer.Write(in[origLen:]) //nolint:errcheck,gosec
return h.outer.Sum(in[:origLen])
}
@@ -79,6 +80,7 @@ func (h *hmac) Reset() {
if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert
panic(err) //nolint
}
return
}

View File

@@ -22,7 +22,7 @@ type hmacTest struct {
blocksize int
}
func hmacTests() []hmacTest {
func hmacTests() []hmacTest { //nolint:maintidx
return []hmacTest{
// Tests from US FIPS 198
// https://csrc.nist.gov/publications/fips/fips198/fips-198a.pdf
@@ -523,41 +523,42 @@ func hmacTests() []hmacTest {
func TestHMAC(t *testing.T) {
for i, tt := range hmacTests() {
h := New(tt.hash, tt.key)
if s := h.Size(); s != tt.size {
hsh := New(tt.hash, tt.key)
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := h.BlockSize(); b != tt.blocksize {
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 4; j++ {
n, err := h.Write(tt.in)
for j := 0; j < 4; j++ { //nolint:varnamelen
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil))
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
}
// Second iteration: make sure reset works.
h.Reset()
hsh.Reset()
// Third and fourth iteration: make sure hmac works on
// hashes without MarshalBinary/UnmarshalBinary
if j == 1 {
h = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
hsh = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key)
}
}
}
}
// justHash implements just the hash.Hash methods and nothing else
// justHash implements just the hash.Hash methods and nothing else.
type justHash struct {
hash.Hash
}

View File

@@ -40,6 +40,7 @@ func (h *hmac) resetTo(key []byte) {
var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals
New: func() interface{} {
h := New(sha1.New, make([]byte, sha1.BlockSize))
return h
},
}
@@ -49,6 +50,7 @@ func AcquireSHA1(key []byte) hash.Hash {
h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert
assertHMACSize(h, sha1.Size, sha1.BlockSize)
h.resetTo(key)
return h
}
@@ -62,6 +64,7 @@ func PutSHA1(h hash.Hash) {
var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals
New: func() interface{} {
h := New(sha256.New, make([]byte, sha256.BlockSize))
return h
},
}
@@ -71,6 +74,7 @@ func AcquireSHA256(key []byte) hash.Hash {
h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert
assertHMACSize(h, sha256.Size, sha256.BlockSize)
h.resetTo(key)
return h
}

View File

@@ -42,100 +42,103 @@ func BenchmarkHMACSHA1_512_Pool(b *testing.B) {
func TestHMACReset(t *testing.T) {
for i, tt := range hmacTests() {
h := New(tt.hash, tt.key)
h.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
if s := h.Size(); s != tt.size {
hsh := New(tt.hash, tt.key)
hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := h.BlockSize(); b != tt.blocksize {
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 2; j++ {
n, err := h.Write(tt.in)
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil))
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
}
// Second iteration: make sure reset works.
h.Reset()
hsh.Reset()
}
}
}
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl
func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
for i, tt := range hmacTests() {
if tt.blocksize != sha1.BlockSize || tt.size != sha1.Size {
continue
}
h := AcquireSHA1(tt.key)
if s := h.Size(); s != tt.size {
hsh := AcquireSHA1(tt.key)
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := h.BlockSize(); b != tt.blocksize {
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 2; j++ {
n, err := h.Write(tt.in)
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil))
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
}
// Second iteration: make sure reset works.
h.Reset()
hsh.Reset()
}
PutSHA1(h)
PutSHA1(hsh)
}
}
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl
func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
for i, tt := range hmacTests() {
if tt.blocksize != sha256.BlockSize || tt.size != sha256.Size {
continue
}
h := AcquireSHA256(tt.key)
if s := h.Size(); s != tt.size {
hsh := AcquireSHA256(tt.key)
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := h.BlockSize(); b != tt.blocksize {
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 2; j++ {
n, err := h.Write(tt.in)
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum(nil))
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
}
// Second iteration: make sure reset works.
h.Reset()
hsh.Reset()
}
PutSHA256(h)
PutSHA256(hsh)
}
}

View File

@@ -10,8 +10,11 @@ import (
// ShouldNotAllocate fails if f allocates.
func ShouldNotAllocate(t *testing.T, f func()) {
t.Helper()
if Race {
t.Skip("skip while running with -race")
return
}
if a := testing.AllocsPerRun(10, f); a > 0 {

View File

@@ -32,6 +32,7 @@ const (
// as source.
func NewTransactionID() (b [TransactionIDSize]byte) {
readFullOrPanic(rand.Reader, b[:])
return b
}
@@ -45,6 +46,7 @@ func IsMessage(b []byte) bool {
// New returns *Message with pre-allocated Raw.
func New() *Message {
const defaultRawCapacity = 120
return &Message{
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
}
@@ -59,6 +61,7 @@ func Decode(data []byte, m *Message) error {
return ErrDecodeToNil
}
m.Raw = append(m.Raw[:0], data...)
return m.Decode()
}
@@ -82,6 +85,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
// contract induced by other implementations.
b := make([]byte, len(m.Raw))
copy(b, m.Raw)
return b, nil
}
@@ -89,6 +93,7 @@ func (m Message) MarshalBinary() (data []byte, err error) {
func (m *Message) UnmarshalBinary(data []byte) error {
// We can't retain data, copy is expected by interface contract.
m.Raw = append(m.Raw[:0], data...)
return m.Decode()
}
@@ -108,6 +113,7 @@ func (m *Message) GobDecode(data []byte) error {
func (m *Message) AddTo(b *Message) error {
b.TransactionID = m.TransactionID
b.WriteTransactionID()
return nil
}
@@ -118,6 +124,7 @@ func (m *Message) NewTransactionID() error {
if err == nil {
m.WriteTransactionID()
}
return err
}
@@ -127,6 +134,7 @@ func (m *Message) String() string {
for k, a := range m.Attributes {
aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
}
return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
}
@@ -144,6 +152,7 @@ func (m *Message) grow(n int) {
}
if cap(m.Raw) >= n {
m.Raw = m.Raw[:n]
return
}
m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
@@ -153,7 +162,7 @@ func (m *Message) grow(n int) {
//
// Value of attribute is copied to internal buffer so
// it is safe to reuse v.
func (m *Message) Add(t AttrType, v []byte) {
func (m *Message) Add(attrType AttrType, val []byte) {
// Allocating buffer for TLV (type-length-value).
// T = t, L = len(v), V = v.
// m.Raw will look like:
@@ -163,31 +172,33 @@ func (m *Message) Add(t AttrType, v []byte) {
// [first:last] <- same as previous
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
// T L V
allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V)
allocSize := attributeHeaderSize + len(val) // ~ len(TLV) = len(TL) + len(V)
first := messageHeaderSize + int(m.Length) // first byte number
last := first + allocSize // last byte number
m.grow(last) // growing cap(Raw) to fit TLV
m.Raw = m.Raw[:last] // now len(Raw) = last
//nolint:gosec // G115
m.Length += uint32(allocSize) // rendering length change
// Sub-slicing internal buffer to simplify encoding.
buf := m.Raw[first:last] // slice for TLV
value := buf[attributeHeaderSize:] // slice for V
attr := RawAttribute{
Type: t, // T
Length: uint16(len(v)), // L
Type: attrType, // T
//nolint:gosec // G115
Length: uint16(len(val)), // L
Value: value, // V
}
// Encoding attribute TLV to allocated buffer.
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
bin.PutUint16(buf[2:4], attr.Length) // L
copy(value, v) // V
copy(value, val) // V
// Checking that attribute value needs padding.
if attr.Length%padding != 0 {
// Performing padding.
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
bytesToAdd := nearestPaddedValueLength(len(val)) - len(val)
last += bytesToAdd
m.grow(last)
// setting all padding bytes to zero
@@ -198,6 +209,7 @@ func (m *Message) Add(t AttrType, v []byte) {
buf[i] = 0
}
m.Raw = m.Raw[:last] // increasing buffer length
//nolint:gosec // G115
m.Length += uint32(bytesToAdd) // rendering length change
}
m.Attributes = append(m.Attributes, attr)
@@ -213,6 +225,7 @@ func attrSliceEqual(a, b Attributes) bool {
}
if attrB.Equal(attr) {
found = true
break
}
}
@@ -220,56 +233,59 @@ func attrSliceEqual(a, b Attributes) bool {
return false
}
}
return true
}
func attrEqual(a, b Attributes) bool {
if a == nil && b == nil {
func attrEqual(attrA, attrB Attributes) bool {
if attrA == nil && attrB == nil {
return true
}
if a == nil || b == nil {
if attrA == nil || attrB == nil {
return false
}
if len(a) != len(b) {
if len(attrA) != len(attrB) {
return false
}
if !attrSliceEqual(a, b) {
if !attrSliceEqual(attrA, attrB) {
return false
}
if !attrSliceEqual(b, a) {
if !attrSliceEqual(attrB, attrA) {
return false
}
return true
}
// Equal returns true if Message b equals to m.
// Equal returns true if Message msg equals to m.
// Ignores m.Raw.
func (m *Message) Equal(b *Message) bool {
if m == nil && b == nil {
func (m *Message) Equal(msg *Message) bool {
if m == nil && msg == nil {
return true
}
if m == nil || b == nil {
if m == nil || msg == nil {
return false
}
if m.Type != b.Type {
if m.Type != msg.Type {
return false
}
if m.TransactionID != b.TransactionID {
if m.TransactionID != msg.TransactionID {
return false
}
if m.Length != b.Length {
if m.Length != msg.Length {
return false
}
if !attrEqual(m.Attributes, b.Attributes) {
if !attrEqual(m.Attributes, msg.Attributes) {
return false
}
return true
}
// WriteLength writes m.Length to m.Raw.
func (m *Message) WriteLength() {
m.grow(4)
bin.PutUint16(m.Raw[2:4], uint16(m.Length))
bin.PutUint16(m.Raw[2:4], uint16(m.Length)) //nolint:gosec // G115
}
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
@@ -322,6 +338,7 @@ func (m *Message) Encode() {
// call result.
func (m *Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.Raw)
return int64(n), err
}
@@ -340,6 +357,7 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) {
return int64(n), err
}
m.Raw = tBuf[:n]
return int64(n), m.Decode()
}
@@ -355,22 +373,24 @@ func (m *Message) Decode() error {
return ErrUnexpectedHeaderEOF
}
var (
t = bin.Uint16(buf[0:2]) // first 2 bytes
msgType = bin.Uint16(buf[0:2]) // first 2 bytes
size = int(bin.Uint16(buf[2:4])) // second 2 bytes
cookie = bin.Uint32(buf[4:8]) // last 4 bytes
fullSize = messageHeaderSize + size // len(m.Raw)
)
if cookie != magicCookie {
msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
return newDecodeErr("message", "cookie", msg)
}
if len(buf) < fullSize {
msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
return newAttrDecodeErr("message", msg)
}
// saving header data
m.Type.ReadValue(t)
m.Length = uint32(size)
m.Type.ReadValue(msgType)
m.Length = uint32(size) //nolint:gosec // G115
copy(m.TransactionID[:], buf[8:messageHeaderSize])
m.Attributes = m.Attributes[:0]
@@ -382,28 +402,31 @@ func (m *Message) Decode() error {
// checking that we have enough bytes to read header
if len(b) < attributeHeaderSize {
msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
return newAttrDecodeErr("header", msg)
}
var (
a = RawAttribute{
attr = RawAttribute{
Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
Length: bin.Uint16(b[2:4]), // second 2 bytes
}
aL = int(a.Length) // attribute length
aL = int(attr.Length) // attribute length
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
)
b = b[attributeHeaderSize:] // slicing again to simplify value read
offset += attributeHeaderSize
if len(b) < aBuffL { // checking size
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, attr.Type)
return newAttrDecodeErr("value", msg)
}
a.Value = b[:aL]
attr.Value = b[:aL]
offset += aBuffL
b = b[aBuffL:]
m.Attributes = append(m.Attributes, a)
m.Attributes = append(m.Attributes, attr)
}
return nil
}
@@ -412,12 +435,14 @@ func (m *Message) Decode() error {
// Any error is unrecoverable, but message could be partially decoded.
func (m *Message) Write(tBuf []byte) (int, error) {
m.Raw = append(m.Raw[:0], tBuf...)
return len(tBuf), m.Decode()
}
// CloneTo clones m to b securing any further m mutations.
func (m *Message) CloneTo(b *Message) error {
b.Raw = append(b.Raw[:0], m.Raw...)
return b.Decode()
}
@@ -436,7 +461,7 @@ const (
var (
// Binding request message type.
BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
// Binding success response message type
// Binding success response message type.
BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
// Binding error response message type.
BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
@@ -501,6 +526,7 @@ func (m Method) String() string {
// Falling back to hex representation.
s = fmt.Sprintf("0x%x", uint16(m))
}
return s
}
@@ -513,6 +539,7 @@ type MessageType struct {
// AddTo sets m type to t.
func (t MessageType) AddTo(m *Message) error {
m.SetType(t)
return nil
}
@@ -554,13 +581,13 @@ func (t MessageType) Value() uint16 {
// Warning: Abandon all hope ye who enter here.
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
m := uint16(t.Method)
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
msg := uint16(t.Method)
a := msg & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := msg & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := msg & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
m = a + (b << methodBShift) + (d << methodDShift)
msg = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is first bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
@@ -573,7 +600,7 @@ func (t MessageType) Value() uint16 {
c1 := (c & c1Bit) << classC1Shift
class := c0 + c1
return m + class
return msg + class
}
// ReadValue decodes uint16 into MessageType.
@@ -604,6 +631,7 @@ func (m *Message) Contains(t AttrType) bool {
return true
}
}
return false
}
@@ -618,5 +646,6 @@ func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
func (t transactionIDValueSetter) AddTo(m *Message) error {
m.TransactionID = t
m.WriteTransactionID()
return nil
}

View File

@@ -27,9 +27,11 @@ type attributeEncoder interface {
AddTo(m *Message) error
}
func addAttr(t testing.TB, m *Message, a attributeEncoder) {
func addAttr(tb testing.TB, m *Message, a attributeEncoder) {
tb.Helper()
if err := a.AddTo(m); err != nil {
t.Error(err)
tb.Error(err)
}
}
@@ -129,21 +131,21 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
}
func TestMessage_WriteTo(t *testing.T) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
msg := New()
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
msg.TransactionID = NewTransactionID()
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
msg.WriteHeader()
buf := new(bytes.Buffer)
if _, err := m.WriteTo(buf); err != nil {
if _, err := msg.WriteTo(buf); err != nil {
t.Fatal(err)
}
mDecoded := New()
if _, err := mDecoded.ReadFrom(buf); err != nil {
t.Error(err)
}
if !mDecoded.Equal(m) {
t.Error(mDecoded, "!", m)
if !mDecoded.Equal(msg) {
t.Error(mDecoded, "!", msg)
}
}
@@ -275,23 +277,23 @@ func BenchmarkMessage_WriteTo(b *testing.B) {
func BenchmarkMessage_ReadFrom(b *testing.B) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
m := &Message{
msg := &Message{
Type: mType,
Length: 0,
TransactionID: [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
},
}
m.WriteHeader()
msg.WriteHeader()
b.ReportAllocs()
b.SetBytes(int64(len(m.Raw)))
reader := m.reader()
b.SetBytes(int64(len(msg.Raw)))
reader := msg.reader()
mRec := New()
for i := 0; i < b.N; i++ {
if _, err := mRec.ReadFrom(reader); err != nil {
b.Fatal(err)
}
reader.Reset(m.Raw)
reader.Reset(msg.Raw)
mRec.Reset()
}
}
@@ -342,7 +344,7 @@ func TestMessageClass_String(t *testing.T) {
}
func TestAttrType_String(t *testing.T) {
v := [...]AttrType{
attrType := [...]AttrType{
AttrMappedAddress,
AttrUsername,
AttrErrorCode,
@@ -355,7 +357,7 @@ func TestAttrType_String(t *testing.T) {
AttrAlternateServer,
AttrFingerprint,
}
for _, k := range v {
for _, k := range attrType {
if k.String() == "" {
t.Error(k, "bad stringer")
}
@@ -379,80 +381,80 @@ func TestMethod_String(t *testing.T) {
}
func TestAttribute_Equal(t *testing.T) {
a := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
b := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
if !a.Equal(b) {
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
if !attr1.Equal(attr2) {
t.Error("should equal")
}
if a.Equal(RawAttribute{Type: 0x2}) {
if attr1.Equal(RawAttribute{Type: 0x2}) {
t.Error("should not equal")
}
if a.Equal(RawAttribute{Length: 0x2}) {
if attr1.Equal(RawAttribute{Length: 0x2}) {
t.Error("should not equal")
}
if a.Equal(RawAttribute{Length: 0x3}) {
if attr1.Equal(RawAttribute{Length: 0x3}) {
t.Error("should not equal")
}
if a.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
t.Error("should not equal")
}
}
func TestMessage_Equal(t *testing.T) {
func TestMessage_Equal(t *testing.T) { //nolint:cyclop
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attrs := Attributes{attr}
a := &Message{Attributes: attrs, Length: 4 + 2}
b := &Message{Attributes: attrs, Length: 4 + 2}
if !a.Equal(b) {
msg1 := &Message{Attributes: attrs, Length: 4 + 2}
msg2 := &Message{Attributes: attrs, Length: 4 + 2}
if !msg1.Equal(msg2) {
t.Error("should equal")
}
if a.Equal(&Message{Type: MessageType{Class: 128}}) {
if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
t.Error("should not equal")
}
tID := [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
}
if a.Equal(&Message{TransactionID: tID}) {
if msg1.Equal(&Message{TransactionID: tID}) {
t.Error("should not equal")
}
if a.Equal(&Message{Length: 3}) {
if msg1.Equal(&Message{Length: 3}) {
t.Error("should not equal")
}
tAttrs := Attributes{
{Length: 1, Value: []byte{0x1}, Type: 0x1},
}
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal")
}
tAttrs = Attributes{
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
}
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal")
}
if !(*Message)(nil).Equal(nil) {
t.Error("nil should be equal to nil")
}
if a.Equal(nil) {
if msg1.Equal(nil) {
t.Error("non-nil should not be equal to nil")
}
t.Run("Nil attributes", func(t *testing.T) {
a := &Message{
msg1 := &Message{
Attributes: nil,
Length: 4 + 2,
}
b := &Message{
msg2 := &Message{
Attributes: attrs,
Length: 4 + 2,
}
if a.Equal(b) {
if msg1.Equal(msg2) {
t.Error("should not equal")
}
if b.Equal(a) {
if msg2.Equal(msg1) {
t.Error("should not equal")
}
b.Attributes = nil
if !a.Equal(b) {
msg2.Attributes = nil
if !msg1.Equal(msg2) {
t.Error("should equal")
}
})
@@ -547,6 +549,8 @@ func BenchmarkIsMessage(b *testing.B) {
}
func loadData(tb testing.TB, name string) []byte {
tb.Helper()
name = filepath.Join("testdata", name)
f, err := os.Open(name) //nolint:gosec
if err != nil {
@@ -561,6 +565,7 @@ func loadData(tb testing.TB, name string) []byte {
if err != nil {
tb.Fatal(err)
}
return v
}
@@ -582,7 +587,7 @@ func TestMessageFromBrowsers(t *testing.T) {
t.Fatal("failed to skip header of csv: ", err)
}
crcTable := crc64.MakeTable(crc64.ISO)
m := New()
msg := New()
for {
line, err := reader.Read()
if errors.Is(err, io.EOF) {
@@ -602,10 +607,10 @@ func TestMessageFromBrowsers(t *testing.T) {
if b != crc64.Checksum(data, crcTable) {
t.Error("crc64 check failed for ", line[1])
}
if _, err = m.Write(data); err != nil {
if _, err = msg.Write(data); err != nil {
t.Error("failed to decode ", line[1], " as message: ", err)
}
m.Reset()
msg.Reset()
}
}
@@ -622,42 +627,42 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
func BenchmarkMessageFull(b *testing.B) {
b.ReportAllocs()
m := new(Message)
msg := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5),
}
for i := 0; i < b.N; i++ {
if err := addr.AddTo(m); err != nil {
if err := addr.AddTo(msg); err != nil {
b.Fatal(err)
}
if err := s.AddTo(m); err != nil {
if err := s.AddTo(msg); err != nil {
b.Fatal(err)
}
m.WriteAttributes()
m.WriteHeader()
Fingerprint.AddTo(m) //nolint:errcheck,gosec
m.WriteHeader()
m.Reset()
msg.WriteAttributes()
msg.WriteHeader()
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
msg.WriteHeader()
msg.Reset()
}
}
func BenchmarkMessageFullHardcore(b *testing.B) {
b.ReportAllocs()
m := new(Message)
msg := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5),
}
for i := 0; i < b.N; i++ {
if err := addr.AddTo(m); err != nil {
if err := addr.AddTo(msg); err != nil {
b.Fatal(err)
}
if err := s.AddTo(m); err != nil {
if err := s.AddTo(msg); err != nil {
b.Fatal(err)
}
m.WriteHeader()
m.Reset()
msg.WriteHeader()
msg.Reset()
}
}
@@ -689,8 +694,8 @@ func TestMessage_Contains(t *testing.T) {
func ExampleMessage() {
buf := new(bytes.Buffer)
m := new(Message)
m.Build(BindingRequest, //nolint:errcheck,gosec
msg := new(Message)
msg.Build(BindingRequest, //nolint:errcheck,gosec
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
@@ -707,8 +712,8 @@ func ExampleMessage() {
// m.Build(&software) // no allocations
// If you pass software as value, there will be 1 allocation.
// This rule is correct for all setters.
fmt.Println(m, "buff length:", len(m.Raw))
n, err := m.WriteTo(buf)
fmt.Println(msg, "buff length:", len(msg.Raw))
n, err := msg.WriteTo(buf)
fmt.Println("wrote", n, "err", err)
// Decoding from buf new *Message.
@@ -743,6 +748,7 @@ func ExampleMessage() {
fmt.Println("fingerprint: failed")
}
//nolint:lll
// Output:
// Binding request l=48 attrs=3 id=AQIDBAUGBwgJAAEA, attr0=SOFTWARE attr1=MESSAGE-INTEGRITY attr2=FINGERPRINT buff length: 68
// wrote 68 err <nil>
@@ -811,8 +817,8 @@ func TestAllocationsGetters(t *testing.T) {
NewShortTermIntegrity("pwd"),
Fingerprint,
}
m := New()
if err := m.Build(setters...); err != nil {
msg := New()
if err := msg.Build(setters...); err != nil {
t.Error("failed to build", err)
}
getters := []Getter{
@@ -826,7 +832,7 @@ func TestAllocationsGetters(t *testing.T) {
g := g
i := i
allocs := testing.AllocsPerRun(10, func() {
if err := g.GetFrom(m); err != nil {
if err := g.GetFrom(msg); err != nil {
t.Errorf("[%d] failed to get", i)
}
})
@@ -837,8 +843,8 @@ func TestAllocationsGetters(t *testing.T) {
}
func TestMessageFullSize(t *testing.T) {
m := new(Message)
if err := m.Build(BindingRequest,
msg := new(Message)
if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
@@ -848,18 +854,18 @@ func TestMessageFullSize(t *testing.T) {
); err != nil {
t.Fatal(err)
}
m.Raw = m.Raw[:len(m.Raw)-10]
msg.Raw = msg.Raw[:len(msg.Raw)-10]
decoder := new(Message)
decoder.Raw = m.Raw[:len(m.Raw)-10]
decoder.Raw = msg.Raw[:len(msg.Raw)-10]
if err := decoder.Decode(); err == nil {
t.Error("decode on truncated buffer should error")
}
}
func TestMessage_CloneTo(t *testing.T) {
m := new(Message)
if err := m.Build(BindingRequest,
msg := new(Message)
if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
@@ -869,29 +875,29 @@ func TestMessage_CloneTo(t *testing.T) {
); err != nil {
t.Fatal(err)
}
m.Encode()
b := new(Message)
if err := m.CloneTo(b); err != nil {
msg.Encode()
msg2 := new(Message)
if err := msg.CloneTo(msg2); err != nil {
t.Fatal(err)
}
if !b.Equal(m) {
if !msg2.Equal(msg) {
t.Fatal("not equal")
}
// Corrupting m and checking that b is not corrupted.
s, ok := b.Attributes.Get(AttrSoftware)
s, ok := msg2.Attributes.Get(AttrSoftware)
if !ok {
t.Fatal("no software attribute")
}
s.Value[0] = 'k'
if b.Equal(m) {
if msg2.Equal(msg) {
t.Fatal("should not be equal")
}
}
func BenchmarkMessage_CloneTo(b *testing.B) {
b.ReportAllocs()
m := new(Message)
if err := m.Build(BindingRequest,
msg := new(Message)
if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
@@ -901,19 +907,19 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
); err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(m.Raw)))
b.SetBytes(int64(len(msg.Raw)))
a := new(Message)
m.CloneTo(a) //nolint:errcheck,gosec
msg.CloneTo(a) //nolint:errcheck,gosec
for i := 0; i < b.N; i++ {
if err := m.CloneTo(a); err != nil {
if err := msg.CloneTo(a); err != nil {
b.Fatal(err)
}
}
}
func TestMessage_AddTo(t *testing.T) {
m := new(Message)
if err := m.Build(BindingRequest,
msg := new(Message)
if err := msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
@@ -921,19 +927,19 @@ func TestMessage_AddTo(t *testing.T) {
); err != nil {
t.Fatal(err)
}
m.Encode()
msg.Encode()
b := new(Message)
if err := m.CloneTo(b); err != nil {
if err := msg.CloneTo(b); err != nil {
t.Fatal(err)
}
m.TransactionID = [TransactionIDSize]byte{
msg.TransactionID = [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
}
if b.Equal(m) {
if b.Equal(msg) {
t.Fatal("should not be equal")
}
m.AddTo(b) //nolint:errcheck,gosec
if !b.Equal(m) {
msg.AddTo(b) //nolint:errcheck,gosec
if !b.Equal(msg) {
t.Fatal("should be equal")
}
}
@@ -964,22 +970,22 @@ func TestDecode(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
})
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
msg := New()
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
msg.TransactionID = NewTransactionID()
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
msg.WriteHeader()
mDecoded := New()
if err := Decode(m.Raw, mDecoded); err != nil {
if err := Decode(msg.Raw, mDecoded); err != nil {
t.Errorf("unexpected error: %v", err)
}
if !mDecoded.Equal(m) {
if !mDecoded.Equal(msg) {
t.Error("decoded result is not equal to encoded message")
}
t.Run("ZeroAlloc", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
mDecoded.Reset()
if err := Decode(m.Raw, mDecoded); err != nil {
if err := Decode(msg.Raw, mDecoded); err != nil {
t.Error(err)
}
})
@@ -1008,22 +1014,22 @@ func BenchmarkDecode(b *testing.B) {
}
func TestMessage_MarshalBinary(t *testing.T) {
m := MustBuild(
msg := MustBuild(
NewSoftware("software"),
&XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5),
},
)
data, err := m.MarshalBinary()
data, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
// Reset m.Raw to check retention.
for i := range m.Raw {
m.Raw[i] = 0
for i := range msg.Raw {
msg.Raw[i] = 0
}
if err := m.UnmarshalBinary(data); err != nil {
if err := msg.UnmarshalBinary(data); err != nil {
t.Fatal(err)
}
@@ -1031,28 +1037,28 @@ func TestMessage_MarshalBinary(t *testing.T) {
for i := range data {
data[i] = 0
}
if err := m.Decode(); err != nil {
if err := msg.Decode(); err != nil {
t.Fatal(err)
}
}
func TestMessage_GobDecode(t *testing.T) {
m := MustBuild(
msg := MustBuild(
NewSoftware("software"),
&XORMappedAddress{
IP: net.IPv4(213, 1, 223, 5),
},
)
data, err := m.GobEncode()
data, err := msg.GobEncode()
if err != nil {
t.Fatal(err)
}
// Reset m.Raw to check retention.
for i := range m.Raw {
m.Raw[i] = 0
for i := range msg.Raw {
msg.Raw[i] = 0
}
if err := m.GobDecode(data); err != nil {
if err := msg.GobDecode(data); err != nil {
t.Fatal(err)
}
@@ -1060,7 +1066,7 @@ func TestMessage_GobDecode(t *testing.T) {
for i := range data {
data[i] = 0
}
if err := m.Decode(); err != nil {
if err := msg.Decode(); err != nil {
t.Fatal(err)
}
}

View File

@@ -8,7 +8,7 @@ import (
"testing"
)
func TestRFC5769(t *testing.T) {
func TestRFC5769(t *testing.T) { //nolint:cyclop
// Test Vectors for Session Traversal Utilities for NAT (STUN)
// see https://tools.ietf.org/html/rfc5769
t.Run("Request", func(t *testing.T) {
@@ -46,7 +46,7 @@ func TestRFC5769(t *testing.T) {
t.Error("check failed: ", err)
}
t.Run("Long-Term credentials", func(t *testing.T) {
m := &Message{
msg := &Message{
Raw: []byte("\x00\x01\x00\x60" +
"\x21\x12\xa4\x42" +
"\x78\xad\x34\x33\xc6\xad\x72\xc0\x29\xda\x41\x2e" +
@@ -64,11 +64,11 @@ func TestRFC5769(t *testing.T) {
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
),
}
if err := m.Decode(); err != nil {
if err := msg.Decode(); err != nil {
t.Error(err)
}
u := new(Username)
if err := u.GetFrom(m); err != nil {
if err := u.GetFrom(msg); err != nil {
t.Error(err)
}
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
@@ -76,14 +76,14 @@ func TestRFC5769(t *testing.T) {
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
}
n := new(Nonce)
if err := n.GetFrom(m); err != nil {
if err := n.GetFrom(msg); err != nil {
t.Error(err)
}
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
t.Error("bad nonce")
}
r := new(Realm)
if err := r.GetFrom(m); err != nil {
if err := r.GetFrom(msg); err != nil {
t.Error(err)
}
if r.String() != "example.org" { //nolint:goconst
@@ -95,14 +95,14 @@ func TestRFC5769(t *testing.T) {
"example.org",
"TheMatrIX",
)
if err := i.Check(m); err != nil {
if err := i.Check(msg); err != nil {
t.Error(err)
}
})
})
t.Run("Response", func(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
m := &Message{
msg := &Message{
Raw: []byte("\x01\x01\x00\x3c" +
"\x21\x12\xa4\x42" +
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
@@ -117,21 +117,21 @@ func TestRFC5769(t *testing.T) {
"\xc0\x7d\x4c\x96",
),
}
if err := m.Decode(); err != nil {
if err := msg.Decode(); err != nil {
t.Error(err)
}
software := new(Software)
if err := software.GetFrom(m); err != nil {
if err := software.GetFrom(msg); err != nil {
t.Error(err)
}
if software.String() != "test vector" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(m); err != nil {
if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err)
}
addr := new(XORMappedAddress)
if err := addr.GetFrom(m); err != nil {
if err := addr.GetFrom(msg); err != nil {
t.Error(err)
}
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
@@ -140,12 +140,12 @@ func TestRFC5769(t *testing.T) {
if addr.Port != 32853 {
t.Error("bad Port")
}
if err := Fingerprint.Check(m); err != nil {
if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err)
}
})
t.Run("IPv6", func(t *testing.T) {
m := &Message{
msg := &Message{
Raw: []byte("\x01\x01\x00\x48" +
"\x21\x12\xa4\x42" +
"\xb7\xe7\xa7\x01\xbc\x34\xd6\x86\xfa\x87\xdf\xae" +
@@ -162,21 +162,21 @@ func TestRFC5769(t *testing.T) {
"\xc8\xfb\x0b\x4c",
),
}
if err := m.Decode(); err != nil {
if err := msg.Decode(); err != nil {
t.Error(err)
}
software := new(Software)
if err := software.GetFrom(m); err != nil {
if err := software.GetFrom(msg); err != nil {
t.Error(err)
}
if software.String() != "test vector" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(m); err != nil {
if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err)
}
addr := new(XORMappedAddress)
if err := addr.GetFrom(m); err != nil {
if err := addr.GetFrom(msg); err != nil {
t.Error(err)
}
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
@@ -185,7 +185,7 @@ func TestRFC5769(t *testing.T) {
if addr.Port != 32853 {
t.Error("bad Port")
}
if err := Fingerprint.Check(m); err != nil {
if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err)
}
})

View File

@@ -27,6 +27,7 @@ func readFullOrPanic(r io.Reader, v []byte) int {
if err != nil {
panic(err) //nolint
}
return n
}
@@ -35,6 +36,7 @@ func writeOrPanic(w io.Writer, v []byte) int {
if err != nil {
panic(err) //nolint
}
return n
}

View File

@@ -13,14 +13,19 @@ import (
var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
// NewUDPServer creates an udp server for testing. The supplied handler function will be called with the request
// NewUDPServer creates an udp server for testing.
// The supplied handler function will be called with the request
// and should be used to emulate the server behavior.
//
//nolint:cyclop
func NewUDPServer(
t *testing.T,
network string,
maxMessageSize int,
handler func(req []byte) ([]byte, error),
) (net.Addr, func(t *testing.T), error) {
t.Helper()
var ip string
switch network {
case "udp4":
@@ -50,28 +55,34 @@ func NewUDPServer(
n, addr, err := udpConn.ReadFrom(bs)
if err != nil {
errCh <- err
return
}
resp, err := handler(bs[:n])
if err != nil {
errCh <- err
return
}
_, err = udpConn.WriteTo(resp, addr)
if err != nil {
errCh <- err
return
}
}
}()
return serverAddr, func(t *testing.T) {
t.Helper()
select {
case err := <-errCh:
if err != nil {
t.Fatal(err) //nolint
t.Fatal(err)
return
}
default:
@@ -79,7 +90,7 @@ func NewUDPServer(
err := udpConn.Close()
if err != nil {
t.Fatal(err) //nolint
t.Fatal(err)
}
<-errCh

View File

@@ -10,7 +10,7 @@ func NewUsername(username string) Username {
// Username represents USERNAME attribute.
//
// RFC 5389 Section 15.3
// RFC 5389 Section 15.3.
type Username []byte
func (u Username) String() string {
@@ -37,7 +37,7 @@ func NewRealm(realm string) Realm {
// Realm represents REALM attribute.
//
// RFC 5389 Section 15.7
// RFC 5389 Section 15.7.
type Realm []byte
func (n Realm) String() string {
@@ -60,7 +60,7 @@ const softwareRawMaxB = 763
// Software is SOFTWARE attribute.
//
// RFC 5389 Section 15.10
// RFC 5389 Section 15.10.
type Software []byte
func (s Software) String() string {
@@ -84,7 +84,7 @@ func (s *Software) GetFrom(m *Message) error {
// Nonce represents NONCE attribute.
//
// RFC 5389 Section 15.8
// RFC 5389 Section 15.8.
type Nonce []byte
// NewNonce returns new Nonce from string.
@@ -118,6 +118,7 @@ func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error {
return err
}
m.Add(t, v)
return nil
}
@@ -128,5 +129,6 @@ func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error {
return err
}
*v = a
return nil
}

View File

@@ -13,26 +13,26 @@ import (
)
func TestSoftware_GetFrom(t *testing.T) {
m := New()
v := "Client v0.0.1"
m.Add(AttrSoftware, []byte(v))
m.WriteHeader()
msg := New()
val := "Client v0.0.1"
msg.Add(AttrSoftware, []byte(val))
msg.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
software := new(Software)
if _, err := m2.ReadFrom(m.reader()); err != nil {
if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err)
}
if err := software.GetFrom(m); err != nil {
if err := software.GetFrom(msg); err != nil {
t.Fatal(err)
}
if software.String() != v {
t.Errorf("Expected %q, got %q.", v, software)
if software.String() != val {
t.Errorf("Expected %q, got %q.", val, software)
}
sAttr, ok := m.Attributes.Get(AttrSoftware)
sAttr, ok := msg.Attributes.Get(AttrSoftware)
if !ok {
t.Error("software attribute should be found")
}
@@ -90,22 +90,22 @@ func BenchmarkUsername_GetFrom(b *testing.B) {
func TestUsername(t *testing.T) {
username := "username"
u := NewUsername(username)
m := new(Message)
m.WriteHeader()
uName := NewUsername(username)
msg := new(Message)
msg.WriteHeader()
t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600)
if err := badU.AddTo(m); !IsAttrSizeOverflow(err) {
if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
})
t.Run("AddTo", func(t *testing.T) {
if err := u.AddTo(m); err != nil {
if err := uName.AddTo(msg); err != nil {
t.Error("errored:", err)
}
t.Run("GetFrom", func(t *testing.T) {
got := new(Username)
if err := got.GetFrom(m); err != nil {
if err := got.GetFrom(msg); err != nil {
t.Error("errored:", err)
}
if got.String() != username {
@@ -136,10 +136,10 @@ func TestUsername(t *testing.T) {
}
func TestRealm_GetFrom(t *testing.T) {
m := New()
v := "realm"
m.Add(AttrRealm, []byte(v))
m.WriteHeader()
msg := New()
val := "realm"
msg.Add(AttrRealm, []byte(val))
msg.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
@@ -148,17 +148,17 @@ func TestRealm_GetFrom(t *testing.T) {
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
if _, err := m2.ReadFrom(m.reader()); err != nil {
if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err)
}
if err := r.GetFrom(m); err != nil {
if err := r.GetFrom(msg); err != nil {
t.Fatal(err)
}
if r.String() != v {
t.Errorf("Expected %q, got %q.", v, r)
if r.String() != val {
t.Errorf("Expected %q, got %q.", val, r)
}
rAttr, ok := m.Attributes.Get(AttrRealm)
rAttr, ok := msg.Attributes.Get(AttrRealm)
if !ok {
t.Error("realm attribute should be found")
}
@@ -180,26 +180,26 @@ func TestRealm_AddTo_Invalid(t *testing.T) {
}
func TestNonce_GetFrom(t *testing.T) {
m := New()
v := "example.org"
m.Add(AttrNonce, []byte(v))
m.WriteHeader()
msg := New()
val := "example.org"
msg.Add(AttrNonce, []byte(val))
msg.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
var nonce Nonce
if _, err := m2.ReadFrom(m.reader()); err != nil {
if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err)
}
if err := nonce.GetFrom(m); err != nil {
if err := nonce.GetFrom(msg); err != nil {
t.Fatal(err)
}
if nonce.String() != v {
t.Errorf("Expected %q, got %q.", v, nonce)
if nonce.String() != val {
t.Errorf("Expected %q, got %q.", val, nonce)
}
nAttr, ok := m.Attributes.Get(AttrNonce)
nAttr, ok := msg.Attributes.Get(AttrNonce)
if !ok {
t.Error("nonce attribute should be found")
}

View File

@@ -7,7 +7,7 @@ import "errors"
// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute.
//
// RFC 5389 Section 15.9
// RFC 5389 Section 15.9.
type UnknownAttributes []AttrType
func (a UnknownAttributes) String() string {
@@ -22,6 +22,7 @@ func (a UnknownAttributes) String() string {
s += ", "
}
}
return s
}
@@ -39,6 +40,7 @@ func (a UnknownAttributes) AddTo(m *Message) error {
bin.PutUint16(v[first:last], t.Value())
}
m.Add(AttrUnknownAttributes, v)
return nil
}
@@ -62,5 +64,6 @@ func (a *UnknownAttributes) GetFrom(m *Message) error {
*a = append(*a, AttrType(bin.Uint16(v[first:last])))
first = last
}
return nil
}

View File

@@ -8,26 +8,26 @@ import (
)
func TestUnknownAttributes(t *testing.T) {
m := new(Message)
a := &UnknownAttributes{
msg := new(Message)
attr := &UnknownAttributes{
AttrDontFragment,
AttrChannelNumber,
}
if a.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
t.Error("bad String:", a)
if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
t.Error("bad String:", attr)
}
if (UnknownAttributes{}).String() != "<nil>" {
t.Error("bad blank string")
}
if err := a.AddTo(m); err != nil {
if err := attr.AddTo(msg); err != nil {
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) {
attrs := make(UnknownAttributes, 10)
if err := attrs.GetFrom(m); err != nil {
if err := attrs.GetFrom(msg); err != nil {
t.Error(err)
}
for i, at := range *a {
for i, at := range *attr {
if at != attrs[i] {
t.Error("expected", at, "!=", attrs[i])
}
@@ -44,8 +44,8 @@ func TestUnknownAttributes(t *testing.T) {
}
func BenchmarkUnknownAttributes(b *testing.B) {
m := new(Message)
a := UnknownAttributes{
msg := new(Message)
attr := UnknownAttributes{
AttrDontFragment,
AttrChannelNumber,
AttrRealm,
@@ -54,20 +54,20 @@ func BenchmarkUnknownAttributes(b *testing.B) {
b.Run("AddTo", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if err := a.AddTo(m); err != nil {
if err := attr.AddTo(msg); err != nil {
b.Fatal(err)
}
m.Reset()
msg.Reset()
}
})
b.Run("GetFrom", func(b *testing.B) {
b.ReportAllocs()
if err := a.AddTo(m); err != nil {
if err := attr.AddTo(msg); err != nil {
b.Fatal(err)
}
attrs := make(UnknownAttributes, 0, 10)
for i := 0; i < b.N; i++ {
if err := attrs.GetFrom(m); err != nil {
if err := attrs.GetFrom(msg); err != nil {
b.Fatal(err)
}
attrs = attrs[:0]

47
uri.go
View File

@@ -124,7 +124,7 @@ func (t ProtoType) String() string {
}
}
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI
// URI represents a STUN (rfc7064) or TURN (rfc7065) URI.
type URI struct {
Scheme SchemeType
Host string
@@ -137,73 +137,76 @@ type URI struct {
// ParseURI parses a STUN or TURN urls following the ABNF syntax described in
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
// respectively.
func ParseURI(raw string) (*URI, error) { //nolint:gocognit
func ParseURI(raw string) (*URI, error) { //nolint:gocognit,cyclop
rawParts, err := url.Parse(raw)
if err != nil {
return nil, err
}
var u URI
u.Scheme = NewSchemeType(rawParts.Scheme)
if u.Scheme == SchemeTypeUnknown {
var uri URI
uri.Scheme = NewSchemeType(rawParts.Scheme)
if uri.Scheme == SchemeTypeUnknown {
return nil, ErrSchemeType
}
var rawPort string
if u.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil {
if uri.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { //nolint:nestif
var e *net.AddrError
if errors.As(err, &e) {
if e.Err == "missing port in address" {
nextRawURL := u.Scheme.String() + ":" + rawParts.Opaque
nextRawURL := uri.Scheme.String() + ":" + rawParts.Opaque
switch {
case u.Scheme == SchemeTypeSTUN || u.Scheme == SchemeTypeTURN:
case uri.Scheme == SchemeTypeSTUN || uri.Scheme == SchemeTypeTURN:
nextRawURL += ":3478"
if rawParts.RawQuery != "" {
nextRawURL += "?" + rawParts.RawQuery
}
return ParseURI(nextRawURL)
case u.Scheme == SchemeTypeSTUNS || u.Scheme == SchemeTypeTURNS:
case uri.Scheme == SchemeTypeSTUNS || uri.Scheme == SchemeTypeTURNS:
nextRawURL += ":5349"
if rawParts.RawQuery != "" {
nextRawURL += "?" + rawParts.RawQuery
}
return ParseURI(nextRawURL)
}
}
}
return nil, err
}
if u.Host == "" {
if uri.Host == "" {
return nil, ErrHost
}
if u.Port, err = strconv.Atoi(rawPort); err != nil {
if uri.Port, err = strconv.Atoi(rawPort); err != nil {
return nil, ErrPort
}
switch u.Scheme {
switch uri.Scheme {
case SchemeTypeSTUN:
qArgs, err := url.ParseQuery(rawParts.RawQuery)
if err != nil || len(qArgs) > 0 {
return nil, ErrSTUNQuery
}
u.Proto = ProtoTypeUDP
uri.Proto = ProtoTypeUDP
case SchemeTypeSTUNS:
qArgs, err := url.ParseQuery(rawParts.RawQuery)
if err != nil || len(qArgs) > 0 {
return nil, ErrSTUNQuery
}
u.Proto = ProtoTypeTCP
uri.Proto = ProtoTypeTCP
case SchemeTypeTURN:
proto, err := parseProto(rawParts.RawQuery)
if err != nil {
return nil, err
}
u.Proto = proto
if u.Proto == ProtoTypeUnknown {
u.Proto = ProtoTypeUDP
uri.Proto = proto
if uri.Proto == ProtoTypeUnknown {
uri.Proto = ProtoTypeUDP
}
case SchemeTypeTURNS:
proto, err := parseProto(rawParts.RawQuery)
@@ -211,15 +214,15 @@ func ParseURI(raw string) (*URI, error) { //nolint:gocognit
return nil, err
}
u.Proto = proto
if u.Proto == ProtoTypeUnknown {
u.Proto = ProtoTypeTCP
uri.Proto = proto
if uri.Proto == ProtoTypeUnknown {
uri.Proto = ProtoTypeTCP
}
case SchemeTypeUnknown:
}
return &u, nil
return &uri, nil
}
func parseProto(raw string) (ProtoType, error) {
@@ -233,6 +236,7 @@ func parseProto(raw string) (ProtoType, error) {
if proto = NewProtoType(rawProto); proto == ProtoType(0) {
return ProtoTypeUnknown, ErrProtoType
}
return proto, nil
}
@@ -248,6 +252,7 @@ func (u URI) String() string {
if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS {
rawURL += "?transport=" + u.Proto.String()
}
return rawURL
}

View File

@@ -35,8 +35,16 @@ func TestParseURL(t *testing.T) {
{"stun:[::1]:123", "stun:[::1]:123", SchemeTypeSTUN, false, "::1", 123, ProtoTypeUDP},
{"turn:google.de", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
{"turns:google.de", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
{"turn:google.de?transport=udp", "turn:google.de:3478?transport=udp", SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP},
{"turns:google.de?transport=tcp", "turns:google.de:5349?transport=tcp", SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP},
{
"turn:google.de?transport=udp",
"turn:google.de:3478?transport=udp",
SchemeTypeTURN, false, "google.de", 3478, ProtoTypeUDP,
},
{
"turns:google.de?transport=tcp",
"turns:google.de:5349?transport=tcp",
SchemeTypeTURNS, true, "google.de", 5349, ProtoTypeTCP,
},
}
for i, testCase := range testCases {

View File

@@ -20,7 +20,7 @@ const (
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
//
// RFC 5389 Section 15.2
// RFC 5389 Section 15.2.
type XORMappedAddress struct {
IP net.IP
Port int
@@ -43,14 +43,15 @@ func isZeros(p net.IP) bool {
return false
}
}
return true
}
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
var ErrBadIPLength = errors.New("invalid length of IP value")
// AddToAs adds XOR-MAPPED-ADDRESS value to m as t attribute.
func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
// AddToAs adds XOR-MAPPED-ADDRESS value to msg as attr attribute.
func (a XORMappedAddress) AddToAs(msg *Message, attr AttrType) error {
var (
family = familyIPv4
ip = a.IP
@@ -67,12 +68,13 @@ func (a XORMappedAddress) AddToAs(m *Message, t AttrType) error {
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
copy(xorValue[4:], msg.TransactionID[:])
bin.PutUint32(xorValue[0:4], magicCookie)
bin.PutUint16(value[0:2], family)
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) //nolint:gosec // G115, false positive, port
xor.XorBytes(value[4:4+len(ip)], ip, xorValue)
m.Add(t, value[:4+len(ip)])
msg.Add(attr, value[:4+len(ip)])
return nil
}
@@ -83,13 +85,13 @@ func (a XORMappedAddress) AddTo(m *Message) error {
}
// GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message
// getting it as for t type.
func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
v, err := m.Get(t)
// getting it as for attr type.
func (a *XORMappedAddress) GetFromAs(msg *Message, attr AttrType) error {
value, err := msg.Get(attr)
if err != nil {
return err
}
family := bin.Uint16(v[0:2])
family := bin.Uint16(value[0:2])
if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
@@ -110,17 +112,18 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
for i := range a.IP {
a.IP[i] = 0
}
if len(v) <= 4 {
if len(value) <= 4 {
return io.ErrUnexpectedEOF
}
if err := CheckOverflow(t, len(v[4:]), len(a.IP)); err != nil {
if err := CheckOverflow(attr, len(value[4:]), len(a.IP)); err != nil {
return err
}
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
a.Port = int(bin.Uint16(value[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 4+TransactionIDSize)
bin.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xor.XorBytes(a.IP, v[4:], xorValue)
copy(xorValue[4:], msg.TransactionID[:])
xor.XorBytes(a.IP, value[4:], xorValue)
return nil
}

View File

@@ -29,21 +29,21 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
}
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
m := New()
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
b.Error(err)
}
copy(m.TransactionID[:], transactionID)
copy(msg.TransactionID[:], transactionID)
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
if err != nil {
b.Error(err)
}
m.Add(AttrXORMappedAddress, addrValue)
msg.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if err := addr.GetFrom(m); err != nil {
if err := addr.GetFrom(msg); err != nil {
b.Fatal(err)
}
}
@@ -94,54 +94,54 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
}
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
m := New()
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254
addr := new(XORMappedAddress)
if err = addr.GetFrom(m); err == nil {
if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should be nil")
}
addr.IP = expectedIP
addr.Port = expectedPort
addr.AddTo(m) //nolint:errcheck,gosec
m.WriteHeader()
addr.AddTo(msg) //nolint:errcheck,gosec
msg.WriteHeader()
mRes := New()
binary.BigEndian.PutUint16(m.Raw[20+4:20+4+2], 0x21)
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil {
t.Fatal(err)
}
if err = addr.GetFrom(m); err == nil {
if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should not be nil")
}
}
func TestXORMappedAddress_AddTo(t *testing.T) {
m := New()
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254
addr := &XORMappedAddress{
IP: net.ParseIP("213.141.156.236"),
Port: expectedPort,
}
if err = addr.AddTo(m); err != nil {
if err = addr.AddTo(msg); err != nil {
t.Fatal(err)
}
m.WriteHeader()
msg.WriteHeader()
mRes := New()
if _, err = mRes.Write(m.Raw); err != nil {
if _, err = mRes.Write(msg.Raw); err != nil {
t.Fatal(err)
}
if err = addr.GetFrom(mRes); err != nil {
@@ -156,27 +156,27 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
}
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
m := New()
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
expectedPort := 21254
addr := &XORMappedAddress{
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
Port: 21254,
}
addr.AddTo(m) //nolint:errcheck,gosec
m.WriteHeader()
addr.AddTo(msg) //nolint:errcheck,gosec
msg.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
if _, err = mRes.ReadFrom(msg.reader()); err != nil {
t.Fatal(err)
}
gotAddr := new(XORMappedAddress)
if err = gotAddr.GetFrom(m); err != nil {
if err = gotAddr.GetFrom(msg); err != nil {
t.Fatal(err)
}
if !gotAddr.IP.Equal(expectedIP) {