fileshare: a better downloader api

This commit is contained in:
rkonfj
2024-06-23 16:04:56 +08:00
parent 8e35276ffe
commit 4ed0dba772
6 changed files with 182 additions and 105 deletions

View File

@@ -1,9 +1,12 @@
package download
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/signal"
@@ -36,7 +39,7 @@ func execute(cmd *cobra.Command, args []string) error {
}
slog.SetLogLoggerLevel(slog.Level(verbose))
downloader := fileshare.Downloader{UDPPort: 29879, ProgressBar: createBar}
downloader := fileshare.Downloader{ListenUDPPort: 29879}
downloader.Server, err = cmd.Flags().GetString("server")
if err != nil {
@@ -55,8 +58,51 @@ func execute(cmd *cobra.Command, args []string) error {
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
return downloader.Request(ctx, args[0], readFile)
}
return downloader.Download(ctx, args[0])
func readFile(fh *fileshare.FileHandle) error {
f, err := os.OpenFile(fh.Filename, os.O_RDWR, 0666)
if err != nil {
f, err = os.Create(fh.Filename)
if err != nil {
return err
}
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
partSize := stat.Size()
sha256Checksum := sha256.New()
if partSize > 0 {
fmt.Println("download resuming")
}
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
return err
}
if err := fh.Handshake(uint32(partSize), sha256Checksum.Sum(nil)); err != nil {
return err
}
r, fileSize, _ := fh.File()
bar := createBar(int64(fileSize), fh.Filename)
bar.Add(int(partSize))
if _, err = io.Copy(io.MultiWriter(f, bar, sha256Checksum), r); err != nil {
return fmt.Errorf("download file falied: %w", err)
}
checksum, err := fh.Sha256()
if err != nil {
return err
}
recvSum := sha256Checksum.Sum(nil)
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
if !bytes.Equal(checksum, recvSum) {
return fmt.Errorf("download file failed: checksum mismatched")
}
fmt.Printf("sha256: %x\n", checksum)
return nil
}
func createBar(total int64, desc string) fileshare.ProgressBar {

View File

@@ -37,7 +37,7 @@ func execute(cmd *cobra.Command, args []string) error {
}
slog.SetLogLoggerLevel(slog.Level(verbose))
fileManager := fileshare.FileManager{UDPPort: 29878, ProgressBar: createBar}
fileManager := fileshare.FileManager{ListenUDPPort: 29878, ProgressBar: createBar}
if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil {
return err

49
fileshare/README.md Normal file
View File

@@ -0,0 +1,49 @@
# fileshare
A p2p file sharing library
### Example
#### download
```go
downloader := &fileshare.Downloader{
Server: "wss://synf.in/pg",
ListenUDPPort: 29999,
}
read := func(fh *fileshare.FileHandle) error {
// handshake (can set the offset to facilitate breakpoint resuming)
if err := fh.Handshake(0, nil); err != nil {
return err
}
reader, fileSize, err := fh.File()
if err != nil {
return err
}
f, err := os.Create(fh.Filename)
if err != nil {
return err
}
if err = io.Copy(io.MultiWriter(f, sum), reader); err != nil {
return err
}
peerSum, err := fh.Sha256() // file checksum from peer
if err != nil {
return err
}
if !bytes.Equal(sum.Sum(nil), peerSum) { // assert that local and remote are consistent
return errors.New("transfer error")
}
return nil
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
err := downloader.Request(ctx, "pg://DJX2csRurJ3DvKeh63JebVHFDqVhnFjckdVhToAAiPYf/0/my-show.pptx", read)
if err != nil {
panic(err)
}
```

View File

@@ -1,6 +1,7 @@
package fileshare
import (
"cmp"
"fmt"
"io"
"net"
@@ -22,15 +23,11 @@ func (pn *PublicNetwork) ListenPacket(udpPort int) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("invalid peermap URL: %w", err)
}
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{
Network: pn.Name,
Secret: pn.Name,
})
network := cmp.Or(pn.Name, "pubnet")
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{Network: network, Secret: network})
if err != nil {
return nil, fmt.Errorf("create peermap failed: %w", err)
}
return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort))
}

View File

