This commit is contained in:
impact-eintr
2022-12-07 19:43:27 +08:00
parent d4d5c61a83
commit de9a9295b5
15 changed files with 2709 additions and 2709 deletions

View File

@@ -1,280 +1,280 @@
package main
import (
"flag"
"fmt"
"log"
"net"
"netstack/tcpip"
"netstack/tcpip/link/fdbased"
"netstack/tcpip/link/tuntap"
"netstack/tcpip/network/arp"
"netstack/tcpip/network/ipv4"
"netstack/tcpip/network/ipv6"
"netstack/tcpip/stack"
"netstack/tcpip/transport/tcp"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
)
// mac is the -mac command-line flag: the hardware address given to the
// TAP-backed link endpoint created in main.
var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
// main creates a TAP device, attaches a netstack instance to it, starts a TCP
// listener on the requested address/port and then blocks until a signal
// arrives.
//
// Usage: <prog> [-mac addr] <tap-device> <local-address/mask> <ip-address> <local-port>
func main() {
	flag.Parse()
	if len(flag.Args()) != 4 {
		log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask> <ip-address> <local-port>")
	}

	log.SetFlags(log.Lshortfile | log.LstdFlags)
	tapName := flag.Arg(0)
	cidrName := flag.Arg(1)
	addrName := flag.Arg(2)
	portName := flag.Arg(3)

	log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName)

	maddr, err := net.ParseMAC(*mac)
	if err != nil {
		log.Fatalf("Bad MAC address: %v", *mac)
	}

	// net.ParseIP signals failure by returning nil; it has no error result,
	// so the address itself must be checked.
	parsedAddr := net.ParseIP(addrName)
	if parsedAddr == nil {
		log.Fatalf("Bad addrress: %v", addrName)
	}

	// Pick the network protocol from the address family; both IPv4 and IPv6
	// textual forms are accepted here.
	var addr tcpip.Address
	var proto tcpip.NetworkProtocolNumber
	if v4 := parsedAddr.To4(); v4 != nil {
		addr = tcpip.Address(v4)
		proto = ipv4.ProtocolNumber
	} else if v6 := parsedAddr.To16(); v6 != nil {
		addr = tcpip.Address(v6)
		proto = ipv6.ProtocolNumber
	} else {
		log.Fatalf("Unknown IP type: %v", parsedAddr)
	}

	localPort, err := strconv.Atoi(portName)
	if err != nil {
		log.Fatalf("Unable to convert port %v: %v", portName, err)
	}
	// The port is later converted to uint16; reject values that would be
	// silently truncated.
	if localPort < 0 || localPort > 65535 {
		log.Fatalf("Port %v out of range [0, 65535]", localPort)
	}

	// Virtual NIC configuration.
	conf := &tuntap.Config{
		Name: tapName,
		Mode: tuntap.TAP,
	}

	// Create the virtual NIC and obtain its file descriptor.
	var fd int
	fd, err = tuntap.NewNetDev(conf)
	if err != nil {
		log.Fatal(err)
	}

	// Bring the TAP interface up and install the route. Both results are
	// deliberately discarded (best effort), as in the original code.
	_ = tuntap.SetLinkUp(tapName)
	_ = tuntap.SetRoute(tapName, cidrName)

	// Wrap the file descriptor in a link-layer endpoint.
	linkID := fdbased.New(&fdbased.Options{
		FD:                 fd,
		MTU:                1500,
		Address:            tcpip.LinkAddress(maddr),
		ResolutionRequired: true,
	})

	// Build the stack with the network and transport protocols in use.
	// NOTE(review): only IPv4 and ARP are registered, so an IPv6 address
	// presumably fails at AddAddress below - confirm.
	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
		[]string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})

	// Create the NIC on top of the link endpoint.
	if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil {
		log.Fatal(err)
	}

	// Register the network-layer address on the NIC.
	if err := s.AddAddress(1, proto, addr); err != nil {
		log.Fatal(err)
	}

	// Register ARP on the NIC.
	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
		log.Fatal(err)
	}

	// Default route: all-zero destination and mask, sent out of NIC 1.
	s.SetRouteTable([]tcpip.Route{
		{
			Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
			Mask:        tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
			Gateway:     "",
			NIC:         1,
		},
	})

	//go func() { // echo server
	//	// listen on UDP localPort
	//	conn := udpListen(s, proto, addr, localPort)
	//	for {
	//		buf := make([]byte, 1024)
	//		n, err := conn.Read(buf)
	//		if err != nil {
	//			log.Println(err)
	//			break
	//		}
	//		log.Println("接收到数据", string(buf[:n]))
	//		conn.Write([]byte("server echo"))
	//	}
	//	// closing the listener releases the port
	//	conn.Close()
	//}()

	go func() { // echo server
		listener := tcpListen(s, proto, addr, localPort)
		for {
			conn, err := listener.Accept()
			if err != nil {
				log.Println(err)
				continue
			}
			// The accepted endpoint is only read once and the result is
			// discarded.
			conn.Read(nil)
		}
	}()

	// signal.Notify never blocks when delivering, so the channel must be
	// buffered or a signal arriving before the receive is lost. os.Kill is
	// not listed because SIGKILL cannot be trapped.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGUSR1, syscall.SIGUSR2)
	<-c
}
// TcpConn bundles a TCP endpoint with the wait queue, wait entry and
// notification channel used to block on it.
type TcpConn struct {
	raddr    tcpip.FullAddress
	ep       tcpip.Endpoint
	wq       *waiter.Queue
	we       *waiter.Entry
	notifyCh chan struct{}
}

// Accept wraps the endpoint's Accept call. While the endpoint reports
// tcpip.ErrWouldBlock it waits on notifyCh and retries; any other error is
// converted to a Go error and returned.
func (conn *TcpConn) Accept() (tcpip.Endpoint, error) {
	conn.wq.EventRegister(conn.we, waiter.EventIn)
	defer conn.wq.EventUnregister(conn.we)

	for {
		accepted, _, acceptErr := conn.ep.Accept()
		switch {
		case acceptErr == nil:
			return accepted, nil
		case acceptErr == tcpip.ErrWouldBlock:
			// Nothing pending yet: sleep until the wait entry fires.
			<-conn.notifyCh
		default:
			return nil, fmt.Errorf("%s", acceptErr.String())
		}
	}
}
// tcpListen creates a TCP endpoint on s, binds it to addr:localPort on NIC 1,
// calls Listen(10) on it and returns it wrapped in a TcpConn together with a
// fresh channel-backed wait entry. Any failure ends the process via log.Fatal.
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn {
	queue := new(waiter.Queue)

	// Create the TCP endpoint.
	endpoint, newErr := s.NewEndpoint(tcp.ProtocolNumber, proto, queue)
	if newErr != nil {
		log.Fatal(newErr)
	}

	// Bind IP and port. Per the original author's note, an empty IP would
	// mean "any address", and binding is where the port manager is invoked.
	local := tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}
	if bindErr := endpoint.Bind(local, nil); bindErr != nil {
		log.Fatal("Bind failed: ", bindErr)
	}

	// Start listening.
	if listenErr := endpoint.Listen(10); listenErr != nil {
		log.Fatal("Listen failed: ", listenErr)
	}

	entry, ch := waiter.NewChannelEntry(nil)
	return &TcpConn{
		ep:       endpoint,
		wq:       queue,
		we:       &entry,
		notifyCh: ch,
	}
}
// UdpConn bundles a UDP endpoint with the wait queue, wait entry and
// notification channel used to block on it. raddr is filled in by Read and
// used as the destination by Write.
type UdpConn struct {
	raddr    tcpip.FullAddress
	ep       tcpip.Endpoint
	wq       *waiter.Queue
	we       *waiter.Entry
	notifyCh chan struct{}
}

// Close closes the underlying endpoint.
func (conn *UdpConn) Close() {
	conn.ep.Close()
}

// Read blocks until the endpoint yields data, copies it into rcv and returns
// the number of bytes copied. At most len(rcv) bytes are written; anything
// beyond that in the received buffer is dropped. The sender's address is
// stored in conn.raddr. Endpoint errors other than tcpip.ErrWouldBlock are
// returned as Go errors.
func (conn *UdpConn) Read(rcv []byte) (int, error) {
	conn.wq.EventRegister(conn.we, waiter.EventIn)
	defer conn.wq.EventUnregister(conn.we)
	for {
		buf, _, err := conn.ep.Read(&conn.raddr)
		if err != nil {
			if err == tcpip.ErrWouldBlock {
				// No data yet: wait for the wait entry to fire, then retry.
				<-conn.notifyCh
				continue
			}
			return 0, fmt.Errorf("%s", err.String())
		}
		// copy is bounded by len(rcv), so the reported count never exceeds
		// what the caller can see and nothing past len(rcv) is overwritten.
		return copy(rcv, buf), nil
	}
}

// Write sends snd to the address recorded in conn.raddr. When the endpoint
// reports tcpip.ErrNoLinkAddress it waits on the channel returned by the
// endpoint and retries; any other error is returned as a Go error.
func (conn *UdpConn) Write(snd []byte) error {
	for {
		_, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
		if err != nil {
			if err == tcpip.ErrNoLinkAddress {
				<-notifyCh
				continue
			}
			return fmt.Errorf("%s", err.String())
		}
		return nil
	}
}
// udpListen creates a UDP endpoint on s, binds it to addr:localPort with NIC 0
// and returns it wrapped in a UdpConn together with a fresh channel-backed
// wait entry. Any failure ends the process via log.Fatal.
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn {
	queue := new(waiter.Queue)

	// Create the UDP endpoint.
	endpoint, newErr := s.NewEndpoint(udp.ProtocolNumber, proto, queue)
	if newErr != nil {
		log.Fatal(newErr)
	}

	// Bind IP and port. Per the original author's note, an empty IP means
	// "any address" (e.g. 0.0.0.0:9999 receives port-9999 traffic for every
	// local IP), and binding is where the port manager is invoked.
	local := tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}
	if bindErr := endpoint.Bind(local, nil); bindErr != nil {
		log.Fatal("Bind failed: ", bindErr)
	}

	entry, ch := waiter.NewChannelEntry(nil)
	return &UdpConn{
		ep:       endpoint,
		wq:       queue,
		we:       &entry,
		notifyCh: ch,
	}
}
package main
import (
"flag"
"fmt"
"log"
"net"
"netstack/tcpip"
"netstack/tcpip/link/fdbased"
"netstack/tcpip/link/tuntap"
"netstack/tcpip/network/arp"
"netstack/tcpip/network/ipv4"
"netstack/tcpip/network/ipv6"
"netstack/tcpip/stack"
"netstack/tcpip/transport/tcp"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
)
// mac is the -mac command-line flag: the hardware address given to the
// TAP-backed link endpoint created in main.
var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
// main creates a TAP device, attaches a netstack instance to it, starts a TCP
// listener on the requested address/port and then blocks until a signal
// arrives.
//
// Usage: <prog> [-mac addr] <tap-device> <local-address/mask> <ip-address> <local-port>
func main() {
	flag.Parse()
	if len(flag.Args()) != 4 {
		log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask> <ip-address> <local-port>")
	}

	log.SetFlags(log.Lshortfile | log.LstdFlags)
	tapName := flag.Arg(0)
	cidrName := flag.Arg(1)
	addrName := flag.Arg(2)
	portName := flag.Arg(3)

	log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName)

	maddr, err := net.ParseMAC(*mac)
	if err != nil {
		log.Fatalf("Bad MAC address: %v", *mac)
	}

	// net.ParseIP signals failure by returning nil; it has no error result,
	// so the address itself must be checked.
	parsedAddr := net.ParseIP(addrName)
	if parsedAddr == nil {
		log.Fatalf("Bad addrress: %v", addrName)
	}

	// Pick the network protocol from the address family; both IPv4 and IPv6
	// textual forms are accepted here.
	var addr tcpip.Address
	var proto tcpip.NetworkProtocolNumber
	if v4 := parsedAddr.To4(); v4 != nil {
		addr = tcpip.Address(v4)
		proto = ipv4.ProtocolNumber
	} else if v6 := parsedAddr.To16(); v6 != nil {
		addr = tcpip.Address(v6)
		proto = ipv6.ProtocolNumber
	} else {
		log.Fatalf("Unknown IP type: %v", parsedAddr)
	}

	localPort, err := strconv.Atoi(portName)
	if err != nil {
		log.Fatalf("Unable to convert port %v: %v", portName, err)
	}
	// The port is later converted to uint16; reject values that would be
	// silently truncated.
	if localPort < 0 || localPort > 65535 {
		log.Fatalf("Port %v out of range [0, 65535]", localPort)
	}

	// Virtual NIC configuration.
	conf := &tuntap.Config{
		Name: tapName,
		Mode: tuntap.TAP,
	}

	// Create the virtual NIC and obtain its file descriptor.
	var fd int
	fd, err = tuntap.NewNetDev(conf)
	if err != nil {
		log.Fatal(err)
	}

	// Bring the TAP interface up and install the route. Both results are
	// deliberately discarded (best effort), as in the original code.
	_ = tuntap.SetLinkUp(tapName)
	_ = tuntap.SetRoute(tapName, cidrName)

	// Wrap the file descriptor in a link-layer endpoint.
	linkID := fdbased.New(&fdbased.Options{
		FD:                 fd,
		MTU:                1500,
		Address:            tcpip.LinkAddress(maddr),
		ResolutionRequired: true,
	})

	// Build the stack with the network and transport protocols in use.
	// NOTE(review): only IPv4 and ARP are registered, so an IPv6 address
	// presumably fails at AddAddress below - confirm.
	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
		[]string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})

	// Create the NIC on top of the link endpoint.
	if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil {
		log.Fatal(err)
	}

	// Register the network-layer address on the NIC.
	if err := s.AddAddress(1, proto, addr); err != nil {
		log.Fatal(err)
	}

	// Register ARP on the NIC.
	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
		log.Fatal(err)
	}

	// Default route: all-zero destination and mask, sent out of NIC 1.
	s.SetRouteTable([]tcpip.Route{
		{
			Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
			Mask:        tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
			Gateway:     "",
			NIC:         1,
		},
	})

	//go func() { // echo server
	//	// listen on UDP localPort
	//	conn := udpListen(s, proto, addr, localPort)
	//	for {
	//		buf := make([]byte, 1024)
	//		n, err := conn.Read(buf)
	//		if err != nil {
	//			log.Println(err)
	//			break
	//		}
	//		log.Println("接收到数据", string(buf[:n]))
	//		conn.Write([]byte("server echo"))
	//	}
	//	// closing the listener releases the port
	//	conn.Close()
	//}()

	go func() { // echo server
		listener := tcpListen(s, proto, addr, localPort)
		for {
			conn, err := listener.Accept()
			if err != nil {
				log.Println(err)
				continue
			}
			// The accepted endpoint is only read once and the result is
			// discarded.
			conn.Read(nil)
		}
	}()

	// signal.Notify never blocks when delivering, so the channel must be
	// buffered or a signal arriving before the receive is lost. os.Kill is
	// not listed because SIGKILL cannot be trapped.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGUSR1, syscall.SIGUSR2)
	<-c
}
// TcpConn bundles a TCP endpoint with the wait queue, wait entry and
// notification channel used to block on it.
type TcpConn struct {
	raddr    tcpip.FullAddress
	ep       tcpip.Endpoint
	wq       *waiter.Queue
	we       *waiter.Entry
	notifyCh chan struct{}
}

// Accept wraps the endpoint's Accept call. While the endpoint reports
// tcpip.ErrWouldBlock it waits on notifyCh and retries; any other error is
// converted to a Go error and returned.
func (conn *TcpConn) Accept() (tcpip.Endpoint, error) {
	conn.wq.EventRegister(conn.we, waiter.EventIn)
	defer conn.wq.EventUnregister(conn.we)

	for {
		accepted, _, acceptErr := conn.ep.Accept()
		switch {
		case acceptErr == nil:
			return accepted, nil
		case acceptErr == tcpip.ErrWouldBlock:
			// Nothing pending yet: sleep until the wait entry fires.
			<-conn.notifyCh
		default:
			return nil, fmt.Errorf("%s", acceptErr.String())
		}
	}
}
// tcpListen creates a TCP endpoint on s, binds it to addr:localPort on NIC 1,
// calls Listen(10) on it and returns it wrapped in a TcpConn together with a
// fresh channel-backed wait entry. Any failure ends the process via log.Fatal.
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn {
	queue := new(waiter.Queue)

	// Create the TCP endpoint.
	endpoint, newErr := s.NewEndpoint(tcp.ProtocolNumber, proto, queue)
	if newErr != nil {
		log.Fatal(newErr)
	}

	// Bind IP and port. Per the original author's note, an empty IP would
	// mean "any address", and binding is where the port manager is invoked.
	local := tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}
	if bindErr := endpoint.Bind(local, nil); bindErr != nil {
		log.Fatal("Bind failed: ", bindErr)
	}

	// Start listening.
	if listenErr := endpoint.Listen(10); listenErr != nil {
		log.Fatal("Listen failed: ", listenErr)
	}

	entry, ch := waiter.NewChannelEntry(nil)
	return &TcpConn{
		ep:       endpoint,
		wq:       queue,
		we:       &entry,
		notifyCh: ch,
	}
}
// UdpConn bundles a UDP endpoint with the wait queue, wait entry and
// notification channel used to block on it. raddr is filled in by Read and
// used as the destination by Write.
type UdpConn struct {
	raddr    tcpip.FullAddress
	ep       tcpip.Endpoint
	wq       *waiter.Queue
	we       *waiter.Entry
	notifyCh chan struct{}
}

// Close closes the underlying endpoint.
func (conn *UdpConn) Close() {
	conn.ep.Close()
}

