fileshare: a better downloader api

This commit is contained in:
rkonfj
2024-06-23 16:04:56 +08:00
parent 8e35276ffe
commit 4ed0dba772
6 changed files with 182 additions and 105 deletions

View File

@@ -1,9 +1,12 @@
package download
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/signal"
@@ -36,7 +39,7 @@ func execute(cmd *cobra.Command, args []string) error {
}
slog.SetLogLoggerLevel(slog.Level(verbose))
downloader := fileshare.Downloader{UDPPort: 29879, ProgressBar: createBar}
downloader := fileshare.Downloader{ListenUDPPort: 29879}
downloader.Server, err = cmd.Flags().GetString("server")
if err != nil {
@@ -55,8 +58,51 @@ func execute(cmd *cobra.Command, args []string) error {
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
return downloader.Request(ctx, args[0], readFile)
}
return downloader.Download(ctx, args[0])
func readFile(fh *fileshare.FileHandle) error {
f, err := os.OpenFile(fh.Filename, os.O_RDWR, 0666)
if err != nil {
f, err = os.Create(fh.Filename)
if err != nil {
return err
}
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
partSize := stat.Size()
sha256Checksum := sha256.New()
if partSize > 0 {
fmt.Println("download resuming")
}
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
return err
}
if err := fh.Handshake(uint32(partSize), sha256Checksum.Sum(nil)); err != nil {
return err
}
r, fileSize, _ := fh.File()
bar := createBar(int64(fileSize), fh.Filename)
bar.Add(int(partSize))
if _, err = io.Copy(io.MultiWriter(f, bar, sha256Checksum), r); err != nil {
return fmt.Errorf("download file falied: %w", err)
}
checksum, err := fh.Sha256()
if err != nil {
return err
}
recvSum := sha256Checksum.Sum(nil)
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
if !bytes.Equal(checksum, recvSum) {
return fmt.Errorf("download file failed: checksum mismatched")
}
fmt.Printf("sha256: %x\n", checksum)
return nil
}
func createBar(total int64, desc string) fileshare.ProgressBar {

View File

@@ -37,7 +37,7 @@ func execute(cmd *cobra.Command, args []string) error {
}
slog.SetLogLoggerLevel(slog.Level(verbose))
fileManager := fileshare.FileManager{UDPPort: 29878, ProgressBar: createBar}
fileManager := fileshare.FileManager{ListenUDPPort: 29878, ProgressBar: createBar}
if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil {
return err

49
fileshare/README.md Normal file
View File

@@ -0,0 +1,49 @@
# fileshare
A p2p file sharing library
### Example
#### download
```go
downloader := &fileshare.Downloader{
Server: "wss://synf.in/pg",
ListenUDPPort: 29999,
}
read := func(fh *fileshare.FileHandle) error {
// handshake (can set the offset to facilitate breakpoint resuming)
if err := fh.Handshake(0, nil); err != nil {
return err
}
reader, fileSize, err := fh.File()
if err != nil {
return err
}
f, err := os.Create(fh.Filename)
if err != nil {
return err
}
if err = io.Copy(io.MultiWriter(f, sum), reader); err != nil {
return err
}
peerSum, err := fh.Sha256() // file checksum from peer
if err != nil {
return err
}
if !bytes.Equal(sum.Sum(nil), peerSum) { // assert that local and remote are consistent
return errors.New("transfer error")
}
return nil
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
err := downloader.Request(ctx, "pg://DJX2csRurJ3DvKeh63JebVHFDqVhnFjckdVhToAAiPYf/0/my-show.pptx", read)
if err != nil {
panic(err)
}
```

View File

@@ -1,6 +1,7 @@
package fileshare
import (
"cmp"
"fmt"
"io"
"net"
@@ -22,15 +23,11 @@ func (pn *PublicNetwork) ListenPacket(udpPort int) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("invalid peermap URL: %w", err)
}
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{
Network: pn.Name,
Secret: pn.Name,
})
network := cmp.Or(pn.Name, "pubnet")
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{Network: network, Secret: network})
if err != nil {
return nil, fmt.Errorf("create peermap failed: %w", err)
}
return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort))
}

View File

