mirror of
https://github.com/sigcn/pg.git
synced 2025-10-01 00:22:07 +08:00
fileshare: a better downloader api
This commit is contained in:
@@ -1,9 +1,12 @@
|
|||||||
package download
|
package download
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -36,7 +39,7 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
slog.SetLogLoggerLevel(slog.Level(verbose))
|
slog.SetLogLoggerLevel(slog.Level(verbose))
|
||||||
|
|
||||||
downloader := fileshare.Downloader{UDPPort: 29879, ProgressBar: createBar}
|
downloader := fileshare.Downloader{ListenUDPPort: 29879}
|
||||||
|
|
||||||
downloader.Server, err = cmd.Flags().GetString("server")
|
downloader.Server, err = cmd.Flags().GetString("server")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -55,8 +58,51 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
return downloader.Request(ctx, args[0], readFile)
|
||||||
|
}
|
||||||
|
|
||||||
return downloader.Download(ctx, args[0])
|
func readFile(fh *fileshare.FileHandle) error {
|
||||||
|
f, err := os.OpenFile(fh.Filename, os.O_RDWR, 0666)
|
||||||
|
if err != nil {
|
||||||
|
f, err = os.Create(fh.Filename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
stat, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
partSize := stat.Size()
|
||||||
|
|
||||||
|
sha256Checksum := sha256.New()
|
||||||
|
if partSize > 0 {
|
||||||
|
fmt.Println("download resuming")
|
||||||
|
}
|
||||||
|
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := fh.Handshake(uint32(partSize), sha256Checksum.Sum(nil)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r, fileSize, _ := fh.File()
|
||||||
|
bar := createBar(int64(fileSize), fh.Filename)
|
||||||
|
bar.Add(int(partSize))
|
||||||
|
if _, err = io.Copy(io.MultiWriter(f, bar, sha256Checksum), r); err != nil {
|
||||||
|
return fmt.Errorf("download file falied: %w", err)
|
||||||
|
}
|
||||||
|
checksum, err := fh.Sha256()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
recvSum := sha256Checksum.Sum(nil)
|
||||||
|
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
|
||||||
|
if !bytes.Equal(checksum, recvSum) {
|
||||||
|
return fmt.Errorf("download file failed: checksum mismatched")
|
||||||
|
}
|
||||||
|
fmt.Printf("sha256: %x\n", checksum)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createBar(total int64, desc string) fileshare.ProgressBar {
|
func createBar(total int64, desc string) fileshare.ProgressBar {
|
||||||
|
@@ -37,7 +37,7 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
slog.SetLogLoggerLevel(slog.Level(verbose))
|
slog.SetLogLoggerLevel(slog.Level(verbose))
|
||||||
|
|
||||||
fileManager := fileshare.FileManager{UDPPort: 29878, ProgressBar: createBar}
|
fileManager := fileshare.FileManager{ListenUDPPort: 29878, ProgressBar: createBar}
|
||||||
|
|
||||||
if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil {
|
if fileManager.Server, err = cmd.Flags().GetString("server"); err != nil {
|
||||||
return err
|
return err
|
||||||
|
49
fileshare/README.md
Normal file
49
fileshare/README.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# fileshare
|
||||||
|
|
||||||
|
A p2p file sharing library
|
||||||
|
|
||||||
|
### Example
|
||||||
|
|
||||||
|
#### download
|
||||||
|
```go
|
||||||
|
downloader := &fileshare.Downloader{
|
||||||
|
Server: "wss://synf.in/pg",
|
||||||
|
ListenUDPPort: 29999,
|
||||||
|
}
|
||||||
|
|
||||||
|
read := func(fh *fileshare.FileHandle) error {
|
||||||
|
// handshake (can set the offset to facilitate breakpoint resuming)
|
||||||
|
if err := fh.Handshake(0, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reader, fileSize, err := fh.File()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f, err := os.Create(fh.Filename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = io.Copy(io.MultiWriter(f, sum), reader); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerSum, err := fh.Sha256() // file checksum from peer
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(sum.Sum(nil), peerSum) { // assert that local and remote are consistent
|
||||||
|
return errors.New("transfer error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||||
|
defer cancel()
|
||||||
|
err := downloader.Request(ctx, "pg://DJX2csRurJ3DvKeh63JebVHFDqVhnFjckdVhToAAiPYf/0/my-show.pptx", read)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
```
|
@@ -1,6 +1,7 @@
|
|||||||
package fileshare
|
package fileshare
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -22,15 +23,11 @@ func (pn *PublicNetwork) ListenPacket(udpPort int) (net.PacketConn, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid peermap URL: %w", err)
|
return nil, fmt.Errorf("invalid peermap URL: %w", err)
|
||||||
}
|
}
|
||||||
|
network := cmp.Or(pn.Name, "pubnet")
|
||||||
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{
|
pmap, err := peermap.New(pmapURL, &peer.NetworkSecret{Network: network, Secret: network})
|
||||||
Network: pn.Name,
|
|
||||||
Secret: pn.Name,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create peermap failed: %w", err)
|
return nil, fmt.Errorf("create peermap failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort))
|
return p2p.ListenPacket(pmap, pn.secureOption(), p2p.ListenUDPPort(udpPort))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,17 +1,13 @@
|
|||||||
package fileshare
|
package fileshare
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -20,22 +16,79 @@ import (
|
|||||||
"github.com/rkonfj/peerguard/rdt"
|
"github.com/rkonfj/peerguard/rdt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type FileHandle struct {
|
||||||
|
Filename string
|
||||||
|
|
||||||
|
c net.Conn
|
||||||
|
index uint16
|
||||||
|
fSize uint32
|
||||||
|
f io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error {
|
||||||
|
_, err := h.c.Write(buildGet(h.index, offset, sha256Checksum))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
header := make([]byte, 5)
|
||||||
|
_, err = io.ReadFull(h.c, header)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch header[0] {
|
||||||
|
case 0:
|
||||||
|
case 20:
|
||||||
|
case 1:
|
||||||
|
return errors.New("bad request. maybe the version is lower than peer")
|
||||||
|
case 2:
|
||||||
|
return errors.New("file not found")
|
||||||
|
case 4:
|
||||||
|
return errors.New("download file size is less than local file")
|
||||||
|
case 5:
|
||||||
|
return errors.New("local file is not part of the file to be downloaded")
|
||||||
|
default:
|
||||||
|
return errors.New("invalid protocol header")
|
||||||
|
}
|
||||||
|
if offset > 0 && header[0] != 20 {
|
||||||
|
return errors.New("sha256 checksum non matched for [0, offset)")
|
||||||
|
}
|
||||||
|
h.fSize = binary.BigEndian.Uint32(header[1:])
|
||||||
|
h.f = io.LimitReader(h.c, int64(h.fSize-offset))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandle) File() (io.Reader, uint32, error) {
|
||||||
|
if h.f == nil {
|
||||||
|
return nil, 0, errors.New("handshake first")
|
||||||
|
}
|
||||||
|
return h.f, h.fSize, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *FileHandle) Sha256() ([]byte, error) {
|
||||||
|
checksum := make([]byte, 32)
|
||||||
|
if _, err := io.ReadFull(h.c, checksum); err != nil {
|
||||||
|
return nil, fmt.Errorf("read checksum failed: %w", err)
|
||||||
|
}
|
||||||
|
return checksum, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Read func(f *FileHandle) error
|
||||||
|
|
||||||
type Downloader struct {
|
type Downloader struct {
|
||||||
Network string
|
Network string
|
||||||
Server string
|
Server string
|
||||||
PrivateKey string
|
PrivateKey string
|
||||||
UDPPort int
|
ListenUDPPort int
|
||||||
ProgressBar func(total int64, desc string) ProgressBar
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Downloader) Download(ctx context.Context, shareURL string) error {
|
func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error {
|
||||||
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
|
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
|
||||||
packetConn, err := pnet.ListenPacket(d.UDPPort)
|
packetConn, err := pnet.ListenPacket(d.ListenUDPPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("listen p2p packet failed: %w", err)
|
return fmt.Errorf("listen p2p packet failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100)))
|
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("listen rdt: %w", err)
|
return fmt.Errorf("listen rdt: %w", err)
|
||||||
}
|
}
|
||||||
@@ -66,77 +119,12 @@ func (d *Downloader) Download(ctx context.Context, shareURL string) error {
|
|||||||
conn.Write(buildClose())
|
conn.Write(buildClose())
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}()
|
}()
|
||||||
return d.download(conn, uint16(index), fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Downloader) download(conn net.Conn, index uint16, filename string) error {
|
|
||||||
f, err := os.OpenFile(filename, os.O_RDWR, 0666)
|
|
||||||
if err != nil {
|
|
||||||
f, err = os.Create(filename)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
stat, err := f.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
partSize := stat.Size()
|
|
||||||
|
|
||||||
sha256Checksum := sha256.New()
|
|
||||||
if partSize > 0 {
|
|
||||||
fmt.Println("download resuming")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil)))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
header := make([]byte, 5)
|
|
||||||
_, err = io.ReadFull(conn, header)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch header[0] {
|
|
||||||
case 0:
|
|
||||||
case 20:
|
|
||||||
case 1:
|
|
||||||
return errors.New("bad request. maybe the version is lower than peer")
|
|
||||||
case 2:
|
|
||||||
return errors.New("file not found")
|
|
||||||
case 4:
|
|
||||||
return errors.New("download file size is less than local file")
|
|
||||||
default:
|
|
||||||
return errors.New("invalid protocol header")
|
|
||||||
}
|
|
||||||
|
|
||||||
fileSize := binary.BigEndian.Uint32(header[1:])
|
|
||||||
var bar ProgressBar = NopProgress{}
|
|
||||||
if d.ProgressBar != nil {
|
|
||||||
bar = d.ProgressBar(int64(fileSize), filename)
|
|
||||||
}
|
|
||||||
bar.Add(int(stat.Size()))
|
|
||||||
defer conn.Write(buildClose())
|
defer conn.Write(buildClose())
|
||||||
|
return read(&FileHandle{
|
||||||
_, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize)))
|
Filename: fn,
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
c: conn,
|
||||||
return fmt.Errorf("download file falied: %w", err)
|
index: uint16(index),
|
||||||
}
|
})
|
||||||
checksum := make([]byte, 32)
|
|
||||||
if _, err = io.ReadFull(conn, checksum); err != nil {
|
|
||||||
return fmt.Errorf("read checksum failed: %w", err)
|
|
||||||
}
|
|
||||||
recvSum := sha256Checksum.Sum(nil)
|
|
||||||
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
|
|
||||||
if !bytes.Equal(checksum, recvSum) {
|
|
||||||
return fmt.Errorf("download file failed: checksum mismatched")
|
|
||||||
}
|
|
||||||
fmt.Printf("sha256: %x\n", checksum)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {
|
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {
|
||||||
|
@@ -24,7 +24,7 @@ type FileManager struct {
|
|||||||
Network string
|
Network string
|
||||||
Server string
|
Server string
|
||||||
PrivateKey string
|
PrivateKey string
|
||||||
UDPPort int
|
ListenUDPPort int
|
||||||
ProgressBar func(total int64, desc string) ProgressBar
|
ProgressBar func(total int64, desc string) ProgressBar
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
@@ -36,12 +36,12 @@ type FileManager struct {
|
|||||||
|
|
||||||
func (m *FileManager) ListenNetwork() (net.Listener, error) {
|
func (m *FileManager) ListenNetwork() (net.Listener, error) {
|
||||||
pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey}
|
pnet := PublicNetwork{Name: m.Network, Server: m.Server, PrivateKey: m.PrivateKey}
|
||||||
packetConn, err := pnet.ListenPacket(m.UDPPort)
|
packetConn, err := pnet.ListenPacket(m.ListenUDPPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listen p2p packet failed: %w", err)
|
return nil, fmt.Errorf("listen p2p packet failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.UDPPort+100)))
|
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", m.ListenUDPPort+100)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listen rdt: %w", err)
|
return nil, fmt.Errorf("listen rdt: %w", err)
|
||||||
}
|
}
|
||||||
@@ -154,13 +154,10 @@ func (m *FileManager) handleRequest(peerID string, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
io.CopyN(sha256Checksum, f, int64(partSize))
|
io.CopyN(sha256Checksum, f, int64(partSize))
|
||||||
if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) {
|
if !bytes.Equal(sha256Checksum.Sum(nil), partChecksum) {
|
||||||
if _, err = f.Seek(0, io.SeekStart); err != nil {
|
conn.Write(buildErr(5)) // not part of file
|
||||||
conn.Write(buildErr(10))
|
slog.Error("Request not part of file", "file", f.Name())
|
||||||
slog.Error("SeekToStart", "err", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sha256Checksum = sha256.New()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pos, err := f.Seek(0, io.SeekCurrent)
|
pos, err := f.Seek(0, io.SeekCurrent)
|
||||||
|
Reference in New Issue
Block a user