// Read blocks until the endpoint yields data, copies it into rcv and returns
// the number of bytes copied. At most len(rcv) bytes are written; anything
// beyond that in the received buffer is dropped. The sender's address is
// stored in conn.raddr. Endpoint errors other than tcpip.ErrWouldBlock are
// returned as Go errors.
func (conn *UdpConn) Read(rcv []byte) (int, error) {
	conn.wq.EventRegister(conn.we, waiter.EventIn)
	defer conn.wq.EventUnregister(conn.we)
	for {
		buf, _, err := conn.ep.Read(&conn.raddr)
		if err != nil {
			if err == tcpip.ErrWouldBlock {
				// No data yet: wait for the wait entry to fire, then retry.
				<-conn.notifyCh
				continue
			}
			return 0, fmt.Errorf("%s", err.String())
		}
		// copy is bounded by len(rcv), so the reported count never exceeds
		// what the caller can see and nothing past len(rcv) is overwritten.
		return copy(rcv, buf), nil
	}
}

// Write sends snd to the address recorded in conn.raddr. When the endpoint
// reports tcpip.ErrNoLinkAddress it waits on the channel returned by the
// endpoint and retries; any other error is returned as a Go error.
func (conn *UdpConn) Write(snd []byte) error {
	for {
		_, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
		if err != nil {
			if err == tcpip.ErrNoLinkAddress {
				<-notifyCh
				continue
			}
			return fmt.Errorf("%s", err.String())
		}
		return nil
	}
}
// udpListen creates a UDP endpoint on s, binds it to addr:localPort with NIC 0
// and returns it wrapped in a UdpConn together with a fresh channel-backed
// wait entry. Any failure ends the process via log.Fatal.
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn {
	queue := new(waiter.Queue)

	// Create the UDP endpoint.
	endpoint, newErr := s.NewEndpoint(udp.ProtocolNumber, proto, queue)
	if newErr != nil {
		log.Fatal(newErr)
	}

	// Bind IP and port. Per the original author's note, an empty IP means
	// "any address" (e.g. 0.0.0.0:9999 receives port-9999 traffic for every
	// local IP), and binding is where the port manager is invoked.
	local := tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}
	if bindErr := endpoint.Bind(local, nil); bindErr != nil {
		log.Fatal("Bind failed: ", bindErr)
	}

	entry, ch := waiter.NewChannelEntry(nil)
	return &UdpConn{
		ep:       endpoint,
		wq:       queue,
		we:       &entry,
		notifyCh: ch,
	}
}

View File

@@ -1,101 +1,101 @@
package main
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/stack"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"runtime"
"strings"
)
// TCPHandler processes one accepted connection. TCPServer invokes Handle in a
// new goroutine for every connection it accepts.
type TCPHandler interface {
Handle(net.Conn)
}
// TCPServer accepts connections from listener until it is closed, handing
// each one to handler.Handle in its own goroutine.
//
// Temporary accept errors are logged and retried. An error whose text
// contains "use of closed network connection" ends the loop and yields nil;
// any other error is returned wrapped.
func TCPServer(listener net.Listener, handler TCPHandler) error {
	log.Printf("netstack 网络解析地址: %s", listener.Addr())

	for {
		clientConn, err := listener.Accept()
		if err == nil {
			go handler.Handle(clientConn)
			continue
		}

		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
			log.Printf("temporary Accept() failure - %s", err)
			runtime.Gosched()
			continue
		}

		// There is no direct way to detect this error because it is not
		// exposed, so the message text is matched instead.
		if !strings.Contains(err.Error(), "use of closed network connection") {
			return fmt.Errorf("listener.Accept() error - %s", err)
		}

		log.Printf("TCP: closing %s", listener.Addr())
		return nil
	}
}
// transportPool maps a uint64 key to a transport endpoint.
// NOTE(review): nothing in the visible code reads or writes it - confirm
// whether it is still needed.
var transportPool = make(map[uint64]tcpip.Endpoint)
// RCV is a TCPHandler that interprets a small binary request and drives a
// netstack endpoint accordingly.
//
// Request layout (as consumed by Handle and Bind): bytes 0-2 are the protocol
// tag ("udp" or "tcp"), bytes 3-6 an IPv4 address, bytes 7-8 a big-endian
// port.
//
// NOTE(review): Handle stores per-request state (rcvBuf, ep, addr) on the
// receiver while TCPServer runs Handle in one goroutine per connection, so a
// single RCV shared between connections is racy - confirm intended usage.
type RCV struct {
	*stack.Stack
	ep     tcpip.Endpoint
	addr   tcpip.FullAddress
	rcvBuf []byte
}

// Handle reads the whole request from conn, then for "udp" creates a UDP
// endpoint, binds it, connects it and closes it again. "tcp" and unknown tags
// are ignored. The connection is always closed on return, and a malformed
// request or endpoint failure is logged instead of taking the process down.
func (r *RCV) Handle(conn net.Conn) {
	defer conn.Close()

	var err error
	r.rcvBuf, err = io.ReadAll(conn)
	// proto + ip + port: anything shorter cannot be parsed, whether or not
	// the read itself failed.
	if len(r.rcvBuf) < 9 {
		log.Printf("Error: request too short (%d bytes, read error: %v)", len(r.rcvBuf), err)
		return
	}
	if err != nil {
		// A complete header arrived before the failure; keep going with it.
		log.Printf("read error after %d bytes: %v", len(r.rcvBuf), err)
	}

	switch string(r.rcvBuf[:3]) {
	case "udp":
		var wq waiter.Queue
		// Create a UDP endpoint.
		ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
		if err != nil {
			log.Println("NewEndpoint failed: ", err)
			return
		}
		r.ep = ep
		r.Bind()
		r.Connect()
		r.Close()
	case "tcp":
		// Not implemented.
	default:
		return
	}
}

// Bind parses the address and port from the request buffer and binds the
// current endpoint to them on NIC 1. A bind failure is logged.
func (r *RCV) Bind() {
	if len(r.rcvBuf) < 9 { // udp ip port
		log.Println("Error: too few arg")
		return
	}
	port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
	r.addr = tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.Address(r.rcvBuf[3:7]),
		Port: port,
	}
	if err := r.ep.Bind(r.addr, nil); err != nil {
		log.Println("Bind failed: ", err)
	}
}

// Connect connects the current endpoint to the hard-coded peer
// 192.168.1.2:8888 on NIC 1. The result of the call is not inspected.
func (r *RCV) Connect() {
	r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
}

// Close closes the current endpoint.
func (r *RCV) Close() {
	r.ep.Close()
}
package main
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/stack"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"runtime"
"strings"
)
// TCPHandler processes one accepted connection. TCPServer invokes Handle in a
// new goroutine for every connection it accepts.
type TCPHandler interface {
Handle(net.Conn)
}
// TCPServer accepts connections from listener until it is closed, handing
// each one to handler.Handle in its own goroutine.
//
// Temporary accept errors are logged and retried. An error whose text
// contains "use of closed network connection" ends the loop and yields nil;
// any other error is returned wrapped.
func TCPServer(listener net.Listener, handler TCPHandler) error {
	log.Printf("netstack 网络解析地址: %s", listener.Addr())

	for {
		clientConn, err := listener.Accept()
		if err == nil {
			go handler.Handle(clientConn)
			continue
		}

		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
			log.Printf("temporary Accept() failure - %s", err)
			runtime.Gosched()
			continue
		}

		// There is no direct way to detect this error because it is not
		// exposed, so the message text is matched instead.
		if !strings.Contains(err.Error(), "use of closed network connection") {
			return fmt.Errorf("listener.Accept() error - %s", err)
		}

		log.Printf("TCP: closing %s", listener.Addr())
		return nil
	}
}
// transportPool maps a uint64 key to a transport endpoint.
// NOTE(review): nothing in the visible code reads or writes it - confirm
// whether it is still needed.
var transportPool = make(map[uint64]tcpip.Endpoint)
// RCV is a TCPHandler that interprets a small binary request and drives a
// netstack endpoint accordingly.
//
// Request layout (as consumed by Handle and Bind): bytes 0-2 are the protocol
// tag ("udp" or "tcp"), bytes 3-6 an IPv4 address, bytes 7-8 a big-endian
// port.
//
// NOTE(review): Handle stores per-request state (rcvBuf, ep, addr) on the
// receiver while TCPServer runs Handle in one goroutine per connection, so a
// single RCV shared between connections is racy - confirm intended usage.
type RCV struct {
	*stack.Stack
	ep     tcpip.Endpoint
	addr   tcpip.FullAddress
	rcvBuf []byte
}

// Handle reads the whole request from conn, then for "udp" creates a UDP
// endpoint, binds it, connects it and closes it again. "tcp" and unknown tags
// are ignored. The connection is always closed on return, and a malformed
// request or endpoint failure is logged instead of taking the process down.
func (r *RCV) Handle(conn net.Conn) {
	defer conn.Close()

	var err error
	r.rcvBuf, err = io.ReadAll(conn)
	// proto + ip + port: anything shorter cannot be parsed, whether or not
	// the read itself failed.
	if len(r.rcvBuf) < 9 {
		log.Printf("Error: request too short (%d bytes, read error: %v)", len(r.rcvBuf), err)
		return
	}
	if err != nil {
		// A complete header arrived before the failure; keep going with it.
		log.Printf("read error after %d bytes: %v", len(r.rcvBuf), err)
	}

	switch string(r.rcvBuf[:3]) {
	case "udp":
		var wq waiter.Queue
		// Create a UDP endpoint.
		ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
		if err != nil {
			log.Println("NewEndpoint failed: ", err)
			return
		}
		r.ep = ep
		r.Bind()
		r.Connect()
		r.Close()
	case "tcp":
		// Not implemented.
	default:
		return
	}
}

// Bind parses the address and port from the request buffer and binds the
// current endpoint to them on NIC 1. A bind failure is logged.
func (r *RCV) Bind() {
	if len(r.rcvBuf) < 9 { // udp ip port
		log.Println("Error: too few arg")
		return
	}
	port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
	r.addr = tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.Address(r.rcvBuf[3:7]),
		Port: port,
	}
	if err := r.ep.Bind(r.addr, nil); err != nil {
		log.Println("Bind failed: ", err)
	}
}

// Connect connects the current endpoint to the hard-coded peer
// 192.168.1.2:8888 on NIC 1. The result of the call is not inspected.
func (r *RCV) Connect() {
	r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
}

// Close closes the current endpoint.
func (r *RCV) Close() {
	r.ep.Close()
}

View File

