mirror of
https://github.com/sigcn/pg.git
synced 2025-09-26 22:05:50 +08:00
fileshare: a better downloader api
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
package download
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -36,7 +39,7 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
slog.SetLogLoggerLevel(slog.Level(verbose))
|
||||
|
||||
downloader := fileshare.Downloader{UDPPort: 29879, ProgressBar: createBar}
|
||||
downloader := fileshare.Downloader{ListenUDPPort: 29879}
|
||||
|
||||
downloader.Server, err = cmd.Flags().GetString("server")
|
||||
if err != nil {
|
||||
@@ -55,8 +58,51 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
return downloader.Request(ctx, args[0], readFile)
|
||||
}
|
||||
|
||||
return downloader.Download(ctx, args[0])
|
||||
func readFile(fh *fileshare.FileHandle) error {
|
||||
f, err := os.OpenFile(fh.Filename, os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
f, err = os.Create(fh.Filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
defer f.Close()
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
partSize := stat.Size()
|
||||
|
||||
sha256Checksum := sha256.New()
|
||||
if partSize > 0 {
|
||||
fmt.Println("download resuming")
|
||||
}
|
||||
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fh.Handshake(uint32(partSize), sha256Checksum.Sum(nil)); err != nil {
|
||||
return err
|
||||
}
|
||||
r, fileSize, _ := fh.File()
|
||||
bar := createBar(int64(fileSize), fh.Filename)
|
||||
bar.Add(int(partSize))
|
||||
if _, err = io.Copy(io.MultiWriter(f, bar, sha256Checksum), r); err != nil {
|
||||
return fmt.Errorf("download file falied: %w", err)
|
||||
}
|
||||
checksum, err := fh.Sha256()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recvSum := sha256Checksum.Sum(nil)
|
||||
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
|
||||
if !bytes.Equal(checksum, recvSum) {
|
||||
return fmt.Errorf("download file failed: checksum mismatched")
|
||||
}
|
||||
fmt.Printf("sha256: %x\n", checksum)
|
||||
return nil
|
||||
}
|
||||
|
||||
func createBar(total int64, desc string) fileshare.ProgressBar {
|
||||
|
@@ -37,7 +37,7 @@ func execute(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
slog.SetLogLoggerLevel(slog.Level(verbose))
|
||||
|
||||
fileManager := fileshare.FileManager{UDPPort: 29878, ProgressBar: createBar}
|
||||
fileManager := fileshare.FileManager{ListenUDPPort: 29878, ProgressBar: createBar}
|
||||
|
||||
if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil {
|
||||
return err
|
||||
|
49
fileshare/README.md
Normal file
49
fileshare/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# fileshare
|
||||
|
||||
A p2p file sharing library
|
||||
|
||||
### Example
|
||||
|
||||
#### download
|
||||
```go
|
||||
downloader := &fileshare.Downloader{
|
||||
Server: "wss://synf.in/pg",
|
||||
ListenUDPPort: 29999,
|
||||
}
|
||||
|
||||
read := func(fh *fileshare.FileHandle) error {
|
||||
// handshake (can set the offset to facilitate breakpoint resuming)
|
||||
if err := fh.Handshake(0, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
reader, fileSize, err := fh.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := os.Create(fh.Filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = io.Copy(io.MultiWriter(f, sum), reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerSum, err := fh.Sha256() // file checksum from peer
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(sum.Sum(nil), peerSum) { // assert that local and remote are consistent
|
||||
return errors.New("transfer error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
err := downloader.Request(ctx, "pg://DJX2csRurJ3DvKeh63JebVHFDqVhnFjckdVhToAAiPYf/0/my-show.pptx", read)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
@@ -1,6 +1,7 @@
|
||||
package fileshare
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -22,15 +23,11 @@ func (pn *PublicNetwork) ListenPacket(udpPort int) (net.PacketConn, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid peermap URL: %w", err)
|
||||
}
|
||||
|
||||
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{
|
||||
Network: pn.Name,
|
||||
Secret: pn.Name,
|
||||
})
|
||||
network := cmp.Or(pn.Name, "pubnet")
|
||||
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{Network: network, Secret: network})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create peermap failed: %w", err)
|
||||
}
|
||||
|
||||
return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort))
|
||||
}
|
||||
|
||||
|
@@ -1,17 +1,13 @@
|
||||
package fileshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -20,22 +16,79 @@ import (
|
||||
"github.com/rkonfj/peerguard/rdt"
|
||||
)
|
||||
|
||||
type Downloader struct {
|
||||
Network string
|
||||
Server string
|
||||
PrivateKey string
|
||||
UDPPort int
|
||||
ProgressBar func(total int64, desc string) ProgressBar
|
||||
type FileHandle struct {
|
||||
Filename string
|
||||
|
||||
c net.Conn
|
||||
index uint16
|
||||
fSize uint32
|
||||
f io.Reader
|
||||
}
|
||||
|
||||
func (d *Downloader) Download(ctx context.Context, shareURL string) error {
|
||||
func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error {
|
||||
_, err := h.c.Write(buildGet(h.index, offset, sha256Checksum))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := make([]byte, 5)
|
||||
_, err = io.ReadFull(h.c, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch header[0] {
|
||||
case 0:
|
||||
case 20:
|
||||
case 1:
|
||||
return errors.New("bad request. maybe the version is lower than peer")
|
||||
case 2:
|
||||
return errors.New("file not found")
|
||||
case 4:
|
||||
return errors.New("download file size is less than local file")
|
||||
case 5:
|
||||
return errors.New("local file is not part of the file to be downloaded")
|
||||
default:
|
||||
return errors.New("invalid protocol header")
|
||||
}
|
||||
if offset > 0 && header[0] != 20 {
|
||||
return errors.New("sha256 checksum non matched for [0, offset)")
|
||||
}
|
||||
h.fSize = binary.BigEndian.Uint32(header[1:])
|
||||
h.f = io.LimitReader(h.c, int64(h.fSize-offset))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *FileHandle) File() (io.Reader, uint32, error) {
|
||||
if h.f == nil {
|
||||
return nil, 0, errors.New("handshake first")
|
||||
}
|
||||
return h.f, h.fSize, nil
|
||||
}
|
||||
|
||||
func (h *FileHandle) Sha256() ([]byte, error) {
|
||||
checksum := make([]byte, 32)
|
||||
if _, err := io.ReadFull(h.c, checksum); err != nil {
|
||||
return nil, fmt.Errorf("read checksum failed: %w", err)
|
||||
}
|
||||
return checksum, nil
|
||||
}
|
||||
|
||||
type Read func(f *FileHandle) error
|
||||
|
||||
type Downloader struct {
|
||||
Network string
|
||||
Server string
|
||||
PrivateKey string
|
||||
ListenUDPPort int
|
||||
}
|
||||
|
||||
func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error {
|
||||
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
|
||||
packetConn, err := pnet.ListenPacket(d.UDPPort)
|
||||
packetConn, err := pnet.ListenPacket(d.ListenUDPPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen p2p packet failed: %w", err)
|
||||
}
|
||||
|
||||
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100)))
|
||||
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen rdt: %w", err)
|
||||
}
|
||||
@@ -66,77 +119,12 @@ func (d *Downloader) Download(ctx context.Context, shareURL string) error {
|
||||
conn.Write(buildClose())
|
||||
conn.Close()
|
||||
}()
|
||||
return d.download(conn, uint16(index), fn)
|
||||
}
|
||||
|
||||
func (d *Downloader) download(conn net.Conn, index uint16, filename string) error {
|
||||
f, err := os.OpenFile(filename, os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
f, err = os.Create(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
defer f.Close()
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
partSize := stat.Size()
|
||||
|
||||
sha256Checksum := sha256.New()
|
||||
if partSize > 0 {
|
||||
fmt.Println("download resuming")
|
||||
}
|
||||
|
||||
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := make([]byte, 5)
|
||||
_, err = io.ReadFull(conn, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch header[0] {
|
||||
case 0:
|
||||
case 20:
|
||||
case 1:
|
||||
return errors.New("bad request. maybe the version is lower than peer")
|
||||
case 2:
|
||||
return errors.New("file not found")
|
||||
case 4:
|
||||
return errors.New("download file size is less than local file")
|
||||
default:
|
||||
return errors.New("invalid protocol header")
|
||||
}
|
||||
|
||||
fileSize := binary.BigEndian.Uint32(header[1:])
|
||||
var bar ProgressBar = NopProgress{}
|
||||
if d.ProgressBar != nil {
|
||||
bar = d.ProgressBar(int64(fileSize), filename)
|
||||
}
|
||||
bar.Add(int(stat.Size()))
|
||||
defer conn.Write(buildClose())
|
||||
|
||||
_, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize)))
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("download file falied: %w", err)
|
||||
}
|
||||
checksum := make([]byte, 32)
|
||||
if _, err = io.ReadFull(conn, checksum); err != nil {
|
||||
return fmt.Errorf("read checksum failed: %w", err)
|
||||
}
|
||||
recvSum := sha256Checksum.Sum(nil)
|
||||
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
|
||||
if !bytes.Equal(checksum, recvSum) {
|
||||
return fmt.Errorf("download file failed: checksum mismatched")
|
||||
}
|
||||
fmt.Printf("sha256: %x\n", checksum)
|
||||
return nil
|
||||
return read(&FileHandle{
|
||||
Filename: fn,
|
||||
c: conn,
|
||||
index: uint16(index),
|
||||
})
|
||||
}
|
||||
|
||||
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {
|
||||
|
@@ -21,11 +21,11 @@ import (
|
||||
)
|
||||
|
||||
type FileManager struct {
|
||||
Network string
|
||||
Server string
|
||||
PrivateKey string
|
||||
UDPPort int
|
||||
ProgressBar func(total int64, desc string) ProgressBar
|
||||
Network string
|
||||
Server string
|
||||
PrivateKey string
|
||||
ListenUDPPort int
|
||||
ProgressBar func(total int64, desc string) ProgressBar
|
||||
|
||||
mutex sync.RWMutex
|
||||
index int
|
||||
@@ -36,12 +36,12 @@ type FileManager struct {
|
||||
|
||||
func (m *FileManager) ListenNetwork() (net.Listener, error) {
|
||||
pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey}
|
||||
packetConn, err := pnet.ListenPacket(m.UDPPort)
|
||||
packetConn, err := pnet.ListenPacket(m.ListenUDPPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen p2p packet failed: %w", err)
|
||||
}
|
||||
|
||||
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.UDPPort+100)))
|
||||
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.ListenUDPPort+100)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen rdt: %w", err)
|
||||
}
|
||||
@@ -154,12 +154,9 @@ func (m *FileManager) handleRequest(peerID string, conn net.Conn) {
|
||||
}
|
||||
io.CopyN(sha256Checksum, f, int64(partSize))
|
||||
if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) {
|
||||
if _, err = f.Seek(0, io.SeekStart); err != nil {
|
||||
conn.Write(buildErr(10))
|
||||
slog.Error("SeekToStart", "err", err)
|
||||
return
|
||||
}
|
||||
sha256Checksum = sha256.New()
|
||||
conn.Write(buildErr(5)) // not part of file
|
||||
slog.Error("Request not part of file", "file", f.Name())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user