From 4ed0dba77219f5a54bb2a06264c5b1d6d611a142 Mon Sep 17 00:00:00 2001 From: rkonfj Date: Sun, 23 Jun 2024 16:04:56 +0800 Subject: [PATCH] fileshare: a better downloader api --- cmd/pgcli/download/download.go | 50 ++++++++++- cmd/pgcli/share/share.go | 2 +- fileshare/README.md | 49 +++++++++++ fileshare/common.go | 9 +- fileshare/downloader.go | 154 +++++++++++++++------------------ fileshare/filemanager.go | 23 +++-- 6 files changed, 182 insertions(+), 105 deletions(-) create mode 100644 fileshare/README.md diff --git a/cmd/pgcli/download/download.go b/cmd/pgcli/download/download.go index c72e420..c5b4f44 100644 --- a/cmd/pgcli/download/download.go +++ b/cmd/pgcli/download/download.go @@ -1,9 +1,12 @@ package download import ( + "bytes" "context" + "crypto/sha256" "errors" "fmt" + "io" "log/slog" "os" "os/signal" @@ -36,7 +39,7 @@ func execute(cmd *cobra.Command, args []string) error { } slog.SetLogLoggerLevel(slog.Level(verbose)) - downloader := fileshare.Downloader{UDPPort: 29879, ProgressBar: createBar} + downloader := fileshare.Downloader{ListenUDPPort: 29879} downloader.Server, err = cmd.Flags().GetString("server") if err != nil { @@ -55,8 +58,51 @@ func execute(cmd *cobra.Command, args []string) error { } ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() + return downloader.Request(ctx, args[0], readFile) +} - return downloader.Download(ctx, args[0]) +func readFile(fh *fileshare.FileHandle) error { + f, err := os.OpenFile(fh.Filename, os.O_RDWR, 0666) + if err != nil { + f, err = os.Create(fh.Filename) + if err != nil { + return err + } + } + defer f.Close() + stat, err := f.Stat() + if err != nil { + return err + } + partSize := stat.Size() + + sha256Checksum := sha256.New() + if partSize > 0 { + fmt.Println("download resuming") + } + if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil { + return err + } + if err := fh.Handshake(uint32(partSize), sha256Checksum.Sum(nil)); err != nil { + return err + } + r, fileSize, _ := fh.File() + bar := createBar(int64(fileSize), fh.Filename) + bar.Add(int(partSize)) + if _, err = io.Copy(io.MultiWriter(f, bar, sha256Checksum), r); err != nil { + return fmt.Errorf("download file falied: %w", err) + } + checksum, err := fh.Sha256() + if err != nil { + return err + } + recvSum := sha256Checksum.Sum(nil) + slog.Debug("Checksum", "recv", recvSum, "send", checksum) + if !bytes.Equal(checksum, recvSum) { + return fmt.Errorf("download file failed: checksum mismatched") + } + fmt.Printf("sha256: %x\n", checksum) + return nil } func createBar(total int64, desc string) fileshare.ProgressBar { diff --git a/cmd/pgcli/share/share.go b/cmd/pgcli/share/share.go index b18e08f..8539a5c 100644 --- a/cmd/pgcli/share/share.go +++ b/cmd/pgcli/share/share.go @@ -37,7 +37,7 @@ func execute(cmd *cobra.Command, args []string) error { } slog.SetLogLoggerLevel(slog.Level(verbose)) - fileManager := fileshare.FileManager{UDPPort: 29878, ProgressBar: createBar} + fileManager := fileshare.FileManager{ListenUDPPort: 29878, ProgressBar: createBar} if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil { return err diff --git a/fileshare/README.md b/fileshare/README.md new file mode 100644 index 0000000..a89ca33 --- /dev/null +++ b/fileshare/README.md @@ -0,0 +1,49 @@ +# fileshare + +A p2p file sharing library + +### Example + +#### download +```go +downloader := &fileshare.Downloader{ + Server: "wss://synf.in/pg", + ListenUDPPort: 29999, +} + +read := func(fh *fileshare.FileHandle) error { + // handshake (can set the offset to facilitate breakpoint resuming) + if err := fh.Handshake(0, nil); err != nil { + return err + } + reader, fileSize, err := fh.File() + if err != nil { + return err + } + f, err := os.Create(fh.Filename) + if err != nil { + return err + } + + if err = io.Copy(io.MultiWriter(f, sum), reader); err != nil { + return err + } + + peerSum, err := fh.Sha256() // file checksum from peer + if err != nil { + return err + } + + if !bytes.Equal(sum.Sum(nil), peerSum) { // assert that local and remote are consistent + return errors.New("transfer error") + } + return nil +} + +ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) +defer cancel() +err := downloader.Request(ctx, "pg://DJX2csRurJ3DvKeh63JebVHFDqVhnFjckdVhToAAiPYf/0/my-show.pptx", read) +if err != nil { + panic(err) +} +``` \ No newline at end of file diff --git a/fileshare/common.go b/fileshare/common.go index e92f782..bba1ee1 100644 --- a/fileshare/common.go +++ b/fileshare/common.go @@ -1,6 +1,7 @@ package fileshare import ( + "cmp" "fmt" "io" "net" @@ -22,15 +23,11 @@ func (pn *PublicNetwork) ListenPacket(udpPort int) (net.PacketConn, error) { if err != nil { return nil, fmt.Errorf("invalid peermap URL: %w", err) } - - pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{ - Network: pn.Name, - Secret: pn.Name, - }) + network := cmp.Or(pn.Name, "pubnet") + pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{Network: network, Secret: network}) if err != nil { return nil, fmt.Errorf("create peermap failed: %w", err) } - return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort)) } diff --git a/fileshare/downloader.go b/fileshare/downloader.go index 4e50c43..2364477 100644 --- a/fileshare/downloader.go +++ b/fileshare/downloader.go @@ -1,17 +1,13 @@ package fileshare import ( - "bytes" "context" - "crypto/sha256" "encoding/binary" "errors" "fmt" "io" - "log/slog" "net" "net/url" - "os" "path" "strconv" "strings" @@ -20,22 +16,79 @@ import ( "github.com/rkonfj/peerguard/rdt" ) -type Downloader struct { - Network string - Server string - PrivateKey string - UDPPort int - ProgressBar func(total int64, desc string) ProgressBar +type FileHandle struct { + Filename string + + c net.Conn + index uint16 + fSize uint32 + f io.Reader } -func (d *Downloader) Download(ctx context.Context, shareURL string) error { +func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error { + _, err := h.c.Write(buildGet(h.index, offset, sha256Checksum)) + if err != nil { + return err + } + header := make([]byte, 5) + _, err = io.ReadFull(h.c, header) + if err != nil { + return err + } + switch header[0] { + case 0: + case 20: + case 1: + return errors.New("bad request. maybe the version is lower than peer") + case 2: + return errors.New("file not found") + case 4: + return errors.New("download file size is less than local file") + case 5: + return errors.New("local file is not part of the file to be downloaded") + default: + return errors.New("invalid protocol header") + } + if offset > 0 && header[0] != 20 { + return errors.New("sha256 checksum non matched for [0, offset)") + } + h.fSize = binary.BigEndian.Uint32(header[1:]) + h.f = io.LimitReader(h.c, int64(h.fSize-offset)) + return nil +} + +func (h *FileHandle) File() (io.Reader, uint32, error) { + if h.f == nil { + return nil, 0, errors.New("handshake first") + } + return h.f, h.fSize, nil +} + +func (h *FileHandle) Sha256() ([]byte, error) { + checksum := make([]byte, 32) + if _, err := io.ReadFull(h.c, checksum); err != nil { + return nil, fmt.Errorf("read checksum failed: %w", err) + } + return checksum, nil +} + +type Read func(f *FileHandle) error + +type Downloader struct { + Network string + Server string + PrivateKey string + ListenUDPPort int +} + +func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error { pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey} - packetConn, err := pnet.ListenPacket(d.UDPPort) + packetConn, err := pnet.ListenPacket(d.ListenUDPPort) if err != nil { return fmt.Errorf("listen p2p packet failed: %w", err) } - listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100))) + listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100))) if err != nil { return fmt.Errorf("listen rdt: %w", err) } @@ -66,77 +119,12 @@ func (d *Downloader) Download(ctx context.Context, shareURL string) error { conn.Write(buildClose()) conn.Close() }() - return d.download(conn, uint16(index), fn) -} - -func (d *Downloader) download(conn net.Conn, index uint16, filename string) error { - f, err := os.OpenFile(filename, os.O_RDWR, 0666) - if err != nil { - f, err = os.Create(filename) - if err != nil { - return err - } - } - defer f.Close() - stat, err := f.Stat() - if err != nil { - return err - } - partSize := stat.Size() - - sha256Checksum := sha256.New() - if partSize > 0 { - fmt.Println("download resuming") - } - - if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil { - return err - } - _, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil))) - if err != nil { - return err - } - header := make([]byte, 5) - _, err = io.ReadFull(conn, header) - if err != nil { - return err - } - switch header[0] { - case 0: - case 20: - case 1: - return errors.New("bad request. maybe the version is lower than peer") - case 2: - return errors.New("file not found") - case 4: - return errors.New("download file size is less than local file") - default: - return errors.New("invalid protocol header") - } - - fileSize := binary.BigEndian.Uint32(header[1:]) - var bar ProgressBar = NopProgress{} - if d.ProgressBar != nil { - bar = d.ProgressBar(int64(fileSize), filename) - } - bar.Add(int(stat.Size())) defer conn.Write(buildClose()) - - _, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize))) - if err != nil && !errors.Is(err, io.EOF) { - return fmt.Errorf("download file falied: %w", err) - } - checksum := make([]byte, 32) - if _, err = io.ReadFull(conn, checksum); err != nil { - return fmt.Errorf("read checksum failed: %w", err) - } - recvSum := sha256Checksum.Sum(nil) - slog.Debug("Checksum", "recv", recvSum, "send", checksum) - if !bytes.Equal(checksum, recvSum) { - return fmt.Errorf("download file failed: checksum mismatched") - } - fmt.Printf("sha256: %x\n", checksum) - return nil + return read(&FileHandle{ + Filename: fn, + c: conn, + index: uint16(index), + }) } func buildGet(index uint16, partSize uint32, checksum []byte) []byte { diff --git a/fileshare/filemanager.go b/fileshare/filemanager.go index 440e059..93550be 100644 --- a/fileshare/filemanager.go +++ b/fileshare/filemanager.go @@ -21,11 +21,11 @@ import ( ) type FileManager struct { - Network string - Server string - PrivateKey string - UDPPort int - ProgressBar func(total int64, desc string) ProgressBar + Network string + Server string + PrivateKey string + ListenUDPPort int + ProgressBar func(total int64, desc string) ProgressBar mutex sync.RWMutex index int @@ -36,12 +36,12 @@ type FileManager struct { func (m *FileManager) ListenNetwork() (net.Listener, error) { pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey} - packetConn, err := pnet.ListenPacket(m.UDPPort) + packetConn, err := pnet.ListenPacket(m.ListenUDPPort) if err != nil { return nil, fmt.Errorf("listen p2p packet failed: %w", err) } - listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.UDPPort+100))) + listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.ListenUDPPort+100))) if err != nil { return nil, fmt.Errorf("listen rdt: %w", err) } @@ -154,12 +154,9 @@ func (m *FileManager) handleRequest(peerID string, conn net.Conn) { } io.CopyN(sha256Checksum, f, int64(partSize)) if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) { - if _, err = f.Seek(0, io.SeekStart); err != nil { - conn.Write(buildErr(10)) - slog.Error("SeekToStart", "err", err) - return - } - sha256Checksum = sha256.New() + conn.Write(buildErr(5)) // not part of file + slog.Error("Request not part of file", "file", f.Name()) + return } }