commit 0492579cf2ff56614284391ab56a5d84d16bd247 Author: Jason Coene Date: Wed Jun 5 22:38:33 2013 -0500 It's sane - check it in! diff --git a/Guardfile b/Guardfile new file mode 100644 index 0000000..8a7d3bf --- /dev/null +++ b/Guardfile @@ -0,0 +1,5 @@ +guard :shell do + watch /^(.*\.go)/ do |m| + system "make" + end +end diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6054544 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +default: + go test + +test: default diff --git a/chunkstream.go b/chunkstream.go new file mode 100644 index 0000000..6e6fd76 --- /dev/null +++ b/chunkstream.go @@ -0,0 +1,100 @@ +package rtmp + +type OutboundChunkStream struct { + Id uint32 + lastHeader *Header + lastOutAbsoluteTimestamp uint32 + lastInAbsoluteTimestamp uint32 + startAtTimestamp uint32 +} + +type InboundChunkStream struct { + Id uint32 + lastHeader *Header + lastOutAbsoluteTimestamp uint32 + lastInAbsoluteTimestamp uint32 + currentMessage *Message +} + +func NewOutboundChunkStream(id uint32) *OutboundChunkStream { + return &OutboundChunkStream { + Id: id, + } +} + +func NewInboundChunkStream(id uint32) *InboundChunkStream { + return &InboundChunkStream { + Id: id, + } +} + +func (cs *OutboundChunkStream) NewOutboundChunkStream(m *Message) *Header { + h := &Header { + ChunkStreamId: cs.Id, + MessageLength: uint32(m.Buffer.Len()), + MessageTypeId: m.Type, + MessageStreamId: m.StreamId, + } + + ts := m.Timestamp + if ts == TIMESTAMP_AUTO { + ts = cs.GetTimestamp() + m.Timestamp = ts + m.AbsoluteTimestamp = ts + } + + deltaTimestamp := uint32(0) + if cs.lastOutAbsoluteTimestamp < m.Timestamp { + deltaTimestamp = m.Timestamp - cs.lastOutAbsoluteTimestamp + } + + if cs.lastHeader == nil { + h.Format = HEADER_FORMAT_FULL + h.Timestamp = ts + } else { + if h.MessageStreamId == cs.lastHeader.MessageStreamId { + if h.MessageTypeId == cs.lastHeader.MessageTypeId && h.MessageLength == cs.lastHeader.MessageLength { + switch cs.lastHeader.Format { + case HEADER_FORMAT_FULL: + h.Format = HEADER_FORMAT_SAME_LENGTH_AND_STREAM + h.Timestamp = deltaTimestamp + case HEADER_FORMAT_SAME_STREAM: + fallthrough + case HEADER_FORMAT_SAME_LENGTH_AND_STREAM: + fallthrough + case HEADER_FORMAT_CONTINUATION: + if cs.lastHeader.Timestamp == deltaTimestamp { + h.Format = HEADER_FORMAT_CONTINUATION + } else { + h.Format = HEADER_FORMAT_SAME_LENGTH_AND_STREAM + h.Timestamp = deltaTimestamp + } + } + } else { + h.Format = HEADER_FORMAT_SAME_STREAM + h.Timestamp = ts + } + } + } + + if h.Timestamp >= TIMESTAMP_EXTENDED { + h.ExtendedTimestamp = m.Timestamp + h.Timestamp = TIMESTAMP_EXTENDED + } else { + h.ExtendedTimestamp = 0 + } + + cs.lastHeader = h + cs.lastOutAbsoluteTimestamp = ts + + return h +} + +func (cs *OutboundChunkStream) GetTimestamp() uint32 { + if cs.startAtTimestamp == uint32(0) { + cs.startAtTimestamp = GetCurrentTimestamp() + return uint32(0) + } + + return GetCurrentTimestamp() - cs.startAtTimestamp +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..4c3c63b --- /dev/null +++ b/client.go @@ -0,0 +1,250 @@ +package rtmp + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/url" +) + +type ClientHandler interface { + OnConnect() + OnDisconnect() + OnReceive(message *Message) +} + +type Client struct { + url string + + handler ClientHandler + connected bool + + conn net.Conn + reader *bufio.Reader + writer *bufio.Writer + + outBytes uint32 + outMessages chan *Message + outWindowSize uint32 + outChunkSize uint32 + outChunkStreams map[uint32]*OutboundChunkStream + + inBytes uint32 + inMessages chan *Message + inNotify chan uint8 + inWindowSize uint32 + inChunkSize uint32 + inChunkStreams map[uint32]*InboundChunkStream + + lastTid uint32 +} + +func NewClient(url string) (*Client, error) { + c := &Client{ + url: url, + + connected: false, + + outMessages: make(chan *Message), + outChunkSize: DEFAULT_CHUNK_SIZE, + outWindowSize: DEFAULT_WINDOW_SIZE, + outChunkStreams: make(map[uint32]*OutboundChunkStream), + + inMessages: make(chan *Message, 100), + inChunkSize: DEFAULT_CHUNK_SIZE, + inWindowSize: DEFAULT_WINDOW_SIZE, + inChunkStreams: make(map[uint32]*InboundChunkStream), + } + + err := c.Connect() + if err != nil { + return c, err + } + + return c, err +} + +func (c *Client) Connect() (err error) { + url, err := url.Parse(c.url) + if err != nil { + return err + } + + switch url.Scheme { + case "rtmp": + c.conn, err = net.Dial("tcp", url.Host) + case "rtmps": + config := &tls.Config{InsecureSkipVerify: true} + c.conn, err = tls.Dial("tcp", url.Host, config) + default: + return errors.New(fmt.Sprintf("Unsupported scheme: %s", url.Scheme)) + } + + c.reader = bufio.NewReader(c.conn) + c.writer = bufio.NewWriter(c.conn) + + err = c.handshake() + if err != nil { + return err + } + + c.connected = true + + go c.dispatchLoop() + go c.receiveLoop() + go c.sendLoop() + + log.Info("connected to %s", c.url) + + return nil +} + +func (c *Client) Disconnect() { + c.connected = false + c.conn.Close() + + log.Info("disconnected from %s", c.url, c.outBytes, c.inBytes) +} + +func (c *Client) dispatchLoop() { + for { + m := <- c.inMessages + + switch m.ChunkStreamId { + case CHUNK_STREAM_ID_PROTOCOL: + log.Debug("dispatch protocol message") + case CHUNK_STREAM_ID_COMMAND: + log.Debug("dispatch command message") + } + } +} + +func (c *Client) sendLoop() { + for { + m := <- c.outMessages + log.Debug("send message %+v", m) + } +} + +func (c *Client) receiveLoop() { + for { + // Read the next header from the connection + h, err := ReadHeader(c.reader) + if err != nil { + if c.connected { + log.Warn("unable to receive next header while connected") + c.Disconnect() + } + return + } + + // Determine whether or not we already have a chunk stream + // allocated for this ID. If we don't, create one. + var cs *InboundChunkStream = c.inChunkStreams[h.ChunkStreamId] + if cs == nil { + cs = NewInboundChunkStream(h.ChunkStreamId) + c.inChunkStreams[h.ChunkStreamId] = cs + } + + var ts uint32 + var m *Message + + if (cs.lastHeader == nil) && (h.Format != HEADER_FORMAT_FULL) { + log.Warn("unable to find previous header on chunk stream %d", h.ChunkStreamId) + c.Disconnect() + return + } + + switch h.Format { + case HEADER_FORMAT_FULL: + // If it's an entirely new header, replace the reference in + // the chunk stream and set the working timestamp from + // the header. + cs.lastHeader = &h + ts = h.Timestamp + + case HEADER_FORMAT_SAME_STREAM: + // If it's the same stream, use the last message stream id, + // but otherwise use values from the header. + h.MessageStreamId = cs.lastHeader.MessageStreamId + cs.lastHeader = &h + ts = cs.lastInAbsoluteTimestamp + h.Timestamp + + case HEADER_FORMAT_SAME_LENGTH_AND_STREAM: + // If it's the same length and stream, copy values from the + // last header and replace it. + h.MessageStreamId = cs.lastHeader.MessageStreamId + h.MessageLength = cs.lastHeader.MessageLength + h.MessageTypeId = cs.lastHeader.MessageTypeId + cs.lastHeader = &h + ts = cs.lastInAbsoluteTimestamp + h.Timestamp + + case HEADER_FORMAT_CONTINUATION: + // A full continuation of the previous stream. Copy all values. + h.MessageStreamId = cs.lastHeader.MessageStreamId + h.MessageLength = cs.lastHeader.MessageLength + h.MessageTypeId = cs.lastHeader.MessageTypeId + h.Timestamp = cs.lastHeader.Timestamp + ts = cs.lastInAbsoluteTimestamp + cs.lastHeader.Timestamp + + // If there's a message already started, use it. + if cs.currentMessage != nil { + m = cs.currentMessage + } + } + + if m == nil { + m = &Message{ + Type: h.MessageTypeId, + ChunkStreamId: h.ChunkStreamId, + StreamId: h.MessageStreamId, + Timestamp: h.CalculateTimestamp(), + AbsoluteTimestamp: ts, + Length: h.MessageLength, + Buffer: new(bytes.Buffer), + } + } + + cs.lastInAbsoluteTimestamp = ts + + rs := m.RemainingBytes() + if rs > c.inChunkSize { + rs = c.inChunkSize + } + + _, err = io.CopyN(m.Buffer, c.reader, int64(rs)) + if err != nil { + if c.connected { + log.Warn("unable to copy %d message bytes from buffer", rs) + c.Disconnect() + } + + return + } + + if m.RemainingBytes() == 0 { + cs.currentMessage = nil + c.inMessages <- m + } else { + cs.currentMessage = m + } + } +} + +func (c *Client) Read(p []byte) (n int, err error) { + n, err = c.conn.Read(p) + c.inBytes += uint32(n) + log.Debug("read %d", n) + return n, err +} + +func (c *Client) Write(p []byte) (n int, err error) { + n, err = c.conn.Write(p) + c.outBytes += uint32(n) + log.Debug("read %d", n) + return n, err +} diff --git a/const.go b/const.go new file mode 100644 index 0000000..eb8efac --- /dev/null +++ b/const.go @@ -0,0 +1,48 @@ +package rtmp + +const ( + TIMESTAMP_MAX = uint32(2000000000) + TIMESTAMP_AUTO = uint32(0) + TIMESTAMP_EXTENDED = 0xFFFFFF +) + +const ( + CHUNK_STREAM_ID_PROTOCOL = uint32(2) + CHUNK_STREAM_ID_COMMAND = uint32(3) + CHUNK_STREAM_ID_USER_CONTROL = uint32(4) +) + +const ( + HEADER_FORMAT_FULL = 0x00 + HEADER_FORMAT_SAME_STREAM = 0x01 + HEADER_FORMAT_SAME_LENGTH_AND_STREAM = 0x02 + HEADER_FORMAT_CONTINUATION = 0x03 +) + +const ( + MESSAGE_TYPE_NONE = 0x00 + MESSAGE_TYPE_CHUNK_SIZE = 0x01 + MESSAGE_TYPE_ABORT = 0x02 + MESSAGE_TYPE_ACK = 0x03 + MESSAGE_TYPE_PING = 0x04 + MESSAGE_TYPE_ACK_SIZE = 0x05 + MESSAGE_TYPE_BANDWIDTH = 0x06 + MESSAGE_TYPE_AUDIO = 0x08 + MESSAGE_TYPE_VIDEO = 0x09 + MESSAGE_TYPE_FLEX = 0x0F + MESSAGE_TYPE_AMF3_SHARED_OBJECT = 0x10 + MESSAGE_TYPE_AMF3 = 0x11 + MESSAGE_TYPE_INVOKE = 0x12 + MESSAGE_TYPE_AMF0_SHARED_OBJECT = 0x13 + MESSAGE_TYPE_AMF0 = 0x14 + MESSAGE_TYPE_FLV = 0x16 +) + +const ( + MESSAGE_DISPATCH_QUEUE_LENGTH = 100 +) + +const ( + DEFAULT_CHUNK_SIZE = uint32(128) + DEFAULT_WINDOW_SIZE = uint32(2500000) +) diff --git a/handshake.go b/handshake.go new file mode 100644 index 0000000..c322915 --- /dev/null +++ b/handshake.go @@ -0,0 +1,60 @@ +package rtmp + +import ( + "bytes" + "crypto/rand" + "errors" +) + +func (c *Client) handshake() error { + C0 := []byte{0x03} + C1 := make([]byte, 1536) + S0 := make([]byte, 1) + S1 := make([]byte, 1536) + S2 := make([]byte, 1536) + + rand.Read(C1) + for i := 0; i < 8; i++ { + C1[i] = 0x00 + } + + _, err := c.Write(C0) + if err != nil { + return err + } + + _, err = c.Write(C1) + if err != nil { + return err + } + + _, err = c.Read(S0) + if err != nil { + return err + } + + if bytes.Equal(C0, S0) != true { + return errors.New("Handshake failed: version mismatch") + } + + _, err = c.Read(S1) + if err != nil { + return err + } + + _, err = c.Write(S1) + if err != nil { + return err + } + + _, err = c.Read(S2) + if err != nil { + return err + } + + if bytes.Equal(C1, S2) != true { + return errors.New("Handshake failed: challenge mismatch") + } + + return nil +} diff --git a/header.go b/header.go new file mode 100644 index 0000000..bc68820 --- /dev/null +++ b/header.go @@ -0,0 +1,119 @@ +package rtmp + +import ( + "encoding/binary" +) + +type Header struct { + Format uint8 + ChunkStreamId uint32 + MessageLength uint32 + MessageTypeId uint8 + MessageStreamId uint32 + Timestamp uint32 + ExtendedTimestamp uint32 +} + +func NewHeader() *Header { + return &Header{} +} + +func ReadHeader(r Reader) (Header, error) { + h := *NewHeader() + u8 := make([]byte, 1) + u16 := make([]byte, 2) + u32 := make([]byte, 4) + + // The first byte we read from the header will indicate the + // format of the packet and chunk stream id + _, err := r.Read(u8) + if err != nil { + return h, err + } + + // Determine the packet format from the byte + h.Format = (u8[0] & 0xC0) >> 6 + + // Determine Chunk Stream ID using the remainder of the byte + h.ChunkStreamId = uint32(u8[0] & 0x3F) + + switch h.ChunkStreamId { + case 0: + // A Chunk Stream ID of 0 indicates that the real value + // is between 64-319, which is reached by adding 64 to the + // next byte. + _, err = r.Read(u8) + if err != nil { + return h, err + } + h.ChunkStreamId = uint32(64) + uint32(u8[0]) + + case 1: + // A Chunk Stream ID of 1 indicates that the real value + // is between 64-65599 and can be reached by adding 64 to + // the next byte and then multiplying the one after it + // by 256. + _, err = r.Read(u16) + if err != nil { + return h, err + } + h.ChunkStreamId = uint32(u16[0]) + (256 * uint32(u16[1])) + } + + // If the header is full, same length, or same length + // and stream, then we only need to extract the timestamp. + if h.Format <= HEADER_FORMAT_SAME_LENGTH_AND_STREAM { + _, err = r.Read(u32[1:]) + if err != nil { + return h, err + } + h.Timestamp = binary.BigEndian.Uint32(u32) + } + + // If the header is full or same stream, then we also + // need to extract the message size and message type. + if h.Format <= HEADER_FORMAT_SAME_STREAM { + _, err = r.Read(u32[1:]) + if err != nil { + return h, err + } + h.MessageLength = binary.BigEndian.Uint32(u32) + + _, err = r.Read(u8) + if err != nil { + return h, err + } + h.MessageTypeId = uint8(u8[0]) + } + + // If the header is full, we also need to extract + // the message stream id. + if h.Format <= HEADER_FORMAT_FULL { + _, err = r.Read(u32) + if err != nil { + return h, err + } + h.MessageStreamId = binary.LittleEndian.Uint32(u32) + } + + // If the header has an extended timestamp, we need to + // extract that as well. + if h.Timestamp == TIMESTAMP_EXTENDED { + _, err = r.Read(u32) + if err != nil { + return h, err + } + + h.ExtendedTimestamp = binary.BigEndian.Uint32(u32) + } + + return h, nil +} + +func (h *Header) CalculateTimestamp() uint32 { + if h.Timestamp >= TIMESTAMP_MAX { + return h.ExtendedTimestamp + } + + return h.Timestamp +} diff --git a/message.go b/message.go new file mode 100644 index 0000000..f6f8306 --- /dev/null +++ b/message.go @@ -0,0 +1,23 @@ +package rtmp + +import ( + "bytes" +) + +type Message struct { + Type uint8 + ChunkStreamId uint32 + StreamId uint32 + Timestamp uint32 + AbsoluteTimestamp uint32 + Length uint32 + Buffer *bytes.Buffer +} + +func (m *Message) RemainingBytes() uint32 { + if m.Buffer == nil { + return m.Length + } + + return m.Length - uint32(m.Buffer.Len()) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..dea75f9 --- /dev/null +++ b/util.go @@ -0,0 +1,12 @@ +package rtmp + +import ( + "time" + "github.com/elobuff/gologger" +) + +var log logger.Logger = *logger.NewLogger(logger.LOG_LEVEL_DEBUG, "rtmp") + +func GetCurrentTimestamp() uint32 { + return uint32(time.Now().UnixNano()/int64(1000000)) % TIMESTAMP_MAX +}