@@ -1,41 +1,41 @@
package main
import (
"flag"
"log"
"net"
)
// main is a small UDP client: it sends one 1600-byte datagram to the address
// given with -a (default 192.168.1.1:9999) and logs the first reply.
func main() {
	addr := flag.String("a", "192.168.1.1:9999", "udp dst address")
	// Without Parse the -a flag is silently ignored and the default is
	// always used.
	flag.Parse()

	log.SetFlags(log.Lshortfile | log.LstdFlags)

	udpAddr, err := net.ResolveUDPAddr("udp", *addr)
	if err != nil {
		panic(err)
	}

	// "Dialing" UDP only records the destination IP and port; no connection
	// is actually established.
	conn, err := net.DialUDP("udp", nil, udpAddr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	//send := []byte("hello world")
	send := make([]byte, 1600)
	if _, err := conn.Write(send); err != nil {
		panic(err)
	}
	log.Printf("send: %s", string(send))

	recv := make([]byte, 32)
	rn, _, err := conn.ReadFrom(recv)
	if err != nil {
		panic(err)
	}
	log.Printf("recv: %s", string(recv[:rn]))
}
package main
import (
"flag"
"log"
"net"
)
// main is a small UDP client: it sends one 1600-byte datagram to the address
// given with -a (default 192.168.1.1:9999) and logs the first reply.
func main() {
	addr := flag.String("a", "192.168.1.1:9999", "udp dst address")
	// Without Parse the -a flag is silently ignored and the default is
	// always used.
	flag.Parse()

	log.SetFlags(log.Lshortfile | log.LstdFlags)

	udpAddr, err := net.ResolveUDPAddr("udp", *addr)
	if err != nil {
		panic(err)
	}

	// "Dialing" UDP only records the destination IP and port; no connection
	// is actually established.
	conn, err := net.DialUDP("udp", nil, udpAddr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	//send := []byte("hello world")
	send := make([]byte, 1600)
	if _, err := conn.Write(send); err != nil {
		panic(err)
	}
	log.Printf("send: %s", string(send))

	recv := make([]byte, 32)
	rn, _, err := conn.ReadFrom(recv)
	if err != nil {
		panic(err)
	}
	log.Printf("recv: %s", string(recv[:rn]))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,50 +1,50 @@
package seqnum
// Value represents the value of a sequence number. All comparisons are done
// modulo 2^32 so that they stay correct when the sequence space wraps.
type Value uint32

// Size represents the size (length) of a sequence number window.
type Size uint32

// LessThan reports whether v is before w in sequence space, i.e. v < w.
func (v Value) LessThan(w Value) bool {
	// The wrapped difference is negative exactly when v precedes w.
	return int32(v-w) < 0
}

// LessThanEq reports whether v == w or v is before w.
func (v Value) LessThanEq(w Value) bool {
	if v == w {
		return true
	}
	return v.LessThan(w)
}

// InRange reports whether v is in the half-open range [a, b).
//
// Distances are measured from a using wrapping subtraction, so a range that
// crosses the 2^32 boundary (a numerically greater than b) is handled
// correctly. An empty range (a == b) contains nothing.
func (v Value) InRange(a, b Value) bool {
	return v-a < b-a
}

// InWindows reports whether v is in the window [first, first+size).
func (v Value) InWindows(first Value, size Size) bool {
	return v.InRange(first, first.Add(size))
}

// Add returns v + s, wrapping modulo 2^32.
func (v Value) Add(s Size) Value {
	return v + Value(s)
}

// Size returns the size of the range [v, w).
func (v Value) Size(w Value) Size {
	return Size(w - v)
}

// UpdateForward advances the value in place to v+s.
func (v *Value) UpdateForward(s Size) {
	*v += Value(s)
}

// Overlap checks if the window [a, a+b) overlaps with the window [x, x+y).
func Overlap(a Value, b Size, x Value, y Size) bool {
	return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
}
package seqnum
// Value represents the value of a sequence number. All comparisons are done
// modulo 2^32 so that they stay correct when the sequence space wraps.
type Value uint32

// Size represents the size (length) of a sequence number window.
type Size uint32

// LessThan reports whether v is before w in sequence space, i.e. v < w.
func (v Value) LessThan(w Value) bool {
	// The wrapped difference is negative exactly when v precedes w.
	return int32(v-w) < 0
}

// LessThanEq reports whether v == w or v is before w.
func (v Value) LessThanEq(w Value) bool {
	if v == w {
		return true
	}
	return v.LessThan(w)
}

// InRange reports whether v is in the half-open range [a, b).
//
// Distances are measured from a using wrapping subtraction, so a range that
// crosses the 2^32 boundary (a numerically greater than b) is handled
// correctly. An empty range (a == b) contains nothing.
func (v Value) InRange(a, b Value) bool {
	return v-a < b-a
}

// InWindows reports whether v is in the window [first, first+size).
func (v Value) InWindows(first Value, size Size) bool {
	return v.InRange(first, first.Add(size))
}

// Add returns v + s, wrapping modulo 2^32.
func (v Value) Add(s Size) Value {
	return v + Value(s)
}

// Size returns the size of the range [v, w).
func (v Value) Size(w Value) Size {
	return Size(w - v)
}

// UpdateForward advances the value in place to v+s.
func (v *Value) UpdateForward(s Size) {
	*v += Value(s)
}

// Overlap checks if the window [a, a+b) overlaps with the window [x, x+y).
func Overlap(a Value, b Size, x Value, y Size) bool {
	return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
}

View File

@@ -1,77 +1,77 @@
# TCP 协议
## tcp特点
1. tcp 是面向连接的传输协议。
2. tcp 的连接是端到端的。
3. tcp 提供可靠的传输。
4. tcp 的传输以字节流的方式。
5. tcp 提供全双工的通信。
6. tcp 有拥塞控制。
![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555573949562.png)
``` sh
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Port | Destination Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Acknowledgment Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data | |U|A|P|R|S|F| |
| Offset| Reserved |R|C|S|S|Y|I| Window |
| | |G|K|H|T|N|N| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Checksum | Urgent Pointer |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
```
1. 源端口和目的端口 各占 2 个字节,分别 tcp 连接的源端口和目的端口。关于端口的概念之前已经介绍过了。
2. 序号 占 4 字节,序号范围是[02^32 - 1],共 2^32即 4294967296个序号。序号增加到 2^32-1 后,下一个序号就又回到 0。TCP 是面向字节流的,在一个 TCP 连接中传送的字节流中的每一个字节都按顺序编号。整个要传送的字节流的起始序号ISN必须在连接建立时设置。首部中的序号字段值则是指的是本报文段所发送的数据的第一个字节的序号。例如一报文段的序号是 301而接待的数据共有 100 字节。这就表明:本报文段的数据的第一个字节的序号是 301最后一个字节的序号是 400。显然下一个报文段如果还有的话的数据序号应当从 401 开始,即下一个报文段的序号字段值应为 401。
3. 确认号 占 4 字节是期望收到对方下一个报文段的第一个数据字节的序号。例如B 正确收到了 A 发送过来的一个报文段,其序号字段值是 501而数据长度是 200 字节(序号 501~700这表明 B 正确收到了 A 发送的到序号 700 为止的数据。因此B 期望收到 A 的下一个数据序号是 701于是 B 在发送给 A 的确认报文段中把确认号置为 701。注意现在确认号不是 501也不是 700而是 701。 总之:若确认号为 N则表明到序号 N-1 为止的所有数据都已正确收到。TCP 除了第一个 SYN 报文之外,所有 TCP 报文都需要携带 ACK 状态位。
4. 数据偏移 占 4 位,它指出 TCP 报文段的数据起始处距离 TCP 报文段的起始处有多远。这个字段实际上是指出 TCP 报文段的首部长度。由于首部中还有长度不确定的选项字段,因此数据偏移字段是必要的,但应注意,“数据偏移”的单位是 4 个字节,由于 4 位二进制数能表示的最大十进制数字是 15因此数据偏移的最大值是 60 字节。
5. 保留 占 6 位,保留为今后使用,但目前应置为 0。
6. 控制报文标志
- **紧急URGURGent** 当 URG=1 时,表明紧急指针字段有效。它告诉系统此报文段中有紧急数据,应尽快发送(相当于高优先级的数据),而不要按原来的排队顺序来传送。例如,已经发送了很长的一个程序要在远地的主机上运行。但后来发现了一些问题,需要取消该程序的运行,因此用户从键盘发出中断命令。如果不使用紧急数据,那么这两个字符将存储在接收 TCP 的缓存末尾。只有在所有的数据被处理完毕后这两个字符才被交付接收方的应用进程。这样做就浪费了很多时间。 当 URG 置为 1 时,发送应用进程就告诉发送方的 TCP 有紧急数据要传送。于是发送方 TCP 就把紧急数据插入到本报文段数据的最前面而在紧急数据后面的数据仍然是普通数据。这时要与首部中紧急指针Urgent Pointer字段配合使用。
- **确认ACKACKnowledgment** 仅当 ACK=1 时确认号字段才有效,当 ACK=0 时确认号无效。TCP 规定,在连接建立后所有的传送的报文段都必须把 ACK 置为 1。
- **推送 PSHPuSH** 当两个应用进程进行交互式的通信时有时在一端的应用进程希望在键入一个命令后立即就能收到对方的响应。在这种情况下TCP 就可以使用推送push操作。这时发送方 TCP 把 PSH 置为 1并立即创建一个报文段发送出去。接收方 TCP 收到 PSH=1 的报文段,就尽快地交付接收应用进程。
- **复位 RST(ReSeT)** 当 RST=1 时,表明 TCP 连接中出现了严重错误(如由于主机崩溃或其他原因),必须释放连接,然后再重新建立传输连接。RST 置为 1 还用来拒绝一个非法的报文段或拒绝打开一个连接。
- **同步SYNSYNchronization** 在连接建立时用来同步序号。当 SYN=1 而 ACK=0 时,表明这是一个连接请求报文段。对方若同意建立连接,则应在响应的报文段中使 SYN=1 和 ACK=1因此 SYN 置为 1 就表示这是一个连接请求或连接接受报文。
- **终止 FIN(FINis,意思是“完”“终”)** 用来释放一个连接。当 FIN=1 时,表明此报文段的发送方的数据已发送完毕,并要求释放运输连接。
7. 窗口 占 2 字节,窗口值是[02^16-1]之间的整数。窗口指的是发送本报文段的一方的接受窗口(而不是自己的发送窗口)。窗口值告诉对方:从本报文段首部中的确认号算起,接收方目前允许对方发送的数据量(以字节为单位)。之所以要有这个限制,是因为接收方的数据缓存空间是有限的。总之,窗口值作为接收方让发送方设置其发送窗口的依据,作为流量控制的依据,后面会详细介绍。 总之:窗口字段明确指出了现在允许对方发送的数据量。窗口值经常在动态变化。
8. 检验和 占 2 字节,检验和字段检验的范围包括首部和数据这两部分。和 UDP 用户数据报一样,在计算检验和时,要在 TCP 报文段的前面加上 12 字节的伪首部。伪首部的格式和 UDP 用户数据报的伪首部一样。但应把伪首部第 4 个字段中的 17 改为 6TCP 的协议号是 6把第 5 字段中的 UDP 中的长度改为 TCP 长度。接收方收到此报文段后,仍要加上这个伪首部来计算检验和。若使用 IPv6则相应的伪首部也要改变。
9. 紧急指针 占 2 字节,紧急指针仅在 URG=1 时才有意义,它指出本报文段中的紧急数据的字节数(紧急数据结束后就是普通数据) 。因此在紧急指针指出了紧急数据的末尾在报文段中的位置。当所有紧急数据都处理完时TCP 就告诉应用程序恢复到正常操作。值得注意的是,即使窗口为 0 时也可以发送紧急数据。
10. 选项 选项长度可变,最长可达 40 字节。当没有使用“选项”时TCP 的首部长度是 20 字节。TCP 首部总长度由 TCP 头中的“数据偏移”字段决定,前面说了,最长偏移为 60 字节。那么“tcp 选项”的长度最大为 60-20=40 字节。
## tcp选项
TCP 最初只规定了一种选项,即最大报文段长度 MSS(Maximum Segment Size)。后来又增加了几个选项,如窗口扩大选项、时间戳选项等,下面说明常用的选项。
1. kind=0 是选项表结束选项。
2. kind=1 是空操作nop选项
没有特殊含义,一般用于将 TCP 选项的总长度填充为 4 字节的整数倍,为啥需要 4 字节整数倍?因为前面讲了数据偏移字段的单位是 4 个字节。
3. kind=2 是最大报文段长度选项
TCP 连接初始化时通信双方使用该选项来协商最大报文段长度Max Segment SizeMSS。TCP 模块通常将 MSS 设置为MTU-40字节减掉的这 40 字节包括 20 字节的 TCP 头部和 20 字节的 IP 头部)。这样携带 TCP 报文段的 IP 数据报的长度就不会超过 MTU假设 TCP 头部和 IP 头部都不包含选项字段,并且这也是一般情况),从而避免本机发生 IP 分片。对以太网而言MSS 值是 14601500-40字节。
4. kind=3 是窗口扩大因子选项
TCP 连接初始化时,通信双方使用该选项来协商接收通告窗口的扩大因子。在 TCP 的头部中,接收通告窗口大小是用 16 位表示的,故最大为 65535 字节,但实际上 TCP 模块允许的接收通告窗口大小远不止这个数(为了提高 TCP 通信的吞吐量)。窗口扩大因子解决了这个问题。假设 TCP 头部中的接收通告窗口大小是 N,窗口扩大因子(移位数)是 M,那么 TCP 报文段的实际接收通告窗口大小是 N 乘 2^M,或者说 N 左移 M 位。注意,M 的取值范围是 0~14。
和 MSS 选项一样,窗口扩大因子选项只能出现在同步报文段中,否则将被忽略。但同步报文段本身不执行窗口扩大操作,即同步报文段头部的接收通告窗口大小就是该 TCP 报文段的实际接收通告窗口大小。当连接建立好之后,每个数据传输方向的窗口扩大因子就固定不变了。关于窗口扩大因子选项的细节,可参考标准文档 RFC 1323。
5. kind=4 是选择性确认Selective AcknowledgmentSACK选项
TCP 通信时,如果某个 TCP 报文段丢失,则 TCP 模块会重传最后被确认的 TCP 报文段后续的所有报文段,这样原先已经正确传输的 TCP 报文段也可能重复发送,从而降低了 TCP 性能。SACK 技术正是为改善这种情况而产生的,它使 TCP 模块只重新发送丢失的 TCP 报文段,不用发送所有未被确认的 TCP 报文段。选择性确认选项用在连接初始化时,表示是否支持 SACK 技术。
6. kind=5 是 SACK 实际工作的选项
该选项的参数告诉发送方本端已经收到并缓存的不连续的数据块,从而让发送端可以据此检查并重发丢失的数据块。每个块边沿(edge of block)参数包含一个 4 字节的序号。其中块左边沿表示不连续块的第一个数据的序号,而块右边沿则表示不连续块的最后一个数据的序号的下一个序号。这样一对参数(块左边沿和块右边沿)所界定的数据是已经收到的,相邻两个块之间的数据才是没有收到的。因为一个块信息占用 8 字节,所以 TCP 头部选项中实际上最多可以包含 4 个这样的不连续数据块(考虑选项类型和长度占用的 2 字节)。
7. kind=8 是时间戳选项
该选项提供了较为准确的计算通信双方之间的回路时间Round Trip TimeRTT的方法从而为 TCP 流量控制提供重要信息。
# TCP 协议
## tcp特点
1. tcp 是面向连接的传输协议。
2. tcp 的连接是端到端的。
3. tcp 提供可靠的传输。
4. tcp 的传输以字节流的方式。
5. tcp 提供全双工的通信。
6. tcp 有拥塞控制。
![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555573949562.png)
``` sh
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Port | Destination Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Acknowledgment Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data | |U|A|P|R|S|F| |
| Offset| Reserved |R|C|S|S|Y|I| Window |
| | |G|K|H|T|N|N| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Checksum | Urgent Pointer |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
```
1. 源端口和目的端口 各占 2 个字节,分别 tcp 连接的源端口和目的端口。关于端口的概念之前已经介绍过了。
2. 序号 占 4 字节,序号范围是[02^32 - 1],共 2^32即 4294967296个序号。序号增加到 2^32-1 后,下一个序号就又回到 0。TCP 是面向字节流的,在一个 TCP 连接中传送的字节流中的每一个字节都按顺序编号。整个要传送的字节流的起始序号ISN必须在连接建立时设置。首部中的序号字段值则是指的是本报文段所发送的数据的第一个字节的序号。例如一报文段的序号是 301而接待的数据共有 100 字节。这就表明:本报文段的数据的第一个字节的序号是 301最后一个字节的序号是 400。显然下一个报文段如果还有的话的数据序号应当从 401 开始,即下一个报文段的序号字段值应为 401。
3. 确认号 占 4 字节是期望收到对方下一个报文段的第一个数据字节的序号。例如B 正确收到了 A 发送过来的一个报文段,其序号字段值是 501而数据长度是 200 字节(序号 501~700这表明 B 正确收到了 A 发送的到序号 700 为止的数据。因此B 期望收到 A 的下一个数据序号是 701于是 B 在发送给 A 的确认报文段中把确认号置为 701。注意现在确认号不是 501也不是 700而是 701。 总之:若确认号为 N则表明到序号 N-1 为止的所有数据都已正确收到。TCP 除了第一个 SYN 报文之外,所有 TCP 报文都需要携带 ACK 状态位。
4. 数据偏移 占 4 位,它指出 TCP 报文段的数据起始处距离 TCP 报文段的起始处有多远。这个字段实际上是指出 TCP 报文段的首部长度。由于首部中还有长度不确定的选项字段,因此数据偏移字段是必要的,但应注意,“数据偏移”的单位是 4 个字节,由于 4 位二进制数能表示的最大十进制数字是 15因此数据偏移的最大值是 60 字节。
5. 保留 占 6 位,保留为今后使用,但目前应置为 0。
6. 控制报文标志
- **紧急URGURGent** 当 URG=1 时,表明紧急指针字段有效。它告诉系统此报文段中有紧急数据,应尽快发送(相当于高优先级的数据),而不要按原来的排队顺序来传送。例如,已经发送了很长的一个程序要在远地的主机上运行。但后来发现了一些问题,需要取消该程序的运行,因此用户从键盘发出中断命令。如果不使用紧急数据,那么这两个字符将存储在接收 TCP 的缓存末尾。只有在所有的数据被处理完毕后这两个字符才被交付接收方的应用进程。这样做就浪费了很多时间。 当 URG 置为 1 时,发送应用进程就告诉发送方的 TCP 有紧急数据要传送。于是发送方 TCP 就把紧急数据插入到本报文段数据的最前面而在紧急数据后面的数据仍然是普通数据。这时要与首部中紧急指针Urgent Pointer字段配合使用。
- **确认ACKACKnowledgment** 仅当 ACK=1 时确认号字段才有效,当 ACK=0 时确认号无效。TCP 规定,在连接建立后所有的传送的报文段都必须把 ACK 置为 1。
- **推送 PSHPuSH** 当两个应用进程进行交互式的通信时有时在一端的应用进程希望在键入一个命令后立即就能收到对方的响应。在这种情况下TCP 就可以使用推送push操作。这时发送方 TCP 把 PSH 置为 1并立即创建一个报文段发送出去。接收方 TCP 收到 PSH=1 的报文段,就尽快地交付接收应用进程。
- **复位 RST(ReSeT)** 当 RST=1 时,表明 TCP 连接中出现了严重错误(如由于主机崩溃或其他原因),必须释放连接,然后再重新建立传输连接。RST 置为 1 还用来拒绝一个非法的报文段或拒绝打开一个连接。
- **同步SYNSYNchronization** 在连接建立时用来同步序号。当 SYN=1 而 ACK=0 时,表明这是一个连接请求报文段。对方若同意建立连接,则应在响应的报文段中使 SYN=1 和 ACK=1因此 SYN 置为 1 就表示这是一个连接请求或连接接受报文。
- **终止 FIN(FINis,意思是“完”“终”)** 用来释放一个连接。当 FIN=1 时,表明此报文段的发送方的数据已发送完毕,并要求释放运输连接。
7. 窗口 占 2 字节,窗口值是[02^16-1]之间的整数。窗口指的是发送本报文段的一方的接受窗口(而不是自己的发送窗口)。窗口值告诉对方:从本报文段首部中的确认号算起,接收方目前允许对方发送的数据量(以字节为单位)。之所以要有这个限制,是因为接收方的数据缓存空间是有限的。总之,窗口值作为接收方让发送方设置其发送窗口的依据,作为流量控制的依据,后面会详细介绍。 总之:窗口字段明确指出了现在允许对方发送的数据量。窗口值经常在动态变化。
8. 检验和 占 2 字节,检验和字段检验的范围包括首部和数据这两部分。和 UDP 用户数据报一样,在计算检验和时,要在 TCP 报文段的前面加上 12 字节的伪首部。伪首部的格式和 UDP 用户数据报的伪首部一样。但应把伪首部第 4 个字段中的 17 改为 6TCP 的协议号是 6把第 5 字段中的 UDP 中的长度改为 TCP 长度。接收方收到此报文段后,仍要加上这个伪首部来计算检验和。若使用 IPv6则相应的伪首部也要改变。
9. 紧急指针 占 2 字节,紧急指针仅在 URG=1 时才有意义,它指出本报文段中的紧急数据的字节数(紧急数据结束后就是普通数据) 。因此在紧急指针指出了紧急数据的末尾在报文段中的位置。当所有紧急数据都处理完时TCP 就告诉应用程序恢复到正常操作。值得注意的是,即使窗口为 0 时也可以发送紧急数据。
10. 选项 选项长度可变,最长可达 40 字节。当没有使用“选项”时TCP 的首部长度是 20 字节。TCP 首部总长度由 TCP 头中的“数据偏移”字段决定,前面说了,最长偏移为 60 字节。那么“tcp 选项”的长度最大为 60-20=40 字节。
## tcp选项
TCP 最初只规定了一种选项,即最大报文段长度 MSS(Maximum Segment Size)。后来又增加了几个选项,如窗口扩大选项、时间戳选项等,下面说明常用的选项。
1. kind=0 是选项表结束选项。
2. kind=1 是空操作nop选项
没有特殊含义,一般用于将 TCP 选项的总长度填充为 4 字节的整数倍,为啥需要 4 字节整数倍?因为前面讲了数据偏移字段的单位是 4 个字节。
3. kind=2 是最大报文段长度选项
TCP 连接初始化时通信双方使用该选项来协商最大报文段长度Max Segment SizeMSS。TCP 模块通常将 MSS 设置为MTU-40字节减掉的这 40 字节包括 20 字节的 TCP 头部和 20 字节的 IP 头部)。这样携带 TCP 报文段的 IP 数据报的长度就不会超过 MTU假设 TCP 头部和 IP 头部都不包含选项字段,并且这也是一般情况),从而避免本机发生 IP 分片。对以太网而言MSS 值是 14601500-40字节。
4. kind=3 是窗口扩大因子选项
TCP 连接初始化时,通信双方使用该选项来协商接收通告窗口的扩大因子。在 TCP 的头部中,接收通告窗口大小是用 16 位表示的,故最大为 65535 字节,但实际上 TCP 模块允许的接收通告窗口大小远不止这个数(为了提高 TCP 通信的吞吐量)。窗口扩大因子解决了这个问题。假设 TCP 头部中的接收通告窗口大小是 N,窗口扩大因子(移位数)是 M,那么 TCP 报文段的实际接收通告窗口大小是 N 乘 2^M,或者说 N 左移 M 位。注意,M 的取值范围是 0~14。
和 MSS 选项一样,窗口扩大因子选项只能出现在同步报文段中,否则将被忽略。但同步报文段本身不执行窗口扩大操作,即同步报文段头部的接收通告窗口大小就是该 TCP 报文段的实际接收通告窗口大小。当连接建立好之后,每个数据传输方向的窗口扩大因子就固定不变了。关于窗口扩大因子选项的细节,可参考标准文档 RFC 1323。
5. kind=4 是选择性确认Selective AcknowledgmentSACK选项
TCP 通信时,如果某个 TCP 报文段丢失,则 TCP 模块会重传最后被确认的 TCP 报文段后续的所有报文段,这样原先已经正确传输的 TCP 报文段也可能重复发送,从而降低了 TCP 性能。SACK 技术正是为改善这种情况而产生的,它使 TCP 模块只重新发送丢失的 TCP 报文段,不用发送所有未被确认的 TCP 报文段。选择性确认选项用在连接初始化时,表示是否支持 SACK 技术。
6. kind=5 是 SACK 实际工作的选项
该选项的参数告诉发送方本端已经收到并缓存的不连续的数据块,从而让发送端可以据此检查并重发丢失的数据块。每个块边沿(edge of block)参数包含一个 4 字节的序号。其中块左边沿表示不连续块的第一个数据的序号,而块右边沿则表示不连续块的最后一个数据的序号的下一个序号。这样一对参数(块左边沿和块右边沿)所界定的数据是已经收到的,相邻两个块之间的数据才是没有收到的。因为一个块信息占用 8 字节,所以 TCP 头部选项中实际上最多可以包含 4 个这样的不连续数据块(考虑选项类型和长度占用的 2 字节)。
7. kind=8 是时间戳选项
该选项提供了较为准确的计算通信双方之间的回路时间Round Trip TimeRTT的方法从而为 TCP 流量控制提供重要信息。

View File

@@ -1,330 +1,330 @@
package tcp
import (
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"hash"
"io"
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"sync"
"time"
)
// SYN cookie layout constants. createCookie places the timestamp at bit
// offset tsOffset and masks the hash/data part with hashMask; isCookieValid
// extracts the timestamp by shifting right by tsOffset.
const (
// tsLen is the length, in bits, of the timestamp in the SYN cookie.
tsLen = 8
// tsMask is a mask for timestamp values (i.e., tsLen bits).
tsMask = (1 << tsLen) - 1
// tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
tsOffset = 24
// hashMask is the mask for hash values (i.e., tsOffset bits).
hashMask = (1 << tsOffset) - 1
// maxTSDiff is the maximum allowed difference between a received cookie
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
)
var (
// SynRcvdCountThreshold is the global maximum number of connections
// that are allowed to be in SYN-RCVD state before TCP starts using SYN
// cookies to accept connections.
//
// It is an exported variable only for testing, and should not otherwise
// be used by importers of this package.
SynRcvdCountThreshold uint64 = 1000
// mssTable is a slice containing the possible MSS values that we
// encode in the SYN cookie with two bits. encodeMSS returns an index
// into this table.
mssTable = []uint16{536, 1300, 1440, 1460}
)
// encodeMSS maps mss to an index into mssTable: the highest index i >= 1
// whose entry does not exceed mss, or 0 when no such entry exists.
func encodeMSS(mss uint16) uint32 {
	var best uint32
	for idx, candidate := range mssTable {
		// Index 0 is the fallback and is never matched explicitly.
		if idx > 0 && mss >= candidate {
			best = uint32(idx)
		}
	}
	return best
}
// synRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
// protected by a mutex so that we can increment only when it's guaranteed not
// to go above a threshold.
//
// pending is incremented by incSynRcvdCount and released by decSynRcvdCount,
// alongside value.
var synRcvdCount struct {
sync.Mutex
value uint64
pending sync.WaitGroup
}
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
stack *stack.Stack
rcvWnd seqnum.Size
nonce [2][sha1.BlockSize]byte // nonce: two random blocks, filled in newListenContext and fed into cookieHash
hasherMu sync.Mutex
hasher hash.Hash // hasher: hash implementation used by cookieHash (SHA-1, set in newListenContext)
v6only bool
netProto tcpip.NetworkProtocolNumber
}
// timeStamp returns the current Unix time in units of 64 seconds, reduced to
// the low tsLen (8) bits.
func timeStamp() uint32 {
	seconds := time.Now().Unix()
	slot := uint32(seconds >> 6) // 64-second granularity
	return slot & tsMask         // keep only the low byte: 00 00 00 FF
}
// incSynRcvdCount tries to reserve one SYN-RCVD slot. It returns false without
// changing anything when the count has already reached
// SynRcvdCountThreshold (1000 by default); otherwise it bumps both the pending
// wait group and the counter and returns true.
func incSynRcvdCount() bool {
	synRcvdCount.Lock()
	defer synRcvdCount.Unlock()

	if synRcvdCount.value < SynRcvdCountThreshold {
		synRcvdCount.pending.Add(1)
		synRcvdCount.value++
		return true
	}
	return false
}
// decSynRcvdCount releases one SYN-RCVD slot: under the mutex it decrements
// the counter and marks one pending entry as done.
func decSynRcvdCount() {
	synRcvdCount.Lock()
	defer synRcvdCount.Unlock()

	synRcvdCount.value -= 1
	synRcvdCount.pending.Done()
}
// newListenContext builds a listenContext from the given stack, receive
// window, v6only flag and network protocol, installs a SHA-1 hasher and fills
// both nonce blocks with random bytes.
func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
	l := new(listenContext)
	l.stack = stack
	l.rcvWnd = rcvWnd
	l.hasher = sha1.New()
	l.v6only = v6only
	l.netProto = netProto

	// Randomize every nonce block, in index order.
	for i := range l.nonce {
		rand.Read(l.nonce[i][:])
	}
	return l
}
// cookieHash computes the hash used to create and validate SYN cookies. The
// input is the local port, remote port and ts, followed by the selected
// nonce block and the local and remote addresses of id. The first 4 bytes of
// the digest are returned as a big-endian uint32.
func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
	// Fixed-size prefix: ports and timestamp.
	var prefix [8]byte
	binary.BigEndian.PutUint16(prefix[0:], id.LocalPort)
	binary.BigEndian.PutUint16(prefix[2:], id.RemotePort)
	binary.BigEndian.PutUint32(prefix[4:], ts)

	// The hasher is shared state, so it is used under hasherMu.
	var scratch [sha1.Size]byte
	l.hasherMu.Lock()
	l.hasher.Reset()
	l.hasher.Write(prefix[:])
	l.hasher.Write(l.nonce[nonceIndex][:])
	io.WriteString(l.hasher, string(id.LocalAddress))
	io.WriteString(l.hasher, string(id.RemoteAddress))
	digest := l.hasher.Sum(scratch[:0])
	l.hasherMu.Unlock()

	return binary.BigEndian.Uint32(digest[:])
}
// createCookie builds a SYN cookie for the given id and incoming sequence
// number, embedding data in the low hashMask bits and the current timestamp
// in the bits above tsOffset.
func (l *listenContext) createCookie(id stack.TransportEndpointID,
	seq seqnum.Value, data uint32) seqnum.Value {
	now := timeStamp()

	// Base part: nonce-0 hash, the peer's sequence number and the timestamp.
	base := l.cookieHash(id, 0, 0) + uint32(seq) + (now << tsOffset)

	// Data part: nonce-1 hash (keyed by the timestamp) plus the payload.
	low := (l.cookieHash(id, now, 1) + data) & hashMask

	return seqnum.Value(base + low)
}
// isCookieValid reports whether cookie is valid for the given id and sequence
// number. A cookie whose embedded timestamp is more than maxTSDiff ticks
// behind the current one is rejected. On success the data that was passed to
// createCookie is returned as well.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID,
	cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
	now := timeStamp()
	rest := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
	stamped := rest >> tsOffset

	if age := (now - stamped) & tsMask; age > maxTSDiff {
		return 0, false
	}

	data := (rest - l.cookieHash(id, stamped, 1)) & hashMask
	return data, true
}
// createConnectedEndpoint creates a new TCP endpoint for the connection that
// segment s belongs to. The endpoint copies its id, NIC and (cloned) route
// from s, is registered with the stack, and gets a sender and receiver seeded
// with iss (our initial sequence number) and irs (the peer's).
//
// If registration fails the new endpoint is closed and the error is returned.
func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value,
irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
// Fall back to the segment's network protocol when the listener has none set.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only
n.id = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it.
if err := n.stack.RegisterTransportEndpoint(n.boundNICID,
n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
n.Close()
return nil, err
}
n.isRegistered = true
n.state = stateConnected
// Create sender and receiver.
// The receiver at least temporarily has a zero receive window scale,
// but the caller may change it (before starting the protocol loop).
n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
return n, nil
}
// createEndpointAndPerformHandshake creates an endpoint for the connection
// that SYN segment s belongs to and runs the handshake on it starting from
// the SYN-RCVD state. A SYN cookie built from s.id, the peer's sequence number
// and the encoded MSS is used as our initial sequence number.
//
// On any failure after the endpoint was created, the endpoint is closed and
// the error returned.
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// create new endpoint
irs := s.sequenceNumber
cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
log.Println("收到一个远端握手申请", irs, "标记cookie", cookie)
ep, err := l.createConnectedEndpoint(s, cookie, irs, opts)
if err != nil {
return nil, err
}
// The steps below perform the three-way handshake.
// Build the handshake state object.
h, err := newHandshake(ep, l.rcvWnd)
if err != nil {
ep.Close()
return nil, err
}
// Put the handshake in the SYN-RCVD state with flags set to SYN+ACK.
h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
ep.Close()
return nil, err
}
// Update the receive window scale factor (not implemented yet).
return ep, nil
}
// handleSynSegment is called in its own goroutine once the listening endpoint
// has received a SYN segment. It is responsible for completing the handshake
// and queueing the new endpoint for acceptance. Only a limited number of
// these goroutines are allowed before TCP starts using SYN cookies to accept
// connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
	// Deferred calls run last-in-first-out: the segment reference is dropped
	// first, then the SYN-RCVD slot is released.
	defer decSynRcvdCount()
	defer s.decRef()

	if _, err := ctx.createEndpointAndPerformHandshake(s, opts); err != nil {
		return
	}
	// The three-way handshake is complete at this point; delivering the new
	// connection is not wired up yet.
	//e.deliverAccepted(n)
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it. Only pure SYN segments are acted upon for now: if a
// SYN-RCVD slot is available the handshake is run in a new goroutine,
// otherwise the segment is logged and dropped.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
	switch s.flags {
	case flagSyn:
		// Parse the TCP options carried by the SYN.
		opts := parseSynSegmentOptions(s)
		if !incSynRcvdCount() {
			log.Println("暂时不处理")
			return
		}
		// The goroutine owns this extra reference and drops it when done;
		// it answers the peer with a SYN+ACK.
		s.incRef()
		go e.handleSynSegment(ctx, s, &opts)

	case flagFin:
		// Placeholder: nothing is done for FIN segments yet.
	}
}
// parseSynSegmentOptions parses the SYN options of s and returns them. When
// the timestamp option is present, its values are also stored in
// s.parsedOptions.
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
	parsed := header.ParseSynOptions(s.options, s.flagIsSet(flagAck))
	if !parsed.TS {
		return parsed
	}
	s.parsedOptions.TSVal = parsed.TSVal
	s.parsedOptions.TSEcr = parsed.TSEcr
	return parsed
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
//
// NOTE(review): the loop has no exit path, so the function never returns and
// the deferred cleanup never runs.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
defer func() {
// TODO: post-processing / cleanup.
}()
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto)
// Set up the sleeper and register the wakers it listens to.
s := sleep.Sleeper{}
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
s.AddWaker(&e.notificationWaker, wakerForNotification)
for {
switch index, _ := s.Fetch(true); index { // Fetch(true) is a blocking fetch
case wakerForNewSegment:
mayRequeue := true
// Receive and process TCP segments, at most maxSegmentsPerWake per wake-up.
for i := 0; i < maxSegmentsPerWake; i++ {
// Note: this s shadows the sleeper s inside the loop body.
s := e.segmentQueue.dequeue()
if s == nil {
mayRequeue = false
break
}
e.handleListenSegment(ctx, s)
s.decRef()
}
// If the queue is not empty, make sure we'll wake up
// in the next iteration.
if mayRequeue && !e.segmentQueue.empty() { // more segments were queued in the meantime
e.newSegmentWaker.Assert() // wake ourselves up to fetch again
}
case wakerForNotification:
// TODO: handle other events.
log.Println("其他事件?")
}
}
}
package tcp
import (
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"hash"
"io"
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"sync"
"time"
)
// SYN cookie layout constants: the timestamp occupies the top tsLen bits
// (shifted by tsOffset) and the hash/data part the low tsOffset bits.
const (
// tsLen is the length, in bits, of the timestamp in the SYN cookie.
tsLen = 8
// tsMask is a mask for timestamp values (i.e., tsLen bits).
tsMask = (1 << tsLen) - 1
// tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
tsOffset = 24
// hashMask is the mask for hash values (i.e., tsOffset bits).
hashMask = (1 << tsOffset) - 1
// maxTSDiff is the maximum allowed difference between a received cookie
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
)
var (
	// SynRcvdCountThreshold is the global maximum number of connections
	// that are allowed to be in SYN-RCVD state before TCP starts using SYN
	// cookies to accept connections.
	//
	// It is an exported variable only for testing, and should not otherwise
	// be used by importers of this package.
	SynRcvdCountThreshold uint64 = 1000

	// mssTable is a slice containing the possible MSS values that we
	// encode in the SYN cookie with two bits.
	mssTable = []uint16{536, 1300, 1440, 1460}
)

