diff --git a/internal/streams/play.go b/internal/streams/play.go index d72c5e0c..1f8c4ade 100644 --- a/internal/streams/play.go +++ b/internal/streams/play.go @@ -7,7 +7,7 @@ import ( "github.com/AlexxIT/go2rtc/pkg/core" ) -func (s *Stream) Play(source string) error { +func (s *Stream) Play(urlOrProd any) error { s.mu.Lock() for _, producer := range s.producers { if producer.state == stateInternal && producer.conn != nil { @@ -16,12 +16,18 @@ func (s *Stream) Play(source string) error { } s.mu.Unlock() - if source == "" { - return nil - } - + var source string var src core.Producer + switch urlOrProd.(type) { + case string: + if source = urlOrProd.(string); source == "" { + return nil + } + case core.Producer: + src = urlOrProd.(core.Producer) + } + for _, producer := range s.producers { if producer.conn == nil { continue diff --git a/internal/wyoming/wyoming.go b/internal/wyoming/wyoming.go new file mode 100644 index 00000000..41ae27c3 --- /dev/null +++ b/internal/wyoming/wyoming.go @@ -0,0 +1,70 @@ +package wyoming + +import ( + "net" + + "github.com/AlexxIT/go2rtc/internal/app" + "github.com/AlexxIT/go2rtc/internal/streams" + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/AlexxIT/go2rtc/pkg/wyoming" + "github.com/rs/zerolog" +) + +func Init() { + streams.HandleFunc("wyoming", wyoming.Dial) + + // server + var cfg struct { + Mod map[string]struct { + Listen string `yaml:"listen"` + Name string `yaml:"name"` + WakeURI string `yaml:"wake_uri"` + VADThreshold float32 `yaml:"vad_threshold"` + } `yaml:"wyoming"` + } + app.LoadConfig(&cfg) + + log = app.GetLogger("wyoming") + + for name, cfg := range cfg.Mod { + stream := streams.Get(name) + if stream == nil { + log.Warn().Msgf("[wyoming] missing stream: %s", name) + continue + } + + ln, err := net.Listen("tcp", cfg.Listen) + if err != nil { + log.Warn().Msgf("[wyoming] listen error: %s", err) + continue + } + + if cfg.Name == "" { + cfg.Name = name + } + + srv := wyoming.Server{ + Name: cfg.Name, + VADThreshold: 
int16(1000 * cfg.VADThreshold), // 1.0 => 1000 + WakeURI: cfg.WakeURI, + MicHandler: func(cons core.Consumer) error { + if err := stream.AddConsumer(cons); err != nil { + return err + } + // not best solution + if i, ok := cons.(interface{ OnClose(func()) }); ok { + i.OnClose(func() { + stream.RemoveConsumer(cons) + }) + } + return nil + }, + SndHandler: func(prod core.Producer) error { + return stream.Play(prod) + }, + } + go srv.Serve(ln) + } +} + +var log zerolog.Logger diff --git a/main.go b/main.go index f8aba89e..295de219 100644 --- a/main.go +++ b/main.go @@ -38,6 +38,7 @@ import ( "github.com/AlexxIT/go2rtc/internal/v4l2" "github.com/AlexxIT/go2rtc/internal/webrtc" "github.com/AlexxIT/go2rtc/internal/webtorrent" + "github.com/AlexxIT/go2rtc/internal/wyoming" "github.com/AlexxIT/go2rtc/pkg/shell" ) @@ -69,6 +70,7 @@ func main() { hass.Init() // hass source, Hass API server onvif.Init() // onvif source, ONVIF API server webtorrent.Init() // webtorrent source, WebTorrent module + wyoming.Init() // 5. 
Other sources diff --git a/pkg/core/codec.go b/pkg/core/codec.go index c7791df9..ba0c656a 100644 --- a/pkg/core/codec.go +++ b/pkg/core/codec.go @@ -277,7 +277,7 @@ func ParseCodecString(s string) *Codec { codec.ClockRate = uint32(Atoi(ss[1])) } if len(ss) >= 3 { - codec.Channels = uint16(Atoi(ss[1])) + codec.Channels = uint8(Atoi(ss[2])) } return &codec diff --git a/pkg/pcm/pcm.go b/pkg/pcm/pcm.go index 6872c503..bf54a6cf 100644 --- a/pkg/pcm/pcm.go +++ b/pkg/pcm/pcm.go @@ -185,3 +185,23 @@ func Transcode(dst, src *core.Codec) func([]byte) []byte { return writer(samples) } } + +func ConsumerCodecs() []*core.Codec { + return []*core.Codec{ + {Name: core.CodecPCML}, + {Name: core.CodecPCM}, + {Name: core.CodecPCMA}, + {Name: core.CodecPCMU}, + } +} + +func ProducerCodecs() []*core.Codec { + return []*core.Codec{ + {Name: core.CodecPCML, ClockRate: 16000}, + {Name: core.CodecPCM, ClockRate: 16000}, + {Name: core.CodecPCML, ClockRate: 8000}, + {Name: core.CodecPCM, ClockRate: 8000}, + {Name: core.CodecPCMA, ClockRate: 8000}, + {Name: core.CodecPCMU, ClockRate: 8000}, + } +} diff --git a/pkg/pcm/s16le/s16le.go b/pkg/pcm/s16le/s16le.go new file mode 100644 index 00000000..acd2d4fc --- /dev/null +++ b/pkg/pcm/s16le/s16le.go @@ -0,0 +1,42 @@ +package s16le + +func PeaksRMS(b []byte) int16 { + // RMS of sine wave = peak / sqrt2 + // https://en.wikipedia.org/wiki/Root_mean_square + // https://www.youtube.com/watch?v=MUDkL4KZi0I + var peaks int32 + var peaksSum int32 + var prevSample int16 + var prevUp bool + + var i int + for n := len(b) - 1; i < n; { + lo := b[i] + i++ + hi := b[i] + i++ + + sample := int16(hi)<<8 | int16(lo) + up := sample >= prevSample + + if i >= 4 { + if up != prevUp { + if prevSample >= 0 { + peaksSum += int32(prevSample) + } else { + peaksSum -= int32(prevSample) + } + peaks++ + } + } + + prevSample = sample + prevUp = up + } + + if peaks == 0 { + return 0 + } + + return int16(peaksSum / peaks) +} diff --git a/pkg/wyoming/api.go b/pkg/wyoming/api.go new 
file mode 100644 index 00000000..59de747c --- /dev/null +++ b/pkg/wyoming/api.go @@ -0,0 +1,98 @@ +package wyoming + +import ( + "bufio" + "encoding/json" + "io" + "net" + + "github.com/AlexxIT/go2rtc/pkg/core" +) + +type API struct { + conn net.Conn + rd *bufio.Reader +} + +func DialAPI(address string) (*API, error) { + conn, err := net.DialTimeout("tcp", address, core.ConnDialTimeout) + if err != nil { + return nil, err + } + + return NewAPI(conn), nil +} + +const Version = "1.5.4" + +func NewAPI(conn net.Conn) *API { + return &API{conn: conn, rd: bufio.NewReader(conn)} +} + +func (w *API) WriteEvent(evt *Event) (err error) { + hdr := EventHeader{ + Type: evt.Type, + Version: Version, + DataLength: len(evt.Data), + PayloadLength: len(evt.Payload), + } + + buf, err := json.Marshal(hdr) + if err != nil { + return err + } + + buf = append(buf, '\n') + buf = append(buf, evt.Data...) + buf = append(buf, evt.Payload...) + + _, err = w.conn.Write(buf) + return err +} + +func (w *API) ReadEvent() (*Event, error) { + data, err := w.rd.ReadBytes('\n') + if err != nil { + return nil, err + } + + var hdr EventHeader + if err = json.Unmarshal(data, &hdr); err != nil { + return nil, err + } + + evt := Event{Type: hdr.Type} + + if hdr.DataLength > 0 { + evt.Data = make([]byte, hdr.DataLength) + if _, err = io.ReadFull(w.rd, evt.Data); err != nil { + return nil, err + } + } + + if hdr.PayloadLength > 0 { + evt.Payload = make([]byte, hdr.PayloadLength) + if _, err = io.ReadFull(w.rd, evt.Payload); err != nil { + return nil, err + } + } + + return &evt, nil +} + +func (w *API) Close() error { + return w.conn.Close() +} + +type Event struct { + Type string + Data []byte + Payload []byte +} + +type EventHeader struct { + Type string `json:"type"` + Version string `json:"version"` + DataLength int `json:"data_length,omitempty"` + PayloadLength int `json:"payload_length,omitempty"` +} diff --git a/pkg/wyoming/backchannel.go b/pkg/wyoming/backchannel.go new file mode 100644 index 
00000000..8760789e --- /dev/null +++ b/pkg/wyoming/backchannel.go @@ -0,0 +1,42 @@ +package wyoming + +import ( + "fmt" + "time" + + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/pion/rtp" +) + +type Backchannel struct { + core.Connection + api *API +} + +func (b *Backchannel) GetTrack(media *core.Media, codec *core.Codec) (*core.Receiver, error) { + return nil, core.ErrCantGetTrack +} + +func (b *Backchannel) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiver) error { + sender := core.NewSender(media, codec) + sender.Handler = func(pkt *rtp.Packet) { + ts := time.Now().Nanosecond() + evt := &Event{ + Type: "audio-chunk", + Data: []byte(fmt.Sprintf(`{"rate":16000,"width":2,"channels":1,"timestamp":%d}`, ts)), + Payload: pkt.Payload, + } + _ = b.api.WriteEvent(evt) + } + sender.HandleRTP(track) + b.Senders = append(b.Senders, sender) + return nil +} + +func (b *Backchannel) Start() error { + for { + if _, err := b.api.ReadEvent(); err != nil { + return err + } + } +} diff --git a/pkg/wyoming/producer.go b/pkg/wyoming/producer.go new file mode 100644 index 00000000..9cd6abb6 --- /dev/null +++ b/pkg/wyoming/producer.go @@ -0,0 +1,43 @@ +package wyoming + +import ( + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/pion/rtp" +) + +type Producer struct { + core.Connection + api *API +} + +func (p *Producer) Start() error { + var seq uint16 + var ts uint32 + + for { + evt, err := p.api.ReadEvent() + if err != nil { + return err + } + + if evt.Type != "audio-chunk" { + continue + } + + p.Recv += len(evt.Payload) + + pkt := &core.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + SequenceNumber: seq, + Timestamp: ts, + }, + Payload: evt.Payload, + } + p.Receivers[0].WriteRTP(pkt) + + seq++ + ts += uint32(len(evt.Payload) / 2) + } +} diff --git a/pkg/wyoming/satellite.go b/pkg/wyoming/satellite.go new file mode 100644 index 00000000..14bded32 --- /dev/null +++ b/pkg/wyoming/satellite.go @@ -0,0 +1,395 @@ +package wyoming + +import ( + 
"errors" + "fmt" + "net" + "sync" + "time" + + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/AlexxIT/go2rtc/pkg/pcm" + "github.com/AlexxIT/go2rtc/pkg/pcm/s16le" + "github.com/pion/rtp" +) + +type Server struct { + Name string + + VADThreshold int16 + WakeURI string + + MicHandler func(cons core.Consumer) error + SndHandler func(prod core.Producer) error +} + +func (s *Server) Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if err != nil { + return err + } + + go s.Handle(conn) + } +} + +func (s *Server) Handle(conn net.Conn) error { + api := NewAPI(conn) + sat := newSatellite(api, s) + defer sat.Close() + + //log.Debug().Msgf("[wyoming] new client: %s", conn.RemoteAddr()) + + var snd []byte + + for { + evt, err := api.ReadEvent() + if err != nil { + return err + } + + //log.Printf("%s %s %d", evt.Type, evt.Data, len(evt.Payload)) + + switch evt.Type { + case "ping": // {"text": null} + _ = api.WriteEvent(&Event{Type: "pong", Data: evt.Data}) + case "describe": + // {"asr": [], "tts": [], "handle": [], "intent": [], "wake": [], "satellite": {"name": "my satellite", "attribution": {"name": "", "url": ""}, "installed": true, "description": "my satellite", "version": "1.4.1", "area": null, "snd_format": null}} + data := fmt.Sprintf(`{"satellite":{"name":%q,"attribution":{"name":"go2rtc","url":"https://github.com/AlexxIT/go2rtc"},"installed":true}}`, s.Name) + _ = api.WriteEvent(&Event{Type: "info", Data: []byte(data)}) + case "run-satellite": + if err = sat.run(); err != nil { + return err + } + case "pause-satellite": + sat.pause() + case "detect": // WAKE_WORD_START {"names": null} + case "detection": // WAKE_WORD_END {"name": "ok_nabu_v0.1", "timestamp": 17580, "speaker": null} + case "transcribe": // STT_START {"language": "en"} + case "voice-started": // STT_VAD_START {"timestamp": 1160} + case "voice-stopped": // STT_VAD_END {"timestamp": 2470} + sat.idle() + case "transcript": // STT_END {"text": "how are you"} + case "synthesize": // 
TTS_START {"text": "Sorry, I couldn't understand that", "voice": {"language": "en"}} + case "audio-start": // TTS_END {"rate": 22050, "width": 2, "channels": 1, "timestamp": 0} + snd = snd[:0] + case "audio-chunk": // {"rate": 22050, "width": 2, "channels": 1, "timestamp": 0} + snd = append(snd, evt.Payload...) + case "audio-stop": // {"timestamp": 2.880000000000002} + sat.respond(snd) + case "error": + sat.start() + } + } +} + +// states like Home Assistant +const ( + stateUnavailable = iota + stateIdle + stateWaitVAD // waiting until the mic level exceeds VADThreshold + stateWaitWakeWord + stateStreaming +) + +type satellite struct { + api *API + srv *Server + + state uint8 + mu sync.Mutex + + timestamp int + + mic *micConsumer + wake *WakeWord +} + +func newSatellite(api *API, srv *Server) *satellite { + sat := &satellite{api: api, srv: srv} + return sat +} + +func (s *satellite) Close() error { + s.pause() + return s.api.Close() +} + +func (s *satellite) run() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.state != stateUnavailable { + return errors.New("wyoming: wrong satellite state") + } + + s.mic = newMicConsumer(s.onMicChunk) + s.mic.RemoteAddr = s.api.conn.RemoteAddr().String() + + if err := s.srv.MicHandler(s.mic); err != nil { + return err + } + + s.state = stateIdle + go s.start() + + return nil +} + +func (s *satellite) pause() { + s.mu.Lock() + + s.state = stateUnavailable + if s.mic != nil { + if s.mic.onClose != nil { + s.mic.onClose() + } + _ = s.mic.Stop() + s.mic = nil + } + if s.wake != nil { + _ = s.wake.Close() + s.wake = nil + } + + s.mu.Unlock() +} + +func (s *satellite) start() { + s.mu.Lock() + + if s.state != stateUnavailable { + s.state = stateWaitVAD + } + + s.mu.Unlock() +} + +func (s *satellite) idle() { + s.mu.Lock() + + if s.state != stateUnavailable { + s.state = stateIdle + } + + s.mu.Unlock() +} + +const wakeTimeout = 5 * 2 * 16000 // bytes in 5 seconds of 16 kHz 16-bit mono audio; NOTE(review): compared with s.timestamp, which advances by len(chunk)/2, so the effective timeout is about 10 seconds - confirm intent + +func (s *satellite) onMicChunk(chunk []byte) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.state == stateIdle { 
+ return + } + + if s.state == stateWaitVAD { + // tests show that values over 1000 are most likely speech + if s.srv.VADThreshold == 0 || s16le.PeaksRMS(chunk) > s.srv.VADThreshold { + if s.wake == nil && s.srv.WakeURI != "" { + s.wake, _ = DialWakeWord(s.srv.WakeURI) + } + if s.wake == nil { + // some problems with wake word - redirect to HA + evt := &Event{ + Type: "run-pipeline", + Data: []byte(`{"start_stage":"wake","end_stage":"tts","restart_on_end":false}`), + } + if err := s.api.WriteEvent(evt); err != nil { + return + } + s.state = stateStreaming + } else { + s.state = stateWaitWakeWord + } + s.timestamp = 0 + } + } + + if s.state == stateWaitWakeWord { + if s.wake.Detection != "" { + // check if wake word detected + evt := &Event{ + Type: "run-pipeline", + Data: []byte(`{"start_stage":"asr","end_stage":"tts","restart_on_end":false}`), + } + _ = s.api.WriteEvent(evt) + s.state = stateStreaming + s.timestamp = 0 + } else if err := s.wake.WriteChunk(chunk); err != nil { + // wake word service failed + s.state = stateWaitVAD + _ = s.wake.Close() + s.wake = nil + } else if s.timestamp > wakeTimeout { + // wake word detection timeout + s.state = stateWaitVAD + } + } else if s.wake != nil { + _ = s.wake.Close() + s.wake = nil + } + + if s.state == stateStreaming { + data := fmt.Sprintf(`{"rate":16000,"width":2,"channels":1,"timestamp":%d}`, s.timestamp) + evt := &Event{Type: "audio-chunk", Data: []byte(data), Payload: chunk} + _ = s.api.WriteEvent(evt) + } + + s.timestamp += len(chunk) / 2 +} + +func (s *satellite) respond(data []byte) { + prod := newSndProducer(data, func() { + _ = s.api.WriteEvent(&Event{Type: "played"}) + s.start() + }) + if err := s.srv.SndHandler(prod); err != nil { + prod.onClose() + } +} + +type micConsumer struct { + core.Connection + onData func(chunk []byte) + onClose func() +} + +func newMicConsumer(onData func(chunk []byte)) *micConsumer { + medias := []*core.Media{ + { + Kind: core.KindAudio, + Direction: core.DirectionSendonly, + 
Codecs: pcm.ConsumerCodecs(), + }, + } + + return &micConsumer{ + Connection: core.Connection{ + ID: core.NewID(), + FormatName: "wyoming", + Protocol: "tcp", + Medias: medias, + }, + onData: onData, + } +} + +func (c *micConsumer) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiver) error { + src := track.Codec + dst := &core.Codec{ + Name: core.CodecPCML, + ClockRate: 16000, + Channels: 1, + } + sender := core.NewSender(media, dst) + sender.Handler = pcm.TranscodeHandler(dst, src, + repack(func(packet *core.Packet) { + c.onData(packet.Payload) + }), + ) + sender.HandleRTP(track) + c.Senders = append(c.Senders, sender) + return nil +} + +type sndProducer struct { + core.Connection + data []byte + onClose func() +} + +func newSndProducer(data []byte, onClose func()) *sndProducer { + medias := []*core.Media{ + { + Kind: core.KindAudio, + Direction: core.DirectionRecvonly, + Codecs: pcm.ProducerCodecs(), + }, + } + + return &sndProducer{ + core.Connection{ + ID: core.NewID(), + FormatName: "wyoming", + Protocol: "tcp", + Medias: medias, + }, + data, + onClose, + } +} + +func (s *sndProducer) Start() error { + if len(s.Receivers) == 0 { + return nil + } + + var pts time.Duration + var seq uint16 + + t0 := time.Now() + + src := &core.Codec{Name: core.CodecPCML, ClockRate: 22050} + dst := s.Receivers[0].Codec + f := pcm.Transcode(dst, src) + + bps := uint32(pcm.BytesPerFrame(dst)) + + chunkBytes := int(2 * src.ClockRate / 50) // 20ms + + for { + n := len(s.data) + if n == 0 { + break + } + if chunkBytes > n { + chunkBytes = n + } + + pkt := &core.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + SequenceNumber: seq, + Timestamp: uint32(s.Recv/2) * bps, + }, + Payload: f(s.data[:chunkBytes]), + } + + if d := pts - time.Since(t0); d > 0 { + time.Sleep(d) + } + + s.Receivers[0].WriteRTP(pkt) + + s.Recv += chunkBytes + s.data = s.data[chunkBytes:] + + pts += 10 * time.Millisecond + seq++ + } + + s.onClose() + + return nil +} + +func 
repack(handler core.HandlerFunc) core.HandlerFunc { + const PacketSize = 2 * 16000 / 50 // 20ms + + var buf []byte + + return func(pkt *rtp.Packet) { + buf = append(buf, pkt.Payload...) + + for len(buf) >= PacketSize { + pkt = &core.Packet{Payload: buf[:PacketSize]} + buf = buf[PacketSize:] + handler(pkt) + } + } +} diff --git a/pkg/wyoming/wakeword.go b/pkg/wyoming/wakeword.go new file mode 100644 index 00000000..3603e22a --- /dev/null +++ b/pkg/wyoming/wakeword.go @@ -0,0 +1,120 @@ +package wyoming + +import ( + "encoding/json" + "fmt" + "net/url" +) + +type WakeWord struct { + *API + names []string + send int + + Detection string +} + +func DialWakeWord(rawURL string) (*WakeWord, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + api, err := DialAPI(u.Host) + if err != nil { + return nil, err + } + + names := u.Query()["name"] + if len(names) == 0 { + names = []string{"ok_nabu_v0.1"} + } + + wake := &WakeWord{API: api, names: names} + if err = wake.Start(); err != nil { + _ = wake.Close() + return nil, err + } + + go wake.handle() + return wake, nil +} + +func (w *WakeWord) handle() { + defer w.Close() + + for { + evt, err := w.ReadEvent() + if err != nil { + return + } + + if evt.Type == "detection" { + var data struct { + Name string `json:"name"` + } + if err = json.Unmarshal(evt.Data, &data); err != nil { + return + } + w.Detection = data.Name + } + } +} + +//func (w *WakeWord) Describe() error { +// if err := w.WriteEvent(&Event{Type: "describe"}); err != nil { +// return err +// } +// +// evt, err := w.ReadEvent() +// if err != nil { +// return err +// } +// +// var info struct { +// Wake []struct { +// Models []struct { +// Name string `json:"name"` +// } `json:"models"` +// } `json:"wake"` +// } +// if err = json.Unmarshal(evt.Data, &info); err != nil { +// return err +// } +// +// return nil +//} + +func (w *WakeWord) Start() error { + msg := struct { + Names []string `json:"names"` + }{ + Names: w.names, + } + data, err 
:= json.Marshal(msg) + if err != nil { + return err + } + evt := &Event{Type: "detect", Data: data} + if err := w.WriteEvent(evt); err != nil { + return err + } + + evt = &Event{Type: "audio-start", Data: audioData(0)} + return w.WriteEvent(evt) +} + +func (w *WakeWord) Close() error { + return w.conn.Close() +} + +func (w *WakeWord) WriteChunk(payload []byte) error { + evt := &Event{Type: "audio-chunk", Data: audioData(w.send), Payload: payload} + w.send += len(payload) + return w.WriteEvent(evt) +} + +func audioData(send int) []byte { + // timestamp in ms = send / 2 * 1000 / 16000 = send / 32 + return []byte(fmt.Sprintf(`{"rate":16000,"width":2,"channels":1,"timestamp":%d}`, send/32)) +} diff --git a/pkg/wyoming/wyoming.go b/pkg/wyoming/wyoming.go new file mode 100644 index 00000000..96d1dc5e --- /dev/null +++ b/pkg/wyoming/wyoming.go @@ -0,0 +1,42 @@ +package wyoming + +import ( + "net" + "net/url" + + "github.com/AlexxIT/go2rtc/pkg/core" +) + +func Dial(rawURL string) (core.Producer, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + conn, err := net.DialTimeout("tcp", u.Host, core.ConnDialTimeout) + if err != nil { + return nil, err + } + + cc := core.Connection{ + ID: core.NewID(), + FormatName: "wyoming", + Medias: []*core.Media{ + { + Kind: core.KindAudio, + Codecs: []*core.Codec{ + {Name: core.CodecPCML, ClockRate: 16000}, + }, + }, + }, + Transport: conn, + } + + if u.Query().Get("backchannel") != "1" { + cc.Medias[0].Direction = core.DirectionRecvonly + return &Producer{cc, NewAPI(conn)}, nil + } else { + cc.Medias[0].Direction = core.DirectionSendonly + return &Backchannel{cc, NewAPI(conn)}, nil + } +}