mirror of
https://github.com/Monibuca/plugin-webtransport.git
synced 2025-12-24 11:51:00 +08:00
🐛 FIX:升级quic库,兼容go1.19
This commit is contained in:
70
internal/frames.go
Normal file
70
internal/frames.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package h3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// Frame types
|
||||
const (
|
||||
FRAME_DATA = 0x00
|
||||
FRAME_HEADERS = 0x01
|
||||
FRAME_CANCEL_PUSH = 0x03
|
||||
FRAME_SETTINGS = 0x04
|
||||
FRAME_PUSH_PROMISE = 0x05
|
||||
FRAME_GOAWAY = 0x07
|
||||
FRAME_MAX_PUSH_ID = 0x0D
|
||||
FRAME_WEBTRANSPORT_STREAM = 0x41
|
||||
)
|
||||
|
||||
// HTTP/3 frame
|
||||
type Frame struct {
|
||||
Type uint64
|
||||
SessionID uint64
|
||||
Length uint64
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (f *Frame) Read(r io.Reader) error {
|
||||
qr := quicvarint.NewReader(r)
|
||||
t, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.Type = t
|
||||
|
||||
// For most (but not all) frame types, l is the data length
|
||||
switch t {
|
||||
case FRAME_WEBTRANSPORT_STREAM:
|
||||
f.Length = 0
|
||||
f.SessionID = l
|
||||
f.Data = []byte{}
|
||||
return nil
|
||||
default:
|
||||
f.Length = l
|
||||
f.Data = make([]byte, l)
|
||||
_, err := r.Read(f.Data)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Frame) Write(w io.Writer) (int, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
quicvarint.Write(buf, f.Type)
|
||||
if f.Type == FRAME_WEBTRANSPORT_STREAM {
|
||||
quicvarint.Write(buf, f.SessionID)
|
||||
} else {
|
||||
quicvarint.Write(buf, f.Length)
|
||||
}
|
||||
buf.Write(f.Data)
|
||||
|
||||
return w.Write(buf.Bytes())
|
||||
}
|
||||
97
internal/request_reader.go
Normal file
97
internal/request_reader.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package h3
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/marten-seemann/qpack"
|
||||
)
|
||||
|
||||
func RequestFromHeaders(headers []qpack.HeaderField) (*http.Request, string, error) {
|
||||
var path, authority, method, contentLengthStr, protocol string
|
||||
httpHeaders := http.Header{}
|
||||
|
||||
for _, h := range headers {
|
||||
switch h.Name {
|
||||
case ":path":
|
||||
path = h.Value
|
||||
case ":method":
|
||||
method = h.Value
|
||||
case ":authority":
|
||||
authority = h.Value
|
||||
case ":protocol":
|
||||
protocol = h.Value
|
||||
case "content-length":
|
||||
contentLengthStr = h.Value
|
||||
default:
|
||||
if !h.IsPseudo() {
|
||||
httpHeaders.Add(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
||||
if len(httpHeaders["Cookie"]) > 0 {
|
||||
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
|
||||
}
|
||||
|
||||
isConnect := method == http.MethodConnect
|
||||
if isConnect {
|
||||
// if path != "" || authority == "" {
|
||||
// return nil, errors.New(":path must be empty and :authority must not be empty")
|
||||
// }
|
||||
} else if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
|
||||
return nil, "", errors.New(":path, :authority and :method must not be empty")
|
||||
}
|
||||
|
||||
var u *url.URL
|
||||
var requestURI string
|
||||
var err error
|
||||
|
||||
if isConnect {
|
||||
u, err = url.ParseRequestURI("https://" + authority + path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
requestURI = path
|
||||
} else {
|
||||
u, err = url.ParseRequestURI(path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
requestURI = path
|
||||
}
|
||||
|
||||
var contentLength int64
|
||||
if len(contentLengthStr) > 0 {
|
||||
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: method,
|
||||
URL: u,
|
||||
Proto: "HTTP/3",
|
||||
ProtoMajor: 3,
|
||||
ProtoMinor: 0,
|
||||
Header: httpHeaders,
|
||||
Body: nil,
|
||||
ContentLength: contentLength,
|
||||
Host: authority,
|
||||
RequestURI: requestURI,
|
||||
TLS: &tls.ConnectionState{},
|
||||
}, protocol, nil
|
||||
}
|
||||
|
||||
func hostnameFromRequest(req *http.Request) string {
|
||||
if req.URL != nil {
|
||||
return req.URL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
111
internal/response_writer.go
Normal file
111
internal/response_writer.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package h3
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/marten-seemann/qpack"
|
||||
)
|
||||
|
||||
// DataStreamer lets the caller take over the stream. After a call to DataStream
|
||||
// the HTTP server library will not do anything else with the connection.
|
||||
//
|
||||
// It becomes the caller's responsibility to manage and close the stream.
|
||||
//
|
||||
// After a call to DataStream, the original Request.Body must not be used.
|
||||
type DataStreamer interface {
|
||||
DataStream() quic.Stream
|
||||
}
|
||||
|
||||
type ResponseWriter struct {
|
||||
stream quic.Stream // needed for DataStream()
|
||||
bufferedStream *bufio.Writer
|
||||
|
||||
header http.Header
|
||||
status int // status code passed to WriteHeader
|
||||
headerWritten bool
|
||||
dataStreamUsed bool // set when DataSteam() is called
|
||||
}
|
||||
|
||||
func NewResponseWriter(stream quic.Stream) *ResponseWriter {
|
||||
return &ResponseWriter{
|
||||
header: http.Header{},
|
||||
stream: stream,
|
||||
bufferedStream: bufio.NewWriter(stream),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) WriteHeader(status int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
|
||||
if status < 100 || status >= 200 {
|
||||
w.headerWritten = true
|
||||
}
|
||||
w.status = status
|
||||
|
||||
var headers bytes.Buffer
|
||||
enc := qpack.NewEncoder(&headers)
|
||||
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
||||
for k, v := range w.header {
|
||||
for index := range v {
|
||||
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
||||
}
|
||||
}
|
||||
|
||||
headersFrame := Frame{Type: FRAME_HEADERS, Length: uint64(headers.Len()), Data: headers.Bytes()}
|
||||
headersFrame.Write(w.bufferedStream)
|
||||
if !w.headerWritten {
|
||||
w.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) Write(p []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
if !bodyAllowedForStatus(w.status) {
|
||||
return 0, http.ErrBodyNotAllowed
|
||||
}
|
||||
|
||||
dataFrame := Frame{Type: FRAME_DATA, Length: uint64(len(p)), Data: p}
|
||||
return dataFrame.Write(w.bufferedStream)
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) Flush() {
|
||||
w.bufferedStream.Flush()
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) usedDataStream() bool {
|
||||
return w.dataStreamUsed
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) DataStream() quic.Stream {
|
||||
w.dataStreamUsed = true
|
||||
w.Flush()
|
||||
return w.stream
|
||||
}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
87
internal/settings.go
Normal file
87
internal/settings.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package h3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// Settings
|
||||
const (
|
||||
// https://datatracker.ietf.org/doc/html/draft-ietf-quic-http-34
|
||||
SETTINGS_MAX_FIELD_SECTION_SIZE = SettingID(0x6)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/draft-ietf-quic-qpack-21
|
||||
SETTINGS_QPACK_MAX_TABLE_CAPACITY = SettingID(0x1)
|
||||
SETTINGS_QPACK_BLOCKED_STREAMS = SettingID(0x7)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-05#section-9.1
|
||||
H3_DATAGRAM_05 = SettingID(0xffd277)
|
||||
|
||||
// https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-02.html#section-8.2
|
||||
ENABLE_WEBTRANSPORT = SettingID(0x2b603742)
|
||||
)
|
||||
|
||||
type SettingID uint64
|
||||
|
||||
type SettingsMap map[SettingID]uint64
|
||||
|
||||
func (s *SettingsMap) FromFrame(f Frame) error {
|
||||
if f.Length > 8*(1<<10) {
|
||||
return fmt.Errorf("unexpected size for SETTINGS frame: %d", f.Length)
|
||||
}
|
||||
|
||||
b := bytes.NewReader(f.Data)
|
||||
for b.Len() > 0 {
|
||||
id, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return err
|
||||
}
|
||||
val, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := (*s)[SettingID(id)]; ok {
|
||||
return fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
(*s)[SettingID(id)] = val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s SettingsMap) ToFrame() Frame {
|
||||
f := Frame{Type: FRAME_SETTINGS}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
var l uint64
|
||||
for id, val := range s {
|
||||
l += uint64(quicvarint.Len(uint64(id)) + quicvarint.Len(val))
|
||||
}
|
||||
f.Length = l
|
||||
for id, val := range s {
|
||||
quicvarint.Write(b, uint64(id))
|
||||
quicvarint.Write(b, val)
|
||||
}
|
||||
f.Data = b.Bytes()
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (id SettingID) String() string {
|
||||
switch id {
|
||||
case 0x01:
|
||||
return "QPACK_MAX_TABLE_CAPACITY"
|
||||
case 0x06:
|
||||
return "MAX_FIELD_SECTION_SIZE"
|
||||
case 0x07:
|
||||
return "QPACK_BLOCKED_STREAMS"
|
||||
case 0x2b603742:
|
||||
return "ENABLE_WEBTRANSPORT"
|
||||
case 0xffd277:
|
||||
return "H3_DATAGRAM_05"
|
||||
default:
|
||||
return fmt.Sprintf("%#x", uint64(id))
|
||||
}
|
||||
}
|
||||
68
internal/streams.go
Normal file
68
internal/streams.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package h3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// Stream types
|
||||
const (
|
||||
STREAM_CONTROL = 0x00
|
||||
STREAM_PUSH = 0x01
|
||||
STREAM_QPACK_ENCODER = 0x02
|
||||
STREAM_QPACK_DECODER = 0x03
|
||||
STREAM_WEBTRANSPORT_UNI_STREAM = 0x54
|
||||
)
|
||||
|
||||
// HTTP/3 stream header
|
||||
type StreamHeader struct {
|
||||
Type uint64
|
||||
ID uint64
|
||||
}
|
||||
|
||||
func (s *StreamHeader) Read(r io.Reader) error {
|
||||
qr := quicvarint.NewReader(r)
|
||||
t, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Type = t
|
||||
|
||||
switch t {
|
||||
// One-byte streams
|
||||
case STREAM_CONTROL, STREAM_QPACK_ENCODER, STREAM_QPACK_DECODER:
|
||||
return nil
|
||||
// Two-byte streams
|
||||
case STREAM_PUSH, STREAM_WEBTRANSPORT_UNI_STREAM:
|
||||
l, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.ID = l
|
||||
return nil
|
||||
default:
|
||||
// skip over unknown streams
|
||||
return fmt.Errorf("unknown stream type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StreamHeader) Write(w io.Writer) (int, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
quicvarint.Write(buf, s.Type)
|
||||
switch s.Type {
|
||||
// One-byte streams
|
||||
case STREAM_CONTROL, STREAM_QPACK_ENCODER, STREAM_QPACK_DECODER:
|
||||
// Two-byte streams
|
||||
case STREAM_PUSH, STREAM_WEBTRANSPORT_UNI_STREAM:
|
||||
quicvarint.Write(buf, s.ID)
|
||||
default:
|
||||
// skip over unknown streams
|
||||
return 0, fmt.Errorf("unknown stream type")
|
||||
}
|
||||
|
||||
return w.Write(buf.Bytes())
|
||||
}
|
||||
11
main.go
11
main.go
@@ -4,7 +4,6 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
webtransport "github.com/adriancable/webtransport-go"
|
||||
. "m7s.live/engine/v4"
|
||||
)
|
||||
|
||||
@@ -24,7 +23,7 @@ func (c *WebTransportConfig) OnEvent(event any) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/play/", func(w http.ResponseWriter, r *http.Request) {
|
||||
streamPath := r.URL.Path[len("/play/"):]
|
||||
session := r.Body.(*webtransport.Session)
|
||||
session := r.Body.(*Session)
|
||||
session.AcceptSession()
|
||||
defer session.CloseSession()
|
||||
// TODO: 多路
|
||||
@@ -44,7 +43,7 @@ func (c *WebTransportConfig) OnEvent(event any) {
|
||||
})
|
||||
mux.HandleFunc("/push/", func(w http.ResponseWriter, r *http.Request) {
|
||||
streamPath := r.URL.Path[len("/push/"):]
|
||||
session := r.Body.(*webtransport.Session)
|
||||
session := r.Body.(*Session)
|
||||
session.AcceptSession()
|
||||
defer session.CloseSession()
|
||||
// TODO: 多路
|
||||
@@ -64,11 +63,11 @@ func (c *WebTransportConfig) OnEvent(event any) {
|
||||
|
||||
}
|
||||
})
|
||||
server := &webtransport.Server{
|
||||
server := &Server{
|
||||
Handler: mux,
|
||||
ListenAddr: c.ListenAddr,
|
||||
TLSCert: webtransport.CertFile{Path: c.CertFile},
|
||||
TLSKey: webtransport.CertFile{Path: c.KeyFile},
|
||||
TLSCert: CertFile{Path: c.CertFile},
|
||||
TLSKey: CertFile{Path: c.KeyFile},
|
||||
}
|
||||
go server.Run(plugin)
|
||||
}
|
||||
|
||||
429
webtransport.go
Normal file
429
webtransport.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package webtransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/http3"
|
||||
"github.com/lucas-clemente/quic-go/quicvarint"
|
||||
"github.com/marten-seemann/qpack"
|
||||
h3 "m7s.live/plugin/webtransport/v4/internal"
|
||||
)
|
||||
|
||||
type receiveMessageResult struct {
|
||||
msg []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// A CertFile represents a TLS certificate or key, expressed either as a file path or as the certificate/key itself as a []byte.
|
||||
type CertFile struct {
|
||||
Path string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Wrapper for quic.Config
|
||||
type QuicConfig quic.Config
|
||||
|
||||
// A Server defines parameters for running a WebTransport server. Use http.HandleFunc to register HTTP/3 endpoints for handling WebTransport requests.
|
||||
type Server struct {
|
||||
http.Handler
|
||||
// ListenAddr sets an address to bind server to, e.g. ":4433"
|
||||
ListenAddr string
|
||||
// TLSCert defines a path to, or byte array containing, a certificate (CRT file)
|
||||
TLSCert CertFile
|
||||
// TLSKey defines a path to, or byte array containing, the certificate's private key (KEY file)
|
||||
TLSKey CertFile
|
||||
// AllowedOrigins represents list of allowed origins to connect from
|
||||
AllowedOrigins []string
|
||||
// Additional configuration parameters to pass onto QUIC listener
|
||||
QuicConfig *QuicConfig
|
||||
}
|
||||
|
||||
// Starts a WebTransport server and blocks while it's running. Cancel the supplied Context to stop the server.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if s.Handler == nil {
|
||||
s.Handler = http.DefaultServeMux
|
||||
}
|
||||
if s.QuicConfig == nil {
|
||||
s.QuicConfig = &QuicConfig{}
|
||||
}
|
||||
s.QuicConfig.EnableDatagrams = true
|
||||
|
||||
listener, err := quic.ListenAddr(s.ListenAddr, s.generateTLSConfig(), (*quic.Config)(s.QuicConfig))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
sess, err := listener.Accept(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleSession(ctx, sess)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleSession(ctx context.Context, sess quic.Connection) {
|
||||
serverControlStream, err := sess.OpenUniStream()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write server settings
|
||||
streamHeader := h3.StreamHeader{Type: h3.STREAM_CONTROL}
|
||||
streamHeader.Write(serverControlStream)
|
||||
|
||||
settingsFrame := (h3.SettingsMap{h3.H3_DATAGRAM_05: 1, h3.ENABLE_WEBTRANSPORT: 1}).ToFrame()
|
||||
settingsFrame.Write(serverControlStream)
|
||||
|
||||
// Accept control stream - client settings will appear here
|
||||
clientControlStream, err := sess.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
// log.Printf("Read settings from control stream id: %d\n", stream.StreamID())
|
||||
|
||||
clientSettingsReader := quicvarint.NewReader(clientControlStream)
|
||||
quicvarint.Read(clientSettingsReader)
|
||||
|
||||
clientSettingsFrame := h3.Frame{}
|
||||
if clientSettingsFrame.Read(clientControlStream); err != nil || clientSettingsFrame.Type != h3.FRAME_SETTINGS {
|
||||
// log.Println("control stream read error, or not a settings frame")
|
||||
return
|
||||
}
|
||||
|
||||
// Accept request stream
|
||||
requestStream, err := sess.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
// log.Printf("request stream err: %v", err)
|
||||
return
|
||||
}
|
||||
// log.Printf("request stream accepted: %d", requestStream.StreamID())
|
||||
|
||||
ctx, cancelFunction := context.WithCancel(requestStream.Context())
|
||||
ctx = context.WithValue(ctx, http3.ServerContextKey, s)
|
||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr())
|
||||
|
||||
// log.Println(streamType, settingsFrame)
|
||||
|
||||
headersFrame := h3.Frame{}
|
||||
err = headersFrame.Read(requestStream)
|
||||
if err != nil {
|
||||
// log.Printf("request stream ParseNextFrame err: %v", err)
|
||||
cancelFunction()
|
||||
requestStream.Close()
|
||||
return
|
||||
}
|
||||
if headersFrame.Type != h3.FRAME_HEADERS {
|
||||
// log.Println("request stream got not HeadersFrame")
|
||||
cancelFunction()
|
||||
requestStream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
decoder := qpack.NewDecoder(nil)
|
||||
hfs, err := decoder.DecodeFull(headersFrame.Data)
|
||||
if err != nil {
|
||||
// log.Printf("request stream decoder err: %v", err)
|
||||
cancelFunction()
|
||||
requestStream.Close()
|
||||
return
|
||||
}
|
||||
req, protocol, err := h3.RequestFromHeaders(hfs)
|
||||
if err != nil {
|
||||
cancelFunction()
|
||||
requestStream.Close()
|
||||
return
|
||||
}
|
||||
req.RemoteAddr = sess.RemoteAddr().String()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
rw := h3.NewResponseWriter(requestStream)
|
||||
rw.Header().Add("sec-webtransport-http3-draft", "draft02")
|
||||
req.Body = &Session{Stream: requestStream, Session: sess, ClientControlStream: clientControlStream, ServerControlStream: serverControlStream, responseWriter: rw, context: ctx, cancel: cancelFunction}
|
||||
|
||||
if protocol != "webtransport" || !s.validateOrigin(req.Header.Get("origin")) {
|
||||
req.Body.(*Session).RejectSession(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Drain request stream - this is so that we can catch the EOF and shut down cleanly when the client closes the transport
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, 1024)
|
||||
_, err := requestStream.Read(buf)
|
||||
if err != nil {
|
||||
cancelFunction()
|
||||
requestStream.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (s *Server) generateTLSConfig() *tls.Config {
|
||||
var cert tls.Certificate
|
||||
var err error
|
||||
|
||||
if s.TLSCert.Path != "" && s.TLSKey.Path != "" {
|
||||
cert, err = tls.LoadX509KeyPair(s.TLSCert.Path, s.TLSKey.Path)
|
||||
} else {
|
||||
cert, err = tls.X509KeyPair(s.TLSCert.Data, s.TLSKey.Data)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
NextProtos: []string{"h3", "h3-32", "h3-31", "h3-30", "h3-29"},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) validateOrigin(origin string) bool {
|
||||
// No origin specified - everything is allowed
|
||||
if s.AllowedOrigins == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Enforce allowed origins
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, b := range s.AllowedOrigins {
|
||||
if b == u.Host {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ReceiveStream wraps a quic.ReceiveStream providing a unidirectional WebTransport client->server stream, including a Read function.
|
||||
type ReceiveStream struct {
|
||||
quic.ReceiveStream
|
||||
readHeaderBeforeData bool
|
||||
headerRead bool
|
||||
requestSessionID uint64
|
||||
}
|
||||
|
||||
// SendStream wraps a quic.SendStream providing a unidirectional WebTransport server->client stream, including a Write function.
|
||||
type SendStream struct {
|
||||
quic.SendStream
|
||||
writeHeaderBeforeData bool
|
||||
headerWritten bool
|
||||
requestSessionID uint64
|
||||
}
|
||||
|
||||
// Stream wraps a quic.Stream providing a bidirectional server<->client stream, including Read and Write functions.
|
||||
type Stream quic.Stream
|
||||
|
||||
// Read reads up to len(p) bytes from a WebTransport unidirectional stream, returning the actual number of bytes read.
|
||||
func (s *ReceiveStream) Read(p []byte) (int, error) {
|
||||
if s.readHeaderBeforeData && !s.headerRead {
|
||||
// Unidirectional stream - so we need to read stream header before first data read
|
||||
|
||||
streamHeader := h3.StreamHeader{}
|
||||
if err := streamHeader.Read(s.ReceiveStream); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if streamHeader.Type != h3.STREAM_WEBTRANSPORT_UNI_STREAM {
|
||||
return 0, fmt.Errorf("unidirectional stream received with the wrong stream type")
|
||||
}
|
||||
s.requestSessionID = streamHeader.ID
|
||||
s.headerRead = true
|
||||
}
|
||||
return s.ReceiveStream.Read(p)
|
||||
}
|
||||
|
||||
// Write writes up to len(p) bytes to a WebTransport unidirectional stream, returning the actual number of bytes written.
|
||||
func (s *SendStream) Write(p []byte) (int, error) {
|
||||
if s.writeHeaderBeforeData && !s.headerWritten {
|
||||
// Unidirectional stream - so we need to write stream header before first data write
|
||||
buf := &bytes.Buffer{}
|
||||
quicvarint.Write(buf, h3.STREAM_WEBTRANSPORT_UNI_STREAM)
|
||||
quicvarint.Write(buf, s.requestSessionID)
|
||||
if _, err := s.SendStream.Write(buf.Bytes()); err != nil {
|
||||
s.Close()
|
||||
return 0, err
|
||||
}
|
||||
s.headerWritten = true
|
||||
}
|
||||
return s.SendStream.Write(p)
|
||||
}
|
||||
|
||||
// Session is a WebTransport session (and the Body of a WebTransport http.Request) wrapping the request stream (a quic.Stream), the two control streams and a quic.Session.
|
||||
type Session struct {
|
||||
quic.Stream
|
||||
Session quic.Connection
|
||||
ClientControlStream quic.ReceiveStream
|
||||
ServerControlStream quic.SendStream
|
||||
responseWriter *h3.ResponseWriter
|
||||
context context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Context returns the context for the WebTransport session.
|
||||
func (s *Session) Context() context.Context {
|
||||
return s.context
|
||||
}
|
||||
|
||||
// AcceptSession accepts an incoming WebTransport session. Call it in your http.HandleFunc.
|
||||
func (s *Session) AcceptSession() {
|
||||
r := s.responseWriter
|
||||
r.WriteHeader(http.StatusOK)
|
||||
r.Flush()
|
||||
}
|
||||
|
||||
// AcceptSession rejects an incoming WebTransport session, returning the supplied HTML error code to the client. Call it in your http.HandleFunc.
|
||||
func (s *Session) RejectSession(errorCode int) {
|
||||
r := s.responseWriter
|
||||
r.WriteHeader(errorCode)
|
||||
r.Flush()
|
||||
s.CloseSession()
|
||||
}
|
||||
|
||||
// ReceiveMessage returns a datagram received from a WebTransport session, blocking if necessary until one is available. Supply your own context, or use the WebTransport
|
||||
// session's Context() so that ending the WebTransport session automatically cancels this call. Note that datagrams are unreliable - depending on network conditions,
|
||||
// datagrams sent by the client may never be received by the server.
|
||||
func (s *Session) ReceiveMessage(ctx context.Context) ([]byte, error) {
|
||||
resultChannel := make(chan receiveMessageResult)
|
||||
|
||||
go func() {
|
||||
msg, err := s.Session.ReceiveMessage()
|
||||
resultChannel <- receiveMessageResult{msg: msg, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChannel:
|
||||
if result.err != nil {
|
||||
return nil, result.err
|
||||
}
|
||||
|
||||
datastream := bytes.NewReader(result.msg)
|
||||
quarterStreamId, err := quicvarint.Read(datastream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.msg[quicvarint.Len(quarterStreamId):], nil
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("WebTransport stream closed")
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessage sends a datagram over a WebTransport session. Supply your own context, or use the WebTransport
|
||||
// session's Context() so that ending the WebTransport session automatically cancels this call. Note that datagrams are unreliable - depending on network conditions,
|
||||
// datagrams sent by the server may never be received by the client.
|
||||
func (s *Session) SendMessage(msg []byte) error {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
// "Quarter Stream ID" (!) of associated request stream, as per https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram
|
||||
quicvarint.Write(buf, uint64(s.StreamID()/4))
|
||||
buf.Write(msg)
|
||||
return s.Session.SendMessage(buf.Bytes())
|
||||
}
|
||||
|
||||
// AcceptStream accepts an incoming (that is, client-initated) bidirectional stream, blocking if necessary until one is available. Supply your own context, or use the WebTransport
|
||||
// session's Context() so that ending the WebTransport session automatically cancels this call.
|
||||
func (s *Session) AcceptStream() (Stream, error) {
|
||||
stream, err := s.Session.AcceptStream(s.context)
|
||||
if err != nil {
|
||||
return stream, err
|
||||
}
|
||||
|
||||
streamFrame := h3.Frame{}
|
||||
err = streamFrame.Read(stream)
|
||||
|
||||
return stream, err
|
||||
}
|
||||
|
||||
// AcceptStream accepts an incoming (that is, client-initated) unidirectional stream, blocking if necessary until one is available. Supply your own context, or use the WebTransport
|
||||
// session's Context() so that ending the WebTransport session automatically cancels this call.
|
||||
func (s *Session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
|
||||
stream, err := s.Session.AcceptUniStream(ctx)
|
||||
return ReceiveStream{ReceiveStream: stream, readHeaderBeforeData: true, headerRead: false}, err
|
||||
}
|
||||
|
||||
func (s *Session) internalOpenStream(ctx *context.Context, sync bool) (Stream, error) {
|
||||
var stream quic.Stream
|
||||
var err error
|
||||
|
||||
if sync {
|
||||
stream, err = s.Session.OpenStreamSync(*ctx)
|
||||
} else {
|
||||
stream, err = s.Session.OpenStream()
|
||||
}
|
||||
if err == nil {
|
||||
// Write frame header
|
||||
buf := &bytes.Buffer{}
|
||||
quicvarint.Write(buf, h3.FRAME_WEBTRANSPORT_STREAM)
|
||||
quicvarint.Write(buf, uint64(s.StreamID()))
|
||||
if _, err := stream.Write(buf.Bytes()); err != nil {
|
||||
stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return stream, err
|
||||
}
|
||||
|
||||
func (s *Session) internalOpenUniStream(ctx *context.Context, sync bool) (SendStream, error) {
|
||||
var stream quic.SendStream
|
||||
var err error
|
||||
|
||||
if sync {
|
||||
stream, err = s.Session.OpenUniStreamSync(*ctx)
|
||||
} else {
|
||||
stream, err = s.Session.OpenUniStream()
|
||||
}
|
||||
return SendStream{SendStream: stream, writeHeaderBeforeData: true, headerWritten: false, requestSessionID: uint64(s.StreamID())}, err
|
||||
}
|
||||
|
||||
// OpenStream creates an outgoing (that is, server-initiated) bidirectional stream. It returns immediately.
|
||||
func (s *Session) OpenStream() (Stream, error) {
|
||||
return s.internalOpenStream(nil, false)
|
||||
}
|
||||
|
||||
// OpenStream creates an outgoing (that is, server-initiated) bidirectional stream. It generally returns immediately, but if the session's maximum number of streams
|
||||
// has been exceeded, it will block until a slot is available. Supply your own context, or use the WebTransport
|
||||
// session's Context() so that ending the WebTransport session automatically cancels this call.
|
||||
func (s *Session) OpenStreamSync(ctx context.Context) (Stream, error) {
|
||||
return s.internalOpenStream(&ctx, true)
|
||||
}
|
||||
|
||||
// OpenUniStream creates an outgoing (that is, server-initiated) bidirectional stream. It returns immediately.
|
||||
func (s *Session) OpenUniStream() (SendStream, error) {
|
||||
return s.internalOpenUniStream(nil, false)
|
||||
}
|
||||
|
||||
// OpenUniStreamSync creates an outgoing (that is, server-initiated) unidirectional stream. It generally returns immediately, but if the session's maximum number of streams
|
||||
// has been exceeded, it will block until a slot is available. Supply your own context, or use the WebTransport
|
||||
// session's Context() so that ending the WebTransport session automatically cancels this call.
|
||||
func (s *Session) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
|
||||
return s.internalOpenUniStream(&ctx, true)
|
||||
}
|
||||
|
||||
// CloseSession cleanly closes a WebTransport session. All active streams are cancelled before terminating the session.
|
||||
func (s *Session) CloseSession() {
|
||||
s.cancel()
|
||||
s.Close()
|
||||
}
|
||||
|
||||
// CloseWithError closes a WebTransport session with a supplied error code and string.
|
||||
func (s *Session) CloseWithError(code quic.ApplicationErrorCode, str string) {
|
||||
s.Session.CloseWithError(code, str)
|
||||
}
|
||||
Reference in New Issue
Block a user