// encodeMSS maps mss to an index into mssTable: the highest index whose entry
// does not exceed mss, or 0 when mss is smaller than every entry above the
// first.
func encodeMSS(mss uint16) uint32 {
	for pos := len(mssTable); pos > 1; pos-- {
		if candidate := pos - 1; mssTable[candidate] <= mss {
			return uint32(candidate)
		}
	}
	return 0
}
// synRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
// protected by a mutex so that we can increment only when it's guaranteed not
// to go above a threshold.
var synRcvdCount struct {
sync.Mutex
// value is the current count; incSynRcvdCount refuses to raise it once it
// has reached SynRcvdCountThreshold.
value uint64
// pending is incremented/decremented together with value.
pending sync.WaitGroup
}
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
//
// NOTE(review): handleSynSegment is started with `go` and shares one
// listenContext, so methods are in fact called concurrently; only hasher is
// guarded (by hasherMu). Confirm the other fields are read-only after creation.
type listenContext struct {
stack *stack.Stack
rcvWnd seqnum.Size
nonce [2][sha1.BlockSize]byte // random nonces mixed into the cookie hash
hasherMu sync.Mutex // serializes use of hasher in cookieHash
hasher hash.Hash // hash implementation (SHA-1, see newListenContext)
v6only bool
netProto tcpip.NetworkProtocolNumber
}
// timeStamp reports the current Unix time in 64-second units, truncated to
// the bits selected by tsMask (an 8-bit value).
func timeStamp() uint32 {
	ticks := time.Now().Unix() >> 6 // 2^6 = 64 second granularity
	return uint32(ticks) & tsMask
}
// incSynRcvdCount tries to reserve one SYN-RCVD slot. It returns false and
// changes nothing once the count has reached SynRcvdCountThreshold; otherwise
// it bumps the counter and the pending WaitGroup and returns true.
func incSynRcvdCount() bool {
	synRcvdCount.Lock()
	defer synRcvdCount.Unlock()

	if synRcvdCount.value < SynRcvdCountThreshold {
		synRcvdCount.pending.Add(1)
		synRcvdCount.value++
		return true
	}
	return false
}
// decSynRcvdCount releases a slot previously reserved by a successful
// incSynRcvdCount: it decrements the counter and marks one pending unit done.
func decSynRcvdCount() {
	synRcvdCount.Lock()
	defer synRcvdCount.Unlock()

	synRcvdCount.value -= 1
	synRcvdCount.pending.Done()
}
// newListenContext creates a new listen context with a fresh SHA-1 hasher and
// two random nonces used by cookieHash.
//
// It panics if the system random source fails: continuing with zeroed nonces
// would make SYN cookies predictable. This mirrors how handshake.resetState
// treats a rand.Read failure.
func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
	l := &listenContext{
		stack:    stack,
		rcvWnd:   rcvWnd,
		hasher:   sha1.New(),
		v6only:   v6only,
		netProto: netProto,
	}

	for i := range l.nonce {
		if _, err := rand.Read(l.nonce[i][:]); err != nil {
			panic(err)
		}
	}

	return l
}
// cookieHash hashes the connection id, the timestamp ts and the nonce chosen
// by nonceIndex, and returns the first four bytes of the digest as a
// big-endian uint32. It is the building block for creating and validating
// SYN cookies.
func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
	// Fixed-size prefix: local port, remote port, then the timestamp.
	var fixed [8]byte
	binary.BigEndian.PutUint16(fixed[0:], id.LocalPort)
	binary.BigEndian.PutUint16(fixed[2:], id.RemotePort)
	binary.BigEndian.PutUint32(fixed[4:], ts)

	// The hasher is shared, so the whole reset/write/sum sequence is done
	// under hasherMu.
	l.hasherMu.Lock()
	l.hasher.Reset()
	l.hasher.Write(fixed[:])
	l.hasher.Write(l.nonce[nonceIndex][:])
	io.WriteString(l.hasher, string(id.LocalAddress))
	io.WriteString(l.hasher, string(id.RemoteAddress))
	digest := l.hasher.Sum(make([]byte, 0, sha1.Size))
	l.hasherMu.Unlock()

	return binary.BigEndian.Uint32(digest[:])
}
// createCookie builds a SYN cookie for the given id and incoming sequence
// number, embedding data in the low hashMask bits and the current timestamp
// in the bits above tsOffset.
func (l *listenContext) createCookie(id stack.TransportEndpointID,
	seq seqnum.Value, data uint32) seqnum.Value {
	now := timeStamp()

	// Base part: nonce-0 hash, the peer's sequence number and the timestamp.
	base := l.cookieHash(id, 0, 0) + uint32(seq) + (now << tsOffset)

	// Data part: nonce-1 hash (keyed by the timestamp) plus the payload.
	low := (l.cookieHash(id, now, 1) + data) & hashMask

	return seqnum.Value(base + low)
}
// isCookieValid reports whether cookie is valid for the given id and sequence
// number. A cookie whose embedded timestamp is more than maxTSDiff ticks
// behind the current one is rejected. On success the data that was passed to
// createCookie is returned as well.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID,
	cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
	now := timeStamp()
	rest := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
	stamped := rest >> tsOffset

	if age := (now - stamped) & tsMask; age > maxTSDiff {
		return 0, false
	}

	data := (rest - l.cookieHash(id, stamped, 1)) & hashMask
	return data, true
}
// createConnectedEndpoint creates a new TCP endpoint for the connection that
// segment s belongs to. The endpoint copies its id, NIC and (cloned) route
// from s, is registered with the stack, and gets a sender and receiver seeded
// with iss (our initial sequence number) and irs (the peer's).
//
// If registration fails the new endpoint is closed and the error is returned.
func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value,
irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
// Fall back to the segment's network protocol when the listener has none set.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only
n.id = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it.
if err := n.stack.RegisterTransportEndpoint(n.boundNICID,
n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
n.Close()
return nil, err
}
n.isRegistered = true
n.state = stateConnected
// Create sender and receiver.
// The receiver at least temporarily has a zero receive window scale,
// but the caller may change it (before starting the protocol loop).
n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
return n, nil
}
// createEndpointAndPerformHandshake creates an endpoint for the connection
// that SYN segment s belongs to and runs the handshake on it starting from
// the SYN-RCVD state. A SYN cookie built from s.id, the peer's sequence number
// and the encoded MSS is used as our initial sequence number.
//
// On any failure after the endpoint was created, the endpoint is closed and
// the error returned.
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// create new endpoint
irs := s.sequenceNumber
cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
log.Println("收到一个远端握手申请", irs, "标记cookie", cookie)
ep, err := l.createConnectedEndpoint(s, cookie, irs, opts)
if err != nil {
return nil, err
}
// The steps below perform the three-way handshake.
// Build the handshake state object.
h, err := newHandshake(ep, l.rcvWnd)
if err != nil {
ep.Close()
return nil, err
}
// Put the handshake in the SYN-RCVD state with flags set to SYN+ACK.
h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
ep.Close()
return nil, err
}
// Update the receive window scale factor (not implemented yet).
return ep, nil
}
// handleSynSegment is called in its own goroutine once the listening endpoint
// has received a SYN segment. It is responsible for completing the handshake
// and queueing the new endpoint for acceptance. Only a limited number of
// these goroutines are allowed before TCP starts using SYN cookies to accept
// connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
	// Deferred calls run last-in-first-out: the segment reference is dropped
	// first, then the SYN-RCVD slot is released.
	defer decSynRcvdCount()
	defer s.decRef()

	if _, err := ctx.createEndpointAndPerformHandshake(s, opts); err != nil {
		return
	}
	// The three-way handshake is complete at this point; delivering the new
	// connection is not wired up yet.
	//e.deliverAccepted(n)
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it. Only pure SYN segments are acted upon for now: if a
// SYN-RCVD slot is available the handshake is run in a new goroutine,
// otherwise the segment is logged and dropped.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
	switch s.flags {
	case flagSyn:
		// Parse the TCP options carried by the SYN.
		opts := parseSynSegmentOptions(s)
		if !incSynRcvdCount() {
			log.Println("暂时不处理")
			return
		}
		// The goroutine owns this extra reference and drops it when done;
		// it answers the peer with a SYN+ACK.
		s.incRef()
		go e.handleSynSegment(ctx, s, &opts)

	case flagFin:
		// Placeholder: nothing is done for FIN segments yet.
	}
}
// parseSynSegmentOptions parses the SYN options of s and returns them. When
// the timestamp option is present, its values are also stored in
// s.parsedOptions.
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
	parsed := header.ParseSynOptions(s.options, s.flagIsSet(flagAck))
	if !parsed.TS {
		return parsed
	}
	s.parsedOptions.TSVal = parsed.TSVal
	s.parsedOptions.TSEcr = parsed.TSEcr
	return parsed
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
//
// NOTE(review): the loop has no exit path, so the function never returns and
// the deferred cleanup never runs.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
defer func() {
// TODO: post-processing / cleanup.
}()
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto)
// Set up the sleeper and register the wakers it listens to.
s := sleep.Sleeper{}
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
s.AddWaker(&e.notificationWaker, wakerForNotification)
for {
switch index, _ := s.Fetch(true); index { // Fetch(true) is a blocking fetch
case wakerForNewSegment:
mayRequeue := true
// Receive and process TCP segments, at most maxSegmentsPerWake per wake-up.
for i := 0; i < maxSegmentsPerWake; i++ {
// Note: this s shadows the sleeper s inside the loop body.
s := e.segmentQueue.dequeue()
if s == nil {
mayRequeue = false
break
}
e.handleListenSegment(ctx, s)
s.decRef()
}
// If the queue is not empty, make sure we'll wake up
// in the next iteration.
if mayRequeue && !e.segmentQueue.empty() { // more segments were queued in the meantime
e.newSegmentWaker.Assert() // wake ourselves up to fetch again
}
case wakerForNotification:
// TODO: handle other events.
log.Println("其他事件?")
}
}
}