@@ -1,17 +1,13 @@
package fileshare
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/url"
"os"
"path"
"strconv"
"strings"
@@ -20,22 +16,79 @@ import (
"github.com/rkonfj/peerguard/rdt"
)
type FileHandle struct {
Filename string
c net.Conn
index uint16
fSize uint32
f io.Reader
}
func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error {
_, err := h.c.Write(buildGet(h.index, offset, sha256Checksum))
if err != nil {
return err
}
header := make([]byte, 5)
_, err = io.ReadFull(h.c, header)
if err != nil {
return err
}
switch header[0] {
case 0:
case 20:
case 1:
return errors.New("bad request. maybe the version is lower than peer")
case 2:
return errors.New("file not found")
case 4:
return errors.New("download file size is less than local file")
case 5:
return errors.New("local file is not part of the file to be downloaded")
default:
return errors.New("invalid protocol header")
}
if offset > 0 && header[0] != 20 {
return errors.New("sha256 checksum non matched for [0, offset)")
}
h.fSize = binary.BigEndian.Uint32(header[1:])
h.f = io.LimitReader(h.c, int64(h.fSize-offset))
return nil
}
func (h *FileHandle) File() (io.Reader, uint32, error) {
if h.f == nil {
return nil, 0, errors.New("handshake first")
}
return h.f, h.fSize, nil
}
func (h *FileHandle) Sha256() ([]byte, error) {
checksum := make([]byte, 32)
if _, err := io.ReadFull(h.c, checksum); err != nil {
return nil, fmt.Errorf("read checksum failed: %w", err)
}
return checksum, nil
}
type Read func(f *FileHandle) error
type Downloader struct {
Network string
Server string
PrivateKey string
UDPPort int
ProgressBar func(total int64, desc string) ProgressBar
ListenUDPPort int
}
func (d *Downloader) Download(ctx context.Context, shareURL string) error {
func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error {
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
packetConn, err := pnet.ListenPacket(d.UDPPort)
packetConn, err := pnet.ListenPacket(d.ListenUDPPort)
if err != nil {
return fmt.Errorf("listen p2p packet failed: %w", err)
}
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100)))
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100)))
if err != nil {
return fmt.Errorf("listen rdt: %w", err)
}
@@ -66,77 +119,12 @@ func (d *Downloader) Download(ctx context.Context, shareURL string) error {
conn.Write(buildClose())
conn.Close()
}()
return d.download(conn, uint16(index), fn)
}
func (d *Downloader) download(conn net.Conn, index uint16, filename string) error {
f, err := os.OpenFile(filename, os.O_RDWR, 0666)
if err != nil {
f, err = os.Create(filename)
if err != nil {
return err
}
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
partSize := stat.Size()
sha256Checksum := sha256.New()
if partSize > 0 {
fmt.Println("download resuming")
}
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
return err
}
_, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil)))
if err != nil {
return err
}
header := make([]byte, 5)
_, err = io.ReadFull(conn, header)
if err != nil {
return err
}
switch header[0] {
case 0:
case 20:
case 1:
return errors.New("bad request. maybe the version is lower than peer")
case 2:
return errors.New("file not found")
case 4:
return errors.New("download file size is less than local file")
default:
return errors.New("invalid protocol header")
}
fileSize := binary.BigEndian.Uint32(header[1:])
var bar ProgressBar = NopProgress{}
if d.ProgressBar != nil {
bar = d.ProgressBar(int64(fileSize), filename)
}
bar.Add(int(stat.Size()))
defer conn.Write(buildClose())
_, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize)))
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("download file falied: %w", err)
}
checksum := make([]byte, 32)
if _, err = io.ReadFull(conn, checksum); err != nil {
return fmt.Errorf("read checksum failed: %w", err)
}
recvSum := sha256Checksum.Sum(nil)
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
if !bytes.Equal(checksum, recvSum) {
return fmt.Errorf("download file failed: checksum mismatched")
}
fmt.Printf("sha256: %x\n", checksum)
return nil
return read(&FileHandle{
Filename: fn,
c: conn,
index: uint16(index),
})
}
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {

View File

@@ -24,7 +24,7 @@ type FileManager struct {
Network string
Server string
PrivateKey string
UDPPort int
ListenUDPPort int
ProgressBar func(total int64, desc string) ProgressBar
mutex sync.RWMutex
@@ -36,12 +36,12 @@ type FileManager struct {
func (m *FileManager) ListenNetwork() (net.Listener, error) {
pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey}
packetConn, err := pnet.ListenPacket(m.UDPPort)
packetConn, err := pnet.ListenPacket(m.ListenUDPPort)
if err != nil {
return nil, fmt.Errorf("listen p2p packet failed: %w", err)
}
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.UDPPort+100)))
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.ListenUDPPort+100)))
if err != nil {
return nil, fmt.Errorf("listen rdt: %w", err)
}
@@ -154,13 +154,10 @@ func (m *FileManager) handleRequest(peerID string, conn net.Conn) {
}
io.CopyN(sha256Checksum, f, int64(partSize))
if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) {
if _, err = f.Seek(0, io.SeekStart); err != nil {
conn.Write(buildErr(10))
slog.Error("SeekToStart", "err", err)
conn.Write(buildErr(5)) // not part of file
slog.Error("Request not part of file", "file", f.Name())
return
}
sha256Checksum = sha256.New()
}
}
pos, err := f.Seek(0, io.SeekCurrent)