mirror of
https://github.com/impact-eintr/netstack.git
synced 2025-10-07 13:50:58 +08:00
可以把协议栈读到的数据发给用户层应用了! 下一步把用户层应用的数据写给客户端
This commit is contained in:
@@ -142,14 +142,21 @@ func main() {
|
||||
|
||||
go func() { // echo server
|
||||
listener := tcpListen(s, proto, addr, localPort)
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
conn.Read(nil)
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := conn.Read(buf); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(string(buf))
|
||||
if string(buf) != "" {
|
||||
conn.Write([]byte("Server echo"))
|
||||
}
|
||||
os.Exit(1)
|
||||
|
||||
select {}
|
||||
}()
|
||||
|
||||
c := make(chan os.Signal)
|
||||
@@ -165,24 +172,72 @@ type TcpConn struct {
|
||||
notifyCh chan struct{}
|
||||
}
|
||||
|
||||
// Accept 封装tcp的accept操作
|
||||
func (conn *TcpConn) Accept() (tcpip.Endpoint, error) {
|
||||
func (conn *TcpConn) Read(rcv []byte) (int, error) {
|
||||
conn.wq.EventRegister(conn.we, waiter.EventIn)
|
||||
defer conn.wq.EventUnregister(conn.we)
|
||||
for {
|
||||
ep, _, err := conn.ep.Accept()
|
||||
buf, _, err := conn.ep.Read(&conn.raddr)
|
||||
if err != nil {
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
<-conn.notifyCh
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("%s", err.String())
|
||||
return 0, fmt.Errorf("%s", err.String())
|
||||
}
|
||||
return ep, nil
|
||||
n := len(buf)
|
||||
if n > cap(rcv) {
|
||||
n = cap(rcv)
|
||||
}
|
||||
rcv = append(rcv[:0], buf[:n]...)
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn {
|
||||
func (conn *TcpConn) Write(snd []byte) error {
|
||||
for {
|
||||
_, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
|
||||
if err != nil {
|
||||
if err == tcpip.ErrNoLinkAddress {
|
||||
<-notifyCh
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("%s", err.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Listener tcp连接监听器
|
||||
type Listener struct {
|
||||
raddr tcpip.FullAddress
|
||||
ep tcpip.Endpoint
|
||||
wq *waiter.Queue
|
||||
we *waiter.Entry
|
||||
notifyCh chan struct{}
|
||||
}
|
||||
|
||||
// Accept 封装tcp的accept操作
|
||||
func (l *Listener) Accept() (*TcpConn, error) {
|
||||
l.wq.EventRegister(l.we, waiter.EventIn)
|
||||
defer l.wq.EventUnregister(l.we)
|
||||
for {
|
||||
ep, wq, err := l.ep.Accept()
|
||||
if err != nil {
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
<-l.notifyCh
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("%s", err.String())
|
||||
}
|
||||
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
|
||||
return &TcpConn{ep: ep,
|
||||
wq: wq,
|
||||
we: &waitEntry,
|
||||
notifyCh: notifyCh}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *Listener {
|
||||
var wq waiter.Queue
|
||||
// 新建一个tcp端
|
||||
ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
|
||||
@@ -202,7 +257,7 @@ func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Add
|
||||
}
|
||||
|
||||
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
|
||||
return &TcpConn{
|
||||
return &Listener{
|
||||
ep: ep,
|
||||
wq: &wq,
|
||||
we: &waitEntry,
|
||||
|
@@ -17,7 +17,9 @@ func main() {
|
||||
log.Println("连接建立")
|
||||
conn.Write([]byte("helloworld"))
|
||||
log.Println("发送了数据")
|
||||
conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
conn.Read(buf)
|
||||
//conn.Close()
|
||||
}()
|
||||
|
||||
t := time.NewTimer(1000 * time.Millisecond)
|
||||
|
@@ -21,7 +21,8 @@ func (v Value) LessThanEq(w Value) bool {
|
||||
|
||||
// InRange v ∈ [a, b)
|
||||
func (v Value) InRange(a, b Value) bool {
|
||||
return a <= v && v < b
|
||||
//return a <= v && v < b
|
||||
return v-a < b-a
|
||||
}
|
||||
|
||||
// InWindows check v in [first, first+size)
|
||||
|
@@ -251,7 +251,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
|
||||
e.mu.RLock()
|
||||
if e.state == stateListen {
|
||||
e.acceptedChan <- n
|
||||
e.waiterQueue.Notify(waiter.EventIn)
|
||||
e.waiterQueue.Notify(waiter.EventIn) // 通知 Accept() 停止阻塞
|
||||
} else {
|
||||
n.Close()
|
||||
}
|
||||
|
@@ -215,7 +215,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error {
|
||||
if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
|
||||
h.sndWnd <<= uint8(h.sndWndScale)
|
||||
}
|
||||
log.Println(h.sndWnd)
|
||||
//log.Println(h.sndWnd)
|
||||
|
||||
switch h.state {
|
||||
case handshakeSynRcvd:
|
||||
@@ -311,8 +311,7 @@ func (h *handshake) execute() *tcpip.Error {
|
||||
}
|
||||
rt.Reset(timeOut)
|
||||
// 重新发送syn|ack报文
|
||||
//sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
|
||||
log.Println("超时重发了 xdm")
|
||||
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
|
||||
case wakerForNotification:
|
||||
|
||||
case wakerForNewSegment:
|
||||
@@ -469,7 +468,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
|
||||
|
||||
// handleSegments 从队列中取出 tcp 段数据,然后处理它们。
|
||||
func (e *endpoint) handleSegments() *tcpip.Error {
|
||||
log.Println("年轻人的第一条数据")
|
||||
//log.Println("年轻人的第一条数据")
|
||||
checkRequeue := true
|
||||
for i := 0; i < maxSegmentsPerWake; i++ {
|
||||
s := e.segmentQueue.dequeue()
|
||||
|
@@ -66,6 +66,8 @@ type endpoint struct {
|
||||
// address).
|
||||
effectiveNetProtos []tcpip.NetworkProtocolNumber
|
||||
|
||||
hardError *tcpip.Error
|
||||
|
||||
// workerRunning specifies if a worker goroutine is running.
|
||||
workerRunning bool
|
||||
|
||||
@@ -139,8 +141,49 @@ func (e *endpoint) Close() {
|
||||
log.Println("TODO 在写了 在写了")
|
||||
}
|
||||
|
||||
// Read 从tcp的接收队列中读取数据
|
||||
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||
return nil, tcpip.ControlMessages{}, nil
|
||||
e.mu.RLock()
|
||||
|
||||
e.rcvListMu.Lock()
|
||||
bufUsed := e.rcvBufUsed
|
||||
if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
|
||||
e.rcvListMu.Unlock()
|
||||
he := e.hardError
|
||||
e.mu.RUnlock()
|
||||
if s == stateError {
|
||||
return buffer.View{}, tcpip.ControlMessages{}, he
|
||||
}
|
||||
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
||||
v, err := e.readLocked()
|
||||
e.rcvListMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
return v, tcpip.ControlMessages{}, err
|
||||
}
|
||||
|
||||
// 从tcp的接收队列中读取数据,并从接收队列中删除已读数据
|
||||
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
|
||||
if e.rcvBufUsed == 0 {
|
||||
if e.rcvClosed || e.state != stateConnected {
|
||||
return buffer.View{}, tcpip.ErrClosedForReceive
|
||||
}
|
||||
return buffer.View{}, tcpip.ErrWouldBlock
|
||||
}
|
||||
s := e.rcvList.Front()
|
||||
views := s.data.Views()
|
||||
v := views[s.viewToDeliver]
|
||||
s.viewToDeliver++
|
||||
|
||||
if s.viewToDeliver >= len(views) {
|
||||
e.rcvList.Remove(s)
|
||||
s.decRef()
|
||||
}
|
||||
log.Println("读到了数据", views, v)
|
||||
// TODO 流量检测
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
|
||||
@@ -175,9 +218,118 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
|
||||
return netProto, nil
|
||||
}
|
||||
|
||||
// Connect 这是客户端用的吧
|
||||
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
|
||||
return e.connect(address, true, true)
|
||||
}
|
||||
|
||||
// connect将端点连接到其对等端。在正常的非S/R情况下,新连接应该运行主goroutine并执行握手。
|
||||
// 在恢复先前连接的端点时,将被动地创建两端(因此不会进行新的握手);对于应用程序尚未接受的堆栈接受连接,
|
||||
// 它们将在不运行主goroutine的情况下进行恢复。
|
||||
func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
defer func() {
|
||||
if err != nil && !err.IgnoreStats() {
|
||||
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
|
||||
}
|
||||
}()
|
||||
|
||||
connectingAddr := addr.Addr
|
||||
|
||||
// 检查ipv4是否映射到ipv6
|
||||
netProto, err := e.checkV4Mapped(&addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nicid := addr.NIC
|
||||
// 判断连接的状态
|
||||
switch e.state {
|
||||
case stateBound:
|
||||
// If we're already bound to a NIC but the caller is requesting
|
||||
// that we use a different one now, we cannot proceed.
|
||||
if e.boundNICID == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if nicid != 0 && nicid != e.boundNICID {
|
||||
return tcpip.ErrNoRoute
|
||||
}
|
||||
|
||||
nicid = e.boundNICID
|
||||
|
||||
case stateInitial:
|
||||
// Nothing to do. We'll eventually fill-in the gaps in the ID
|
||||
// (if any) when we find a route.
|
||||
|
||||
case stateConnecting:
|
||||
// A connection request has already been issued but hasn't
|
||||
// completed yet.
|
||||
return tcpip.ErrAlreadyConnecting
|
||||
|
||||
case stateConnected:
|
||||
// The endpoint is already connected. If caller hasn't been notified yet, return success.
|
||||
if !e.isConnectNotified {
|
||||
e.isConnectNotified = true
|
||||
return nil
|
||||
}
|
||||
// Otherwise return that it's already connected.
|
||||
return tcpip.ErrAlreadyConnected
|
||||
|
||||
case stateError:
|
||||
return e.hardError
|
||||
|
||||
default:
|
||||
return tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
||||
// Find a route to the desired destination.
|
||||
// 根据目标ip查找路由信息
|
||||
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Release()
|
||||
|
||||
origID := e.id
|
||||
|
||||
netProtos := []tcpip.NetworkProtocolNumber{netProto}
|
||||
e.id.LocalAddress = r.LocalAddress
|
||||
e.id.RemoteAddress = r.RemoteAddress
|
||||
e.id.RemotePort = addr.Port
|
||||
|
||||
if e.id.LocalPort != 0 {
|
||||
// 记录和检查原端口是否已被使用
|
||||
// The endpoint is bound to a port, attempt to register it.
|
||||
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// TODO 需要添加
|
||||
}
|
||||
|
||||
// Remove the port reservation. This can happen when Bind is called
|
||||
// before Connect: in such a case we don't want to hold on to
|
||||
// reservations anymore.
|
||||
if e.isPortReserved {
|
||||
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
|
||||
e.isPortReserved = false
|
||||
}
|
||||
|
||||
// 记录该端点的参数
|
||||
e.isRegistered = true
|
||||
e.state = stateConnecting
|
||||
e.route = r.Clone()
|
||||
e.boundNICID = nicid
|
||||
e.effectiveNetProtos = netProtos
|
||||
e.connectingAddress = connectingAddr
|
||||
|
||||
// TODO 需要添加
|
||||
|
||||
return tcpip.ErrConnectStarted
|
||||
}
|
||||
|
||||
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
|
||||
return nil
|
||||
@@ -238,7 +390,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
|
||||
|
||||
var n *endpoint
|
||||
select {
|
||||
case n = <-e.acceptedChan:
|
||||
case n = <-e.acceptedChan: // 外部再次调用后尝试取出ep
|
||||
log.Println("监听者进行一个新连接的分发", n.id)
|
||||
default:
|
||||
return nil, nil, tcpip.ErrWouldBlock
|
||||
@@ -343,7 +495,48 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
|
||||
}
|
||||
|
||||
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
|
||||
return waiter.EventErr
|
||||
result := waiter.EventMask(0)
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
switch e.state {
|
||||
case stateInitial, stateBound, stateConnecting:
|
||||
// Ready for nothing.
|
||||
|
||||
case stateClosed, stateError:
|
||||
// Ready for anything.
|
||||
result = mask
|
||||
|
||||
case stateListen:
|
||||
// Check if there's anything in the accepted channel.
|
||||
if (mask & waiter.EventIn) != 0 {
|
||||
if len(e.acceptedChan) > 0 {
|
||||
result |= waiter.EventIn
|
||||
}
|
||||
}
|
||||
|
||||
case stateConnected:
|
||||
// Determine if the endpoint is writable if requested.
|
||||
if (mask & waiter.EventOut) != 0 {
|
||||
e.sndBufMu.Lock()
|
||||
if e.sndClosed || e.sndBufUsed < e.sndBufSize {
|
||||
result |= waiter.EventOut
|
||||
}
|
||||
e.sndBufMu.Unlock()
|
||||
}
|
||||
|
||||
// Determine if the endpoint is readable if requested.
|
||||
if (mask & waiter.EventIn) != 0 {
|
||||
e.rcvListMu.Lock()
|
||||
if e.rcvBufUsed > 0 || e.rcvClosed {
|
||||
result |= waiter.EventIn
|
||||
}
|
||||
e.rcvListMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
|
||||
@@ -385,6 +578,20 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
|
||||
|
||||
}
|
||||
|
||||
func (e *endpoint) readyToRead(s *segment) {
|
||||
e.rcvListMu.Lock()
|
||||
if s != nil {
|
||||
s.incRef()
|
||||
e.rcvBufUsed += s.data.Size()
|
||||
e.rcvList.PushBack(s)
|
||||
} else {
|
||||
e.rcvClosed = true
|
||||
}
|
||||
e.rcvListMu.Unlock()
|
||||
|
||||
e.waiterQueue.Notify(waiter.EventIn)
|
||||
}
|
||||
|
||||
// receiveBufferAvailable calculates how many bytes are still available in the
|
||||
// receive buffer.
|
||||
// tcp流量控制:计算未被占用的接收缓存大小
|
||||
|
@@ -5,18 +5,81 @@ import (
|
||||
"netstack/tcpip/seqnum"
|
||||
)
|
||||
|
||||
type receiver struct{}
|
||||
type receiver struct {
|
||||
ep *endpoint
|
||||
rcvNxt seqnum.Value // 准备接收的下一个报文序列号
|
||||
closed bool
|
||||
}
|
||||
|
||||
// 新建并初始化接收器
|
||||
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
|
||||
r := &receiver{}
|
||||
r := &receiver{
|
||||
ep: ep,
|
||||
rcvNxt: irs + 1,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// tcp流量控制:判断 segSeq 在窗口內
|
||||
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
|
||||
// TODO 流量控制
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
|
||||
if segLen > 0 {
|
||||
// 我们期望接收到的序列号范围应该是 seqStart <= rcvNxt < seqEnd,
|
||||
// 如果不在这个范围内说明我们少了数据段,返回false,表示不能立马消费
|
||||
if !r.rcvNxt.InWindows(segSeq, segLen) {
|
||||
return false
|
||||
}
|
||||
// 尝试去除已经确认过的数据
|
||||
if segSeq.LessThan(r.rcvNxt) {
|
||||
log.Println("收到重复数据")
|
||||
diff := segSeq.Size(r.rcvNxt)
|
||||
segLen -= diff
|
||||
segSeq.UpdateForward(diff)
|
||||
s.sequenceNumber.UpdateForward(diff)
|
||||
s.data.TrimFront(int(diff))
|
||||
}
|
||||
// 将tcp段插入接收链表,并通知应用层用数据来了
|
||||
r.ep.readyToRead(s)
|
||||
} else if segSeq != r.rcvNxt { // 空数据 还是非顺序到达 丢弃
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果收到 fin 报文
|
||||
if s.flagIsSet(flagFin) {
|
||||
// TODO 处理fin报文
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleRcvdSegment handles TCP segments directed at the connection managed by
|
||||
// r as they arrive. It is called by the protocol main loop.
|
||||
// 从 handleSegments 接收到tcp段,然后进行处理消费,所谓的消费就是将负载内容插入到接收队列中
|
||||
func (r *receiver) handleRcvdSegment(s *segment) {
|
||||
log.Println(s.data)
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
segLen := seqnum.Size(s.data.Size())
|
||||
segSeq := s.sequenceNumber
|
||||
|
||||
// TODO tcp流量控制
|
||||
// tcp流量控制:判断该数据段的序列号是否在接收窗口内,如果不在,立即返回ack给对端。
|
||||
if !r.acceptable(segSeq, segLen) {
|
||||
r.ep.snd.sendAck()
|
||||
return
|
||||
}
|
||||
|
||||
log.Println(s.data, segLen, segSeq)
|
||||
|
||||
// Defer segment processing if it can't be consumed now.
|
||||
// tcp可靠性:r.consumeSegment 返回值是个bool类型,如果是true,表示已经消费该数据段,
|
||||
// 如果不是,那么进行下面的处理,插入到 pendingRcvdSegments,且进行堆排序
|
||||
if !r.consumeSegment(s, segSeq, segLen) {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -1,6 +1,9 @@
|
||||
package tcp
|
||||
|
||||
import "netstack/tcpip/seqnum"
|
||||
import (
|
||||
"log"
|
||||
"netstack/tcpip/seqnum"
|
||||
)
|
||||
|
||||
type sender struct {
|
||||
}
|
||||
@@ -10,3 +13,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
|
||||
s := &sender{}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *sender) sendAck() {
|
||||
log.Fatal("TODO 需要发送一个ack")
|
||||
}
|
||||
|
Reference in New Issue
Block a user