可以把协议栈读到的数据发给用户层应用了! 下一步把用户层应用的数据写给客户端

This commit is contained in:
impact-eintr
2022-12-08 18:39:11 +08:00
parent c025408041
commit 9a46ec9db5
8 changed files with 363 additions and 29 deletions

View File

@@ -142,14 +142,21 @@ func main() {
go func() { // echo server
listener := tcpListen(s, proto, addr, localPort)
for {
conn, err := listener.Accept()
if err != nil {
log.Println(err)
continue
}
conn.Read(nil)
buf := make([]byte, 1024)
if _, err := conn.Read(buf); err != nil {
log.Fatal(err)
}
fmt.Println(string(buf))
if string(buf) != "" {
conn.Write([]byte("Server echo"))
}
os.Exit(1)
select {}
}()
c := make(chan os.Signal)
@@ -165,24 +172,72 @@ type TcpConn struct {
notifyCh chan struct{}
}
// Accept 封装tcp的accept操作
func (conn *TcpConn) Accept() (tcpip.Endpoint, error) {
func (conn *TcpConn) Read(rcv []byte) (int, error) {
conn.wq.EventRegister(conn.we, waiter.EventIn)
defer conn.wq.EventUnregister(conn.we)
for {
ep, _, err := conn.ep.Accept()
buf, _, err := conn.ep.Read(&conn.raddr)
if err != nil {
if err == tcpip.ErrWouldBlock {
<-conn.notifyCh
continue
}
return nil, fmt.Errorf("%s", err.String())
return 0, fmt.Errorf("%s", err.String())
}
return ep, nil
n := len(buf)
if n > cap(rcv) {
n = cap(rcv)
}
rcv = append(rcv[:0], buf[:n]...)
return n, nil
}
}
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn {
func (conn *TcpConn) Write(snd []byte) error {
for {
_, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
if err != nil {
if err == tcpip.ErrNoLinkAddress {
<-notifyCh
continue
}
return fmt.Errorf("%s", err.String())
}
return nil
}
}
// Listener tcp连接监听器
type Listener struct {
raddr tcpip.FullAddress
ep tcpip.Endpoint
wq *waiter.Queue
we *waiter.Entry
notifyCh chan struct{}
}
// Accept 封装tcp的accept操作
func (l *Listener) Accept() (*TcpConn, error) {
l.wq.EventRegister(l.we, waiter.EventIn)
defer l.wq.EventUnregister(l.we)
for {
ep, wq, err := l.ep.Accept()
if err != nil {
if err == tcpip.ErrWouldBlock {
<-l.notifyCh
continue
}
return nil, fmt.Errorf("%s", err.String())
}
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
return &TcpConn{ep: ep,
wq: wq,
we: &waitEntry,
notifyCh: notifyCh}, nil
}
}
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *Listener {
var wq waiter.Queue
// 新建一个tcp端
ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
@@ -202,7 +257,7 @@ func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Add
}
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
return &TcpConn{
return &Listener{
ep: ep,
wq: &wq,
we: &waitEntry,

View File

@@ -17,7 +17,9 @@ func main() {
log.Println("连接建立")
conn.Write([]byte("helloworld"))
log.Println("发送了数据")
conn.Close()
buf := make([]byte, 1024)
conn.Read(buf)
//conn.Close()
}()
t := time.NewTimer(1000 * time.Millisecond)

View File

@@ -21,7 +21,8 @@ func (v Value) LessThanEq(w Value) bool {
// InRange v ∈ [a, b)
func (v Value) InRange(a, b Value) bool {
return a <= v && v < b
//return a <= v && v < b
return v-a < b-a
}
// InWindows check v in [first, first+size)

View File

@@ -251,7 +251,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.RLock()
if e.state == stateListen {
e.acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
e.waiterQueue.Notify(waiter.EventIn) // 通知 Accept() 停止阻塞
} else {
n.Close()
}

View File

@@ -215,7 +215,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error {
if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
h.sndWnd <<= uint8(h.sndWndScale)
}
log.Println(h.sndWnd)
//log.Println(h.sndWnd)
switch h.state {
case handshakeSynRcvd:
@@ -311,8 +311,7 @@ func (h *handshake) execute() *tcpip.Error {
}
rt.Reset(timeOut)
// 重新发送syn|ack报文
//sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
log.Println("超时重发了 xdm")
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification:
case wakerForNewSegment:
@@ -469,7 +468,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
// handleSegments 从队列中取出 tcp 段数据,然后处理它们。
func (e *endpoint) handleSegments() *tcpip.Error {
log.Println("年轻人的第一条数据")
//log.Println("年轻人的第一条数据")
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
s := e.segmentQueue.dequeue()

View File

@@ -66,6 +66,8 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
hardError *tcpip.Error
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -139,8 +141,49 @@ func (e *endpoint) Close() {
log.Println("TODO 在写了 在写了")
}
// Read 从tcp的接收队列中读取数据
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
return nil, tcpip.ControlMessages{}, nil
e.mu.RLock()
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.hardError
e.mu.RUnlock()
if s == stateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
v, err := e.readLocked()
e.rcvListMu.Unlock()
e.mu.RUnlock()
return v, tcpip.ControlMessages{}, err
}
// 从tcp的接收队列中读取数据并从接收队列中删除已读数据
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || e.state != stateConnected {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
}
s := e.rcvList.Front()
views := s.data.Views()
v := views[s.viewToDeliver]
s.viewToDeliver++
if s.viewToDeliver >= len(views) {
e.rcvList.Remove(s)
s.decRef()
}
log.Println("读到了数据", views, v)
// TODO 流量检测
return v, nil
}
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
@@ -175,9 +218,118 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
// Connect 这是客户端用的吧
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
return e.connect(address, true, true)
}
// connect将端点连接到其对等端。在正常的非S/R情况下新连接应该运行主goroutine并执行握手。
// 在恢复先前连接的端点时,将被动地创建两端(因此不会进行新的握手);对于应用程序尚未接受的堆栈接受连接,
// 它们将在不运行主goroutine的情况下进行恢复。
func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
defer func() {
if err != nil && !err.IgnoreStats() {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
}
}()
connectingAddr := addr.Addr
// 检查ipv4是否映射到ipv6
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
nicid := addr.NIC
// 判断连接的状态
switch e.state {
case stateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
if e.boundNICID == 0 {
break
}
if nicid != 0 && nicid != e.boundNICID {
return tcpip.ErrNoRoute
}
nicid = e.boundNICID
case stateInitial:
// Nothing to do. We'll eventually fill-in the gaps in the ID
// (if any) when we find a route.
case stateConnecting:
// A connection request has already been issued but hasn't
// completed yet.
return tcpip.ErrAlreadyConnecting
case stateConnected:
// The endpoint is already connected. If caller hasn't been notified yet, return success.
if !e.isConnectNotified {
e.isConnectNotified = true
return nil
}
// Otherwise return that it's already connected.
return tcpip.ErrAlreadyConnected
case stateError:
return e.hardError
default:
return tcpip.ErrInvalidEndpointState
}
// Find a route to the desired destination.
// 根据目标ip查找路由信息
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
if err != nil {
return err
}
defer r.Release()
origID := e.id
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.id.LocalAddress = r.LocalAddress
e.id.RemoteAddress = r.RemoteAddress
e.id.RemotePort = addr.Port
if e.id.LocalPort != 0 {
// 记录和检查原端口是否已被使用
// The endpoint is bound to a port, attempt to register it.
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
if err != nil {
return err
}
} else {
// TODO 需要添加
}
// Remove the port reservation. This can happen when Bind is called
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
e.isPortReserved = false
}
// 记录该端点的参数
e.isRegistered = true
e.state = stateConnecting
e.route = r.Clone()
e.boundNICID = nicid
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
// TODO 需要添加
return tcpip.ErrConnectStarted
}
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return nil
@@ -238,7 +390,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
var n *endpoint
select {
case n = <-e.acceptedChan:
case n = <-e.acceptedChan: // 外部再次调用后尝试取出ep
log.Println("监听者进行一个新连接的分发", n.id)
default:
return nil, nil, tcpip.ErrWouldBlock
@@ -343,7 +495,48 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
return waiter.EventErr
result := waiter.EventMask(0)
e.mu.RLock()
defer e.mu.RUnlock()
switch e.state {
case stateInitial, stateBound, stateConnecting:
// Ready for nothing.
case stateClosed, stateError:
// Ready for anything.
result = mask
case stateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
if len(e.acceptedChan) > 0 {
result |= waiter.EventIn
}
}
case stateConnected:
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
if e.sndClosed || e.sndBufUsed < e.sndBufSize {
result |= waiter.EventOut
}
e.sndBufMu.Unlock()
}
// Determine if the endpoint is readable if requested.
if (mask & waiter.EventIn) != 0 {
e.rcvListMu.Lock()
if e.rcvBufUsed > 0 || e.rcvClosed {
result |= waiter.EventIn
}
e.rcvListMu.Unlock()
}
}
return result
}
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
@@ -385,6 +578,20 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
}
func (e *endpoint) readyToRead(s *segment) {
e.rcvListMu.Lock()
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
e.waiterQueue.Notify(waiter.EventIn)
}
// receiveBufferAvailable calculates how many bytes are still available in the
// receive buffer.
// tcp流量控制计算未被占用的接收缓存大小

View File

@@ -5,18 +5,81 @@ import (
"netstack/tcpip/seqnum"
)
type receiver struct{}
type receiver struct {
ep *endpoint
rcvNxt seqnum.Value // 准备接收的下一个报文序列号
closed bool
}
// 新建并初始化接收器
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
r := &receiver{}
r := &receiver{
ep: ep,
rcvNxt: irs + 1,
}
return r
}
// tcp流量控制判断 segSeq 在窗口內
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// TODO 流量控制
return true
}
func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
if segLen > 0 {
// 我们期望接收到的序列号范围应该是 seqStart <= rcvNxt < seqEnd
// 如果不在这个范围内说明我们少了数据段返回false表示不能立马消费
if !r.rcvNxt.InWindows(segSeq, segLen) {
return false
}
// 尝试去除已经确认过的数据
if segSeq.LessThan(r.rcvNxt) {
log.Println("收到重复数据")
diff := segSeq.Size(r.rcvNxt)
segLen -= diff
segSeq.UpdateForward(diff)
s.sequenceNumber.UpdateForward(diff)
s.data.TrimFront(int(diff))
}
// 将tcp段插入接收链表并通知应用层用数据来了
r.ep.readyToRead(s)
} else if segSeq != r.rcvNxt { // 空数据 还是非顺序到达 丢弃
return false
}
// 如果收到 fin 报文
if s.flagIsSet(flagFin) {
// TODO 处理fin报文
}
return true
}
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
// 从 handleSegments 接收到tcp段然后进行处理消费所谓的消费就是将负载内容插入到接收队列中
func (r *receiver) handleRcvdSegment(s *segment) {
log.Println(s.data)
if r.closed {
return
}
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// TODO tcp流量控制
// tcp流量控制判断该数据段的序列号是否在接收窗口内如果不在立即返回ack给对端。
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
return
}
log.Println(s.data, segLen, segSeq)
// Defer segment processing if it can't be consumed now.
// tcp可靠性r.consumeSegment 返回值是个bool类型如果是true表示已经消费该数据段
// 如果不是,那么进行下面的处理,插入到 pendingRcvdSegments且进行堆排序
if !r.consumeSegment(s, segSeq, segLen) {
return
}
}

View File

@@ -1,6 +1,9 @@
package tcp
import "netstack/tcpip/seqnum"
import (
"log"
"netstack/tcpip/seqnum"
)
type sender struct {
}
@@ -10,3 +13,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s := &sender{}
return s
}
func (s *sender) sendAck() {
log.Fatal("TODO 需要发送一个ack")
}