View File

@@ -1,364 +1,364 @@
package tcp
import (
"crypto/rand"
"fmt"
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"sync"
"time"
)
// maxSegmentsPerWake is the maximum number of segments the listen loop
// dequeues and handles per wake-up before re-checking the queue.
const maxSegmentsPerWake = 100
// handshakeState is the progress of a TCP three-way handshake.
type handshakeState int
// Handshake states, in the order they are normally traversed.
const (
handshakeSynSent handshakeState = iota
handshakeSynRcvd
handshakeCompleted
)
// The following are used to set up sleepers.
const (
wakerForNotification = iota
wakerForNewSegment
wakerForResend
wakerForResolution
)
// handshake holds the state used during a TCP 3-way handshake.
type handshake struct {
ep *endpoint
// state is the current state of the handshake.
state handshakeState
// active is set to true by newHandshake and cleared by resetToSynRcvd.
active bool
// flags are the TCP flags sent in the handshake segment (SYN or SYN|ACK).
flags uint8
// ackNum is the acknowledgement number sent to the peer.
ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
// rcvWnd is the receive window, as defined in RFC 793.
rcvWnd seqnum.Size
// sndWnd is the send window, as defined in RFC 793.
sndWnd seqnum.Size
// mss is the maximum segment size received from the peer.
mss uint16
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
}
const (
// Maximum space available for options.
// This is the maximum length, in bytes, of the TCP options area.
maxOptionSize = 40
)
// newHandshake returns a handshake for ep with the given initial receive
// window, marked active and reset to its initial (SYN-SENT) state. If the
// reset fails, a zero handshake and the error are returned.
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) {
	var h handshake
	h.ep = ep
	h.active = true   // mark this handshake as active
	h.rcvWnd = rcvWnd // initial receive window
	// TODO

	if err := h.resetState(); err != nil {
		return handshake{}, err
	}
	return h, nil
}
// resetState puts the handshake back into the SYN-SENT state with a freshly
// randomized initial send sequence number. It panics if the random source
// fails and otherwise always returns nil.
func (h *handshake) resetState() *tcpip.Error {
	// Pick a random iss (the sequence number the peer will see) so that it
	// cannot be guessed by an attacker.
	var raw [4]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err)
	}

	// Start out in SYN-SENT.
	h.state = handshakeSynSent
	log.Println("收到 syn 同步报文 设置tcp状态为 [sent]")
	h.flags = flagSyn
	h.ackNum = 0
	h.mss = 0

	// Assemble the four random bytes little-endian.
	var iss uint32
	for i, by := range raw {
		iss |= uint32(by) << (8 * uint(i))
	}
	h.iss = seqnum.Value(iss)
	return nil
}
// resetToSynRcvd switches the handshake to the SYN-RCVD state: it becomes
// passive, will send SYN|ACK with sequence number iss, acknowledges irs+1,
// and records the peer's MSS and window scale from opts.
func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
	h.active, h.state = false, handshakeSynRcvd
	log.Println("发送 syn|ack 确认报文 设置tcp状态为 [rcvd]")

	// The ACK number is the peer's SYN sequence number plus one.
	h.flags, h.iss, h.ackNum = flagSyn|flagAck, iss, irs+1
	h.mss, h.sndWndScale = opts.MSS, opts.WS
}
// resolveRoute repeatedly calls Resolve on the endpoint's route until it
// returns something other than tcpip.ErrWouldBlock, and returns that result
// (nil on success). Between attempts it blocks on a sleeper that is woken by
// the resolution waker or the endpoint's notification waker.
func (h *handshake) resolveRoute() *tcpip.Error {
log.Printf("tcp resolveRoute")
// Set up the wakers.
s := sleep.Sleeper{}
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
defer s.Done()
// Initial action is to resolve route.
index := wakerForResolution
for {
log.Println(index)
switch index {
case wakerForResolution:
if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
// Either success (err == nil) or failure.
return err
}
// Resolution not completed. Keep trying...
case wakerForNotification:
// TODO: notification handling (close/drain) is not implemented yet.
//n := h.ep.fetchNotifications()
//if n&notifyClose != 0 {
// h.ep.route.RemoveWaker(resolutionWaker)
// return tcpip.ErrAborted
//}
//if n&notifyDrain != 0 {
// close(h.ep.drainDone)
// <-h.ep.undrain
//}
}
// Wait for notification.
index, _ = s.Fetch(true)
}
}
// execute executes the TCP 3-way handshake. Both the client and the server
// side call this function to perform the handshake.
/*
               c                          s
               |                          |
      syn_sent |---------- syn ---------->| syn_rcvd
               |                          |
   established |<------- syn|ack ---------|
               |                          |
               |---------- ack ---------->| established
*/
// The handshake segment is retransmitted with exponential backoff starting at
// one second; tcpip.ErrTimeout is returned once the backoff exceeds 60s.
//
// NOTE(review): the wakerForNotification and wakerForNewSegment cases are
// empty, so h.state is never advanced here and the loop can currently only
// end through the timeout. The return value of sendSynTCP is also ignored.
func (h *handshake) execute() *tcpip.Error {
// Resolve the next-hop (link) address first if the route requires it.
if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil {
return err
}
}
// Initialize the resend timer.
resendWaker := sleep.Waker{}
// Start with a one second timeout.
timeOut := time.Duration(time.Second)
rt := time.AfterFunc(timeOut, func() {
resendWaker.Assert()
})
defer rt.Stop()
// Set up the wakers.
s := sleep.Sleeper{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
// Options carried by the SYN segment.
synOpts := header.TCPSynOptions{}
// A client sends SYN here; a server sends SYN+ACK (h.flags decides).
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted {
// Wait for the next event and dispatch on its id.
switch index, _ := s.Fetch(true); index {
case wakerForResend: // NOTE: TCP retransmission on timeout
// Client: no reply to the SYN arrived in time, so retransmit it.
// Server: no ACK for the SYN+ACK arrived in time, so retransmit it.
// The timeout doubles on every retransmission.
timeOut *= 2
if timeOut > 60*time.Second {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
// Resend the handshake segment.
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification:
case wakerForNewSegment:
// Handle handshake segments (not implemented yet).
}
}
return nil
}
// optionPool recycles option buffers of maxOptionSize bytes to avoid
// allocating one per segment.
var optionPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, maxOptionSize)
		return buf
	},
}

// getOptions takes an option buffer from the pool.
func getOptions() []byte {
	buf := optionPool.Get().([]byte)
	return buf
}

// putOptions hands an option buffer back to the pool, resliced to its full
// capacity.
func putOptions(options []byte) {
	full := options[:cap(options)]
	optionPool.Put(full)
}
// makeSynOptions encodes opts into a byte slice taken from the option pool
// and returns the used prefix of it. The MSS option is always written; the
// timestamp, SACK-permitted and window scale options are written depending on
// opts. It panics if the encoded options would need padding.
func makeSynOptions(opts header.TCPSynOptions) []byte {
// Emulate linux option order. This is as follows:
//
// if md5: NOP NOP MD5SIG 18 md5sig(16)
// if mss: MSS 4 mss(2)
// if ts and sack_advertise:
// SACK 2 TIMESTAMP 2 timestamp(8)
// elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
// elif sack: NOP NOP SACK 2
// if wscale: NOP WINDOW 3 ws(1)
// if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
// [for each block] start_seq(4) end_seq(4)
// if fastopen_cookie:
// if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
// else: FASTOPEN (2 + len(cookie))
// cookie(variable) [padding to four bytes]
//
options := getOptions()
// Always encode the mss.
offset := header.EncodeMSSOption(uint32(opts.MSS), options)
// Special ordering is required here. If both TS and SACK are enabled,
// then the SACK option precedes TS, with no padding. If they are
// enabled individually, then we see padding before the option.
if opts.TS && opts.SACKPermitted {
offset += header.EncodeSACKPermittedOption(options[offset:])
offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
} else if opts.TS {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
} else if opts.SACKPermitted {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeSACKPermittedOption(options[offset:])
}
// Initialize the WS option.
if opts.WS >= 0 {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeWSOption(opts.WS, options[offset:])
}
// Padding to the end; note that this never apply unless we add a
// fastopen option, we always expect the offset to remain the same.
if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
panic("unexpected option encoding")
}
return options[:offset]
}
// sendSynTCP wraps sendTCP to send a SYN (or SYN|ACK) segment carrying the
// given SYN options and no payload. When opts.MSS is zero it is derived from
// the route MTU minus the minimum TCP header size. The result of sendTCP is
// returned.
func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte,
	seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
	if opts.MSS == 0 {
		opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
	}

	options := makeSynOptions(opts)
	// sendTCP copies the options into the header buffer it builds, so the
	// pooled slice can be handed back as soon as it returns. Without this the
	// pool never gets a buffer back and every call allocates a new one.
	defer putOptions(options)

	return sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
}
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
// It builds the TCP header (including opts), computes the checksum unless the
// route offloads it, updates the send statistics and writes the packet to the
// network layer. rcvWnd is clamped to 0xffff to fit the 16-bit window field.
func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte,
seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
log.Println("进行一个报文的发送")
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
if rcvWnd > 0xffff {
rcvWnd = 0xffff
}
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
SrcPort: id.LocalPort,
DstPort: id.RemotePort,
SeqNum: uint32(seq),
AckNum: uint32(ack),
DataOffset: uint8(header.TCPMinimumSize + optLen),
Flags: flags,
WindowSize: uint16(rcvWnd),
})
// The options are copied right behind the fixed header, so opts is not
// retained after this call.
copy(tcp[header.TCPMinimumSize:], opts)
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
length := uint16(hdr.UsedLength() + data.Size())
// Start from the pseudo-header checksum and fold in the payload.
xsum := r.PseudoHeaderChecksum(ProtocolNumber)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
}
// Store the final TCP checksum, used to detect corrupted segments.
tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
}
r.Stats().TCP.SegmentsSent.Increment()
if (flags & flagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d",
flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort),
seq, ack, rcvWnd)
return r.WritePacket(hdr, data, ProtocolNumber, ttl)
}
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for the handshake, sending segments and
// handling received segments.
//
// NOTE(review): currently a stub — it logs once and then blocks forever on an
// empty select, so it never returns and the handshake argument is unused.
func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
for {
log.Println("三次握手机制在这里实现")
select {}
}
}
package tcp
import (
"crypto/rand"
"fmt"
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"sync"
"time"
)
// maxSegmentsPerWake is the maximum number of segments the listen loop
// dequeues and handles per wake-up before re-checking the queue.
const maxSegmentsPerWake = 100
// handshakeState is the progress of a TCP three-way handshake.
type handshakeState int
// Handshake states, in the order they are normally traversed.
const (
handshakeSynSent handshakeState = iota
handshakeSynRcvd
handshakeCompleted
)
// The following are used to set up sleepers.
const (
wakerForNotification = iota
wakerForNewSegment
wakerForResend
wakerForResolution
)
// handshake holds the state used during a TCP 3-way handshake.
type handshake struct {
ep *endpoint
// state is the current state of the handshake.
state handshakeState
// active is set to true by newHandshake and cleared by resetToSynRcvd.
active bool
// flags are the TCP flags sent in the handshake segment (SYN or SYN|ACK).
flags uint8
// ackNum is the acknowledgement number sent to the peer.
ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
// rcvWnd is the receive window, as defined in RFC 793.
rcvWnd seqnum.Size
// sndWnd is the send window, as defined in RFC 793.
sndWnd seqnum.Size
// mss is the maximum segment size received from the peer.
mss uint16
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
}
const (
// Maximum space available for options.
// This is the maximum length, in bytes, of the TCP options area.
maxOptionSize = 40
)
// newHandshake returns a handshake for ep with the given initial receive
// window, marked active and reset to its initial (SYN-SENT) state. If the
// reset fails, a zero handshake and the error are returned.
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) {
	var h handshake
	h.ep = ep
	h.active = true   // mark this handshake as active
	h.rcvWnd = rcvWnd // initial receive window
	// TODO

	if err := h.resetState(); err != nil {
		return handshake{}, err
	}
	return h, nil
}
// resetState puts the handshake back into the SYN-SENT state with a freshly
// randomized initial send sequence number. It panics if the random source
// fails and otherwise always returns nil.
func (h *handshake) resetState() *tcpip.Error {
	// Pick a random iss (the sequence number the peer will see) so that it
	// cannot be guessed by an attacker.
	var raw [4]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err)
	}

	// Start out in SYN-SENT.
	h.state = handshakeSynSent
	log.Println("收到 syn 同步报文 设置tcp状态为 [sent]")
	h.flags = flagSyn
	h.ackNum = 0
	h.mss = 0

	// Assemble the four random bytes little-endian.
	var iss uint32
	for i, by := range raw {
		iss |= uint32(by) << (8 * uint(i))
	}
	h.iss = seqnum.Value(iss)
	return nil
}
// resetToSynRcvd switches the handshake to the SYN-RCVD state: it becomes
// passive, will send SYN|ACK with sequence number iss, acknowledges irs+1,
// and records the peer's MSS and window scale from opts.
func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
	h.active, h.state = false, handshakeSynRcvd
	log.Println("发送 syn|ack 确认报文 设置tcp状态为 [rcvd]")

	// The ACK number is the peer's SYN sequence number plus one.
	h.flags, h.iss, h.ackNum = flagSyn|flagAck, iss, irs+1
	h.mss, h.sndWndScale = opts.MSS, opts.WS
}
// resolveRoute repeatedly calls Resolve on the endpoint's route until it
// returns something other than tcpip.ErrWouldBlock, and returns that result
// (nil on success). Between attempts it blocks on a sleeper that is woken by
// the resolution waker or the endpoint's notification waker.
func (h *handshake) resolveRoute() *tcpip.Error {
log.Printf("tcp resolveRoute")
// Set up the wakers.
s := sleep.Sleeper{}
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
defer s.Done()
// Initial action is to resolve route.
index := wakerForResolution
for {
log.Println(index)
switch index {
case wakerForResolution:
if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
// Either success (err == nil) or failure.
return err
}
// Resolution not completed. Keep trying...
case wakerForNotification:
// TODO: notification handling (close/drain) is not implemented yet.
//n := h.ep.fetchNotifications()
//if n&notifyClose != 0 {
// h.ep.route.RemoveWaker(resolutionWaker)
// return tcpip.ErrAborted
//}
//if n&notifyDrain != 0 {
// close(h.ep.drainDone)
// <-h.ep.undrain
//}
}
// Wait for notification.
index, _ = s.Fetch(true)
}
}
// execute executes the TCP 3-way handshake. Both the client and the server
// side call this function to perform the handshake.
/*
               c                          s
               |                          |
      syn_sent |---------- syn ---------->| syn_rcvd
               |                          |
   established |<------- syn|ack ---------|
               |                          |
               |---------- ack ---------->| established
*/
// The handshake segment is retransmitted with exponential backoff starting at
// one second; tcpip.ErrTimeout is returned once the backoff exceeds 60s.
//
// NOTE(review): the wakerForNotification and wakerForNewSegment cases are
// empty, so h.state is never advanced here and the loop can currently only
// end through the timeout. The return value of sendSynTCP is also ignored.
func (h *handshake) execute() *tcpip.Error {
// Resolve the next-hop (link) address first if the route requires it.
if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil {
return err
}
}
// Initialize the resend timer.
resendWaker := sleep.Waker{}
// Start with a one second timeout.
timeOut := time.Duration(time.Second)
rt := time.AfterFunc(timeOut, func() {
resendWaker.Assert()
})
defer rt.Stop()
// Set up the wakers.
s := sleep.Sleeper{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
// Options carried by the SYN segment.
synOpts := header.TCPSynOptions{}
// A client sends SYN here; a server sends SYN+ACK (h.flags decides).
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted {
// Wait for the next event and dispatch on its id.
switch index, _ := s.Fetch(true); index {
case wakerForResend: // NOTE: TCP retransmission on timeout
// Client: no reply to the SYN arrived in time, so retransmit it.
// Server: no ACK for the SYN+ACK arrived in time, so retransmit it.
// The timeout doubles on every retransmission.
timeOut *= 2
if timeOut > 60*time.Second {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
// Resend the handshake segment.
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification:
case wakerForNewSegment:
// Handle handshake segments (not implemented yet).
}
}
return nil
}
// optionPool recycles option buffers of maxOptionSize bytes to avoid
// allocating one per segment.
var optionPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, maxOptionSize)
		return buf
	},
}

