84 lines
1.9 KiB
Go
84 lines
1.9 KiB
Go
package service
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
)
|
|
|
|
// Counter is the method set a progress tracker must provide.
type Counter interface {
	// SetTotal hands the tracker the total size, in bytes, to measure against.
	SetTotal(total int64)
	// Write receives the next chunk of data; the signature matches io.Writer.
	Write(p []byte) (int, error)
}
|
|
|
|
// WriteCounter 实现了 Counter 接口
|
|
type WriteCounter struct {
|
|
Total int64
|
|
Current int64
|
|
Rate int64
|
|
Package string
|
|
ProgressCb ProgressCallback
|
|
}
|
|
|
|
// ProgressCallback is the function type used to report download progress.
// It receives the bytes downloaded so far, the total byte count, the progress
// rate and the name of the package being downloaded.
type ProgressCallback func(bytesDownloaded, totalBytes, rate int64, packageName string)
|
|
|
|
func (w *WriteCounter) Write(p []byte) (int, error) {
|
|
n := len(p)
|
|
w.Current += int64(n)
|
|
w.Rate = w.Current * 100 / w.Total
|
|
if w.ProgressCb != nil {
|
|
w.ProgressCb(w.Current, w.Total, w.Rate, w.Package)
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
func (w *WriteCounter) SetTotal(total int64) {
|
|
w.Total = total
|
|
}
|
|
|
|
// EnhancedDownload 下载文件并报告进度
|
|
func EnhancedDownload(url string, path string, overwrite bool, progressCb ProgressCallback) error {
|
|
// 如果路径为空,使用URL中的文件名
|
|
if path == "" {
|
|
path = filepath.Base(url)
|
|
}
|
|
|
|
// 检查文件是否已存在
|
|
if _, err := os.Stat(path); err == nil && !overwrite {
|
|
return fmt.Errorf("文件 %s 已存在且未设置覆盖", path)
|
|
}
|
|
|
|
// 创建HTTP请求
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
return fmt.Errorf("request %s error: %v", url, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 创建输出文件
|
|
out, err := os.Create(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer out.Close()
|
|
|
|
// 创建一个WriteCounter来跟踪进度
|
|
counter := &WriteCounter{
|
|
Total: resp.ContentLength,
|
|
Package: filepath.Base(path),
|
|
ProgressCb: progressCb,
|
|
}
|
|
|
|
// 使用TeeReader来同时写入文件和更新进度
|
|
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
|
|
if err != nil {
|
|
return fmt.Errorf("write error: %v", err)
|
|
}
|
|
|
|
fmt.Printf("\ndownload [ %v ] -> [ %s ] success\n", url, path)
|
|
return nil
|
|
} |