diff --git a/api.go b/api.go index d786d5d..676a461 100644 --- a/api.go +++ b/api.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "m7s.live/v5/pkg/db" "m7s.live/v5/pkg/task" myip "github.com/husanpao/ip" @@ -34,6 +35,11 @@ import ( var localIP string var empty = &emptypb.Empty{} +func init() { + // Add auto-migration for User model + db.AutoMigrations = append(db.AutoMigrations, &db.User{}) +} + func (s *Server) SysInfo(context.Context, *emptypb.Empty) (res *pb.SysInfoResponse, err error) { if localIP == "" { localIP = myip.LocalIP() @@ -121,21 +127,22 @@ func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err tmp, _ := json.Marshal(pub.GetDescriptions()) res = &pb.StreamInfoResponse{ Data: &pb.StreamInfo{ - Meta: string(tmp), - Path: pub.StreamPath, - State: int32(pub.State), - StartTime: timestamppb.New(pub.StartTime), - Subscribers: int32(pub.Subscribers.Length), - PluginName: pub.Plugin.Meta.Name, - Type: pub.Type, - Speed: float32(pub.Speed), - StopOnIdle: pub.DelayCloseTimeout > 0, - IsPaused: pub.Paused != nil, - Gop: int32(pub.GOP), - BufferTime: durationpb.New(pub.BufferTime), + Meta: string(tmp), + Path: pub.StreamPath, + State: int32(pub.State), + StartTime: timestamppb.New(pub.StartTime), + // Subscribers: int32(pub.Subscribers.Length), + PluginName: pub.Plugin.Meta.Name, + Type: pub.Type, + Speed: float32(pub.Speed), + StopOnIdle: pub.DelayCloseTimeout > 0, + IsPaused: pub.Paused != nil, + Gop: int32(pub.GOP), + BufferTime: durationpb.New(pub.BufferTime), }, } var audioBpsOut, videoBpsOut uint32 + var serverSubCount int32 for sub := range pub.Subscribers.Range { if sub.AudioReader != nil { audioBpsOut += sub.AudioReader.BPS @@ -143,7 +150,11 @@ func (s *Server) getStreamInfo(pub *Publisher) (res *pb.StreamInfoResponse, err if sub.VideoReader != nil { videoBpsOut += sub.VideoReader.BPS } + if sub.Type == SubscribeTypeServer { + serverSubCount++ + } } + res.Data.Subscribers = serverSubCount if t := pub.AudioTrack.AVTrack; t != nil { if t.ICodecCtx != nil { res.Data.AudioTrack = &pb.AudioTrackInfo{ diff --git a/example/default/config.yaml b/example/default/config.yaml index cc4327f..cae0680 100644 --- a/example/default/config.yaml +++ b/example/default/config.yaml @@ -1,5 +1,6 @@ global: loglevel: debug + enablelogin: true # db: # dbtype: mysql # dsn: root:Monibuca#!4@tcp(sh-cynosdbmysql-grp-kxt43lv6.sql.tencentcdb.com:28520)/lkm7s_v5?parseTime=true diff --git a/go.mod b/go.mod index 15b0ab1..0263601 100644 --- a/go.mod +++ b/go.mod @@ -64,6 +64,7 @@ require ( github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect diff --git a/go.sum b/go.sum index aa39dd7..a2edba5 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.3.2 h1:zlnbNHxumkRvfPWgfXu8RBwyNR1x8wh9cf5PTOCqs9Q= github.com/gobwas/ws v1.3.2/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/pb/auth.pb.go b/pb/auth.pb.go new file mode 100644 index 0000000..b4151ce --- /dev/null +++ b/pb/auth.pb.go @@ -0,0 +1,686 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.19.1 +// source: auth.proto + +package pb + +import ( + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type LoginRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` +} + +func (x *LoginRequest) Reset() { + *x = LoginRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LoginRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LoginRequest) ProtoMessage() {} + +func (x *LoginRequest) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. +func (*LoginRequest) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{0} +} + +func (x *LoginRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *LoginRequest) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +type LoginSuccess struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + UserInfo *UserInfo `protobuf:"bytes,2,opt,name=userInfo,proto3" json:"userInfo,omitempty"` +} + +func (x *LoginSuccess) Reset() { + *x = LoginSuccess{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LoginSuccess) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LoginSuccess) ProtoMessage() {} + +func (x *LoginSuccess) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LoginSuccess.ProtoReflect.Descriptor instead. +func (*LoginSuccess) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{1} +} + +func (x *LoginSuccess) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *LoginSuccess) GetUserInfo() *UserInfo { + if x != nil { + return x.UserInfo + } + return nil +} + +type LoginResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Data *LoginSuccess `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *LoginResponse) Reset() { + *x = LoginResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LoginResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LoginResponse) ProtoMessage() {} + +func (x *LoginResponse) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. +func (*LoginResponse) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{2} +} + +func (x *LoginResponse) GetCode() int32 { + if x != nil { + return x.Code + } + return 0 +} + +func (x *LoginResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *LoginResponse) GetData() *LoginSuccess { + if x != nil { + return x.Data + } + return nil +} + +type LogoutRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` +} + +func (x *LogoutRequest) Reset() { + *x = LogoutRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LogoutRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogoutRequest) ProtoMessage() {} + +func (x *LogoutRequest) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. +func (*LogoutRequest) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{3} +} + +func (x *LogoutRequest) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +type LogoutResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *LogoutResponse) Reset() { + *x = LogoutResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LogoutResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogoutResponse) ProtoMessage() {} + +func (x *LogoutResponse) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. +func (*LogoutResponse) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{4} +} + +func (x *LogoutResponse) GetCode() int32 { + if x != nil { + return x.Code + } + return 0 +} + +func (x *LogoutResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type UserInfoRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` +} + +func (x *UserInfoRequest) Reset() { + *x = UserInfoRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UserInfoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UserInfoRequest) ProtoMessage() {} + +func (x *UserInfoRequest) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UserInfoRequest.ProtoReflect.Descriptor instead. +func (*UserInfoRequest) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{5} +} + +func (x *UserInfoRequest) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +type UserInfo struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + ExpiresAt int64 `protobuf:"varint,2,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"` // Token expiration timestamp +} + +func (x *UserInfo) Reset() { + *x = UserInfo{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UserInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UserInfo) ProtoMessage() {} + +func (x *UserInfo) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UserInfo.ProtoReflect.Descriptor instead. +func (*UserInfo) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{6} +} + +func (x *UserInfo) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *UserInfo) GetExpiresAt() int64 { + if x != nil { + return x.ExpiresAt + } + return 0 +} + +type UserInfoResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Data *UserInfo `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *UserInfoResponse) Reset() { + *x = UserInfoResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_auth_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UserInfoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UserInfoResponse) ProtoMessage() {} + +func (x *UserInfoResponse) ProtoReflect() protoreflect.Message { + mi := &file_auth_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UserInfoResponse.ProtoReflect.Descriptor instead. +func (*UserInfoResponse) Descriptor() ([]byte, []int) { + return file_auth_proto_rawDescGZIP(), []int{7} +} + +func (x *UserInfoResponse) GetCode() int32 { + if x != nil { + return x.Code + } + return 0 +} + +func (x *UserInfoResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *UserInfoResponse) GetData() *UserInfo { + if x != nil { + return x.Data + } + return nil +} + +var File_auth_proto protoreflect.FileDescriptor + +var file_auth_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, + 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, + 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x46, + 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, + 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x4e, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x53, + 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x28, 0x0a, 0x08, + 0x75, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0c, + 0x2e, 0x70, 0x62, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x63, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x24, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x70, 0x62, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x53, 0x75, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x25, 0x0a, 0x0d, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, + 0x65, 0x6e, 0x22, 0x3e, 0x0a, 0x0e, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x27, 0x0a, 0x0f, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x45, 0x0a, 0x08, 0x55, + 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, + 0x61, 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x61, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, + 0x41, 0x74, 0x22, 0x62, 0x0a, 0x10, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x12, 0x20, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0c, 0x2e, 0x70, 0x62, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x32, 0xf4, 0x01, 0x0a, 0x04, 0x41, 0x75, 0x74, 0x68, 0x12, + 0x48, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x10, 0x2e, 0x70, 0x62, 0x2e, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x11, 0x2e, 0x70, 0x62, 0x2e, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x82, + 0xd3, 0xe4, 0x93, 0x02, 0x14, 0x22, 0x0f, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x75, 0x74, 0x68, + 0x2f, 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x3a, 0x01, 0x2a, 0x12, 0x4c, 0x0a, 0x06, 0x4c, 0x6f, 0x67, + 0x6f, 0x75, 0x74, 0x12, 0x11, 0x2e, 0x70, 0x62, 0x2e, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x70, 0x62, 0x2e, 0x4c, 0x6f, 0x67, 0x6f, + 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x82, 0xd3, 0xe4, 0x93, + 0x02, 0x15, 0x22, 0x10, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, 0x6c, 0x6f, + 0x67, 0x6f, 0x75, 0x74, 0x3a, 0x01, 0x2a, 0x12, 0x54, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x55, 0x73, + 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x13, 0x2e, 0x70, 0x62, 0x2e, 0x55, 0x73, 0x65, 0x72, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x70, 0x62, + 0x2e, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x1a, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x14, 0x12, 0x12, 0x2f, 0x61, 0x70, 0x69, 0x2f, + 0x61, 0x75, 0x74, 0x68, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x69, 0x6e, 0x66, 0x6f, 0x42, 0x10, 0x5a, + 0x0e, 0x6d, 0x37, 0x73, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x2f, 0x76, 0x35, 0x2f, 0x70, 0x62, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_auth_proto_rawDescOnce sync.Once + file_auth_proto_rawDescData = file_auth_proto_rawDesc +) + +func file_auth_proto_rawDescGZIP() []byte { + file_auth_proto_rawDescOnce.Do(func() { + file_auth_proto_rawDescData = protoimpl.X.CompressGZIP(file_auth_proto_rawDescData) + }) + return file_auth_proto_rawDescData +} + +var file_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_auth_proto_goTypes = []interface{}{ + (*LoginRequest)(nil), // 0: pb.LoginRequest + (*LoginSuccess)(nil), // 1: pb.LoginSuccess + (*LoginResponse)(nil), // 2: pb.LoginResponse + (*LogoutRequest)(nil), // 3: pb.LogoutRequest + (*LogoutResponse)(nil), // 4: pb.LogoutResponse + (*UserInfoRequest)(nil), // 5: pb.UserInfoRequest + (*UserInfo)(nil), // 6: pb.UserInfo + (*UserInfoResponse)(nil), // 7: pb.UserInfoResponse +} +var file_auth_proto_depIdxs = []int32{ + 6, // 0: pb.LoginSuccess.userInfo:type_name -> pb.UserInfo + 1, // 1: pb.LoginResponse.data:type_name -> pb.LoginSuccess + 6, // 2: pb.UserInfoResponse.data:type_name -> pb.UserInfo + 0, // 3: pb.Auth.Login:input_type -> pb.LoginRequest + 3, // 4: pb.Auth.Logout:input_type -> pb.LogoutRequest + 5, // 5: pb.Auth.GetUserInfo:input_type -> pb.UserInfoRequest + 2, // 6: pb.Auth.Login:output_type -> pb.LoginResponse + 4, // 7: pb.Auth.Logout:output_type -> pb.LogoutResponse + 7, // 8: pb.Auth.GetUserInfo:output_type -> pb.UserInfoResponse + 6, // [6:9] is the sub-list for method output_type + 3, // [3:6] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_auth_proto_init() } +func file_auth_proto_init() { + if File_auth_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_auth_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LoginRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LoginSuccess); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LoginResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LogoutRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LogoutResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UserInfoRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UserInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_auth_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UserInfoResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_auth_proto_rawDesc, + NumEnums: 0, + NumMessages: 8, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_auth_proto_goTypes, + DependencyIndexes: file_auth_proto_depIdxs, + MessageInfos: file_auth_proto_msgTypes, + }.Build() + File_auth_proto = out.File + file_auth_proto_rawDesc = nil + file_auth_proto_goTypes = nil + file_auth_proto_depIdxs = nil +} diff --git a/pb/auth.pb.gw.go b/pb/auth.pb.gw.go new file mode 100644 index 0000000..e7f75d6 --- /dev/null +++ b/pb/auth.pb.gw.go @@ -0,0 +1,327 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: auth.proto + +/* +Package pb is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package pb + +import ( + "context" + "io" + "net/http" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// Suppress "imported and not used" errors +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray +var _ = metadata.Join + +func request_Auth_Login_0(ctx context.Context, marshaler runtime.Marshaler, client AuthClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq LoginRequest + var metadata runtime.ServerMetadata + + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Login(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Auth_Login_0(ctx context.Context, marshaler runtime.Marshaler, server AuthServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq LoginRequest + var metadata runtime.ServerMetadata + + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Login(ctx, &protoReq) + return msg, metadata, err + +} + +func request_Auth_Logout_0(ctx context.Context, marshaler runtime.Marshaler, client AuthClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq LogoutRequest + var metadata runtime.ServerMetadata + + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Logout(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Auth_Logout_0(ctx context.Context, marshaler runtime.Marshaler, server AuthServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq LogoutRequest + var metadata runtime.ServerMetadata + + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Logout(ctx, &protoReq) + return msg, metadata, err + +} + +var ( + filter_Auth_GetUserInfo_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + +func request_Auth_GetUserInfo_0(ctx context.Context, marshaler runtime.Marshaler, client AuthClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq UserInfoRequest + var metadata runtime.ServerMetadata + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_Auth_GetUserInfo_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.GetUserInfo(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Auth_GetUserInfo_0(ctx context.Context, marshaler runtime.Marshaler, server AuthServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq UserInfoRequest + var metadata runtime.ServerMetadata + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_Auth_GetUserInfo_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.GetUserInfo(ctx, &protoReq) + return msg, metadata, err + +} + +// RegisterAuthHandlerServer registers the http handlers for service Auth to "mux". +// UnaryRPC :call AuthServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterAuthHandlerFromEndpoint instead. +func RegisterAuthHandlerServer(ctx context.Context, mux *runtime.ServeMux, server AuthServer) error { + + mux.Handle("POST", pattern_Auth_Login_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.Auth/Login", runtime.WithHTTPPathPattern("/api/auth/login")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Auth_Login_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + + forward_Auth_Login_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_Auth_Logout_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.Auth/Logout", runtime.WithHTTPPathPattern("/api/auth/logout")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Auth_Logout_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + + forward_Auth_Logout_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_Auth_GetUserInfo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.Auth/GetUserInfo", runtime.WithHTTPPathPattern("/api/auth/userinfo")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Auth_GetUserInfo_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + + forward_Auth_GetUserInfo_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +// RegisterAuthHandlerFromEndpoint is same as RegisterAuthHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterAuthHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.DialContext(ctx, endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterAuthHandler(ctx, mux, conn) +} + +// RegisterAuthHandler registers the http handlers for service Auth to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterAuthHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterAuthHandlerClient(ctx, mux, NewAuthClient(conn)) +} + +// RegisterAuthHandlerClient registers the http handlers for service Auth +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "AuthClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "AuthClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "AuthClient" to call the correct interceptors. +func RegisterAuthHandlerClient(ctx context.Context, mux *runtime.ServeMux, client AuthClient) error { + + mux.Handle("POST", pattern_Auth_Login_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.Auth/Login", runtime.WithHTTPPathPattern("/api/auth/login")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Auth_Login_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + + forward_Auth_Login_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_Auth_Logout_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.Auth/Logout", runtime.WithHTTPPathPattern("/api/auth/logout")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Auth_Logout_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + + forward_Auth_Logout_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_Auth_GetUserInfo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.Auth/GetUserInfo", runtime.WithHTTPPathPattern("/api/auth/userinfo")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Auth_GetUserInfo_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + + forward_Auth_GetUserInfo_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_Auth_Login_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "auth", "login"}, "")) + + pattern_Auth_Logout_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "auth", "logout"}, "")) + + pattern_Auth_GetUserInfo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "auth", "userinfo"}, "")) +) + +var ( + forward_Auth_Login_0 = runtime.ForwardResponseMessage + + forward_Auth_Logout_0 = runtime.ForwardResponseMessage + + forward_Auth_GetUserInfo_0 = runtime.ForwardResponseMessage +) diff --git a/pb/auth.proto b/pb/auth.proto new file mode 100644 index 0000000..6e103b5 --- /dev/null +++ b/pb/auth.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; +package pb; +option go_package = "m7s.live/v5/pb"; + +import "google/api/annotations.proto"; + +message LoginRequest { + string username = 1; + string password = 2; +} + +message LoginSuccess { + string token = 1; + UserInfo userInfo = 2; +} + +message LoginResponse { + int32 code = 1; + string message = 2; + LoginSuccess data = 3; +} + +message LogoutRequest { + string token = 1; +} + +message LogoutResponse { + int32 code = 1; + string message = 2; +} + +message UserInfoRequest { + string token = 1; +} + +message UserInfo { + string username = 1; + int64 expires_at = 2; // Token expiration timestamp +} + +message UserInfoResponse { + int32 code = 1; + string message = 2; + UserInfo data = 3; +} + +service Auth { + rpc Login(LoginRequest) returns (LoginResponse) { + option (google.api.http) = { + post: "/api/auth/login" + body: "*" + }; + } + rpc Logout(LogoutRequest) returns (LogoutResponse) { + option (google.api.http) = { + post: "/api/auth/logout" + body: "*" + }; + } + rpc GetUserInfo(UserInfoRequest) returns (UserInfoResponse) { + option (google.api.http) = { + get: "/api/auth/userinfo" + }; + } +} \ No newline at end of file diff --git a/pb/auth_grpc.pb.go b/pb/auth_grpc.pb.go new file mode 100644 index 0000000..b7f0e25 --- /dev/null +++ b/pb/auth_grpc.pb.go @@ -0,0 +1,177 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.19.1 +// source: auth.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// AuthClient is the client API for Auth service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AuthClient interface { + Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) + Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) + GetUserInfo(ctx context.Context, in *UserInfoRequest, opts ...grpc.CallOption) (*UserInfoResponse, error) +} + +type authClient struct { + cc grpc.ClientConnInterface +} + +func NewAuthClient(cc grpc.ClientConnInterface) AuthClient { + return &authClient{cc} +} + +func (c *authClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { + out := new(LoginResponse) + err := c.cc.Invoke(ctx, "/pb.Auth/Login", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *authClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) { + out := new(LogoutResponse) + err := c.cc.Invoke(ctx, "/pb.Auth/Logout", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *authClient) GetUserInfo(ctx context.Context, in *UserInfoRequest, opts ...grpc.CallOption) (*UserInfoResponse, error) { + out := new(UserInfoResponse) + err := c.cc.Invoke(ctx, "/pb.Auth/GetUserInfo", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// AuthServer is the server API for Auth service. +// All implementations must embed UnimplementedAuthServer +// for forward compatibility +type AuthServer interface { + Login(context.Context, *LoginRequest) (*LoginResponse, error) + Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) + GetUserInfo(context.Context, *UserInfoRequest) (*UserInfoResponse, error) + mustEmbedUnimplementedAuthServer() +} + +// UnimplementedAuthServer must be embedded to have forward compatible implementations. +type UnimplementedAuthServer struct { +} + +func (UnimplementedAuthServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") +} +func (UnimplementedAuthServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") +} +func (UnimplementedAuthServer) GetUserInfo(context.Context, *UserInfoRequest) (*UserInfoResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetUserInfo not implemented") +} +func (UnimplementedAuthServer) mustEmbedUnimplementedAuthServer() {} + +// UnsafeAuthServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AuthServer will +// result in compilation errors. +type UnsafeAuthServer interface { + mustEmbedUnimplementedAuthServer() +} + +func RegisterAuthServer(s grpc.ServiceRegistrar, srv AuthServer) { + s.RegisterService(&Auth_ServiceDesc, srv) +} + +func _Auth_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(LoginRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AuthServer).Login(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pb.Auth/Login", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AuthServer).Login(ctx, req.(*LoginRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Auth_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(LogoutRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AuthServer).Logout(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pb.Auth/Logout", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AuthServer).Logout(ctx, req.(*LogoutRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Auth_GetUserInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(UserInfoRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AuthServer).GetUserInfo(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pb.Auth/GetUserInfo", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AuthServer).GetUserInfo(ctx, req.(*UserInfoRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Auth_ServiceDesc is the grpc.ServiceDesc for Auth service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Auth_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "pb.Auth", + HandlerType: (*AuthServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Login", + Handler: _Auth_Login_Handler, + }, + { + MethodName: "Logout", + Handler: _Auth_Logout_Handler, + }, + { + MethodName: "GetUserInfo", + Handler: _Auth_GetUserInfo_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "auth.proto", +} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 0000000..6be5818 --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,57 @@ +package auth + +import ( + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + jwtSecret = []byte("m7s_secret_key") // In production, this should be properly configured + tokenTTL = 24 * time.Hour +) + +// JWTClaims represents the JWT claims +type JWTClaims struct { + Username string `json:"username"` +} + +// TokenValidator is an interface for token validation +type TokenValidator interface { + ValidateToken(tokenString string) (*JWTClaims, error) +} + +// GenerateToken generates a new JWT token for a user +func GenerateToken(username string) (string, error) { + claims := jwt.RegisteredClaims{ + Subject: username, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(tokenTTL)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(jwtSecret) +} + +// ValidateJWT validates a JWT token and returns the claims +func ValidateJWT(tokenString string) (*JWTClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return jwtSecret, nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid { + return &JWTClaims{Username: claims.Subject}, nil + } + + return nil, errors.New("invalid token") +} diff --git a/pkg/auth/middleware.go b/pkg/auth/middleware.go new file mode 100644 index 0000000..527d8d6 --- /dev/null +++ b/pkg/auth/middleware.go @@ -0,0 +1,38 @@ +package auth + +import ( + "context" + "net/http" + "strings" +) + +// Middleware creates a new middleware for HTTP authentication +func Middleware(validator TokenValidator) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip auth for login endpoint + if r.URL.Path == "/api/auth/login" { + next.ServeHTTP(w, r) + return + } + + // Get token from Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "missing authorization header", http.StatusUnauthorized) + return + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + claims, err := validator.ValidateToken(tokenString) + if err != nil { + http.Error(w, "invalid token", http.StatusUnauthorized) + return + } + + // Add claims to context + ctx := context.WithValue(r.Context(), "claims", claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/pkg/config/http.go b/pkg/config/http.go index 0370f81..941843f 100644 --- a/pkg/config/http.go +++ b/pkg/config/http.go @@ -5,9 +5,10 @@ import ( "crypto/subtle" "crypto/tls" "log/slog" - "m7s.live/v5/pkg/task" "net/http" + "m7s.live/v5/pkg/task" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "time" @@ -45,7 +46,8 @@ func (config *HTTP) GetHandler() http.Handler { return config.mux } -func (config *HTTP) GetHttpMux() *http.ServeMux { +func (config *HTTP) CreateHttpMux() *http.ServeMux { + config.mux = http.NewServeMux() return config.mux } diff --git a/pkg/db/db.go b/pkg/db/db.go new file mode 100644 index 0000000..a85a46c --- /dev/null +++ b/pkg/db/db.go @@ -0,0 +1,4 @@ +package db + +// AutoMigrations is a slice of models that need to be auto-migrated +var AutoMigrations []interface{} diff --git a/pkg/db/user.go b/pkg/db/user.go new file mode 100644 index 0000000..358b288 --- /dev/null +++ b/pkg/db/user.go @@ -0,0 +1,36 @@ +package db + +import ( + "time" + + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +// User represents a user in the system +type User struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` + Username string `gorm:"uniqueIndex;size:64"` + Password string `gorm:"size:60"` // bcrypt hash + Role string `gorm:"size:20;default:'user'"` // admin or user + LastLogin time.Time +} + +// BeforeCreate hook to hash password before saving +func (u *User) BeforeCreate(tx *gorm.DB) error { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost) + if err != nil { + return err + } + u.Password = string(hashedPassword) + return nil +} + +// CheckPassword verifies if the provided password matches the hash +func (u *User) CheckPassword(password string) bool { + err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + return err == nil +} diff --git a/pkg/error.go b/pkg/error.go index 80b6fa7..64216a9 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -4,6 +4,7 @@ import "errors" var ( ErrNotFound = errors.New("not found") + ErrDisabled = errors.New("disabled") ErrStreamExist = errors.New("stream exist") ErrRecordExists = errors.New("record exists") ErrKick = errors.New("kick") @@ -24,4 +25,5 @@ var ( ErrRecordSamePath = errors.New("record same path") ErrTransformSame = errors.New("transform same") ErrNotListen = errors.New("not listen") + ErrInvalidCredentials = errors.New("invalid credentials") ) diff --git a/scripts/protoc.sh b/scripts/protoc.sh index 9d445f5..1ec06d1 100755 --- a/scripts/protoc.sh +++ b/scripts/protoc.sh @@ -3,14 +3,16 @@ if [ $# -eq 0 ]; then cd pb # Run the global protoc command when no argument provided - protoc -I. \ - --go_out=. \ - --go_opt=paths=source_relative \ - --go-grpc_out=. \ - --go-grpc_opt=paths=source_relative \ - --grpc-gateway_out=. \ - --grpc-gateway_opt=paths=source_relative \ - "global.proto" + for proto in *.proto; do + protoc -I. \ + --go_out=. \ + --go_opt=paths=source_relative \ + --go-grpc_out=. \ + --go-grpc_opt=paths=source_relative \ + --grpc-gateway_out=. \ + --grpc-gateway_opt=paths=source_relative \ + "$proto" + done # Check if the command was successful if [ $? -eq 0 ]; then @@ -23,15 +25,17 @@ else name=$1 cd plugin/${name}/pb # Run the protoc command for plugin - protoc -I. \ - -I"../../../pb" \ - --go_out=. \ - --go_opt=paths=source_relative \ - --go-grpc_out=. \ - --go-grpc_opt=paths=source_relative \ - --grpc-gateway_out=. \ - --grpc-gateway_opt=paths=source_relative \ - "${name}.proto" + for proto in *.proto; do + protoc -I. \ + -I"../../../pb" \ + --go_out=. \ + --go_opt=paths=source_relative \ + --go-grpc_out=. \ + --go-grpc_opt=paths=source_relative \ + --grpc-gateway_out=. \ + --grpc-gateway_opt=paths=source_relative \ + "$proto" + done # Check if the command was successful if [ $? -eq 0 ]; then diff --git a/server.go b/server.go index ea206f8..2132110 100644 --- a/server.go +++ b/server.go @@ -17,9 +17,9 @@ import ( "github.com/shirou/gopsutil/v4/cpu" "google.golang.org/protobuf/proto" - "m7s.live/v5/pkg/task" - + "m7s.live/v5/pkg" "m7s.live/v5/pkg/config" + "m7s.live/v5/pkg/task" sysruntime "runtime" @@ -30,10 +30,12 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "gopkg.in/yaml.v3" "gorm.io/gorm" "m7s.live/v5/pb" . "m7s.live/v5/pkg" + "m7s.live/v5/pkg/auth" "m7s.live/v5/pkg/db" "m7s.live/v5/pkg/util" ) @@ -53,13 +55,13 @@ var ( type ( ServerConfig struct { - EnableSubEvent bool `default:"true" desc:"启用订阅事件,禁用可以提高性能"` //启用订阅事件,禁用可以提高性能 - SettingDir string `default:".m7s" desc:""` - FatalDir string `default:"fatal" desc:""` - PulseInterval time.Duration `default:"5s" desc:"心跳事件间隔"` //心跳事件间隔 - DisableAll bool `default:"false" desc:"禁用所有插件"` //禁用所有插件 - StreamAlias map[config.Regexp]string `desc:"流别名"` - PullProxy []*PullProxy + SettingDir string `default:".m7s" desc:""` + FatalDir string `default:"fatal" desc:""` + PulseInterval time.Duration `default:"5s" desc:"心跳事件间隔"` //心跳事件间隔 + DisableAll bool `default:"false" desc:"禁用所有插件"` //禁用所有插件 + StreamAlias map[config.Regexp]string `desc:"流别名"` + PullProxy []*PullProxy + EnableLogin bool `default:"false" desc:"启用登录机制"` //启用登录机制 } WaitStream struct { StreamPath string @@ -67,7 +69,9 @@ type ( } Server struct { pb.UnimplementedApiServer + pb.UnimplementedAuthServer Plugin + ServerConfig Plugins util.Collection[string, *Plugin] Streams task.Manager[string, *Publisher] @@ -191,18 +195,29 @@ func (s *Server) Start() (err error) { s.LogHandler.Add(defaultLogHandler) s.Logger = slog.New(&s.LogHandler).With("server", s.ID) s.Waiting.Logger = s.Logger - mux := runtime.NewServeMux(runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), runtime.WithForwardResponseOption(func(ctx context.Context, w http.ResponseWriter, m proto.Message) error { - header := w.Header() - header.Set("Access-Control-Allow-Credentials", "true") - header.Set("Cross-Origin-Resource-Policy", "cross-origin") - header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token") - header.Set("Access-Control-Allow-Private-Network", "true") - header.Set("Access-Control-Allow-Origin", "*") - return nil - }), runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) { - httpConf.GetHttpMux().ServeHTTP(w, r) - })) + + var httpMux http.Handler = httpConf.CreateHttpMux() + if s.ServerConfig.EnableLogin { + httpMux = auth.Middleware(s)(httpMux) + } + mux := runtime.NewServeMux( + runtime.WithMarshalerOption("text/plain", &pb.TextPlain{}), + runtime.WithForwardResponseOption(func(ctx context.Context, w http.ResponseWriter, m proto.Message) error { + header := w.Header() + header.Set("Access-Control-Allow-Credentials", "true") + header.Set("Cross-Origin-Resource-Policy", "cross-origin") + header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token,Authorization") + header.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS") + header.Set("Access-Control-Allow-Private-Network", "true") + header.Set("Access-Control-Allow-Origin", "*") + return nil + }), + runtime.WithRoutingErrorHandler(func(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w http.ResponseWriter, r *http.Request, _ int) { + httpMux.ServeHTTP(w, r) + }), + ) httpConf.SetMux(mux) + var cg RawConfig var configYaml []byte switch v := s.conf.(type) { @@ -243,6 +258,7 @@ func (s *Server) Start() (err error) { "/api/videotrack/sse/{streamPath...}": s.api_VideoTrack_SSE, "/api/audiotrack/sse/{streamPath...}": s.api_AudioTrack_SSE, }) + if s.config.DSN != "" { if factory, ok := db.Factory[s.config.DBType]; ok { s.DB, err = gorm.Open(factory(s.config.DSN), &gorm.Config{}) @@ -250,19 +266,43 @@ func (s *Server) Start() (err error) { s.Error("failed to connect database", "error", err, "dsn", s.config.DSN, "type", s.config.DBType) return } + // Auto-migrate the User model + if err = s.DB.AutoMigrate(&db.User{}); err != nil { + s.Error("failed to auto-migrate User model", "error", err) + return + } + // Create default admin user if not exists + var count int64 + s.DB.Model(&db.User{}).Count(&count) + if count == 0 { + adminUser := &db.User{ + Username: "admin", + Password: "admin", + Role: "admin", + } + if err = s.DB.Create(adminUser).Error; err != nil { + s.Error("failed to create default admin user", "error", err) + return + } + } } } + if httpConf.ListenAddrTLS != "" { s.AddDependTask(httpConf.CreateHTTPSWork(s.Logger)) } if httpConf.ListenAddr != "" { s.AddDependTask(httpConf.CreateHTTPWork(s.Logger)) } + var grpcServer *GRPCServer if tcpConf.ListenAddr != "" { var opts []grpc.ServerOption + // Add the auth interceptor + opts = append(opts, grpc.UnaryInterceptor(s.AuthInterceptor())) s.grpcServer = grpc.NewServer(opts...) pb.RegisterApiServer(s.grpcServer, s) + pb.RegisterAuthServer(s.grpcServer, s) s.grpcClientConn, err = grpc.DialContext(s.Context, tcpConf.ListenAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -270,7 +310,11 @@ func (s *Server) Start() (err error) { return } if err = pb.RegisterApiHandler(s.Context, mux, s.grpcClientConn); err != nil { - s.Error("register handler faild", "error", err) + s.Error("register handler failed", "error", err) + return + } + if err = pb.RegisterAuthHandler(s.Context, mux, s.grpcClientConn); err != nil { + s.Error("register auth handler failed", "error", err) return } grpcServer = &GRPCServer{s: s, tcpTask: tcpConf.CreateTCPWork(s.Logger, nil)} @@ -279,6 +323,7 @@ func (s *Server) Start() (err error) { return } } + s.AddTask(&s.Records) s.AddTask(&s.Streams) s.AddTask(&s.Pulls) @@ -450,3 +495,135 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprintf(w, "%s\n", api) } } + +// ValidateToken implements auth.TokenValidator +func (s *Server) ValidateToken(tokenString string) (*auth.JWTClaims, error) { + if !s.ServerConfig.EnableLogin { + return &auth.JWTClaims{Username: "anonymous"}, nil + } + return auth.ValidateJWT(tokenString) +} + +// Login implements the Login RPC method +func (s *Server) Login(ctx context.Context, req *pb.LoginRequest) (res *pb.LoginResponse, err error) { + res = &pb.LoginResponse{} + if !s.ServerConfig.EnableLogin { + err = pkg.ErrDisabled + return + } + if s.DB == nil { + err = pkg.ErrNoDB + return + } + var user db.User + if err = s.DB.Where("username = ?", req.Username).First(&user).Error; err != nil { + return + } + + if !user.CheckPassword(req.Password) { + err = pkg.ErrInvalidCredentials + return + } + + // Generate JWT token + var tokenString string + tokenString, err = auth.GenerateToken(user.Username) + if err != nil { + return + } + + // Update last login time + s.DB.Model(&user).Update("last_login", time.Now()) + res.Data = &pb.LoginSuccess{ + Token: tokenString, + UserInfo: &pb.UserInfo{ + Username: user.Username, + ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), + }, + } + return +} + +// Logout implements the Logout RPC method +func (s *Server) Logout(ctx context.Context, req *pb.LogoutRequest) (res *pb.LogoutResponse, err error) { + if !s.ServerConfig.EnableLogin { + err = pkg.ErrDisabled + return + } + // In a more complex system, you might want to maintain a blacklist of logged-out tokens + // For now, we'll just return success as JWT tokens are stateless + res = &pb.LogoutResponse{Code: 0, Message: "success"} + return +} + +// GetUserInfo implements the GetUserInfo RPC method +func (s *Server) GetUserInfo(ctx context.Context, req *pb.UserInfoRequest) (res *pb.UserInfoResponse, err error) { + if !s.ServerConfig.EnableLogin { + res = &pb.UserInfoResponse{ + Code: 0, + Message: "success", + Data: &pb.UserInfo{ + Username: "anonymous", + ExpiresAt: time.Now().Add(24 * time.Hour).Unix(), + }, + } + return + } + res = &pb.UserInfoResponse{} + claims, err := s.ValidateToken(req.Token) + if err != nil { + err = pkg.ErrInvalidCredentials + return + } + + var user db.User + if err = s.DB.Where("username = ?", claims.Username).First(&user).Error; err != nil { + return + } + + // Token is valid for 24 hours from now + expiresAt := time.Now().Add(24 * time.Hour).Unix() + + return &pb.UserInfoResponse{ + Code: 0, + Message: "success", + Data: &pb.UserInfo{ + Username: user.Username, + ExpiresAt: expiresAt, + }, + }, nil +} + +// AuthInterceptor creates a new unary interceptor for authentication +func (s *Server) AuthInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if !s.ServerConfig.EnableLogin { + return handler(ctx, req) + } + + // Skip auth for login endpoint + if info.FullMethod == "/pb.Auth/Login" { + return handler(ctx, req) + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("missing metadata") + } + + authHeader := md.Get("authorization") + if len(authHeader) == 0 { + return nil, errors.New("missing authorization header") + } + + tokenString := strings.TrimPrefix(authHeader[0], "Bearer ") + claims, err := s.ValidateToken(tokenString) + if err != nil { + return nil, errors.New("invalid token") + } + + // Add claims to context + newCtx := context.WithValue(ctx, "claims", claims) + return handler(newCtx, req) + } +}