// getOptions takes an option buffer from the pool.
func getOptions() []byte {
	buf := optionPool.Get().([]byte)
	return buf
}

// putOptions hands an option buffer back to the pool, resliced to its full
// capacity.
func putOptions(options []byte) {
	full := options[:cap(options)]
	optionPool.Put(full)
}
// makeSynOptions encodes opts into a byte slice taken from the option pool
// and returns the used prefix of it. The MSS option is always written; the
// timestamp, SACK-permitted and window scale options are written depending on
// opts. It panics if the encoded options would need padding.
func makeSynOptions(opts header.TCPSynOptions) []byte {
// Emulate linux option order. This is as follows:
//
// if md5: NOP NOP MD5SIG 18 md5sig(16)
// if mss: MSS 4 mss(2)
// if ts and sack_advertise:
// SACK 2 TIMESTAMP 2 timestamp(8)
// elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
// elif sack: NOP NOP SACK 2
// if wscale: NOP WINDOW 3 ws(1)
// if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
// [for each block] start_seq(4) end_seq(4)
// if fastopen_cookie:
// if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
// else: FASTOPEN (2 + len(cookie))
// cookie(variable) [padding to four bytes]
//
options := getOptions()
// Always encode the mss.
offset := header.EncodeMSSOption(uint32(opts.MSS), options)
// Special ordering is required here. If both TS and SACK are enabled,
// then the SACK option precedes TS, with no padding. If they are
// enabled individually, then we see padding before the option.
if opts.TS && opts.SACKPermitted {
offset += header.EncodeSACKPermittedOption(options[offset:])
offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
} else if opts.TS {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
} else if opts.SACKPermitted {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeSACKPermittedOption(options[offset:])
}
// Initialize the WS option.
if opts.WS >= 0 {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeWSOption(opts.WS, options[offset:])
}
// Padding to the end; note that this never apply unless we add a
// fastopen option, we always expect the offset to remain the same.
if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
panic("unexpected option encoding")
}
return options[:offset]
}
// sendSynTCP wraps sendTCP to send a SYN (or SYN|ACK) segment carrying the
// given SYN options and no payload. When opts.MSS is zero it is derived from
// the route MTU minus the minimum TCP header size. The result of sendTCP is
// returned.
func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte,
	seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
	if opts.MSS == 0 {
		opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
	}

	options := makeSynOptions(opts)
	// sendTCP copies the options into the header buffer it builds, so the
	// pooled slice can be handed back as soon as it returns. Without this the
	// pool never gets a buffer back and every call allocates a new one.
	defer putOptions(options)

	return sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
}
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
// It builds the TCP header (including opts), computes the checksum unless the
// route offloads it, updates the send statistics and writes the packet to the
// network layer. rcvWnd is clamped to 0xffff to fit the 16-bit window field.
func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte,
seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
log.Println("进行一个报文的发送")
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
if rcvWnd > 0xffff {
rcvWnd = 0xffff
}
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
SrcPort: id.LocalPort,
DstPort: id.RemotePort,
SeqNum: uint32(seq),
AckNum: uint32(ack),
DataOffset: uint8(header.TCPMinimumSize + optLen),
Flags: flags,
WindowSize: uint16(rcvWnd),
})
// The options are copied right behind the fixed header, so opts is not
// retained after this call.
copy(tcp[header.TCPMinimumSize:], opts)
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
length := uint16(hdr.UsedLength() + data.Size())
// Start from the pseudo-header checksum and fold in the payload.
xsum := r.PseudoHeaderChecksum(ProtocolNumber)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
}
// Store the final TCP checksum, used to detect corrupted segments.
tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
}
r.Stats().TCP.SegmentsSent.Increment()
if (flags & flagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d",
flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort),
seq, ack, rcvWnd)
return r.WritePacket(hdr, data, ProtocolNumber, ttl)
}
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for the handshake, sending segments and
// handling received segments.
//
// NOTE(review): currently a stub — it logs once and then blocks forever on an
// empty select, so it never returns and the handshake argument is unused.
func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
for {
log.Println("三次握手机制在这里实现")
select {}
}
}

View File

@@ -1,403 +1,403 @@
package tcp
import (
"fmt"
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"netstack/waiter"
"sync"
)
// endpointState is the state of an endpoint's TCP state machine.
type endpointState int
// The states of the TCP state machine, numbered from zero via iota.
const (
// stateInitial is the zero value and the only state in which Bind is
// accepted.
stateInitial endpointState = iota
// stateBound is set by a successful Bind and required by Listen.
stateBound
// stateListen is set by Listen and required by Accept.
stateListen
stateConnecting
// stateConnected is required by GetRemoteAddress.
stateConnected
stateClosed
stateError
)
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal
// to have concurrent goroutines make calls into the endpoint, they are
// properly synchronized. The protocol implementation, however, runs in a
// single goroutine.
type endpoint struct {
stack *stack.Stack // network stack this endpoint belongs to
netProto tcpip.NetworkProtocolNumber // network protocol number (IPv4 or IPv6)
waiterQueue *waiter.Queue // event notification queue
// TODO: more fields still need to be added.
// rcvListMu can be taken after the endpoint mu below.
rcvListMu sync.Mutex
rcvList segmentList
rcvClosed bool
rcvBufSize int
rcvBufUsed int
// The following fields are protected by the mutex.
mu sync.RWMutex
id stack.TransportEndpointID // unique ID of this TCP endpoint within the stack
state endpointState // current state of the TCP state machine
isPortReserved bool // whether a port has been reserved
isRegistered bool // whether the endpoint is registered with the stack
boundNICID tcpip.NICID
route stack.Route // route used by this endpoint
v6only bool // whether only IPv6 is supported
isConnectNotified bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
// sendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per
// RFC7323#section-1.1
sendTSOk bool
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
recentTS uint32
// sackPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
sackPermitted bool
// segmentQueue holds received segments queued by HandlePacket.
segmentQueue segmentQueue
// When the send side is closed, the protocol goroutine is notified via
// sndCloseWaker, and sndClosed is set to true.
sndBufMu sync.Mutex
sndBufSize int
sndBufUsed int
sndClosed bool
sndBufInQueue seqnum.Size
sndQueue segmentList
sndWaker sleep.Waker
sndCloseWaker sleep.Waker
// notificationWaker is used to indicate to the protocol goroutine that
// it needs to wake up and check for notifications.
notificationWaker sleep.Waker
// newSegmentWaker is used to indicate to the protocol goroutine that
// it needs to wake up and handle new segments queued to it.
// HandlePacket asserts it after queueing a received segment.
newSegmentWaker sleep.Waker
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
acceptedChan chan *endpoint
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
rcv *receiver
snd *sender
// The following are only used to assist the restore run to re-connect.
bindAddress tcpip.Address
connectingAddress tcpip.Address
}
// newEndpoint allocates a TCP endpoint for the given stack, network protocol
// and waiter queue. Both buffer sizes start at DefaultBufferSize, and the
// incoming segment queue is limited to twice the receive buffer size.
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
	ep := new(endpoint)
	ep.stack = stack
	ep.netProto = netProto
	ep.waiterQueue = waiterQueue
	ep.rcvBufSize = DefaultBufferSize
	ep.sndBufSize = DefaultBufferSize

	// TODO: more initialisation still needs to be added here.
	ep.segmentQueue.setLimit(2 * ep.rcvBufSize)
	return ep
}
// Close is not implemented yet: it only logs a placeholder message and
// releases nothing.
func (e *endpoint) Close() {
log.Println("TODO 在写了 在写了")
}
// Read is a stub: it always returns a nil view, empty control messages and
// a nil error.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
return nil, tcpip.ControlMessages{}, nil
}
// Write is a stub: it ignores its arguments and reports zero bytes written,
// a nil channel and a nil error.
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
return 0, nil, nil
}
// Peek is a stub: it always returns zero, empty control messages and a nil
// error.
func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
// checkV4Mapped determines the network protocol to use for addr.
//
// If addr.Addr is an IPv4-mapped address it is rewritten in place (through
// the pointer) to its trailing IPv4 bytes, or to the empty address when
// those bytes are all zero, and IPv4 is selected. That case fails with
// ErrNoRoute on a v6only endpoint. Otherwise the endpoint's own protocol is
// returned.
//
// ErrInvalidEndpointState is returned when the endpoint already has a local
// address whose length differs from the (possibly rewritten) addr.Addr.
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
	proto := e.netProto

	if header.IsV4MappedAddress(addr.Addr) {
		// Fail if using a v4 mapped address on a v6only endpoint.
		if e.v6only {
			return 0, tcpip.ErrNoRoute
		}

		proto = header.IPv4ProtocolNumber

		// Keep only the IPv4 bytes at the tail of the mapped address.
		v4Start := header.IPv6AddressSize - header.IPv4AddressSize
		addr.Addr = addr.Addr[v4Start:]
		if addr.Addr == "\x00\x00\x00\x00" {
			addr.Addr = ""
		}
	}

	// Fail if we're bound to an address length different from the one we're
	// checking.
	boundLen, wantLen := len(e.id.LocalAddress), len(addr.Addr)
	if boundLen != 0 && wantLen != 0 && boundLen != wantLen {
		return 0, tcpip.ErrInvalidEndpointState
	}

	return proto, nil
}
// Connect is a stub: it ignores address and always returns nil.
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
return nil
}
// Shutdown is a stub: it ignores flags and always returns nil.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return nil
}
// Listen puts a bound endpoint into the listening state.
//
// The endpoint must be in stateBound (Bind must have succeeded), otherwise
// ErrInvalidEndpointState is returned. On success the endpoint is registered
// with the stack so that incoming packets matching its ID are dispatched to
// it, the accept channel is created with capacity backlog if it does not
// exist yet, and the listen loop is started in its own goroutine.
//
// A failure is counted in FailedConnectionAttempts unless the error asks to
// be ignored by the statistics.
func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
	log.Println("监听一个tcp端口")

	e.mu.Lock()
	defer e.mu.Unlock()
	defer func() {
		// err is the named result, so this sees whatever is being returned.
		// Only errors that do NOT opt out of statistics are counted.
		if err != nil && !err.IgnoreStats() {
			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
		}
	}()

	// TODO: more handling still needs to be added here.

	// Bind must have been called before Listen.
	if e.state != stateBound {
		return tcpip.ErrInvalidEndpointState
	}

	// Register the endpoint so the network layer can find it by ID when
	// dispatching packets and deliver them to this endpoint.
	if err := e.stack.RegisterTransportEndpoint(e.boundNICID,
		e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
		return err
	}

	e.isRegistered = true
	e.state = stateListen
	if e.acceptedChan == nil {
		e.acceptedChan = make(chan *endpoint, backlog)
	}
	e.workerRunning = true

	e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()

	// TODO: the server-side main loop; it is served by its own goroutine.
	go e.protocolListenLoop(seqnum.Size(0))

	return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
//
// It stores waiterQueue on the endpoint, marks the worker as running and
// launches protocolMainLoop with handshake set to false.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
e.waiterQueue = waiterQueue
e.workerRunning = true
go e.protocolMainLoop(false)
}
// Accept returns the next connection queued on a listening endpoint together
// with a fresh waiter queue for it, and starts that connection's main loop.
//
// It never blocks: ErrWouldBlock is returned when no connection is queued,
// and ErrInvalidEndpointState when the endpoint is not listening.
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	// Endpoint must be in listen state before it can accept connections.
	if e.state != stateListen {
		return nil, nil, tcpip.ErrInvalidEndpointState
	}

	// Non-blocking receive from the accept channel.
	select {
	case accepted := <-e.acceptedChan:
		queue := new(waiter.Queue)
		accepted.startAcceptedLoop(queue)
		return accepted, queue, nil
	default:
		return nil, nil, tcpip.ErrWouldBlock
	}
}
// Bind binds the endpoint to a specific local port and optionally address.
//
// Binding is only allowed from stateInitial; otherwise ErrAlreadyBound is
// returned. The port is reserved with the stack first. If a later step fails
// (the address is not local, or commit returns an error) the reservation and
// all endpoint fields set along the way are rolled back. commit may be nil;
// when present it is called last and can veto the bind.
//
// The result is named so that the deferred rollback observes the error that
// is actually being returned, whichever return statement produced it.
func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Don't allow binding once the endpoint has left the initial state:
	// an endpoint that is connecting or listening is already bound.
	if e.state != stateInitial {
		return tcpip.ErrAlreadyBound
	}

	// Remember the address as requested, before any v4-mapped rewriting.
	e.bindAddress = addr.Addr
	netProto, err := e.checkV4Mapped(&addr)
	if err != nil {
		return err
	}

	// Work out which network protocols the endpoint will serve. An IPv6
	// endpoint bound to the any-address without v6only serves both.
	netProtos := []tcpip.NetworkProtocolNumber{netProto}
	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
		netProtos = []tcpip.NetworkProtocolNumber{
			header.IPv6ProtocolNumber,
			header.IPv4ProtocolNumber,
		}
	}

	// Reserve the port.
	port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
	if err != nil {
		return err
	}

	e.isPortReserved = true
	e.effectiveNetProtos = netProtos
	e.id.LocalPort = port

	// From here on any failure must undo the reservation. This runs before
	// the deferred Unlock, so the rollback happens under the lock.
	defer func() {
		if err != nil {
			e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
			e.isPortReserved = false
			e.effectiveNetProtos = nil
			e.id.LocalPort = 0
			e.id.LocalAddress = ""
			e.boundNICID = 0
		}
	}()

	// If an address was given, make sure it is one of the local addresses.
	if len(addr.Addr) != 0 {
		nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
		if nic == 0 {
			return tcpip.ErrBadLocalAddress
		}

		e.boundNICID = nic
		e.id.LocalAddress = addr.Addr
	}

	// Check the commit function.
	if commit != nil {
		if err = commit(); err != nil {
			// The defer takes care of unwind.
			return err
		}
	}

	// Mark the endpoint as bound.
	e.state = stateBound

	return nil
}
// GetLocalAddress returns the local address, port and bound NIC of the
// endpoint, read under the endpoint's read lock. The error is always nil.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	var local tcpip.FullAddress
	local.NIC = e.boundNICID
	local.Addr = e.id.LocalAddress
	local.Port = e.id.LocalPort
	return local, nil
}
// GetRemoteAddress returns the remote address and port of a connected
// endpoint along with the bound NIC. It returns ErrNotConnected when the
// endpoint is not in stateConnected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.state == stateConnected {
		var remote tcpip.FullAddress
		remote.NIC = e.boundNICID
		remote.Addr = e.id.RemoteAddress
		remote.Port = e.id.RemotePort
		return remote, nil
	}

	return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
