package rtsp import ( "fmt" "github.com/lkmio/avformat/utils" "github.com/lkmio/lkm/log" "github.com/lkmio/lkm/stream" "net/http" "net/textproto" "net/url" "reflect" "strconv" "strings" "time" ) type Request struct { session *session sourceId string method string url *url.URL headers textproto.MIMEHeader } // Handler 处理RTSP各个请求消息 type Handler interface { // Process 路由请求给具体的handler, 并发送响应 Process(session *session, method string, url_ *url.URL, headers textproto.MIMEHeader) error OnOptions(request Request) (*http.Response, []byte, error) // OnDescribe 获取spd OnDescribe(request Request) (*http.Response, []byte, error) // OnSetup 订阅track OnSetup(request Request) (*http.Response, []byte, error) // OnPlay 请求播放 OnPlay(request Request) (*http.Response, []byte, error) // OnTeardown 结束播放 OnTeardown(request Request) (*http.Response, []byte, error) OnPause(request Request) (*http.Response, []byte, error) OnGetParameter(request Request) (*http.Response, []byte, error) OnSetParameter(request Request) (*http.Response, []byte, error) OnRedirect(request Request) (*http.Response, []byte, error) // OnRecord 推流 OnRecord(request Request) (*http.Response, []byte, error) } type handler struct { methods map[string]reflect.Value password string publicHeader string } func (h handler) Process(session *session, method string, url_ *url.URL, headers textproto.MIMEHeader) error { m, ok := h.methods[method] if !ok { return fmt.Errorf("the method %s is not implmented", method) } //确保拉流要经过授权 state, ok := method2StateMap[method] if ok && state > SessionStateSetup && session.sink == nil { return fmt.Errorf("please establish a session first") } source, _ := stream.Path2SourceID(url_.Path, "") // 反射调用各个处理函数 results := m.Call([]reflect.Value{ reflect.ValueOf(&h), reflect.ValueOf(Request{session, source, method, url_, headers}), }) err, _ := results[2].Interface().(error) if err != nil { return err } response := results[0].Interface().(*http.Response) if ok { session.state = state } if response == nil { return nil } body := results[1].Bytes() err = session.response(response, body) return err } func (h handler) OnOptions(request Request) (*http.Response, []byte, error) { rep := NewOKResponse(request.headers.Get("Cseq")) rep.Header.Set("Public", h.publicHeader) return rep, nil, nil } func (h handler) OnDescribe(request Request) (*http.Response, []byte, error) { var err error var response *http.Response var body []byte // 校验密码 if h.password != "" { var ok bool authorization := request.headers.Get("Authorization") if authorization != "" { params, err := parseAuthParams(authorization) ok = err == nil && DoAuthenticatePlainTextPassword(params, h.password) } if !ok { response401 := NewResponse(http.StatusUnauthorized, request.headers.Get("Cseq")) response401.Header.Set("WWW-Authenticate", generateAuthHeader("lkm")) return response401, nil, nil } } sinkId := stream.NetAddr2SinkID(request.session.conn.RemoteAddr()) sink := NewSink(sinkId, request.sourceId, request.session.conn, func(sdp string) { // 响应sdp回调 response = NewOKResponse(request.headers.Get("Cseq")) response.Header.Set("Content-Type", "application/sdp") request.session.response(response, []byte(sdp)) }) ok := stream.SubscribeStreamWithOptions(sink, request.url.Query(), false, false) if utils.HookStateOK != ok { return nil, nil, fmt.Errorf("hook failed. code: %d", ok) } request.session.sink = sink.(*Sink) return nil, body, err } func (h handler) OnSetup(request Request) (*http.Response, []byte, error) { var response *http.Response // 修复解析拉流携带的参数失败问题 params := strings.ReplaceAll(request.url.RawQuery, "/?", "&") query, err := url.ParseQuery(params) if err != nil { return nil, nil, err } track := query.Get("track") index, err := strconv.Atoi(track) if err != nil { return nil, nil, err } transportHeader := request.headers.Get("Transport") if transportHeader == "" { return nil, nil, fmt.Errorf("not find transport header") } split := strings.Split(transportHeader, ";") if len(split) < 3 { return nil, nil, fmt.Errorf("failed to parsing TRANSPORT header:%s", transportHeader) } tcp := "RTP/AVP" != split[0] && "RTP/AVP/UDP" != split[0] if !tcp { for _, value := range split { if !strings.HasPrefix(value, "client_port=") { continue } pairPort := strings.Split(value[len("client_port="):], "-") if len(pairPort) != 2 { return nil, nil, fmt.Errorf("failed to parsing client_port:%s", value) } port, err := strconv.Atoi(pairPort[0]) if err != nil { return nil, nil, err } _ = port port2, err := strconv.Atoi(pairPort[1]) if err != nil { return nil, nil, err } _ = port2 log.Sugar.Debugf("client port:%d-%d", port, port2) } } ssrc := 0xFFFFFFFF rtpPort, rtcpPort, err := request.session.sink.AddSender(index, tcp, uint32(ssrc)) if err != nil { return nil, nil, err } responseHeader := transportHeader if tcp { // 修改interleaved为实际的stream index responseHeader += ";interleaved=" + fmt.Sprintf("%d-%d", index, index) } else { responseHeader += ";server_port=" + fmt.Sprintf("%d-%d", rtpPort, rtcpPort) } responseHeader += ";ssrc=" + strconv.FormatInt(int64(ssrc), 16) response = NewOKResponse(request.headers.Get("Cseq")) response.Header.Set("Transport", responseHeader) response.Header.Set("Session", request.session.sessionId) return response, nil, nil } func (h handler) OnPlay(request Request) (*http.Response, []byte, error) { response := NewOKResponse(request.headers.Get("Cseq")) response.Header.Set("Date", time.Now().Format("Mon, 02 Jan 2006 15:04:05 GMT")) if sessionHeader := request.headers.Get("Session"); sessionHeader != "" { response.Header.Set("Session", sessionHeader) } if rangeV := request.headers.Get("Range"); rangeV != "" { response.Header.Set("Range", rangeV) } sink := request.session.sink sink.SetReady(true) source := stream.SourceManager.Find(sink.GetSourceID()) if source == nil { return nil, nil, fmt.Errorf("Source with ID %s does not exist.", request.sourceId) } source.GetTransStreamPublisher().AddSink(sink) // RTP-Info: url=rtsp://192.168.2.110:8554/hls/mystream/trackID=0;seq=21592;rtptime=4586400,url=rtsp://192.168.2.110:8554/hls/mystream/trackID=1;seq=403;rtptime=412672\r\n //info := <-sink.onPlayResponse //response.Header.Set("RTP-Info", fmt.Sprintf("url=%s;seq=%d;rtptime=%d", "rtsp://192.168.2.119:554/hls/mystream/?track=0", info[0], info[1])) return response, nil, nil } func (h handler) OnTeardown(request Request) (*http.Response, []byte, error) { response := NewOKResponse(request.headers.Get("Cseq")) return response, nil, nil } func (h handler) OnPause(request Request) (*http.Response, []byte, error) { response := NewOKResponse(request.headers.Get("Cseq")) return response, nil, nil } func NewHandler(password string) *handler { h := handler{ methods: make(map[string]reflect.Value, 10), password: password, } //反射获取所有成员函数, 映射对应的RTSP请求方法 t := reflect.TypeOf(&h) numMethod := t.NumMethod() headers := make([]string, 0, 10) for i := 0; i < numMethod; i++ { method := t.Method(i) if !strings.HasPrefix(method.Name, "On") { continue } //确保函数名和RTSP标准的请求方法保持一致 methodName := strings.ToUpper(method.Name[2:]) h.methods[methodName] = method.Func headers = append(headers, methodName) } h.publicHeader = strings.Join(headers, ",") return &h }