fileshare: a better downloader api

This commit is contained in:
rkonfj
2024-06-23 16:04:56 +08:00
parent 8e35276ffe
commit 4ed0dba772
6 changed files with 182 additions and 105 deletions

View File

@@ -1,9 +1,12 @@
package download package download
import ( import (
"bytes"
"context" "context"
"crypto/sha256"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"os" "os"
"os/signal" "os/signal"
@@ -36,7 +39,7 @@ func execute(cmd *cobra.Command, args []string) error {
} }
slog.SetLogLoggerLevel(slog.Level(verbose)) slog.SetLogLoggerLevel(slog.Level(verbose))
downloader := fileshare.Downloader{UDPPort: 29879, ProgressBar: createBar} downloader := fileshare.Downloader{ListenUDPPort: 29879}
downloader.Server, err = cmd.Flags().GetString("server") downloader.Server, err = cmd.Flags().GetString("server")
if err != nil { if err != nil {
@@ -55,8 +58,51 @@ func execute(cmd *cobra.Command, args []string) error {
} }
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel() defer cancel()
return downloader.Request(ctx, args[0], readFile)
}
return downloader.Download(ctx, args[0]) func readFile(fh *fileshare.FileHandle) error {
f, err := os.OpenFile(fh.Filename, os.O_RDWR, 0666)
if err != nil {
f, err = os.Create(fh.Filename)
if err != nil {
return err
}
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
partSize := stat.Size()
sha256Checksum := sha256.New()
if partSize > 0 {
fmt.Println("download resuming")
}
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
return err
}
if err := fh.Handshake(uint32(partSize), sha256Checksum.Sum(nil)); err != nil {
return err
}
r, fileSize, _ := fh.File()
bar := createBar(int64(fileSize), fh.Filename)
bar.Add(int(partSize))
if _, err = io.Copy(io.MultiWriter(f, bar, sha256Checksum), r); err != nil {
return fmt.Errorf("download file falied: %w", err)
}
checksum, err := fh.Sha256()
if err != nil {
return err
}
recvSum := sha256Checksum.Sum(nil)
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
if !bytes.Equal(checksum, recvSum) {
return fmt.Errorf("download file failed: checksum mismatched")
}
fmt.Printf("sha256: %x\n", checksum)
return nil
} }
func createBar(total int64, desc string) fileshare.ProgressBar { func createBar(total int64, desc string) fileshare.ProgressBar {

View File

@@ -37,7 +37,7 @@ func execute(cmd *cobra.Command, args []string) error {
} }
slog.SetLogLoggerLevel(slog.Level(verbose)) slog.SetLogLoggerLevel(slog.Level(verbose))
fileManager := fileshare.FileManager{UDPPort: 29878, ProgressBar: createBar} fileManager := fileshare.FileManager{ListenUDPPort: 29878, ProgressBar: createBar}
if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil { if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil {
return err return err

49
fileshare/README.md Normal file
View File

@@ -0,0 +1,49 @@
# fileshare
A p2p file sharing library
### Example
#### download
```go
downloader := &fileshare.Downloader{
Server: "wss://synf.in/pg",
ListenUDPPort: 29999,
}
read := func(fh *fileshare.FileHandle) error {
// handshake (can set the offset to facilitate breakpoint resuming)
if err := fh.Handshake(0, nil); err != nil {
return err
}
reader, fileSize, err := fh.File()
if err != nil {
return err
}
f, err := os.Create(fh.Filename)
if err != nil {
return err
}
if err = io.Copy(io.MultiWriter(f, sum), reader); err != nil {
return err
}
peerSum, err := fh.Sha256() // file checksum from peer
if err != nil {
return err
}
if !bytes.Equal(sum.Sum(nil), peerSum) { // assert that local and remote are consistent
return errors.New("transfer error")
}
return nil
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
err := downloader.Request(ctx, "pg://DJX2csRurJ3DvKeh63JebVHFDqVhnFjckdVhToAAiPYf/0/my-show.pptx", read)
if err != nil {
panic(err)
}
```

View File

@@ -1,6 +1,7 @@
package fileshare package fileshare
import ( import (
"cmp"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -22,15 +23,11 @@ func (pn *PublicNetwork) ListenPacket(udpPort int) (net.PacketConn, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid peermap URL: %w", err) return nil, fmt.Errorf("invalid peermap URL: %w", err)
} }
network := cmp.Or(pn.Name, "pubnet")
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{ pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{Network: network, Secret: network})
Network: pn.Name,
Secret: pn.Name,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("create peermap failed: %w", err) return nil, fmt.Errorf("create peermap failed: %w", err)
} }
return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort)) return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort))
} }

View File