// Readiness is a stub: it ignores mask and always reports waiter.EventErr.
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
return waiter.EventErr
}
// SetSockOpt is a stub: no option is applied and nil is always returned.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
// GetSockOpt is a stub: opt is left untouched and nil is always returned.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
// HandlePacket wraps an incoming packet in a segment, parses it and queues
// it for the protocol goroutine.
//
// Segments that fail to parse are counted as malformed/invalid and dropped.
// Valid segments are counted (RST segments additionally in ResetsReceived),
// then queued; if the queue refuses the segment it is counted as dropped.
// After a successful enqueue the segment is logged and newSegmentWaker is
// asserted.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
	seg := newSegment(r, id, vv)

	// Drop the packet if the TCP header cannot be parsed.
	if !seg.parse() {
		e.stack.Stats().MalformedRcvdPackets.Increment()
		e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
		seg.decRef()
		return
	}

	e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
	if seg.flags&flagRst != 0 {
		// RST segments get their own counter.
		e.stack.Stats().TCP.ResetsReceived.Increment()
	}

	// Send packet to worker goroutine.
	if !e.segmentQueue.enqueue(seg) {
		// The queue is full, so we drop the segment.
		e.stack.Stats().DroppedPackets.Increment()
		seg.decRef()
		return
	}

	log.Printf("收到 tcp [%s] 报文片段 from %s, seq: %d, ack: %d",
		flagString(seg.flags), fmt.Sprintf("%s:%d", seg.id.RemoteAddress, seg.id.RemotePort),
		seg.sequenceNumber, seg.ackNumber)
	e.newSegmentWaker.Assert()
}
// HandleControlPacket is a stub: control packets are currently ignored.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSVal.
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
	if !synOpts.TS {
		// Timestamps were not negotiated; leave the endpoint untouched.
		return
	}
	e.sendTSOk, e.recentTS = true, synOpts.TSVal
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
	var stackSACK SACKEnabled
	if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackSACK); err != nil {
		// Stack doesn't support SACK. So just return.
		return
	}

	// Both sides have to agree: the stack setting and the peer's SYN option.
	if !bool(stackSACK) || !synOpts.SACKPermitted {
		return
	}
	e.sackPermitted = true
}
package tcp
import (
"fmt"
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"netstack/waiter"
"sync"
)
// tcp状态机的状态
type endpointState int
// tcp 状态机的各种状态
const (
stateInitial endpointState = iota
stateBound
stateListen
stateConnecting
stateConnected
stateClosed
stateError
)
// endpoint 表示TCP端点。该结构用作端点用户和协议实现之间的接口;让并发goroutine调用端点是合法的
// 它们是正确同步的。然而协议实现在单个goroutine中运行。
type endpoint struct {
stack *stack.Stack // 网络协议栈
netProto tcpip.NetworkProtocolNumber // 网络协议号 ipv4 ipv6
waiterQueue *waiter.Queue // 事件驱动机制
// TODO 需要添加
// rcvListMu can be taken after the endpoint mu below.
rcvListMu sync.Mutex
rcvList segmentList
rcvClosed bool
rcvBufSize int
rcvBufUsed int
// The following fields are protected by the mutex.
mu sync.RWMutex
id stack.TransportEndpointID // tcp端在网络协议栈的唯一ID
state endpointState // 目前tcp状态机的状态
isPortReserved bool // 是否已经分配端口
isRegistered bool // 是否已经注册在网络协议栈
boundNICID tcpip.NICID
route stack.Route // tcp端在网络协议栈中的路由地址
v6only bool // 是否仅仅支持ipv6
isConnectNotified bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
// sendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per
// RFC7323#section-1.1
sendTSOk bool
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
recentTS uint32
// sackPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
sackPermitted bool
segmentQueue segmentQueue
// When the send side is closed, the protocol goroutine is notified via
// sndCloseWaker, and sndClosed is set to true.
sndBufMu sync.Mutex
sndBufSize int
sndBufUsed int
sndClosed bool
sndBufInQueue seqnum.Size
sndQueue segmentList
sndWaker sleep.Waker
sndCloseWaker sleep.Waker
// notificationWaker is used to indicate to the protocol goroutine that
// it needs to wake up and check for notifications.
notificationWaker sleep.Waker
// newSegmentWaker is used to indicate to the protocol goroutine that
// it needs to wake up and handle new segments queued to it.
// HandlePacket收到segment后通知处理的事件驱动器
newSegmentWaker sleep.Waker
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
acceptedChan chan *endpoint
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
rcv *receiver
snd *sender
// The following are only used to assist the restore run to re-connect.
bindAddress tcpip.Address
connectingAddress tcpip.Address
}
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
stack: stack,
netProto: netProto,
waiterQueue: waiterQueue,
rcvBufSize: DefaultBufferSize,
sndBufSize: DefaultBufferSize,
}
// TODO 需要添加
e.segmentQueue.setLimit(2 * e.rcvBufSize)
return e
}
func (e *endpoint) Close() {
log.Println("TODO 在写了 在写了")
}
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
return nil, tcpip.ControlMessages{}, nil
}
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
return 0, nil, nil
}
func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only {
return 0, tcpip.ErrNoRoute
}
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
if addr.Addr == "\x00\x00\x00\x00" {
addr.Addr = ""
}
}
// Fail if we're bound to an address length different from the one we're
// checking.
if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
return netProto, nil
}
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
return nil
}
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return nil
}
// Listen puts a bound endpoint into the listening state.
//
// The endpoint must be in stateBound (Bind must have succeeded), otherwise
// ErrInvalidEndpointState is returned. On success the endpoint is registered
// with the stack so that incoming packets matching its ID are dispatched to
// it, the accept channel is created with capacity backlog if it does not
// exist yet, and the listen loop is started in its own goroutine.
//
// A failure is counted in FailedConnectionAttempts unless the error asks to
// be ignored by the statistics.
func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
	log.Println("监听一个tcp端口")

	e.mu.Lock()
	defer e.mu.Unlock()
	defer func() {
		// err is the named result, so this sees whatever is being returned.
		// Only errors that do NOT opt out of statistics are counted.
		if err != nil && !err.IgnoreStats() {
			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
		}
	}()

	// TODO: more handling still needs to be added here.

	// Bind must have been called before Listen.
	if e.state != stateBound {
		return tcpip.ErrInvalidEndpointState
	}

	// Register the endpoint so the network layer can find it by ID when
	// dispatching packets and deliver them to this endpoint.
	if err := e.stack.RegisterTransportEndpoint(e.boundNICID,
		e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
		return err
	}

	e.isRegistered = true
	e.state = stateListen
	if e.acceptedChan == nil {
		e.acceptedChan = make(chan *endpoint, backlog)
	}
	e.workerRunning = true

	e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()

	// TODO: the server-side main loop; it is served by its own goroutine.
	go e.protocolListenLoop(seqnum.Size(0))

	return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
e.waiterQueue = waiterQueue
e.workerRunning = true
go e.protocolMainLoop(false)
}
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
if e.state != stateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
var n *endpoint
select {
case n = <-e.acceptedChan:
default:
return nil, nil, tcpip.ErrWouldBlock
}
wq := &waiter.Queue{}
n.startAcceptedLoop(wq)
return n, wq, nil
}
// Bind binds the endpoint to a specific local port and optionally address.
//
// Binding is only allowed from stateInitial; otherwise ErrAlreadyBound is
// returned. The port is reserved with the stack first. If a later step fails
// (the address is not local, or commit returns an error) the reservation and
// all endpoint fields set along the way are rolled back. commit may be nil;
// when present it is called last and can veto the bind.
//
// The result is named so that the deferred rollback observes the error that
// is actually being returned, whichever return statement produced it.
func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Don't allow binding once the endpoint has left the initial state:
	// an endpoint that is connecting or listening is already bound.
	if e.state != stateInitial {
		return tcpip.ErrAlreadyBound
	}

	// Remember the address as requested, before any v4-mapped rewriting.
	e.bindAddress = addr.Addr
	netProto, err := e.checkV4Mapped(&addr)
	if err != nil {
		return err
	}

	// Work out which network protocols the endpoint will serve. An IPv6
	// endpoint bound to the any-address without v6only serves both.
	netProtos := []tcpip.NetworkProtocolNumber{netProto}
	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
		netProtos = []tcpip.NetworkProtocolNumber{
			header.IPv6ProtocolNumber,
			header.IPv4ProtocolNumber,
		}
	}

	// Reserve the port.
	port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
	if err != nil {
		return err
	}

	e.isPortReserved = true
	e.effectiveNetProtos = netProtos
	e.id.LocalPort = port

	// From here on any failure must undo the reservation. This runs before
	// the deferred Unlock, so the rollback happens under the lock.
	defer func() {
		if err != nil {
			e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
			e.isPortReserved = false
			e.effectiveNetProtos = nil
			e.id.LocalPort = 0
			e.id.LocalAddress = ""
			e.boundNICID = 0
		}
	}()

	// If an address was given, make sure it is one of the local addresses.
	if len(addr.Addr) != 0 {
		nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
		if nic == 0 {
			return tcpip.ErrBadLocalAddress
		}

		e.boundNICID = nic
		e.id.LocalAddress = addr.Addr
	}

	// Check the commit function.
	if commit != nil {
		if err = commit(); err != nil {
			// The defer takes care of unwind.
			return err
		}
	}

	// Mark the endpoint as bound.
	e.state = stateBound

	return nil
}
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
return tcpip.FullAddress{
Addr: e.id.LocalAddress,
Port: e.id.LocalPort,
NIC: e.boundNICID,
}, nil
}
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.state != stateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
return tcpip.FullAddress{
Addr: e.id.RemoteAddress,
Port: e.id.RemotePort,
NIC: e.boundNICID,
}, nil
}
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
return waiter.EventErr
}
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
s := newSegment(r, id, vv)
// 解析tcp段如果解析失败丢弃该报文
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
s.decRef()
return
}
e.stack.Stats().TCP.ValidSegmentsReceived.Increment() // 有效报文喜加一
if (s.flags & flagRst) != 0 { // RST报文需要拒绝
e.stack.Stats().TCP.ResetsReceived.Increment()
}
// Send packet to worker goroutine.
if e.segmentQueue.enqueue(s) {
log.Printf("收到 tcp [%s] 报文片段 from %s, seq: %d, ack: %d",
flagString(s.flags), fmt.Sprintf("%s:%d", s.id.RemoteAddress, s.id.RemotePort),
s.sequenceNumber, s.ackNumber)
e.newSegmentWaker.Assert()
} else {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
s.decRef()
}
}
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSval.
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
e.sendTSOk = true
e.recentTS = synOpts.TSVal
}
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
var v SACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return
}
if bool(v) && synOpts.SACKPermitted {
e.sackPermitted = true
}
}

View File

@@ -1,73 +1,73 @@
package tcp
import (
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/stack"
"netstack/waiter"
)
const (
// ProtocolName is the string representation of the tcp protocol name.
ProtocolName = "tcp"
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
// minBufferSize is the smallest size of a receive or send buffer.
minBufferSize = 4 << 10 // 4096 bytes.
// DefaultBufferSize is the default size of the receive and send buffers.
DefaultBufferSize = 1 << 20 // 1MB
// maxBufferSize is the largest size a receive and send buffer can grow to.
maxBufferSize = 4 << 20 // 4MB
)
// SACKEnabled is the transport protocol option used to enable SACK support
// in the TCP protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
// protocol is the stateless TCP transport protocol registered with the stack
// under ProtocolName.
type protocol struct{}
// Number returns ProtocolNumber, the tcp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
return ProtocolNumber
}
// NewEndpoint creates a new tcp endpoint for the given stack, network
// protocol and waiter queue. The returned error is always nil.
func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
// ParsePorts returns the source and destination ports stored in the given tcp
// packet. The error is always nil; v is not length-checked here.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
	tcpHdr := header.TCP(v)
	src, dst = tcpHdr.SourcePort(), tcpHdr.DestinationPort()
	return src, dst, nil
}
// MinimumPacketSize returns the minimum valid tcp packet size, which is
// header.TCPMinimumSize.
func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize
}
// HandleUnknownDestinationPacket is a stub: the packet is not processed and
// false is always returned.
func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
return false
}
// SetOption implements TransportProtocol.SetOption.
// It is currently a stub: the option is ignored and nil is returned.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
// Option implements TransportProtocol.Option.
// It is currently a stub: option is left unmodified and nil is returned, so
// callers read back whatever value they passed in.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return nil
}
// init registers the TCP protocol factory with the stack under ProtocolName.
func init() {
	newProtocol := func() stack.TransportProtocol {
		return new(protocol)
	}
	stack.RegisterTransportProtocolFactory(ProtocolName, newProtocol)
}
package tcp
import (
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/stack"
"netstack/waiter"
)
const (
// ProtocolName is the string representation of the tcp protocol name.
ProtocolName = "tcp"
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
// MinBufferSize is the smallest size of a receive or send buffer.
minBufferSize = 4 << 10 // 4096 bytes.
// DefaultBufferSize is the default size of the receive and send buffers.
DefaultBufferSize = 1 << 20 // 1MB
// MaxBufferSize is the largest size a receive and send buffer can grow to.
maxBufferSize = 4 << 20 // 4MB
)
// SACKEnabled option can be used to enable SACK support in the TCP
// protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
type protocol struct{}
// Number returns the tcp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
return ProtocolNumber
}
// NewEndpoint creates a new tcp endpoint.
func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
// ParsePorts returns the source and destination ports stored in the given tcp
// packet.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
h := header.TCP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
// MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize
}
func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
return false
}
// SetOption implements TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
// Option implements TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return nil
}
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{}
})
}

View File

@@ -1,11 +1,11 @@
package tcp
import "netstack/tcpip/seqnum"
// receiver is the receive side of a TCP connection. It has no fields yet.
type receiver struct{}
// newReceiver creates and initialises a receiver. All parameters are
// currently unused because receiver has no state yet.
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
	return new(receiver)
}
package tcp
import "netstack/tcpip/seqnum"
type receiver struct{}
// 新建并初始化接收器
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
r := &receiver{}
return r
}

View File

@@ -1,135 +1,135 @@
package tcp
import (
"fmt"
"log"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"strings"
"sync/atomic"
)
// TCP is complex enough to warrant a dedicated segment parser; segments are
// told apart by the flag bits below.

// Flags that may be set in a TCP segment.
const (
	flagFin = 1 << iota
	flagSyn
	flagRst
	flagPsh
	flagAck
	flagUrg
)

// flagString renders the bits set in flags as lower-case names joined by
// "|", always in the order ack, fin, psh, rst, syn, urg. Bits without a name
// are ignored, and the result is empty when no named bit is set.
func flagString(flags uint8) string {
	// Listed in output order, which is alphabetical rather than bit order.
	labels := [...]struct {
		bit  uint8
		name string
	}{
		{flagAck, "ack"},
		{flagFin, "fin"},
		{flagPsh, "psh"},
		{flagRst, "rst"},
		{flagSyn, "syn"},
		{flagUrg, "urg"},
	}

	var set []string
	for _, l := range labels {
		if flags&l.bit != 0 {
			set = append(set, l.name)
		}
	}
	return strings.Join(set, "|")
}
// segment represents a TCP segment. It holds the payload and the parsed TCP
// segment information, and can be added to intrusive lists.
type segment struct {
segmentEntry
refCnt int32 // reference count; the route is released when it drops to zero
id stack.TransportEndpointID
route stack.Route
data buffer.VectorisedView
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
// TODO: needs explanation.
viewToDeliver int
sequenceNumber seqnum.Value // TCP sequence number read from the header
ackNumber seqnum.Value // acknowledgement number read from the header
flags uint8
window seqnum.Size
// parsedOptions stores the parsed values from the options in the segment.
parsedOptions header.TCPOptions
options []byte
}
// newSegment creates a segment with a single reference for the given
// endpoint ID. It takes its own clone of the route and clones vv into the
// segment's embedded views array.
func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
	seg := new(segment)
	seg.refCnt = 1
	seg.id = id
	seg.route = r.Clone()
	seg.data = vv.Clone(seg.views[:])
	return seg
}
// newSegmentFromView creates a segment with a single reference whose data
// consists of exactly one view, v. The route is cloned.
func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
	seg := new(segment)
	seg.refCnt = 1
	seg.id = id
	seg.route = r.Clone()

	// Only the first slot of views is populated, so the vectorised view is
	// built over that single element.
	seg.views[0] = v
	seg.data = buffer.NewVectorisedView(len(v), seg.views[:1])
	return seg
}
// clone returns a copy of s with its own reference count of one, its own
// route clone and its own clone of the data views. The header fields and
// viewToDeliver are copied; options and parsedOptions are not.
func (s *segment) clone() *segment {
	dup := new(segment)
	dup.refCnt = 1
	dup.id = s.id
	dup.sequenceNumber = s.sequenceNumber
	dup.ackNumber = s.ackNumber
	dup.flags = s.flags
	dup.window = s.window
	dup.route = s.route.Clone()
	dup.viewToDeliver = s.viewToDeliver
	dup.data = s.data.Clone(dup.views[:])
	return dup
}
// flagIsSet reports whether any bit of flag is set in the segment's flags.
func (s *segment) flagIsSet(flag uint8) bool {
	masked := s.flags & flag
	return masked != 0
}
// decRef atomically drops one reference and releases the segment's route
// once the count reaches zero.
func (s *segment) decRef() {
	if remaining := atomic.AddInt32(&s.refCnt, -1); remaining != 0 {
		return
	}
	s.route.Release()
}
// incRef atomically adds one reference to the segment.
func (s *segment) incRef() {
atomic.AddInt32(&s.refCnt, 1)
}
// parse decodes the TCP header at the front of the segment's data.
//
// It returns false, leaving the segment untouched, when the header's data
// offset is smaller than the minimum header size or larger than the first
// view. Otherwise it records the raw and parsed options, prints the header
// and options for debugging, trims the header off the data and fills in the
// sequence number, ack number, flags and window.
//
// NOTE(review): the header fields are read without checking that the first
// view holds at least header.TCPMinimumSize bytes; presumably the caller
// guarantees that. Confirm.
func (s *segment) parse() bool {
	tcpHdr := header.TCP(s.data.First())

	dataOff := int(tcpHdr.DataOffset())
	if dataOff < header.TCPMinimumSize {
		// Part of the fixed header would otherwise be treated as payload.
		return false
	}
	if dataOff > len(tcpHdr) {
		// The header claims to extend past the bytes in the first view.
		return false
	}

	s.options = tcpHdr.Options()
	s.parsedOptions = header.ParseTCPOptions(s.options)
	log.Println(tcpHdr)
	fmt.Println(s.parsedOptions)

	// Strip the header so that data holds the payload only.
	s.data.TrimFront(dataOff)

	s.sequenceNumber = seqnum.Value(tcpHdr.SequenceNumber())
	s.ackNumber = seqnum.Value(tcpHdr.AckNumber())
	s.flags = tcpHdr.Flags() // U|A|P|R|S|F
	s.window = seqnum.Size(tcpHdr.WindowSize())
	return true
}
package tcp
import (
"fmt"
"log"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/seqnum"
"netstack/tcpip/stack"
"strings"
"sync/atomic"
)
// tcp 太复杂了 专门写一个协议解析器 segment 是有种类之分的
// Flags that may be set in a TCP segment.
const (
flagFin = 1 << iota
flagSyn
flagRst
flagPsh
flagAck
flagUrg
)
func flagString(flags uint8) string {
var s []string
if (flags & flagAck) != 0 {
s = append(s, "ack")
}
if (flags & flagFin) != 0 {
s = append(s, "fin")
}
if (flags & flagPsh) != 0 {
s = append(s, "psh")
}
if (flags & flagRst) != 0 {
s = append(s, "rst")
}
if (flags & flagSyn) != 0 {
s = append(s, "syn")
}
if (flags & flagUrg) != 0 {
s = append(s, "urg")
}
return strings.Join(s, "|")
}
// segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中
type segment struct {
segmentEntry
refCnt int32 // 引用计数
id stack.TransportEndpointID
route stack.Route
data buffer.VectorisedView
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
// TODO 需要解析
viewToDeliver int
sequenceNumber seqnum.Value // tcp序号 第一个字节在整个报文的位置
ackNumber seqnum.Value // 确认号 希望继续获取的下一个字节序号
flags uint8
window seqnum.Size
// parsedOptions stores the parsed values from the options in the segment.
parsedOptions header.TCPOptions
options []byte
}
func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
s := &segment{refCnt: 1, id: id, route: r.Clone()}
s.data = vv.Clone(s.views[:])
return s
}
func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1]) // TODO 为什么只复制1?
return s
}
func (s *segment) clone() *segment {
t := &segment{
refCnt: 1,
id: s.id,
sequenceNumber: s.sequenceNumber,
ackNumber: s.ackNumber,
flags: s.flags,
window: s.window,
route: s.route.Clone(),
viewToDeliver: s.viewToDeliver,
}
t.data = s.data.Clone(t.views[:])
return t
}
func (s *segment) flagIsSet(flag uint8) bool {
return (s.flags & flag) != 0
}
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
s.route.Release()
}
}
func (s *segment) incRef() {
atomic.AddInt32(&s.refCnt, 1)
}
func (s *segment) parse() bool {
h := header.TCP(s.data.First())
offset := int(h.DataOffset())
if offset < header.TCPMinimumSize || offset > len(h) {
return false
}
s.options = h.Options()
s.parsedOptions = header.ParseTCPOptions(s.options)
log.Println(h)
fmt.Println(s.parsedOptions)
s.data.TrimFront(offset)
s.sequenceNumber = seqnum.Value(h.SequenceNumber())
s.ackNumber = seqnum.Value(h.AckNumber())
s.flags = h.Flags() // U|A|P|R|S|F
s.window = seqnum.Size(h.WindowSize())
return true
}