@@ -1,17 +1,13 @@
package fileshare
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/url"
"os"
"path"
"strconv"
"strings"
@@ -20,22 +16,79 @@ import (
"github.com/rkonfj/peerguard/rdt"
)
type Downloader struct {
Network string
Server string
PrivateKey string
UDPPort int
ProgressBar func(total int64, desc string) ProgressBar
type FileHandle struct {
Filename string
c net.Conn
index uint16
fSize uint32
f io.Reader
}
func (d *Downloader) Download(ctx context.Context, shareURL string) error {
func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error {
_, err := h.c.Write(buildGet(h.index, offset, sha256Checksum))
if err != nil {
return err
}
header := make([]byte, 5)
_, err = io.ReadFull(h.c, header)
if err != nil {
return err
}
switch header[0] {
case 0:
case 20:
case 1:
return errors.New("bad request. maybe the version is lower than peer")
case 2:
return errors.New("file not found")
case 4:
return errors.New("download file size is less than local file")
case 5:
return errors.New("local file is not part of the file to be downloaded")
default:
return errors.New("invalid protocol header")
}
if offset > 0 && header[0] != 20 {
return errors.New("sha256 checksum non matched for [0, offset)")
}
h.fSize = binary.BigEndian.Uint32(header[1:])
h.f = io.LimitReader(h.c, int64(h.fSize-offset))
return nil
}
func (h *FileHandle) File() (io.Reader, uint32, error) {
if h.f == nil {
return nil, 0, errors.New("handshake first")
}
return h.f, h.fSize, nil
}
func (h *FileHandle) Sha256() ([]byte, error) {
checksum := make([]byte, 32)
if _, err := io.ReadFull(h.c, checksum); err != nil {
return nil, fmt.Errorf("read checksum failed: %w", err)
}
return checksum, nil
}
type Read func(f *FileHandle) error
type Downloader struct {
Network string
Server string
PrivateKey string
ListenUDPPort int
}
func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error {
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
packetConn, err := pnet.ListenPacket(d.UDPPort)
packetConn, err := pnet.ListenPacket(d.ListenUDPPort)
if err != nil {
return fmt.Errorf("listen p2p packet failed: %w", err)
}
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100)))
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100)))
if err != nil {
return fmt.Errorf("listen rdt: %w", err)
}
@@ -66,77 +119,12 @@ func (d *Downloader) Download(ctx context.Context, shareURL string) error {
conn.Write(buildClose())
conn.Close()
}()
return d.download(conn, uint16(index), fn)
}
func (d *Downloader) download(conn net.Conn, index uint16, filename string) error {
f, err := os.OpenFile(filename, os.O_RDWR, 0666)
if err != nil {
f, err = os.Create(filename)
if err != nil {
return err
}
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
partSize := stat.Size()
sha256Checksum := sha256.New()
if partSize > 0 {
fmt.Println("download resuming")
}
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
return err
}
_, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil)))
if err != nil {
return err
}
header := make([]byte, 5)
_, err = io.ReadFull(conn, header)
if err != nil {
return err
}
switch header[0] {
case 0:
case 20:
case 1:
return errors.New("bad request. maybe the version is lower than peer")
case 2:
return errors.New("file not found")
case 4:
return errors.New("download file size is less than local file")
default:
return errors.New("invalid protocol header")
}
fileSize := binary.BigEndian.Uint32(header[1:])
var bar ProgressBar = NopProgress{}
if d.ProgressBar != nil {
bar = d.ProgressBar(int64(fileSize), filename)
}
bar.Add(int(stat.Size()))
defer conn.Write(buildClose())
_, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize)))
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("download file falied: %w", err)
}
checksum := make([]byte, 32)
if _, err = io.ReadFull(conn, checksum); err != nil {
return fmt.Errorf("read checksum failed: %w", err)
}
recvSum := sha256Checksum.Sum(nil)
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
if !bytes.Equal(checksum, recvSum) {
return fmt.Errorf("download file failed: checksum mismatched")
}
fmt.Printf("sha256: %x\n", checksum)
return nil
return read(&FileHandle{
Filename: fn,
c: conn,
index: uint16(index),
})
}
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {

View File

@@ -21,11 +21,11 @@ import (
)
type FileManager struct {
Network string
Server string
PrivateKey string
UDPPort int
ProgressBar func(total int64, desc string) ProgressBar
Network string
Server string
PrivateKey string
ListenUDPPort int
ProgressBar func(total int64, desc string) ProgressBar
mutex sync.RWMutex
index int
@@ -36,12 +36,12 @@ type FileManager struct {
func (m *FileManager) ListenNetwork() (net.Listener, error) {
pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey}
packetConn, err := pnet.ListenPacket(m.UDPPort)
packetConn, err := pnet.ListenPacket(m.ListenUDPPort)
if err != nil {
return nil, fmt.Errorf("listen p2p packet failed: %w", err)
}
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.UDPPort+100)))
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.ListenUDPPort+100)))
if err != nil {
return nil, fmt.Errorf("listen rdt: %w", err)
}
@@ -154,12 +154,9 @@ func (m *FileManager) handleRequest(peerID string, conn net.Conn) {
}
io.CopyN(sha256Checksum, f, int64(partSize))
if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) {
if _, err = f.Seek(0, io.SeekStart); err != nil {
conn.Write(buildErr(10))
slog.Error("SeekToStart", "err", err)
return
}
sha256Checksum = sha256.New()
conn.Write(buildErr(5)) // not part of file
slog.Error("Request not part of file", "file", f.Name())
return
}
}