diff --git a/example/tcp_server.go b/example/tcp_server.go index 759a416..82aedb3 100644 --- a/example/tcp_server.go +++ b/example/tcp_server.go @@ -36,3 +36,11 @@ func TCPServer(listener net.Listener, handler TCPHandler) error { return nil } + +func main() { + _, err := net.Dial("tcp", "192.168.1.1:9999") + if err != nil { + fmt.Println("err : ", err) + return + } +} diff --git a/tcpip/header/checksum.go b/tcpip/header/checksum.go new file mode 100644 index 0000000..59054ee --- /dev/null +++ b/tcpip/header/checksum.go @@ -0,0 +1,37 @@ +package header + +import "netstack/tcpip" + +// 校验和的计算 +func Checksum(buf []byte, initial uint16) uint16 { + v := uint32(initial) + + l := len(buf) + if l&1 != 0 { + l-- + v += uint32(buf[l]) << 8 + } + + for i := 0; i < l; i += 2 { + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + } + + return ChecksumCombine(uint16(v), uint16(v>>16)) +} + +// ChecksumCombine combines the two uint16 to form their checksum. This is done +// by adding them and the carry. +func ChecksumCombine(a, b uint16) uint16 { + v := uint32(a) + uint32(b) + return uint16(v + v>>16) +} + +// PseudoHeaderChecksum calculates the pseudo-header checksum for the +// given destination protocol and network address, ignoring the length +// field. Pseudo-headers are needed by transport layers when calculating +// their own checksum. +func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address) uint16 { + xsum := Checksum([]byte(srcAddr), 0) + xsum = Checksum([]byte(dstAddr), xsum) + return Checksum([]byte{0, uint8(protocol)}, xsum) +} diff --git a/tcpip/header/checksum_test.go b/tcpip/header/checksum_test.go new file mode 100644 index 0000000..637fe94 --- /dev/null +++ b/tcpip/header/checksum_test.go @@ -0,0 +1,19 @@ +package header_test + +import ( + "log" + "math/rand" + "netstack/tcpip/header" + "testing" + "time" +) + +func TestChecksum(t *testing.T) { + buf := make([]byte, 1024) + rand.Seed(time.Now().Unix()) + for i := range buf { + buf[i] = uint8(rand.Intn(255)) + } + sum := header.Checksum(buf, 0) + log.Println(sum) +} diff --git a/tcpip/header/icmpv4.go b/tcpip/header/icmpv4.go new file mode 100644 index 0000000..bddc80d --- /dev/null +++ b/tcpip/header/icmpv4.go @@ -0,0 +1,108 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "netstack/tcpip" +) + +// ICMPv4 represents an ICMPv4 header stored in a byte array. +type ICMPv4 []byte + +const ( + // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. + ICMPv4MinimumSize = 4 + + // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet. + ICMPv4EchoMinimumSize = 6 + + // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP + // destination unreachable packet. + ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4 + + // ICMPv4ProtocolNumber is the ICMP transport protocol number. + ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 +) + +// ICMPv4Type is the ICMP type field described in RFC 792. +type ICMPv4Type byte + +// Typical values of ICMPv4Type defined in RFC 792. +const ( + ICMPv4EchoReply ICMPv4Type = 0 + ICMPv4DstUnreachable ICMPv4Type = 3 + ICMPv4SrcQuench ICMPv4Type = 4 + ICMPv4Redirect ICMPv4Type = 5 + ICMPv4Echo ICMPv4Type = 8 + ICMPv4TimeExceeded ICMPv4Type = 11 + ICMPv4ParamProblem ICMPv4Type = 12 + ICMPv4Timestamp ICMPv4Type = 13 + ICMPv4TimestampReply ICMPv4Type = 14 + ICMPv4InfoRequest ICMPv4Type = 15 + ICMPv4InfoReply ICMPv4Type = 16 +) + +// Values for ICMP code as defined in RFC 792. +const ( + ICMPv4PortUnreachable = 3 + ICMPv4FragmentationNeeded = 4 +) + +// Type is the ICMP type field. +func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } + +// SetType sets the ICMP type field. +func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) } + +// Code is the ICMP code field. Its meaning depends on the value of Type. +func (b ICMPv4) Code() byte { return b[1] } + +// SetCode sets the ICMP code field. +func (b ICMPv4) SetCode(c byte) { b[1] = c } + +// Checksum is the ICMP checksum field. +func (b ICMPv4) Checksum() uint16 { + return binary.BigEndian.Uint16(b[2:]) +} + +// SetChecksum sets the ICMP checksum field. +func (b ICMPv4) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[2:], checksum) +} + +// SourcePort implements Transport.SourcePort. +func (ICMPv4) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (ICMPv4) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (ICMPv4) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (ICMPv4) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (b ICMPv4) Payload() []byte { + return b[ICMPv4MinimumSize:] +} diff --git a/tcpip/header/icmpv6.go b/tcpip/header/icmpv6.go new file mode 100644 index 0000000..2344b88 --- /dev/null +++ b/tcpip/header/icmpv6.go @@ -0,0 +1,121 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "netstack/tcpip" +) + +// ICMPv6 represents an ICMPv6 header stored in a byte array. +type ICMPv6 []byte + +const ( + // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. + ICMPv6MinimumSize = 4 + + // ICMPv6ProtocolNumber is the ICMP transport protocol number. + ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 + + // ICMPv6NeighborSolicitMinimumSize is the minimum size of a + // neighbor solicitation packet. + ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16 + + // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. + ICMPv6NeighborAdvertSize = 32 + + // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. + ICMPv6EchoMinimumSize = 8 + + // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP + // destination unreachable packet. + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4 + + // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP + // packet-too-big packet. + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4 +) + +// ICMPv6Type is the ICMP type field described in RFC 4443 and friends. +type ICMPv6Type byte + +// Typical values of ICMPv6Type defined in RFC 4443. +const ( + ICMPv6DstUnreachable ICMPv6Type = 1 + ICMPv6PacketTooBig ICMPv6Type = 2 + ICMPv6TimeExceeded ICMPv6Type = 3 + ICMPv6ParamProblem ICMPv6Type = 4 + ICMPv6EchoRequest ICMPv6Type = 128 + ICMPv6EchoReply ICMPv6Type = 129 + + // Neighbor Discovery Protocol (NDP) messages, see RFC 4861. + + ICMPv6RouterSolicit ICMPv6Type = 133 + ICMPv6RouterAdvert ICMPv6Type = 134 + ICMPv6NeighborSolicit ICMPv6Type = 135 + ICMPv6NeighborAdvert ICMPv6Type = 136 + ICMPv6RedirectMsg ICMPv6Type = 137 +) + +// Values for ICMP code as defined in RFC 4443. +const ( + ICMPv6PortUnreachable = 4 +) + +// Type is the ICMP type field. +func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } + +// SetType sets the ICMP type field. +func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) } + +// Code is the ICMP code field. Its meaning depends on the value of Type. +func (b ICMPv6) Code() byte { return b[1] } + +// SetCode sets the ICMP code field. +func (b ICMPv6) SetCode(c byte) { b[1] = c } + +// Checksum is the ICMP checksum field. +func (b ICMPv6) Checksum() uint16 { + return binary.BigEndian.Uint16(b[2:]) +} + +// SetChecksum calculates and sets the ICMP checksum field. +func (b ICMPv6) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[2:], checksum) +} + +// SourcePort implements Transport.SourcePort. +func (ICMPv6) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (ICMPv6) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (ICMPv6) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (ICMPv6) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (b ICMPv6) Payload() []byte { + return b[ICMPv6MinimumSize:] +} diff --git a/tcpip/header/ipv4.go b/tcpip/header/ipv4.go index 2495c37..8a3f05c 100644 --- a/tcpip/header/ipv4.go +++ b/tcpip/header/ipv4.go @@ -1,6 +1,83 @@ package header -import "netstack/tcpip" +import ( + "encoding/binary" + "netstack/tcpip" +) + +/* _ +|Version 4b|IHL 4b|Type of Service 8b| Total Length 16b | + ---------------------------------------------------------------- +| fragment ID 16b |R|DF|MF|Fragment Offset 13b| + ---------------------------------------------------------------- +| TTL 8b | Protocol 8b | Header Checksum 16b | 20 bytes + ---------------------------------------------------------------- +| Sorece IP Address 32b | + ---------------------------------------------------------------- +| Destination IP Address 32b | _ + ---------------------------------------------------------------- +| Options | Padding | +*/ + +const ( + versIHL = 0 + tos = 1 + totalLen = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksum = 10 + srcAddr = 12 + dstAddr = 16 +) + +// 表示IPv4头部信息的结构体 +type IPv4Fields struct { + // IHL is the "internet header length" field of an IPv4 packet. + // 头部长度 + IHL uint8 + + // TOS is the "type of service" field of an IPv4 packet. + // 服务区分的表示 + TOS uint8 + + // TotalLength is the "total length" field of an IPv4 packet. + // 数据报文总长 + TotalLength uint16 + + // ID is the "identification" field of an IPv4 packet. + // 标识符 + ID uint16 + + // Flags is the "flags" field of an IPv4 packet. + // 标签 + Flags uint8 + + // FragmentOffset is the "fragment offset" field of an IPv4 packet. + // 分片偏移 + FragmentOffset uint16 + + // TTL is the "time to live" field of an IPv4 packet. + // 存活时间 + TTL uint8 + + // Protocol is the "protocol" field of an IPv4 packet. + // 表示的传输层协议 + Protocol uint8 + + // Checksum is the "checksum" field of an IPv4 packet. + // 首部校验和 + Checksum uint16 + + // SrcAddr is the "source ip address" of an IPv4 packet. + // 源IP地址 + SrcAddr tcpip.Address + + // DstAddr is the "destination ip address" of an IPv4 packet. + // 目的IP地址 + DstAddr tcpip.Address +} type IPv4 []byte @@ -28,3 +105,162 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" ) + +func IPVersion(b []byte) int { + if len(b) < versIHL+1 { + return -1 + } + return int(b[versIHL] >> 4) +} + +// 首部长度说明首部有多少 32 位字(4 字节) 这个函数返回其实际占用的字节数 +func (b IPv4) HeaderLength() uint8 { + return (b[versIHL] & 0xf) * 4 +} + +func (b IPv4) ID() uint16 { + return binary.BigEndian.Uint16(b[id:]) +} + +// Protocol returns the value of the protocol field of the ipv4 header. +func (b IPv4) Protocol() uint8 { + return b[protocol] +} + +// Flags returns the "flags" field of the ipv4 header. +func (b IPv4) Flags() uint8 { + return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) +} + +// TTL returns the "TTL" field of the ipv4 header. +func (b IPv4) TTL() uint8 { + return b[ttl] +} + +// FragmentOffset returns the "fragment offset" field of the ipv4 header. +func (b IPv4) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[flagsFO:]) << 3 +} + +// TotalLength returns the "total length" field of the ipv4 header. +func (b IPv4) TotalLength() uint16 { + return binary.BigEndian.Uint16(b[totalLen:]) +} + +// Checksum returns the checksum field of the ipv4 header. +func (b IPv4) Checksum() uint16 { + return binary.BigEndian.Uint16(b[checksum:]) +} + +// SourceAddress returns the "source address" field of the ipv4 header. +func (b IPv4) SourceAddress() tcpip.Address { + return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize]) +} + +// DestinationAddress returns the "destination address" field of the ipv4 +// header. +func (b IPv4) DestinationAddress() tcpip.Address { + return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber { + return tcpip.TransportProtocolNumber(b.Protocol()) +} + +// Payload implements Network.Payload. +func (b IPv4) Payload() []byte { + return b[b.HeaderLength():][:b.PayloadLength()] +} + +// PayloadLength returns the length of the payload portion of the ipv4 packet. +func (b IPv4) PayloadLength() uint16 { + return b.TotalLength() - uint16(b.HeaderLength()) +} + +// TOS returns the "type of service" field of the ipv4 header. +func (b IPv4) TOS() (uint8, uint32) { + return b[tos], 0 +} + +// SetTOS sets the "type of service" field of the ipv4 header. +func (b IPv4) SetTOS(v uint8, _ uint32) { + b[tos] = v +} + +// SetTotalLength sets the "total length" field of the ipv4 header. +func (b IPv4) SetTotalLength(totalLength uint16) { + binary.BigEndian.PutUint16(b[totalLen:], totalLength) +} + +// SetChecksum sets the checksum field of the ipv4 header. +func (b IPv4) SetChecksum(v uint16) { + binary.BigEndian.PutUint16(b[checksum:], v) +} + +// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the +// ipv4 header. +func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { + v := (uint16(flags) << 13) | (offset >> 3) + binary.BigEndian.PutUint16(b[flagsFO:], v) +} + +// SetSourceAddress sets the "source address" field of the ipv4 header. +func (b IPv4) SetSourceAddress(addr tcpip.Address) { + copy(b[srcAddr:srcAddr+IPv4AddressSize], addr) +} + +// SetDestinationAddress sets the "destination address" field of the ipv4 +// header. +func (b IPv4) SetDestinationAddress(addr tcpip.Address) { + copy(b[dstAddr:dstAddr+IPv4AddressSize], addr) +} + +// Encode encodes all the fields of the ipv4 header. +func (b IPv4) Encode(i *IPv4Fields) { + b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) + b[tos] = i.TOS + b.SetTotalLength(i.TotalLength) + binary.BigEndian.PutUint16(b[id:], i.ID) + b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset) + b[ttl] = i.TTL + b[protocol] = i.Protocol + b.SetChecksum(i.Checksum) + copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr) + copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) +} + +// EncodePartial updates the total length and checksum fields of ipv4 header, +// taking in the partial checksum, which is the checksum of the header without +// the total length and checksum fields. It is useful in cases when similar +// packets are produced. +func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) { + b.SetTotalLength(totalLength) + checksum := Checksum(b[totalLen:totalLen+2], partialChecksum) + b.SetChecksum(^checksum) +} + +// IsValid performs basic validation on the packet. +func (b IPv4) IsValid(pktSize int) bool { + if len(b) < IPv4MinimumSize { + return false + } + + hlen := int(b.HeaderLength()) + tlen := int(b.TotalLength()) + if hlen > tlen || tlen > pktSize { + return false + } + + return true +} + +// IsV4MulticastAddress determines if the provided address is an IPv4 multicast +// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits +// will be 1110 = 0xe0. +func IsV4MulticastAddress(addr tcpip.Address) bool { + if len(addr) != IPv4AddressSize { + return false + } + return (addr[0] & 0xf0) == 0xe0 +} diff --git a/tcpip/header/ipv6.go b/tcpip/header/ipv6.go index 6687cd1..d2fa4b2 100644 --- a/tcpip/header/ipv6.go +++ b/tcpip/header/ipv6.go @@ -1,7 +1,49 @@ package header -import "netstack/tcpip" +import ( + "encoding/binary" + "netstack/tcpip" + "strings" +) +const ( + versTCFL = 0 + payloadLen = 4 + nextHdr = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = 24 +) + +// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the +// fields of a packet that needs to be encoded. +type IPv6Fields struct { + // TrafficClass is the "traffic class" field of an IPv6 packet. + TrafficClass uint8 + + // FlowLabel is the "flow label" field of an IPv6 packet. + FlowLabel uint32 + + // PayloadLength is the "payload length" field of an IPv6 packet. + PayloadLength uint16 + + // NextHeader is the "next header" field of an IPv6 packet. + NextHeader uint8 + + // HopLimit is the "hop limit" field of an IPv6 packet. + HopLimit uint8 + + // SrcAddr is the "source ip address" of an IPv6 packet. + SrcAddr tcpip.Address + + // DstAddr is the "destination ip address" of an IPv6 packet. + DstAddr tcpip.Address +} + +// IPv6 represents an ipv6 header stored in a byte array. +// Most of the methods of IPv6 access to the underlying slice without +// checking the boundaries and could panic because of 'index out of range'. +// Always call IsValid() to validate an instance of IPv6 before using other methods. type IPv6 []byte const ( @@ -21,3 +63,156 @@ const ( // section 5. IPv6MinimumMTU = 1280 ) + +// PayloadLength returns the value of the "payload length" field of the ipv6 +// header. +func (b IPv6) PayloadLength() uint16 { + return binary.BigEndian.Uint16(b[payloadLen:]) +} + +// HopLimit returns the value of the "hop limit" field of the ipv6 header. +func (b IPv6) HopLimit() uint8 { + return b[hopLimit] +} + +// NextHeader returns the value of the "next header" field of the ipv6 header. +func (b IPv6) NextHeader() uint8 { + return b[nextHdr] +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber { + return tcpip.TransportProtocolNumber(b.NextHeader()) +} + +// Payload implements Network.Payload. +func (b IPv6) Payload() []byte { + return b[IPv6MinimumSize:][:b.PayloadLength()] +} + +// SourceAddress returns the "source address" field of the ipv6 header. +func (b IPv6) SourceAddress() tcpip.Address { + return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) +} + +// DestinationAddress returns the "destination address" field of the ipv6 +// header. +func (b IPv6) DestinationAddress() tcpip.Address { + return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) +} + +// Checksum implements Network.Checksum. Given that IPv6 doesn't have a +// checksum, it just returns 0. +func (IPv6) Checksum() uint16 { + return 0 +} + +// TOS returns the "traffic class" and "flow label" fields of the ipv6 header. +func (b IPv6) TOS() (uint8, uint32) { + v := binary.BigEndian.Uint32(b[versTCFL:]) + return uint8(v >> 20), v & 0xfffff +} + +// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header. +func (b IPv6) SetTOS(t uint8, l uint32) { + vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff) + binary.BigEndian.PutUint32(b[versTCFL:], vtf) +} + +// SetPayloadLength sets the "payload length" field of the ipv6 header. +func (b IPv6) SetPayloadLength(payloadLength uint16) { + binary.BigEndian.PutUint16(b[payloadLen:], payloadLength) +} + +// SetSourceAddress sets the "source address" field of the ipv6 header. +func (b IPv6) SetSourceAddress(addr tcpip.Address) { + copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) +} + +// SetDestinationAddress sets the "destination address" field of the ipv6 +// header. +func (b IPv6) SetDestinationAddress(addr tcpip.Address) { + copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) +} + +// SetNextHeader sets the value of the "next header" field of the ipv6 header. +func (b IPv6) SetNextHeader(v uint8) { + b[nextHdr] = v +} + +// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a +// checksum, it is empty. +func (IPv6) SetChecksum(uint16) { +} + +// Encode encodes all the fields of the ipv6 header. +func (b IPv6) Encode(i *IPv6Fields) { + b.SetTOS(i.TrafficClass, i.FlowLabel) + b.SetPayloadLength(i.PayloadLength) + b[nextHdr] = i.NextHeader + b[hopLimit] = i.HopLimit + copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) + copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) +} + +// IsValid performs basic validation on the packet. +func (b IPv6) IsValid(pktSize int) bool { + if len(b) < IPv6MinimumSize { + return false + } + + dlen := int(b.PayloadLength()) + + return dlen <= pktSize-IPv6MinimumSize +} + +// IsV4MappedAddress determines if the provided address is an IPv4 mapped +// address by checking if its prefix is 0:0:0:0:0:ffff::/96. +func IsV4MappedAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + + return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff") +} + +// IsV6MulticastAddress determines if the provided address is an IPv6 +// multicast address (anything starting with FF). +func IsV6MulticastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + return addr[0] == 0xff +} + +// SolicitedNodeAddr computes the solicited-node multicast address. This is +// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 +// address. +func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { + const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + return solicitedNodeMulticastPrefix + addr[len(addr)-3:] +} + +// LinkLocalAddr computes the default IPv6 link-local address from a link-layer +// (MAC) address. +func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { + // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local + // header, FE80::. + // + // The conversion is very nearly: + // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff + // Note the capital A. The conversion aa->Aa involves a bit flip. + lladdrb := [16]byte{ + 0: 0xFE, + 1: 0x80, + 8: linkAddr[0] ^ 2, + 9: linkAddr[1], + 10: linkAddr[2], + 11: 0xFF, + 12: 0xFE, + 13: linkAddr[3], + 14: linkAddr[4], + 15: linkAddr[5], + } + return tcpip.Address(lladdrb[:]) +} diff --git a/tcpip/header/ipv6_fragment.go b/tcpip/header/ipv6_fragment.go new file mode 100644 index 0000000..1b4362c --- /dev/null +++ b/tcpip/header/ipv6_fragment.go @@ -0,0 +1,146 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "netstack/tcpip" +) + +const ( + nextHdrFrag = 0 + fragOff = 2 + more = 3 + idV6 = 4 +) + +// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the +// fields of a packet that needs to be encoded. +type IPv6FragmentFields struct { + // NextHeader is the "next header" field of an IPv6 fragment. + NextHeader uint8 + + // FragmentOffset is the "fragment offset" field of an IPv6 fragment. + FragmentOffset uint16 + + // M is the "more" field of an IPv6 fragment. + M bool + + // Identification is the "identification" field of an IPv6 fragment. + Identification uint32 +} + +// IPv6Fragment represents an ipv6 fragment header stored in a byte array. +// Most of the methods of IPv6Fragment access to the underlying slice without +// checking the boundaries and could panic because of 'index out of range'. +// Always call IsValid() to validate an instance of IPv6Fragment before using other methods. +type IPv6Fragment []byte + +const ( + // IPv6FragmentHeader header is the number used to specify that the next + // header is a fragment header, per RFC 2460. + IPv6FragmentHeader = 44 + + // IPv6FragmentHeaderSize is the size of the fragment header. + IPv6FragmentHeaderSize = 8 +) + +// Encode encodes all the fields of the ipv6 fragment. +func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { + b[nextHdrFrag] = i.NextHeader + binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) + if i.M { + b[more] |= 1 + } + binary.BigEndian.PutUint32(b[idV6:], i.Identification) +} + +// IsValid performs basic validation on the fragment header. +func (b IPv6Fragment) IsValid() bool { + return len(b) >= IPv6FragmentHeaderSize +} + +// NextHeader returns the value of the "next header" field of the ipv6 fragment. +func (b IPv6Fragment) NextHeader() uint8 { + return b[nextHdrFrag] +} + +// FragmentOffset returns the "fragment offset" field of the ipv6 fragment. +func (b IPv6Fragment) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[fragOff:]) >> 3 +} + +// More returns the "more" field of the ipv6 fragment. +func (b IPv6Fragment) More() bool { + return b[more]&1 > 0 +} + +// Payload implements Network.Payload. +func (b IPv6Fragment) Payload() []byte { + return b[IPv6FragmentHeaderSize:] +} + +// ID returns the value of the identifier field of the ipv6 fragment. +func (b IPv6Fragment) ID() uint32 { + return binary.BigEndian.Uint32(b[idV6:]) +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber { + return tcpip.TransportProtocolNumber(b.NextHeader()) +} + +// The functions below have been added only to satisfy the Network interface. + +// Checksum is not supported by IPv6Fragment. +func (b IPv6Fragment) Checksum() uint16 { + panic("not supported") +} + +// SourceAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) SourceAddress() tcpip.Address { + panic("not supported") +} + +// DestinationAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) DestinationAddress() tcpip.Address { + panic("not supported") +} + +// SetSourceAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) SetSourceAddress(tcpip.Address) { + panic("not supported") +} + +// SetDestinationAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) { + panic("not supported") +} + +// SetChecksum is not supported by IPv6Fragment. +func (b IPv6Fragment) SetChecksum(uint16) { + panic("not supported") +} + +// TOS is not supported by IPv6Fragment. +func (b IPv6Fragment) TOS() (uint8, uint32) { + panic("not supported") +} + +// SetTOS is not supported by IPv6Fragment. +func (b IPv6Fragment) SetTOS(t uint8, l uint32) { + panic("not supported") +} diff --git a/tcpip/network/READMD.md b/tcpip/network/READMD.md index 6c5a90a..483adb7 100644 --- a/tcpip/network/READMD.md +++ b/tcpip/network/READMD.md @@ -1,5 +1,117 @@ -# CIDR地址的计算方法 +# 网络层的基本实现 +本章介绍网络层的实现,网络层又称网际层、ip 层,它是 tcpip 架构中核心的实现,全球计算机的互联很大部分归功于网络层,核心网络(路由器)都跑在网络层,为网络提供路由交换的功能,将数据包分发到相应的主机。虽然网络层在路由器上的实现比较复杂,因为要实现各种路由协议,但主机协议栈中的网络层并不复杂,因为它没有实现各种路由协议,路由表也很简单。下面介绍网络层提供的服务和实现网络层的 ip 协议-ipv4。 + +## 网络层提供的服务 +在计算机网络领域,曾经为网络层应该提供怎样的服务(面向连接还是无连接)引起了长时间的争论。最终因特网采用的设计思路是:网络层向上提供简单灵活的、无连接的、尽最大努力交付的数据报服务。所谓的数据报服务具有以下几个特点: + +1. 无需建立连接 +2. 不保证可靠性 +3. 每个分组都有终点的完整地址 +4. 每个分组独立选择路由进行转发 +5. 可靠通信应该有上层负责 网络层的目的是实现两个主机之间的数据透明传送,具体功能包括寻址和路由选择等。它提供的服务使传输层不需要了解网络中的数据传输和交换技术。对网络层而言使用一种逻辑地址来唯一标识互联网上的设备,网络层依靠逻辑地址进行相互通信(类似于数据链路层的 MAC 地址),逻辑地址编址方案现主要有两种,IPv4 和 IPv6,我们主要讲协议栈对 IPv4 协议的处理。一般我们说 IP 地址,指的是 ipv4 地址。 + +## 网络层和链路层的功能区别 +之前讲过链路层也可以实现主机到主机的数据透明传输,那为何还需要网络层实现主机到主机的数据传输? + +因为链路层的数据交换是在同个局域网实现的,链路层的交换也就是二层交换,它依赖二层广播 ARP 报文,来学习 MAC 地址和端口的对应关系。当交换机从某个端口收到一个数据包,它会先读取包中的源 MAC 地址,再去读取包中的目的 MAC 地址,并在地址表中查找对应的端口,如表中有和目的 MAC 地址对应的端口,就把数据包直接复制到这个端口上。链路层其最基本的服务是将源自网络层来的数据可靠地传输到相邻节点的目标机网络层。 + +而网络层的数据交换是不限于局域网的,网络层连接着因特网中各局域网、广域网的设备,是互联网络的枢纽。网络层的数据交换(路由交换)是根据目的 IP,查找路由表找到下一跳的 IP 地址,再根据这个下一跳 IP 地址,查找转发表,将数据包转发给相应的端口。简单的说链路层的寻址关心 MAC 地址而不管数据包中的 IP 地址,而网络层的寻址关心 IP 地址,而不关心 MAC 地址,链路层和网络层的结合实现了世界上两台主机的数据互相传输。 + +## ipv4简介 + +IPv4,是互联网协议(Internet Protocol,IP)的第四版,也是第一个被广泛使用,构成现今互联网技术的基础的协议。IPv4 是一种无连接的协议,操作在使用分组交换的链路层(如以太网)上。此协议会尽最大努力交付数据包,意即它不保证任何数据包均能送达目的地,也不保证所有数据包均按照正确的顺序无重复地到达。这些方面是由上层的传输协议(如传输控制协议)处理的。 + +## ip报文 + +- IPv4,是互联网协议(Internet Protocol,IP)的第四版,也是第一个被广泛使用,构成现今互联网技术的基础的协议。IPv4 是一种无连接的协议,操作在使用分组交换的链路层(如以太网)上。此协议会尽最大努力交付数据包,意即它不保证任何数据包均能送达目的地,也不保证所有数据包均按照正确的顺序无重复地到达。这些方面是由上层的传输协议(如传输控制协议)处理的。 + +- 版本(Version) 版本字段占 4bit,通信双方使用的版本必须一致。对于 IPv4,字段的值是 4。 + +- 首部长度(Internet Header Length, IHL) 占 4bit,首部长度说明首部有多少 32 位字(4 字节)。由于 IPv4 首部可能包含数目不定的选项,这个字段也用来确定数据的偏移量。这个字段的最小值是 5(二进制 0101),相当于 5*4=20 字节(RFC 791),最大十进制值是 15。 + +- 区分服务(Differentiated Services,DS) 占 8bit,最初被定义为服务类型字段,实际上并未使用,但 1998 年被 IETF 重定义为区分服务 RFC 2474。只有在使用区分服务时,这个字段才起作用,在一般的情况 下都不使用这个字段。例如需要实时数据流的技术会应用这个字段,一个例子是 VoIP。 + +- 显式拥塞通告( Explicit Congestion Notification,ECN) 在 RFC 3168 中定义,允许在不丢弃报文的同时通知对方网络拥塞的发生。ECN 是一种可选的功能,仅当两端都支持并希望使用,且底层网络支持时才被使用。 + +- 全长(Total Length) 这个 16 位字段定义了报文总长,包含首部和数据,单位为字节。这个字段的最小值是 20(20 字节首部+0 字节数据),最大值是 2^16-1=65,535。IP 规定所有主机都必须支持最小 576 字节的报文,这是假定上层数据长度 512 字节,加上最长 IP 首部 60 字节,加上 4 字节富裕量,得出 576 字节,但大多数现代主机支持更大的报文。**当下层的数据链路协议的最大传输单元(MTU)字段的值小于 IP 报文长度时,报文就必须被分片,详细见下个标题。** + +- 标识符(Identification) 占 16 位,这个字段主要被用来唯一地标识一个报文的所有分片,因为分片不一定按序到达,所以在重组时需要知道分片所属的报文。每产生一个数据报,计数器加 1,并赋值给此字段。一些实验性的工作建议将此字段用于其它目的,例如增加报文跟踪信息以协助探测伪造的源地址。 + +- 标志 (Flags) 这个 3 位字段用于控制和识别分片,它们是: 位 0:保留,必须为 0; 位 1:禁止分片(Don’t Fragment,DF),当 DF=0 时才允许分片; 位 2:更多分片(More Fragment,MF),MF=1 代表后面还有分片,MF=0 代表已经是最后一个分片。 如果 DF 标志被设置为 1,但路由要求必须分片报文,此报文会被丢弃。这个标志可被用于发往没有能力组装分片的主机。当一个报文被分片,除了最后一片外的所有分片都设置 MF 为 1。最后一个片段具有非零片段偏移字段,将其与未分片数据包区分开,未分片的偏移字段为 0。 + +- 分片偏移 (Fragment Offset) 这个 13 位字段指明了每个分片相对于原始报文开头的偏移量,以 8 字节作单位。 + +- 存活时间(Time To Live,TTL) 这个 8 位字段避免报文在互联网中永远存在(例如陷入路由环路)。存活时间以秒为单位,但小于一秒的时间均向上取整到一秒。在现实中,这实际上成了一个跳数计数器:报文经过的每个路由器都将此字段减 1,当此字段等于 0 时,报文不再向下一跳传送并被丢弃,最大值是 255。常规地,一份 ICMP 报文被发回报文发送端说明其发送的报文已被丢弃。这也是 traceroute 的核心原理。 + +- 协议 (Protocol) 占 8bit,这个字段定义了该报文数据区使用的协议。IANA 维护着一份协议列表(最初由 RFC 790 定义),详细参见 IP 协议号列表。 + +- 首部检验和 (Header Checksum) 这个 16 位检验和字段只对首部查错,不包括数据部分。在每一跳,路由器都要重新计算出首部检验和并与此字段进行比对,如果不一致,此报文将会被丢弃。重新计算的必要性是因为每一跳的一些首部字段(如 TTL、Flag、Offset 等)都有可能发生变化,不检查数据部分是为了减少工作量。数据区的错误留待上层协议处理——用户数据报协议(UDP)和传输控制协议(TCP)都有检验和字段。此处的检验计算方法不使用 CRC。 + +- 源地址 一个 IPv4 地址由四个字节共 32 位构成,此字段的值是将每个字节转为二进制并拼在一起所得到的 32 位值。例如,10.9.8.7 是 00001010000010010000100000000111。但请注意,因为 NAT 的存在,这个地址并不总是报文的真实发送端,因此发往此地址的报文会被送往 NAT 设备,并由它被翻译为真实的地址。 + +- 目的地址 与源地址格式相同,但指出报文的接收端。 + +- 选项 附加的首部字段可能跟在目的地址之后,但这并不被经常使用,从 1 到 40 个字节不等。请注意首部长度字段必须包括足够的 32 位字来放下所有的选项(包括任何必须的填充以使首部长度能够被 32 位整除)。当选项列表的结尾不是首部的结尾时,EOL(选项列表结束,0x00)选项被插入列表末尾。下表列出了可能。 + +|字段|长度(位)|描述| +|----|---------|----| +|备份| 1 |当此选项需要被备份到所有分片中时,设为 1。| +| 类 | 2 |常规的选项类别,0 为“控制”,2 为“查错和措施”,1 和 3 保留。| +|数字| 5 |指明一个选项。| +|长度| 8 |指明整个选项的长度,对于简单的选项此字段可能不存在。| +|数据| 可变|选项相关数据,对于简单的选项此字段可能不存在。| + +**注:如果首部长度大于 5,那么选项字段必然存在并必须被考虑。 注:备份、类和数字经常被一并称呼为“类型”。** + +- 数据 数据字段不是首部的一部分,因此并不被包含在首部检验和中。数据的格式在协议首部字段中被指明,并可以是任意的传输层协议。一些常见协议的协议字段值被列在下面: + +|协议字段值| 协议名 |缩写| +|---------|------------|----| +|1 |互联网控制消息协议|ICMP| +|2 |互联网组管理协议 |IGMP| +|6 |传输控制协议 |TCP| +|17 |用户数据报协议 |UDP| +|41 |IPv6 封装 |ENCAP| +|89 |开放式最短路径优先 |OSPF| +|132|流控制传输协议 |SCTP| + +## ipv4地址 +IPv4 使用 32 位(4 字节)地址,因此地址空间中只有 4,294,967,296(2^32)个地址。不过,一些地址是为特殊用途所保留的,如专用网络(约 1800 万 个地址)和多播地址(约 2.7 亿个地址),这减少了可在互联网上路由的地址数量。随着地址不断被分配给最终用户,IPv4 地址枯竭问题也在随之产生。基于分类网络、无类别域间路由和网络地址转换的地址结构重构显著地减少了地址枯竭的速度。但在 2011 年 2 月 3 日,在最后 5 个地址块被分配给 5 个区域互联网注册管理机构之后,IANA 的主要地址池已经用尽。 + +IPv4 地址可被写作任何表示一个 32 位整数值的形式,但为了方便人类阅读和分析,它通常被写作点分十进制的形式,即四个字节被分开用十进制写出,中间用点分隔,如 192.168.1.1。ip 地址的编址方法一共经历过三个阶段: + +### 分类的 IP 地址 +- A 类网络地址占有 1 个字节(8 位),定义最高位为 0 来标识此类网络,余下 7 位为真正的网络地址。后面 3 个字节(24)为主机地址。A 类网络地址第一个字节的十进制值为:001~127.通常用于大型网络。 +- B 类网络地址占 2 个字节,使用最高两位为“10”来标识此类地址,其余 14 位为真正的网络地址,主机地址占后面的 2 个字节(16 位)。B 类网络地址第一个字节的十进制值为:128~191.通常用于中型网络。 +- C 类网络地址占 3 个字节,它是最通用的 Internet 地址。使用最高三位为“110”来标识此类地址。其余 21 位为真正的网络地址。主机地址占最后 1 个字节。C 类网络地址第一个字节的十进制值为:192~223。通常用于小型网络。 +- D 类地址是相当新的。它的识别头是 1110,用于组播,例如用于路由器修改。D 类网络地址第一个字节的十进制值为:224~239。 +- E 类地址为实验保留,其识别头是 1111。E 类网络地址第一个字节的十进制值为:240~255。 + +**但要注意得是,上面得这些地址分类已成为了历史,现在用的都是无分类 IP 地址进行路由选择。** + +### 子网的划分 + +由于上面固定分类的 IP 地址有不少的缺陷,比如,IP 地址空间的利用率很低、固定就意味着不够灵活、使路由表太大而影响性能,为了解决上述的问题,在 IP 地址概念中,又增加了一个“子网字段”,这样的话,一个 IP 地址可以用下面的方式表示 + +``` sh +IP地址 = (网络号,子网号,主机号) +``` + +### 无分类编址(CIDR) + +为了提高 ip 地址资源的利用率,提出了变长子网掩码(VLSM),而在 VLSM 的研究基础上又提出了“无分类编址”方法,也叫无分类域间路由选择-CIDR。 CIDR 最主要有两个以下特点: + +- 消除传统的 A,B,C 地址和划分子网的概念,更有效的分配 IPv4 的地址空间,CIDR 使 IP 地址又回到无分类的两级编码。记法:IP 地址::={<<网络前缀>,<<主机号>}。CIDR 还使用“斜线记法”即在 IP 地址后面加上“/”然后写网络前缀所占的位数。 +- CIDR 把网络前缀都相同的连续 IP 地址组成一个“CIDR 地址块”,即强化路由聚合(构成超网)。 其表示方法 + +``` sh +IP地址 = (网络前缀,主机号) +``` + +CIDR 还使用“斜线记法”,在 IP 地址后面加个“/”,紧跟着网络前缀所占的位数。例如:192.168.1.0/24,这种表示方式其实我们在上一章就用了,也是我们最常用的编址方式。 + +#### CIDR地址的计算方法 CIDR无类域间路由,打破了原本的ABC类地址的规划限定,使用地址段分配更加灵活,日常工作中也经常使用,也正是因为其灵活的特点使我们无法一眼辨认出网络号、广播地址、网络中的第一台主机等信息,本文主要针对这些信息的获得介绍一些计算方法。 当给定一个IP地址,比如18.232.133.86/22,需要求一下这个IP所在网络的 网络地址、子网掩码、广播i地址、这个网络的第一台主机的IP地址: @@ -17,3 +129,42 @@ CIDR无类域间路由,打破了原本的ABC类地址的规划限定,使用 5. 将主机位全部置1便是广播地址,18.232.<100001><11>.<11111111>即18.232.135.255 6. 子网掩码可以直接使用22计算即可,即前22位都为1,其余为0,即255.255.252.0 + + +| TYPE | CODE | Description | +| ---- | ---- | ------------| +| 0 | 0 | Echo Reply——回显应答(Ping 应答)   | +| 3 | 0 | Network Unreachable——网络不可达   | +| 3 | 1 | Host Unreachable——主机不可达   | +| 3 | 2 | Protocol Unreachable——协议不可达   | +| 3 | 3 | Port Unreachable——端口不可达   | +| 3 | 4 | Fragmentation needed but no frag. bit set——需要进行分片但设置不分片标志   | +| 3 | 5 | Source routing failed——源站选路失败   | +| 3 | 6 | Destination network unknown——目的网络未知   | +| 3 | 7 | Destination host unknown——目的主机未知   | +| 3 | 8 | Source host isolated (obsolete)——源主机被隔离(作废不用)   | +| 3 | 9 | Destination network administratively prohibited——目的网络被强制禁止   | +| 3 | 10 | Destination host administratively prohibited——目的主机被强制禁止   | +| 3 | 11 | Network unreachable for TOS——由于服务类型 TOS,网络不可达   | +| 3 | 12 | Host unreachable for TOS——由于服务类型 TOS,主机不可达   | +| 3 | 13 | Communication administratively prohibited by filtering——由于过滤,通信被强制禁止   | +| 3 | 14 | Host precedence violation——主机越权   | +| 3 | 15 | Precedence cutoff in effect——优先中止生效   | +| 4 | 0 | Source quench——源端被关闭(基本流控制)     | +| 5 | 0 | Redirect for network——对网络重定向     | +| 5 | 1 | Redirect for host——对主机重定向     | +| 5 | 2 | Redirect for TOS and network——对服务类型和网络重定向     | +| 5 | 3 | Redirect for TOS and host——对服务类型和主机重定向     | +| 8 | 0 | Echo request——回显请求(Ping 请求)   | +| 9 | 0 | Router advertisement——路由器通告     | +| 10 | 0 | Route solicitation——路由器请求     | +| 11 | 0 | TTL equals 0 during transit——传输期间生存时间为 0   | +| 11 | 1 | TTL equals 0 during reassembly——在数据报组装期间生存时间为 0   | +| 12 | 0 | IP header bad (catchall error)——坏的 IP 首部(包括各种差错)   | +| 12 | 1 | Required options missing——缺少必需的选项   | +| 13 | 0 | Timestamp request (obsolete)——时间戳请求(作废不用)   | +| 14 |   | Timestamp reply (obsolete)——时间戳应答(作废不用)   | +| 15 | 0 | Information request (obsolete)——信息请求(作废不用)   | +| 16 | 0 | Information reply (obsolete)——信息应答(作废不用)   | +| 17 | 0 | Address mask request——地址掩码请求   | +| 18 | 0 | Address mask | reply——地址掩码应答 | \ No newline at end of file diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go new file mode 100644 index 0000000..74029ec --- /dev/null +++ b/tcpip/network/ipv6/icmp.go @@ -0,0 +1,231 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "encoding/binary" + + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/stack" +) + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + h := header.IPv6(vv.First()) + + // We don't use IsValid() here because ICMP only requires that up to + // 1280 bytes of the original packet be included. So it's likely that it + // is truncated, which would cause IsValid to return false. + // + // Drop packet if it doesn't have the basic IPv6 header or if the + // original source address doesn't match the endpoint's address. + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + return + } + + // Skip the IP header, then handle the fragmentation header if there + // is one. + vv.TrimFront(header.IPv6MinimumSize) + p := h.TransportProtocol() + if p == header.IPv6FragmentHeader { + f := header.IPv6Fragment(vv.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { + // We can't handle fragments that aren't at offset 0 + // because they don't have the transport headers. + return + } + + // Skip fragmentation header and find out the actual protocol + // number. + vv.TrimFront(header.IPv6FragmentHeaderSize) + p = f.TransportProtocol() + } + + // Deliver the control packet to the transport endpoint. + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) +} + +func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { + v := vv.First() + if len(v) < header.ICMPv6MinimumSize { + return + } + h := header.ICMPv6(v) + + switch h.Type() { + case header.ICMPv6PacketTooBig: + if len(v) < header.ICMPv6PacketTooBigMinimumSize { + return + } + vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:]) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + + case header.ICMPv6DstUnreachable: + if len(v) < header.ICMPv6DstUnreachableMinimumSize { + return + } + vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + switch h.Code() { + case header.ICMPv6PortUnreachable: + e.handleControl(stack.ControlPortUnreachable, 0, vv) + } + + case header.ICMPv6NeighborSolicit: + if len(v) < header.ICMPv6NeighborSolicitMinimumSize { + return + } + targetAddr := tcpip.Address(v[8 : 8+16]) + if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + // We don't have a useful answer; the best we can do is ignore the request. + return + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag + copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr) + pkt[icmpV6OptOffset] = ndpOptDstLinkAddr + pkt[icmpV6LengthOffset] = 1 + copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:]) + + // ICMPv6 Neighbor Solicit messages are always sent to + // specially crafted IPv6 multicast addresses. As a result, the + // route we end up with here has as its LocalAddress such a + // multicast address. It would be nonsense to claim that our + // source address is a multicast address, so we manually set + // the source address to the target address requested in the + // solicit message. Since that requires mutating the route, we + // must first clone it. + r := r.Clone() + defer r.Release() + r.LocalAddress = targetAddr + pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + r.WritePacket(hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + + e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + + case header.ICMPv6NeighborAdvert: + if len(v) < header.ICMPv6NeighborAdvertSize { + return + } + targetAddr := tcpip.Address(v[8 : 8+16]) + e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + if targetAddr != r.RemoteAddress { + e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + } + + case header.ICMPv6EchoRequest: + if len(v) < header.ICMPv6EchoMinimumSize { + return + } + vv.TrimFront(header.ICMPv6EchoMinimumSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(pkt, h) + pkt.SetType(header.ICMPv6EchoReply) + pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) + r.WritePacket(hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + + case header.ICMPv6EchoReply: + if len(v) < header.ICMPv6EchoMinimumSize { + return + } + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv) + + } +} + +const ( + ndpSolicitedFlag = 1 << 6 + ndpOverrideFlag = 1 << 5 + + ndpOptSrcLinkAddr = 1 + ndpOptDstLinkAddr = 2 + + icmpV6FlagOffset = 4 + icmpV6OptOffset = 24 + icmpV6LengthOffset = 25 +) + +var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + +var _ stack.LinkAddressResolver = (*protocol)(nil) + +// LinkAddressProtocol implements stack.LinkAddressResolver. +func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements stack.LinkAddressResolver. +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { + snaddr := header.SolicitedNodeAddr(addr) + r := &stack.Route{ + LocalAddress: localAddr, + RemoteAddress: snaddr, + RemoteLinkAddress: broadcastMAC, + } + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + copy(pkt[icmpV6OptOffset-len(addr):], addr) + pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr + pkt[icmpV6LengthOffset] = 1 + copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(hdr.UsedLength()) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: length, + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: defaultIPv6HopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + + return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) +} + +// ResolveStaticAddress implements stack.LinkAddressResolver. +func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + return "", false +} + +func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { + // Calculate the IPv6 pseudo-header upper-layer checksum. + xsum := header.Checksum([]byte(src), 0) + xsum = header.Checksum([]byte(dst), xsum) + var upperLayerLength [4]byte + binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) + xsum = header.Checksum(upperLayerLength[:], xsum) + xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum) + for _, v := range vv.Views() { + xsum = header.Checksum(v, xsum) + } + + // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. + h2, h3 := h[2], h[3] + h[2], h[3] = 0, 0 + xsum = ^header.Checksum(h, xsum) + h[2], h[3] = h2, h3 + + return xsum +} diff --git a/tcpip/network/ipv6/ipv6.go b/tcpip/network/ipv6/ipv6.go new file mode 100644 index 0000000..ac458a0 --- /dev/null +++ b/tcpip/network/ipv6/ipv6.go @@ -0,0 +1,187 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ipv6 contains the implementation of the ipv6 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the +// network protocols when calling stack.New(). Then endpoints can be created +// by passing ipv6.ProtocolNumber as the network protocol number when calling +// Stack.NewEndpoint(). +package ipv6 + +import ( + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ipv6 protocol name. + ProtocolName = "ipv6" + + // ProtocolNumber is the ipv6 protocol number. + ProtocolNumber = header.IPv6ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // PayloadLength field of the ipv6 header. + maxPayloadSize = 0xffff + + // defaultIPv6HopLimit is the default hop limit for IPv6 Packets + // egressed by Netstack. + defaultIPv6HopLimit = 255 +) + +type endpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + linkEP stack.LinkEndpoint + linkAddrCache stack.LinkAddressCache + dispatcher stack.TransportDispatcher +} + +// DefaultTTL is the default hop limit for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return 255 +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv6 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength returns the maximum length needed by ipv6 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + length := uint16(hdr.UsedLength() + payload.Size()) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: length, + NextHeader: uint8(protocol), + HopLimit: ttl, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + r.Stats().IP.PacketsSent.Increment() + + return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) +} + +// HandlePacket is called by the link layer when new ipv6 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { + h := header.IPv6(vv.First()) + if !h.IsValid(vv.Size()) { + return + } + + vv.TrimFront(header.IPv6MinimumSize) + vv.CapLength(int(h.PayloadLength())) + + p := h.TransportProtocol() + if p == header.ICMPv6ProtocolNumber { + e.handleICMP(r, vv) + return + } + + r.Stats().IP.PacketsDelivered.Increment() + e.dispatcher.DeliverTransportPacket(r, p, vv) +} + +// Close cleans up resources associated with the endpoint. +func (*endpoint) Close() {} + +type protocol struct{} + +// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv6 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv6 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint creates a new ipv6 endpoint. +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return &endpoint{ + nicid: nicid, + id: stack.NetworkEndpointID{LocalAddress: addr}, + linkEP: linkEP, + linkAddrCache: linkAddrCache, + dispatcher: dispatcher, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + mtu -= header.IPv6MinimumSize + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +func init() { + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go index 81c133e..eb855c9 100644 --- a/tcpip/stack/nic.go +++ b/tcpip/stack/nic.go @@ -379,6 +379,19 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLin n.stack.stats.IP.InvalidAddressesReceived.Increment() } +// DeliverTransportPacket delivers packets to the appropriate +// transport protocol endpoint. +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { + +} + +// DeliverTransportControlPacket delivers control packets to the +// appropriate transport protocol endpoint. +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, + trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { + +} + func (n *NIC) ID() tcpip.NICID { return n.id } diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go index dff37e8..e9d6176 100644 --- a/tcpip/stack/registration.go +++ b/tcpip/stack/registration.go @@ -185,18 +185,32 @@ type TransportEndpointID struct { // ControlType 是网络层控制消息的类型 type ControlType int +const ( + ControlPacketTooBig ControlType = iota + ControlPortUnreachable + ControlUnknown +) + // TODO 需要解读 type TransportEndpoint interface { HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) HandleControlPacker(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) } -// TODO 需要解读 +// 传输层协议 TCP OR UDP type TransportProtocol interface { } -// TODO 需要解读 +// 传输层调度器 type TransportDispatcher interface { + // DeliverTransportPacket delivers packets to the appropriate + // transport protocol endpoint. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) + + // DeliverTransportControlPacket delivers control packets to the + // appropriate transport protocol endpoint. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, + trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) } // 注册一个新的网络协议工厂 diff --git a/waiter/waiter.go b/waiter/waiter.go new file mode 100644 index 0000000..7998d4e --- /dev/null +++ b/waiter/waiter.go @@ -0,0 +1,240 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package waiter provides the implementation of a wait queue, where waiters can +// be enqueued to be notified when an event of interest happens. +// +// Becoming readable and/or writable are examples of events. Waiters are +// expected to use a pattern similar to this to make a blocking function out of +// a non-blocking one: +// +// func (o *object) blockingRead(...) error { +// err := o.nonBlockingRead(...) +// if err != ErrAgain { +// // Completed with no need to wait! +// return err +// } +// +// e := createOrGetWaiterEntry(...) +// o.EventRegister(&e, waiter.EventIn) +// defer o.EventUnregister(&e) +// +// // We need to try to read again after registration because the +// // object may have become readable between the last attempt to +// // read and read registration. +// err = o.nonBlockingRead(...) +// for err == ErrAgain { +// wait() +// err = o.nonBlockingRead(...) +// } +// +// return err +// } +// +// Another goroutine needs to notify waiters when events happen. For example: +// +// func (o *object) Write(...) ... { +// // Do write work. +// [...] +// +// if oldDataAvailableSize == 0 && dataAvailableSize > 0 { +// // If no data was available and now some data is +// // available, the object became readable, so notify +// // potential waiters about this. +// o.Notify(waiter.EventIn) +// } +// } +package waiter + +import ( + "sync" + + "netstack/ilist" +) + +// EventMask represents io events as used in the poll() syscall. +type EventMask uint16 + +// Events that waiters can wait on. The meaning is the same as those in the +// poll() syscall. +const ( + EventIn EventMask = 0x01 // syscall.EPOLLIN + EventPri EventMask = 0x02 // syscall.EPOLLPRI + EventOut EventMask = 0x04 // syscall.EPOLLOUT + EventErr EventMask = 0x08 // syscall.EPOLLERR + EventHUp EventMask = 0x10 // syscall.EPOLLHUP + EventNVal EventMask = 0x20 // Not defined in syscall. +) + +// Waitable contains the methods that need to be implemented by waitable +// objects. +type Waitable interface { + // Readiness returns what the object is currently ready for. If it's + // not ready for a desired purpose, the caller may use EventRegister and + // EventUnregister to get notifications once the object becomes ready. + // + // Implementations should allow for events like EventHUp and EventErr + // to be returned regardless of whether they are in the input EventMask. + Readiness(mask EventMask) EventMask + + // EventRegister registers the given waiter entry to receive + // notifications when an event occurs that makes the object ready for + // at least one of the events in mask. + EventRegister(e *Entry, mask EventMask) + + // EventUnregister unregisters a waiter entry previously registered with + // EventRegister(). + EventUnregister(e *Entry) +} + +// EntryCallback provides a notify callback. +type EntryCallback interface { + // Callback is the function to be called when the waiter entry is + // notified. It is responsible for doing whatever is needed to wake up + // the waiter. + // + // The callback is supposed to perform minimal work, and cannot call + // any method on the queue itself because it will be locked while the + // callback is running. + Callback(e *Entry) +} + +// Entry represents a waiter that can be add to the a wait queue. It can +// only be in one queue at a time, and is added "intrusively" to the queue with +// no extra memory allocations. +// +// +stateify savable +type Entry struct { + // Context stores any state the waiter may wish to store in the entry + // itself, which may be used at wake up time. + // + // Note that use of this field is optional and state may alternatively be + // stored in the callback itself. + Context interface{} + + Callback EntryCallback + + // The following fields are protected by the queue lock. + mask EventMask + ilist.Entry +} + +type channelCallback struct{} + +// Callback implements EntryCallback.Callback. +func (*channelCallback) Callback(e *Entry) { + ch := e.Context.(chan struct{}) + select { + case ch <- struct{}{}: + default: + } +} + +// NewChannelEntry initializes a new Entry that does a non-blocking write to a +// struct{} channel when the callback is called. It returns the new Entry +// instance and the channel being used. +// +// If a channel isn't specified (i.e., if "c" is nil), then NewChannelEntry +// allocates a new channel. +func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { + if c == nil { + c = make(chan struct{}, 1) + } + + return Entry{Context: c, Callback: &channelCallback{}}, c +} + +// Queue represents the wait queue where waiters can be added and +// notifiers can notify them when events happen. +// +// The zero value for waiter.Queue is an empty queue ready for use. +// +// +stateify savable +type Queue struct { + list ilist.List + mu sync.RWMutex +} + +// EventRegister adds a waiter to the wait queue; the waiter will be notified +// when at least one of the events specified in mask happens. +func (q *Queue) EventRegister(e *Entry, mask EventMask) { + q.mu.Lock() + e.mask = mask + q.list.PushBack(e) + q.mu.Unlock() +} + +// EventUnregister removes the given waiter entry from the wait queue. +func (q *Queue) EventUnregister(e *Entry) { + q.mu.Lock() + q.list.Remove(e) + q.mu.Unlock() +} + +// Notify notifies all waiters in the queue whose masks have at least one bit +// in common with the notification mask. +func (q *Queue) Notify(mask EventMask) { + q.mu.RLock() + for it := q.list.Front(); it != nil; it = it.Next() { + e := it.(*Entry) + if mask&e.mask != 0 { + e.Callback.Callback(e) + } + } + q.mu.RUnlock() +} + +// Events returns the set of events being waited on. It is the union of the +// masks of all registered entries. +func (q *Queue) Events() EventMask { + ret := EventMask(0) + + q.mu.RLock() + for it := q.list.Front(); it != nil; it = it.Next() { + e := it.(*Entry) + ret |= e.mask + } + q.mu.RUnlock() + + return ret +} + +// IsEmpty returns if the wait queue is empty or not. +func (q *Queue) IsEmpty() bool { + q.mu.Lock() + defer q.mu.Unlock() + + return q.list.Front() == nil +} + +// AlwaysReady implements the Waitable interface but is always ready. Embedding +// this struct into another struct makes it implement the boilerplate empty +// functions automatically. +type AlwaysReady struct { +} + +// Readiness always returns the input mask because this object is always ready. +func (*AlwaysReady) Readiness(mask EventMask) EventMask { + return mask +} + +// EventRegister doesn't do anything because this object doesn't need to issue +// notifications because its readiness never changes. +func (*AlwaysReady) EventRegister(*Entry, EventMask) { +} + +// EventUnregister doesn't do anything because this object doesn't need to issue +// notifications because its readiness never changes. +func (*AlwaysReady) EventUnregister(e *Entry) { +} diff --git a/waiter/waiter_test.go b/waiter/waiter_test.go new file mode 100644 index 0000000..60853f9 --- /dev/null +++ b/waiter/waiter_test.go @@ -0,0 +1,192 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package waiter + +import ( + "sync/atomic" + "testing" +) + +type callbackStub struct { + f func(e *Entry) +} + +// Callback implements EntryCallback.Callback. +func (c *callbackStub) Callback(e *Entry) { + c.f(e) +} + +func TestEmptyQueue(t *testing.T) { + var q Queue + + // Notify the zero-value of a queue. + q.Notify(EventIn) + + // Register then unregister a waiter, then notify the queue. + cnt := 0 + e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + q.EventRegister(&e, EventIn) + q.EventUnregister(&e) + q.Notify(EventIn) + if cnt != 0 { + t.Errorf("Callback was called when it shouldn't have been") + } +} + +func TestMask(t *testing.T) { + // Register a waiter. + var q Queue + var cnt int + e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + q.EventRegister(&e, EventIn|EventErr) + + // Notify with an overlapping mask. + cnt = 0 + q.Notify(EventIn | EventOut) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a subset mask. + cnt = 0 + q.Notify(EventIn) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a superset mask. + cnt = 0 + q.Notify(EventIn | EventErr | EventOut) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with the exact same mask. + cnt = 0 + q.Notify(EventIn | EventErr) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a disjoint mask. + cnt = 0 + q.Notify(EventOut | EventHUp) + if cnt != 0 { + t.Errorf("Callback was called when it shouldn't have been") + } +} + +func TestConcurrentRegistration(t *testing.T) { + var q Queue + var cnt int + const concurrency = 1000 + + ch1 := make(chan struct{}) + ch2 := make(chan struct{}) + ch3 := make(chan struct{}) + + // Create goroutines that will all register/unregister concurrently. + for i := 0; i < concurrency; i++ { + go func() { + var e Entry + e.Callback = &callbackStub{func(entry *Entry) { + cnt++ + if entry != &e { + t.Errorf("entry = %p, want %p", entry, &e) + } + }} + + // Wait for notification, then register. + <-ch1 + q.EventRegister(&e, EventIn|EventErr) + + // Tell main goroutine that we're done registering. + ch2 <- struct{}{} + + // Wait for notification, then unregister. + <-ch3 + q.EventUnregister(&e) + + // Tell main goroutine that we're done unregistering. + ch2 <- struct{}{} + }() + } + + // Let the goroutines register. + close(ch1) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Issue a notification. + q.Notify(EventIn) + if cnt != concurrency { + t.Errorf("cnt = %d, want %d", cnt, concurrency) + } + + // Let the goroutine unregister. + close(ch3) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Issue a notification. + q.Notify(EventIn) + if cnt != concurrency { + t.Errorf("cnt = %d, want %d", cnt, concurrency) + } +} + +func TestConcurrentNotification(t *testing.T) { + var q Queue + var cnt int32 + const concurrency = 1000 + const waiterCount = 1000 + + // Register waiters. + for i := 0; i < waiterCount; i++ { + var e Entry + e.Callback = &callbackStub{func(entry *Entry) { + atomic.AddInt32(&cnt, 1) + if entry != &e { + t.Errorf("entry = %p, want %p", entry, &e) + } + }} + + q.EventRegister(&e, EventIn|EventErr) + } + + // Launch notifiers. + ch1 := make(chan struct{}) + ch2 := make(chan struct{}) + for i := 0; i < concurrency; i++ { + go func() { + <-ch1 + q.Notify(EventIn) + ch2 <- struct{}{} + }() + } + + // Let notifiers go. + close(ch1) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Check the count. + if cnt != concurrency*waiterCount { + t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) + } +}