View File

@@ -1,50 +1,50 @@
package tcp
import (
"netstack/tcpip/header"
"sync"
)
// segmentQueue is a mutex-protected FIFO of segments with a size limit.
type segmentQueue struct {
mu sync.Mutex
list segmentList // underlying list holding the queued segments
limit int // capacity that used is compared against on enqueue
used int // bytes accounted for: payload size plus header.TCPMinimumSize per queued segment
}
// empty reports whether the queue currently accounts for zero bytes.
func (q *segmentQueue) empty() bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.used == 0
}
// enqueue appends s to the queue and reports whether it was accepted.
//
// A segment is accepted whenever the bytes already accounted for are below
// the limit, so a single accepted segment may push used past the limit. Each
// accepted segment is charged its payload size plus header.TCPMinimumSize.
func (q *segmentQueue) enqueue(s *segment) bool {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.used >= q.limit {
		return false
	}

	q.list.PushBack(s)
	q.used += s.data.Size() + header.TCPMinimumSize
	return true
}
// dequeue removes and returns the segment at the front of the queue, giving
// back the bytes it was charged on enqueue. It returns nil when the queue
// holds no segments.
func (q *segmentQueue) dequeue() *segment {
	q.mu.Lock()
	defer q.mu.Unlock()

	front := q.list.Front()
	if front == nil {
		return front
	}

	q.list.Remove(front)
	q.used -= front.data.Size() + header.TCPMinimumSize
	return front
}
// setLimit replaces the queue's capacity. Segments already queued are not
// affected.
func (q *segmentQueue) setLimit(limit int) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.limit = limit
}
package tcp
import (
"netstack/tcpip/header"
"sync"
)
type segmentQueue struct {
mu sync.Mutex
list segmentList // 队列实现
limit int // 队列容量
used int // 队列长度
}
// empty reports whether q.used is zero, i.e. the queue currently accounts
// for no queued data. The check is made while holding q.mu.
func (q *segmentQueue) empty() bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.used == 0
}
// enqueue appends s to the back of the queue if q.used is currently below
// q.limit, and reports whether the segment was accepted.
//
// The limit is checked before the segment is added, so the segment that
// takes q.used to or past the limit is still accepted. Each accepted segment
// is charged its data size plus header.TCPMinimumSize.
func (q *segmentQueue) enqueue(s *segment) bool {
	q.mu.Lock()
	if q.used >= q.limit {
		q.mu.Unlock()
		return false
	}

	q.list.PushBack(s)
	q.used += s.data.Size() + header.TCPMinimumSize
	q.mu.Unlock()
	return true
}
// dequeue removes and returns the segment at the front of the queue, or
// returns nil when the queue holds no segments. The amount charged for the
// segment on enqueue (data size plus header.TCPMinimumSize) is subtracted
// from q.used.
func (q *segmentQueue) dequeue() *segment {
	q.mu.Lock()
	first := q.list.Front()
	if first == nil {
		q.mu.Unlock()
		return nil
	}

	q.list.Remove(first)
	q.used -= first.data.Size() + header.TCPMinimumSize
	q.mu.Unlock()
	return first
}
// setLimit replaces the queue's limit with the given value while holding
// q.mu. Segments that are already queued are not touched.
func (q *segmentQueue) setLimit(limit int) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.limit = limit
}

View File

@@ -1,12 +1,12 @@
package tcp
import "netstack/tcpip/seqnum"
// sender is the type returned by newSender. It currently declares no fields.
type sender struct {
}
// newSender creates and returns a new sender.
//
// None of the arguments are used yet; the result is a pointer to a
// zero-value sender.
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
	return &sender{}
}
package tcp
import "netstack/tcpip/seqnum"
// sender is the type returned by newSender. It currently declares no fields.
type sender struct {
}
// newSender creates and returns a new sender.
//
// None of the arguments are used yet; the result is a pointer to a
// zero-value sender.
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
	return &sender{}
}

View File

@@ -1,173 +1,173 @@
package tcp
// segmentElementMapper provides an identity mapping from a *segment element
// to its linker.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An element mapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type segmentElementMapper struct{}
// linkerFor maps an element to its linker. Here the element is its own
// linker, so elem is returned unchanged.
//
// This default implementation should be inlined.
//
//go:nosplit
func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
// segmentList is an intrusive list of segments. Entries can be added to or
// removed from the list in O(1) time and with no additional memory
// allocations.
//
// The zero value for segmentList is an empty list ready to use.
//
// To iterate over a list (where l is a segmentList):
//
//	for e := l.Front(); e != nil; e = e.Next() {
//		// do something with e.
//	}
//
// +stateify savable
type segmentList struct {
head *segment
tail *segment
}
// Reset puts list l back into the empty state by clearing its head and tail
// pointers. Elements that were in the list are not modified.
func (l *segmentList) Reset() {
	l.head, l.tail = nil, nil
}
// Empty reports whether list l holds no elements, i.e. its head is nil.
func (l *segmentList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l, or nil if l is empty.
func (l *segmentList) Front() *segment {
return l.head
}
// Back returns the last element of list l, or nil if l is empty.
func (l *segmentList) Back() *segment {
return l.tail
}
// PushFront inserts the element e at the front of list l. Both of e's links
// are overwritten: its next link points at the previous head and its prev
// link is cleared.
func (l *segmentList) PushFront(e *segment) {
	linker := segmentElementMapper{}.linkerFor(e)
	oldHead := l.head

	linker.SetNext(oldHead)
	linker.SetPrev(nil)

	if oldHead == nil {
		// The list was empty, so e is also the last element.
		l.tail = e
	} else {
		segmentElementMapper{}.linkerFor(oldHead).SetPrev(e)
	}
	l.head = e
}
// PushBack inserts the element e at the back of list l. Both of e's links
// are overwritten: its prev link points at the previous tail and its next
// link is cleared.
func (l *segmentList) PushBack(e *segment) {
	linker := segmentElementMapper{}.linkerFor(e)
	oldTail := l.tail

	linker.SetNext(nil)
	linker.SetPrev(oldTail)

	if oldTail == nil {
		// The list was empty, so e is also the first element.
		l.head = e
	} else {
		segmentElementMapper{}.linkerFor(oldTail).SetNext(e)
	}
	l.tail = e
}
// PushBackList moves every element of list m to the end of list l and
// leaves m empty.
func (l *segmentList) PushBackList(m *segmentList) {
	switch {
	case l.head == nil:
		// l is empty: it simply takes over m's elements.
		l.head, l.tail = m.head, m.tail
	case m.head != nil:
		// Both lists hold elements: splice m's first element after l's last.
		oldTail, first := l.tail, m.head
		segmentElementMapper{}.linkerFor(oldTail).SetNext(first)
		segmentElementMapper{}.linkerFor(first).SetPrev(oldTail)
		l.tail = m.tail
	}

	m.head, m.tail = nil, nil
}
// InsertAfter inserts e immediately after b. If b was the last element, e
// becomes the new tail of l.
func (l *segmentList) InsertAfter(b, e *segment) {
	bLinker := segmentElementMapper{}.linkerFor(b)
	eLinker := segmentElementMapper{}.linkerFor(e)

	a := bLinker.Next() // element that used to follow b, if any
	eLinker.SetNext(a)
	eLinker.SetPrev(b)
	bLinker.SetNext(e)

	if a == nil {
		l.tail = e
		return
	}
	segmentElementMapper{}.linkerFor(a).SetPrev(e)
}
// InsertBefore inserts e immediately before a. If a was the first element,
// e becomes the new head of l.
func (l *segmentList) InsertBefore(a, e *segment) {
	aLinker := segmentElementMapper{}.linkerFor(a)
	eLinker := segmentElementMapper{}.linkerFor(e)

	b := aLinker.Prev() // element that used to precede a, if any
	eLinker.SetNext(a)
	eLinker.SetPrev(b)
	aLinker.SetPrev(e)

	if b == nil {
		l.head = e
		return
	}
	segmentElementMapper{}.linkerFor(b).SetNext(e)
}
// Remove unlinks e from l by joining its neighbours and, where e was the
// first or last element, updating the list's head or tail.
//
// The next and prev links stored in e itself are left as they were.
func (l *segmentList) Remove(e *segment) {
	linker := segmentElementMapper{}.linkerFor(e)
	prev, next := linker.Prev(), linker.Next()

	if prev == nil {
		l.head = next
	} else {
		segmentElementMapper{}.linkerFor(prev).SetNext(next)
	}

	if next == nil {
		l.tail = prev
	} else {
		segmentElementMapper{}.linkerFor(next).SetPrev(prev)
	}
}
// segmentEntry is a default implementation of Linker. Users can add
// anonymous fields of this type to their structs to make them automatically
// implement the methods needed by segmentList.
//
// +stateify savable
type segmentEntry struct {
next *segment
prev *segment
}
// Next returns the entry that follows e in the list, as stored in e.next.
func (e *segmentEntry) Next() *segment {
return e.next
}
// Prev returns the entry that precedes e in the list, as stored in e.prev.
func (e *segmentEntry) Prev() *segment {
return e.prev
}
// SetNext assigns elem as the entry that follows e in the list.
func (e *segmentEntry) SetNext(elem *segment) {
e.next = elem
}
// SetPrev assigns elem as the entry that precedes e in the list.
func (e *segmentEntry) SetPrev(elem *segment) {
e.prev = elem
}
package tcp
// segmentElementMapper provides an identity mapping from a *segment element
// to its linker.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An element mapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type segmentElementMapper struct{}
// linkerFor maps an element to its linker. Here the element is its own
// linker, so elem is returned unchanged.
//
// This default implementation should be inlined.
//
//go:nosplit
func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
// segmentList is an intrusive list of segments. Entries can be added to or
// removed from the list in O(1) time and with no additional memory
// allocations.
//
// The zero value for segmentList is an empty list ready to use.
//
// To iterate over a list (where l is a segmentList):
//
//	for e := l.Front(); e != nil; e = e.Next() {
//		// do something with e.
//	}
//
// +stateify savable
type segmentList struct {
head *segment
tail *segment
}
// Reset puts list l back into the empty state by clearing its head and tail
// pointers. Elements that were in the list are not modified.
func (l *segmentList) Reset() {
	l.head, l.tail = nil, nil
}
// Empty reports whether list l holds no elements, i.e. its head is nil.
func (l *segmentList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l, or nil if l is empty.
func (l *segmentList) Front() *segment {
return l.head
}
// Back returns the last element of list l, or nil if l is empty.
func (l *segmentList) Back() *segment {
return l.tail
}
// PushFront inserts the element e at the front of list l. Both of e's links
// are overwritten: its next link points at the previous head and its prev
// link is cleared.
func (l *segmentList) PushFront(e *segment) {
	linker := segmentElementMapper{}.linkerFor(e)
	oldHead := l.head

	linker.SetNext(oldHead)
	linker.SetPrev(nil)

	if oldHead == nil {
		// The list was empty, so e is also the last element.
		l.tail = e
	} else {
		segmentElementMapper{}.linkerFor(oldHead).SetPrev(e)
	}
	l.head = e
}
// PushBack inserts the element e at the back of list l. Both of e's links
// are overwritten: its prev link points at the previous tail and its next
// link is cleared.
func (l *segmentList) PushBack(e *segment) {
	linker := segmentElementMapper{}.linkerFor(e)
	oldTail := l.tail

	linker.SetNext(nil)
	linker.SetPrev(oldTail)

	if oldTail == nil {
		// The list was empty, so e is also the first element.
		l.head = e
	} else {
		segmentElementMapper{}.linkerFor(oldTail).SetNext(e)
	}
	l.tail = e
}
// PushBackList moves every element of list m to the end of list l and
// leaves m empty.
func (l *segmentList) PushBackList(m *segmentList) {
	switch {
	case l.head == nil:
		// l is empty: it simply takes over m's elements.
		l.head, l.tail = m.head, m.tail
	case m.head != nil:
		// Both lists hold elements: splice m's first element after l's last.
		oldTail, first := l.tail, m.head
		segmentElementMapper{}.linkerFor(oldTail).SetNext(first)
		segmentElementMapper{}.linkerFor(first).SetPrev(oldTail)
		l.tail = m.tail
	}

	m.head, m.tail = nil, nil
}
// InsertAfter inserts e immediately after b. If b was the last element, e
// becomes the new tail of l.
func (l *segmentList) InsertAfter(b, e *segment) {
	bLinker := segmentElementMapper{}.linkerFor(b)
	eLinker := segmentElementMapper{}.linkerFor(e)

	a := bLinker.Next() // element that used to follow b, if any
	eLinker.SetNext(a)
	eLinker.SetPrev(b)
	bLinker.SetNext(e)

	if a == nil {
		l.tail = e
		return
	}
	segmentElementMapper{}.linkerFor(a).SetPrev(e)
}
// InsertBefore inserts e immediately before a. If a was the first element,
// e becomes the new head of l.
func (l *segmentList) InsertBefore(a, e *segment) {
	aLinker := segmentElementMapper{}.linkerFor(a)
	eLinker := segmentElementMapper{}.linkerFor(e)

	b := aLinker.Prev() // element that used to precede a, if any
	eLinker.SetNext(a)
	eLinker.SetPrev(b)
	aLinker.SetPrev(e)

	if b == nil {
		l.head = e
		return
	}
	segmentElementMapper{}.linkerFor(b).SetNext(e)
}
// Remove unlinks e from l by joining its neighbours and, where e was the
// first or last element, updating the list's head or tail.
//
// The next and prev links stored in e itself are left as they were.
func (l *segmentList) Remove(e *segment) {
	linker := segmentElementMapper{}.linkerFor(e)
	prev, next := linker.Prev(), linker.Next()

	if prev == nil {
		l.head = next
	} else {
		segmentElementMapper{}.linkerFor(prev).SetNext(next)
	}

	if next == nil {
		l.tail = prev
	} else {
		segmentElementMapper{}.linkerFor(next).SetPrev(prev)
	}
}
// segmentEntry is a default implementation of Linker. Users can add
// anonymous fields of this type to their structs to make them automatically
// implement the methods needed by segmentList.
//
// +stateify savable
type segmentEntry struct {
next *segment
prev *segment
}
// Next returns the entry that follows e in the list, as stored in e.next.
func (e *segmentEntry) Next() *segment {
return e.next
}
// Prev returns the entry that precedes e in the list, as stored in e.prev.
func (e *segmentEntry) Prev() *segment {
return e.prev
}
// SetNext assigns elem as the entry that follows e in the list.
func (e *segmentEntry) SetNext(elem *segment) {
e.next = elem
}
// SetPrev assigns elem as the entry that precedes e in the list.
func (e *segmentEntry) SetPrev(elem *segment) {
e.prev = elem
}