mirror of
https://github.com/sigcn/pg.git
synced 2025-10-05 15:26:52 +08:00
158 lines
3.7 KiB
Go
158 lines
3.7 KiB
Go
package fileshare
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/rkonfj/peerguard/peer"
|
|
"github.com/rkonfj/peerguard/rdt"
|
|
)
|
|
|
|
type Downloader struct {
|
|
Network string
|
|
Server string
|
|
PrivateKey string
|
|
UDPPort int
|
|
ProgressBar func(total int64, desc string) ProgressBar
|
|
}
|
|
|
|
func (d *Downloader) Download(ctx context.Context, shareURL string) error {
|
|
pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
|
|
packetConn, err := pnet.ListenPacket(d.UDPPort)
|
|
if err != nil {
|
|
return fmt.Errorf("listen p2p packet failed: %w", err)
|
|
}
|
|
|
|
listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.UDPPort+100)))
|
|
if err != nil {
|
|
return fmt.Errorf("listen rdt: %w", err)
|
|
}
|
|
|
|
resourceURL, err := url.Parse(shareURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
|
|
dir, filename := path.Split(resourceURL.Path)
|
|
index, err := strconv.ParseInt(strings.Trim(dir, "/"), 10, 16)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
|
|
fn, err := url.QueryUnescape(filename)
|
|
if err != nil {
|
|
fn = filename
|
|
}
|
|
|
|
conn, err := listener.OpenStream(peer.ID(resourceURL.Host))
|
|
if err != nil {
|
|
return fmt.Errorf("dial server failed: %w", err)
|
|
}
|
|
defer conn.Close()
|
|
go func() { // watch exit program event
|
|
<-ctx.Done()
|
|
conn.Write(buildClose())
|
|
conn.Close()
|
|
}()
|
|
return d.download(conn, uint16(index), fn)
|
|
}
|
|
|
|
func (d *Downloader) download(conn net.Conn, index uint16, filename string) error {
|
|
f, err := os.OpenFile(filename, os.O_RDWR, 0666)
|
|
if err != nil {
|
|
f, err = os.Create(filename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
defer f.Close()
|
|
stat, err := f.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partSize := stat.Size()
|
|
|
|
sha256Checksum := sha256.New()
|
|
if partSize > 0 {
|
|
fmt.Println("download resuming")
|
|
}
|
|
|
|
if _, err = io.CopyN(sha256Checksum, f, partSize); err != nil {
|
|
return err
|
|
}
|
|
_, err = conn.Write(buildGet(uint16(index), uint32(stat.Size()), sha256Checksum.Sum(nil)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
header := make([]byte, 5)
|
|
_, err = io.ReadFull(conn, header)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
switch header[0] {
|
|
case 0:
|
|
case 20:
|
|
case 1:
|
|
return errors.New("bad request. maybe the version is lower than peer")
|
|
case 2:
|
|
return errors.New("file not found")
|
|
case 4:
|
|
return errors.New("download file size is less than local file")
|
|
default:
|
|
return errors.New("invalid protocol header")
|
|
}
|
|
|
|
fileSize := binary.BigEndian.Uint32(header[1:])
|
|
var bar ProgressBar = NopProgress{}
|
|
if d.ProgressBar != nil {
|
|
bar = d.ProgressBar(int64(fileSize), filename)
|
|
}
|
|
bar.Add(int(stat.Size()))
|
|
defer conn.Write(buildClose())
|
|
|
|
_, err = io.CopyN(io.MultiWriter(f, bar, sha256Checksum), conn, int64(fileSize-uint32(partSize)))
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
return fmt.Errorf("download file falied: %w", err)
|
|
}
|
|
checksum := make([]byte, 32)
|
|
if _, err = io.ReadFull(conn, checksum); err != nil {
|
|
return fmt.Errorf("read checksum failed: %w", err)
|
|
}
|
|
recvSum := sha256Checksum.Sum(nil)
|
|
slog.Debug("Checksum", "recv", recvSum, "send", checksum)
|
|
if !bytes.Equal(checksum, recvSum) {
|
|
return fmt.Errorf("download file failed: checksum mismatched")
|
|
}
|
|
fmt.Printf("sha256: %x\n", checksum)
|
|
return nil
|
|
}
|
|
|
|
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {
|
|
header := []byte{0, 0}
|
|
if partSize > 0 {
|
|
header[1] = 36
|
|
}
|
|
header = append(header, binary.BigEndian.AppendUint16(nil, index)...)
|
|
if partSize > 0 {
|
|
header = append(header, binary.BigEndian.AppendUint32(nil, partSize)...)
|
|
header = append(header, checksum...)
|
|
}
|
|
return header
|
|
}
|
|
|
|
func buildClose() []byte {
|
|
return []byte{1}
|
|
}
|