Files
rpcx/serverplugin/file_transfer.go
smallnest c13fa98bc4 remove v5
2020-04-28 18:20:33 +08:00

222 lines
4.9 KiB
Go

package serverplugin
import (
"context"
"crypto/rand"
"io"
"net"
"sync"
"time"
"github.com/hashicorp/golang-lru"
"github.com/smallnest/rpcx/log"
"github.com/smallnest/rpcx/server"
)
// SendFileServiceName is the RPC service name under which RegisterFileTransfer
// registers the file-transfer service. It is a variable so it can be changed
// before registration.
var SendFileServiceName = "_filetransfer"
// FileTransferHandler serves an upload connection after its token has been
// validated. args are the FileTransferArgs the client passed to TransferFile
// when it obtained that token. The handler takes ownership of conn and must
// close it when it has finished.
type FileTransferHandler func(conn net.Conn, args *FileTransferArgs)
// DownloadFileHandler serves a download connection after its token has been
// validated. args are the DownloadFileArgs the client passed to DownloadFile
// when it obtained that token. The handler takes ownership of conn and must
// close it when it has finished.
type DownloadFileHandler func(conn net.Conn, args *DownloadFileArgs)
// FileTransferArgs are the arguments a client sends to TransferFile to
// announce an upload. They are kept alongside the issued token and later
// handed to the FileTransferHandler for the connection that presents it.
type FileTransferArgs struct {
// FileName names the file being uploaded.
FileName string `json:"file_name,omitempty"`
// FileSize is the size of the file (presumably in bytes; not checked here).
FileSize int64 `json:"file_size,omitempty"`
// Meta is free-form key/value data; this file does not interpret it.
Meta map[string]string `json:"meta,omitempty"`
}
// FileTransferReply is returned by both TransferFile and DownloadFile. The
// client must open a TCP connection to Addr and send Token as the first 32
// bytes. Each token is accepted only once.
type FileTransferReply struct {
// Token is the random 32-byte one-time token.
Token []byte `json:"token,omitempty"`
// Addr is the address of the file-transfer listener.
Addr string `json:"addr,omitempty"`
}
// DownloadFileArgs are the arguments a client sends to DownloadFile to request
// a download. They are kept alongside the issued token and later handed to the
// DownloadFileHandler for the connection that presents it.
type DownloadFileArgs struct {
// FileName names the file the client wants to download.
FileName string `json:"file_name,omitempty"`
}
// tokenInfo is the cache entry for a pending upload. It holds the token issued
// by TransferFile and the args the handler will receive.
type tokenInfo struct {
token []byte
args *FileTransferArgs
}
// downloadTokenInfo is the cache entry for a pending download. It holds the
// token issued by DownloadFile and the args the download handler will receive.
type downloadTokenInfo struct {
token []byte
args *DownloadFileArgs
}
// FileTransfer supports transferring files between clients and the server.
// It registers a file-transfer RPC service and runs its own TCP listener on
// Addr. A client first calls the service to obtain a one-time token and the
// listener address. It then connects to that address and sends the token; the
// connection is handed to the upload or download handler, which performs the
// transfer.
type FileTransfer struct {
// Addr is the TCP address the listener binds to. It is also returned to
// clients as FileTransferReply.Addr.
Addr string
// handler serves connections whose token was issued by TransferFile.
handler FileTransferHandler
// downloadFileHandler serves connections whose token was issued by DownloadFile.
downloadFileHandler DownloadFileHandler
// cachedTokens maps string(token) to a pending *tokenInfo or *downloadTokenInfo.
// It is an LRU cache, so once its capacity (waitNum in NewFileTransfer) is
// exceeded, the least recently used tokens are evicted.
cachedTokens *lru.Cache
// service is the RPC receiver registered by RegisterFileTransfer.
service *FileTransferService
// startOnce ensures Start launches the listener at most once.
startOnce sync.Once
// done is closed by Stop to ask the accept loop to exit.
done chan struct{}
}
// FileTransferService is the RPC receiver registered under SendFileServiceName.
// Its TransferFile and DownloadFile methods issue the one-time tokens that
// clients present on the listener of the owning FileTransfer.
type FileTransferService struct {
FileTransfer *FileTransfer
}
// NewFileTransfer creates a FileTransfer whose listener will bind to addr once
// Start is called.
//
// handler serves upload connections and downloadFileHandler serves download
// connections. Either may be nil, in which case connections of that kind are
// closed immediately.
//
// waitNum is the maximum number of issued-but-unused tokens kept at once. Once
// it is exceeded, the least recently used token is evicted. A non-positive
// waitNum is replaced by a default size.
func NewFileTransfer(addr string, handler FileTransferHandler, downloadFileHandler DownloadFileHandler, waitNum int) *FileTransfer {
	// lru.New rejects a non-positive size and returns a nil cache, which would
	// make the first TransferFile/DownloadFile call panic.
	const defaultWaitNum = 1024
	if waitNum <= 0 {
		log.Errorf("filetransfer: invalid waitNum %d, using %d", waitNum, defaultWaitNum)
		waitNum = defaultWaitNum
	}
	// With a positive size lru.New cannot fail, so the error is safe to ignore.
	cachedTokens, _ := lru.New(waitNum)

	fi := &FileTransfer{
		Addr:                addr,
		handler:             handler,
		downloadFileHandler: downloadFileHandler,
		cachedTokens:        cachedTokens,
		// done must be non-nil. Stop closes it (closing a nil channel panics),
		// and the accept loop watches it to know when to shut down.
		done: make(chan struct{}),
	}
	fi.service = &FileTransferService{FileTransfer: fi}
	return fi
}
// RegisterFileTransfer starts fileTransfer's listener and registers its RPC
// service on s under SendFileServiceName.
//
// The function has no error result, so failures are logged rather than
// silently discarded.
func RegisterFileTransfer(s *server.Server, fileTransfer *FileTransfer) {
	if err := fileTransfer.Start(); err != nil {
		log.Errorf("filetransfer: failed to start listener on %s: %v", fileTransfer.Addr, err)
	}
	if err := s.RegisterName(SendFileServiceName, fileTransfer.service, ""); err != nil {
		log.Errorf("filetransfer: failed to register service %q: %v", SendFileServiceName, err)
	}
}
func (s *FileTransferService) TransferFile(ctx context.Context, args *FileTransferArgs, reply *FileTransferReply) error {
token := make([]byte, 32)
_, err := rand.Read(token)
if err != nil {
return err
}
*reply = FileTransferReply{
Token: token,
Addr: s.FileTransfer.Addr,
}
s.FileTransfer.cachedTokens.Add(string(token), &tokenInfo{token, args})
return nil
}
func (s *FileTransferService) DownloadFile(ctx context.Context, args *DownloadFileArgs, reply *FileTransferReply) error {
token := make([]byte, 32)
_, err := rand.Read(token)
if err != nil {
return err
}
*reply = FileTransferReply{
Token: token,
Addr: s.FileTransfer.Addr,
}
s.FileTransfer.cachedTokens.Add(string(token), &downloadTokenInfo{token, args})
return nil
}
// Start launches the file-transfer listener in a background goroutine.
//
// Only the first call starts a listener; later calls do nothing. Because the
// listener runs asynchronously, its errors (for example, the address is
// already in use) are logged rather than returned, so Start always returns nil.
func (s *FileTransfer) Start() error {
	s.startOnce.Do(func() {
		go func() {
			if err := s.start(); err != nil {
				log.Errorf("filetransfer: listener on %s stopped: %v", s.Addr, err)
			}
		}()
	})
	return nil
}
// start binds a TCP listener on s.Addr and serves transfer connections until
// Stop closes s.done or Accept fails permanently.
//
// Each accepted connection is handled on its own goroutine. A client that
// connects but never sends its token therefore cannot stall other transfers.
//
// It returns nil after Stop, the listen error if binding fails, or the first
// non-temporary Accept error.
func (s *FileTransfer) start() error {
	ln, err := net.Listen("tcp", s.Addr)
	if err != nil {
		return err
	}
	defer ln.Close()

	// Accept blocks, so closing the listener is the only way to make the loop
	// notice Stop promptly. The watcher also exits when start returns for any
	// other reason, so it never outlives the listener. A nil s.done is never
	// ready, which simply disables the Stop path.
	exited := make(chan struct{})
	defer close(exited)
	go func() {
		select {
		case <-s.done:
			ln.Close()
		case <-exited:
		}
	}()

	// claimMu makes "look up token, then remove it" one atomic step across the
	// concurrently served connections, so each token is redeemed at most once.
	var claimMu sync.Mutex

	var tempDelay time.Duration
	for {
		conn, e := ln.Accept()
		if e != nil {
			// After Stop, the Accept error just reflects the closed listener.
			select {
			case <-s.done:
				return nil
			default:
			}
			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				log.Errorf("filetransfer: accept error: %v; retrying in %v", e, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			return e
		}
		tempDelay = 0
		go s.serveTransferConn(conn, &claimMu)
	}
}

// serveTransferConn reads the 32-byte token a client sends first on conn and
// redeems it against the tokens issued by TransferFile/DownloadFile. It then
// hands conn to the matching handler. conn is closed here on every path that
// does not reach a handler; otherwise the handler owns conn and must close it.
func (s *FileTransfer) serveTransferConn(conn net.Conn, claimMu *sync.Mutex) {
	// tokenReadTimeout bounds how long an accepted client may take to send its
	// token, so idle connections cannot pin a goroutine indefinitely.
	const tokenReadTimeout = time.Minute

	if tc, ok := conn.(*net.TCPConn); ok {
		tc.SetKeepAlive(true)
		tc.SetKeepAlivePeriod(3 * time.Minute)
		tc.SetLinger(10)
	}

	token := make([]byte, 32)
	// Deadline errors are deliberately ignored: if the deadline cannot be set,
	// the read simply has no timeout, as before.
	conn.SetReadDeadline(time.Now().Add(tokenReadTimeout))
	if _, err := io.ReadFull(conn, token); err != nil {
		conn.Close()
		log.Errorf("filetransfer: failed to read token from %s: %v", conn.RemoteAddr().String(), err)
		return
	}
	// Clear the deadline so it does not affect the handler's own reads.
	conn.SetReadDeadline(time.Time{})

	tokenStr := string(token)
	claimMu.Lock()
	info, ok := s.cachedTokens.Get(tokenStr)
	if ok {
		s.cachedTokens.Remove(tokenStr)
	}
	claimMu.Unlock()
	if !ok {
		// The token was never issued, was already used, or was evicted from the cache.
		conn.Close()
		log.Errorf("filetransfer: unknown token from %s", conn.RemoteAddr().String())
		return
	}

	switch ti := info.(type) {
	case *tokenInfo:
		if s.handler == nil {
			conn.Close()
			return
		}
		s.handler(conn, ti.args)
	case *downloadTokenInfo:
		if s.downloadFileHandler == nil {
			conn.Close()
			return
		}
		s.downloadFileHandler(conn, ti.args)
	default:
		conn.Close()
	}
}
// Stop asks the listener started by Start to shut down by closing s.done.
//
// Calling Stop again after it has returned is a no-op; concurrent calls are
// not synchronized with each other. If s.done was never initialized, there is
// nothing to signal and Stop returns nil instead of panicking.
func (s *FileTransfer) Stop() error {
	if s.done == nil {
		return nil
	}
	select {
	case <-s.done:
		// Already stopped; closing the channel again would panic.
	default:
		close(s.done)
	}
	return nil
}