@@ -1,17 +1,13 @@
package fileshare package fileshare
import ( import (
"bytes"
"context" "context"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net" "net"
"net/url" "net/url"
"os"
"path" "path"
"strconv" "strconv"
"strings" "strings"
@@ -20,22 +16,79 @@ import (
"github.com/rkonfj/peerguard/rdt" "github.com/rkonfj/peerguard/rdt"
) )
type FileHandle struct {
Filename string
c net.Conn
index uint16
fSize uint32
f io.Reader
}
func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error {
_, err := h.c.Write(buildGet(h.index, offset, sha256Checksum))
if err != nil {
return err
}
header := make([]byte, 5)
_, err = io.ReadFull(h.c, header)
if err != nil {
return err
}
switch header[0] {
case 0:
case 20:
case 1:
return errors.New("bad request. maybe the version is lower than peer")
case 2:
return errors.New("file not found")
case 4:
return errors.New("download file size is less than local file")
case 5:
return errors.New("local file is not part of the file to be downloaded")
default:
return errors.New("invalid protocol header")
}
if offset > 0 && header[0] != 20 {
return errors.New("sha256 checksum non matched for [0, offset)")
}
h.fSize = binary.BigEndian.Uint32(header[1:])
h.f = io.LimitReader(h.c, int64(h.fSize-offset))
return nil
}
func (h *FileHandle) File() (io.Reader, uint32, error) {
if h.f == nil {
return nil, 0, errors.New("handshake first")
}
return h.f, h.fSize, nil
}
func (h *FileHandle) Sha256() ([]byte, error) {
checksum := make([]byte, 32)
if _, err := io.ReadFull(h.c, checksum); err != nil {
return nil, fmt.Errorf("read checksum failed: %w", err)
}
return checksum, nil
}
type Read func(f *FileHandle) error
type Downloader struct { type Downloader struct {
Network string Network string
Server string Server string
PrivateKey string PrivateKey string
UDPPort int ListenUDPPort int
ProgressBar func(total int64, desc string) ProgressBar
} }
func (d *Downloader) Download(ctx context.Context, shareURL string) error { func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error {
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey} pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
packetConn, err := pnet.ListenPacket(d.UDPPort) packetConn, err := pnet.ListenPacket(d.ListenUDPPort)
if err != nil { if err != nil {
return fmt.Errorf("listen p2p packet failed: %w", err) return fmt.Errorf("listen p2p packet failed: %w", err)
} }
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100))) listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100)))
if err != nil { if err != nil {
return fmt.Errorf("listen rdt: %w", err) return fmt.Errorf("listen rdt: %w", err)
} }
@@ -66,77 +119,12 @@ func (d *Downloader) Download(ctx context.Context, shareURL string) error {
conn.Write(buildClose()) conn.Write(buildClose())
conn.Close() conn.Close()
}() }()
return d.download(conn, uint16(index), fn)
}
func (d *Downloader) download(conn net.Conn, index uint16, filename string) error {
f, err := os.OpenFile(filename, os.O_RDWR, 0666)
if err != nil {
f, err = os.Create(filename)
if err != nil {
return err
}
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
partSize := stat.Size()
sha256Checksum := sha256.New()
if partSize > 0 {
fmt.Println("download resuming")
}
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
return err
}
_, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil)))
if err != nil {
return err
}
header := make([]byte, 5)
_, err = io.ReadFull(conn, header)
if err != nil {
return err
}
switch header[0] {
case 0:
case 20:
case 1:
return errors.New("bad request. maybe the version is lower than peer")
case 2:
return errors.New("file not found")
case 4:
return errors.New("download file size is less than local file")
default:
return errors.New("invalid protocol header")
}
fileSize := binary.BigEndian.Uint32(header[1:])
var bar ProgressBar = NopProgress{}
if d.ProgressBar != nil {
bar = d.ProgressBar(int64(fileSize), filename)
}
bar.Add(int(stat.Size()))
defer conn.Write(buildClose()) defer conn.Write(buildClose())
return read(&FileHandle{
_, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize))) Filename: fn,
if err != nil && !errors.Is(err, io.EOF) { c: conn,
return fmt.Errorf("download file falied: %w", err) index: uint16(index),
} })
checksum := make([]byte, 32)
if _, err = io.ReadFull(conn, checksum); err != nil {
return fmt.Errorf("read checksum failed: %w", err)
}
recvSum := sha256Checksum.Sum(nil)
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
if !bytes.Equal(checksum, recvSum) {
return fmt.Errorf("download file failed: checksum mismatched")
}
fmt.Printf("sha256: %x\n", checksum)
return nil
} }
func buildGet(index uint16, partSize uint32, checksum []byte) []byte { func buildGet(index uint16, partSize uint32, checksum []byte) []byte {

View File

@@ -24,7 +24,7 @@ type FileManager struct {
Network string Network string
Server string Server string
PrivateKey string PrivateKey string
UDPPort int ListenUDPPort int
ProgressBar func(total int64, desc string) ProgressBar ProgressBar func(total int64, desc string) ProgressBar
mutex sync.RWMutex mutex sync.RWMutex
@@ -36,12 +36,12 @@ type FileManager struct {
func (m *FileManager) ListenNetwork() (net.Listener, error) { func (m *FileManager) ListenNetwork() (net.Listener, error) {
pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey} pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey}
packetConn, err := pnet.ListenPacket(m.UDPPort) packetConn, err := pnet.ListenPacket(m.ListenUDPPort)
if err != nil { if err != nil {
return nil, fmt.Errorf("listen p2p packet failed: %w", err) return nil, fmt.Errorf("listen p2p packet failed: %w", err)
} }
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.UDPPort+100))) listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.ListenUDPPort+100)))
if err != nil { if err != nil {
return nil, fmt.Errorf("listen rdt: %w", err) return nil, fmt.Errorf("listen rdt: %w", err)
} }
@@ -154,13 +154,10 @@ func (m *FileManager) handleRequest(peerID string, conn net.Conn) {
} }
io.CopyN(sha256Checksum, f, int64(partSize)) io.CopyN(sha256Checksum, f, int64(partSize))
if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) { if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) {
if _, err = f.Seek(0, io.SeekStart); err != nil { conn.Write(buildErr(5)) // not part of file
conn.Write(buildErr(10)) slog.Error("Request not part of file", "file", f.Name())
slog.Error("SeekToStart", "err", err)
return return
} }
sha256Checksum = sha256.New()
}
} }
pos, err := f.Seek(0, io.SeekCurrent) pos, err := f.Seek(0, io.SeekCurrent)