Files
gortsplib/pkg/mikey/header.go

144 lines
2.8 KiB
Go

package mikey
import "fmt"
// boolToUint8 converts a boolean into a single bit value:
// 1 when v is true, 0 otherwise.
func boolToUint8(v bool) uint8 {
	var bit uint8
	if v {
		bit = 1
	}
	return bit
}
// DataType is the data type of a MIKEY message, as carried in the
// second byte of the header.
type DataType uint8

// Data type values from RFC3830, Table 6.1.a.
const (
	// DataTypeInitiatorPSK has wire value 0.
	DataTypeInitiatorPSK = DataType(0)
)
// CSIDMapType is the type of the CS ID map that follows the fixed part
// of a MIKEY header.
type CSIDMapType uint8

// CS ID map type values from RFC3830, Table 6.1.d.
const (
	// CSIDMapTypeSRTPID has wire value 0.
	CSIDMapTypeSRTPID = CSIDMapType(0)
)
// SRTPIDEntry is a single entry of a SRTP-ID map.
// On the wire it takes 9 bytes: one byte for PolicyNo followed by
// SSRC and ROC as 32-bit big-endian integers.
type SRTPIDEntry struct {
	// PolicyNo is the first byte of the entry.
	PolicyNo uint8
	// SSRC and ROC follow PolicyNo, in this order.
	SSRC, ROC uint32
}
// Header is a MIKEY header.
//
// Its wire format is a 10-byte fixed part followed by one 9-byte entry
// for each element of CSIDMapInfo.
type Header struct {
// Version is the first header byte; unmarshal only accepts 1.
Version uint8
// DataType is the second header byte; unmarshal only accepts
// DataTypeInitiatorPSK.
DataType DataType
// V is the most significant bit of the fourth header byte; unmarshal
// rejects headers in which it is set.
V bool
// PRFFunc is the lower 7 bits of the fourth header byte; unmarshal
// only accepts 0.
PRFFunc uint8
// CSBID is encoded as a 32-bit big-endian integer.
CSBID uint32
// CSIDMapType is the type of CSIDMapInfo; unmarshal only accepts
// CSIDMapTypeSRTPID.
CSIDMapType CSIDMapType
// CSIDMapInfo contains the map entries. Their count is encoded in a
// single byte.
CSIDMapInfo []SRTPIDEntry
}
// unmarshal decodes a MIKEY header from the beginning of buf.
//
// It returns the number of bytes consumed and the type of the payload
// that follows the header. An error is returned when buf is too short
// or when the header uses a version, data type, V flag, PRF function or
// CS ID map type that is not supported. When an error is returned, the
// fields decoded before the failing check have already been written to h.
func (h *Header) unmarshal(buf []byte) (int, payloadType, error) {
	const (
		fixedLen = 10 // size of the fixed part of the header
		entryLen = 9  // size of a SRTP-ID map entry
	)

	if len(buf) < fixedLen {
		return 0, 0, fmt.Errorf("header too short")
	}

	// byte 0: version
	h.Version = buf[0]
	if h.Version != 1 {
		return 0, 0, fmt.Errorf("unsupported version: %v", h.Version)
	}

	// byte 1: data type
	h.DataType = DataType(buf[1])
	if h.DataType != DataTypeInitiatorPSK {
		return 0, 0, fmt.Errorf("unsupported data type: %v", h.DataType)
	}

	// byte 2: type of the payload following the header
	next := payloadType(buf[2])

	// byte 3: V flag (top bit) and PRF function (lower 7 bits)
	flags := buf[3]
	h.V = flags&0x80 != 0
	h.PRFFunc = flags & 0x7f

	if h.V {
		return 0, 0, fmt.Errorf("unsupported V: %v", h.V)
	}

	if h.PRFFunc != 0 {
		return 0, 0, fmt.Errorf("unsupported PRFFunc: %v", h.PRFFunc)
	}

	// bytes 4-7: CSB ID, big-endian
	h.CSBID = uint32(buf[4])<<24 | uint32(buf[5])<<16 | uint32(buf[6])<<8 | uint32(buf[7])

	// byte 8: number of map entries
	count := int(buf[8])

	// byte 9: CS ID map type
	h.CSIDMapType = CSIDMapType(buf[9])
	if h.CSIDMapType != CSIDMapTypeSRTPID {
		return 0, 0, fmt.Errorf("unsupported map type: %d", h.CSIDMapType)
	}

	// make sure that all the announced entries are present before reading them.
	entries := buf[fixedLen:]
	if len(entries) < count*entryLen {
		return 0, 0, fmt.Errorf("header too short")
	}

	h.CSIDMapInfo = make([]SRTPIDEntry, count)

	for i := range h.CSIDMapInfo {
		e := entries[i*entryLen : (i+1)*entryLen]

		h.CSIDMapInfo[i] = SRTPIDEntry{
			PolicyNo: e[0],
			SSRC:     uint32(e[1])<<24 | uint32(e[2])<<16 | uint32(e[3])<<8 | uint32(e[4]),
			ROC:      uint32(e[5])<<24 | uint32(e[6])<<16 | uint32(e[7])<<8 | uint32(e[8]),
		}
	}

	return fixedLen + count*entryLen, next, nil
}
// marshalSize returns the number of bytes needed to encode the header:
// a 10-byte fixed part plus 9 bytes for each entry of CSIDMapInfo.
func (h *Header) marshalSize() int {
	const (
		fixedLen = 10
		entryLen = 9
	)

	return fixedLen + entryLen*len(h.CSIDMapInfo)
}
// marshalTo encodes the header into the beginning of buf and returns the
// number of bytes written. nextPayload is stored in the third byte of the
// header and identifies the payload that follows it.
//
// buf must be at least marshalSize() bytes long. An error is returned,
// and nothing is written, when buf is too short or when a field does not
// fit into its wire representation: the number of map entries is encoded
// in a single byte and PRFFunc in 7 bits.
func (h *Header) marshalTo(buf []byte, nextPayload payloadType) (int, error) {
	// the entry count is stored in one byte; a larger count would be
	// silently truncated and produce a header that cannot be decoded.
	if len(h.CSIDMapInfo) > 255 {
		return 0, fmt.Errorf("too many CS ID map entries: %d", len(h.CSIDMapInfo))
	}

	// PRFFunc shares its byte with V; a value wider than 7 bits would
	// overwrite the V bit.
	if h.PRFFunc > 0b01111111 {
		return 0, fmt.Errorf("invalid PRFFunc: %v", h.PRFFunc)
	}

	// report a short buffer as an error instead of panicking halfway
	// through the encoding.
	if size := h.marshalSize(); len(buf) < size {
		return 0, fmt.Errorf("buffer too short: need %d bytes, have %d", size, len(buf))
	}

	buf[0] = h.Version
	buf[1] = byte(h.DataType)
	buf[2] = byte(nextPayload)
	buf[3] = boolToUint8(h.V)<<7 | h.PRFFunc
	buf[4] = byte(h.CSBID >> 24)
	buf[5] = byte(h.CSBID >> 16)
	buf[6] = byte(h.CSBID >> 8)
	buf[7] = byte(h.CSBID)
	buf[8] = byte(len(h.CSIDMapInfo))
	buf[9] = byte(h.CSIDMapType)
	n := 10

	for _, mi := range h.CSIDMapInfo {
		buf[n] = mi.PolicyNo
		buf[n+1] = byte(mi.SSRC >> 24)
		buf[n+2] = byte(mi.SSRC >> 16)
		buf[n+3] = byte(mi.SSRC >> 8)
		buf[n+4] = byte(mi.SSRC)
		buf[n+5] = byte(mi.ROC >> 24)
		buf[n+6] = byte(mi.ROC >> 16)
		buf[n+7] = byte(mi.ROC >> 8)
		buf[n+8] = byte(mi.ROC)
		n += 9
	}

	return n, nil
}