// File: MacFastLookup/service/download.go

package service
import (
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
)
// Counter is the interface a progress tracker implements. It is handed each
// chunk of transferred data through Write and can be told the expected total
// size through SetTotal.
type Counter interface {
	// Write observes a chunk of data and returns the number of bytes
	// consumed and any error (same signature as io.Writer).
	Write(p []byte) (int, error)
	// SetTotal records the expected total size, in bytes as used by
	// WriteCounter.
	SetTotal(total int64)
}
2024-07-25 10:06:09 +08:00
// WriteCounter implements Counter. Each Write adds the chunk length to
// Current, recomputes Rate as a whole-number percentage of Total, and then
// invokes ProgressCb if one is set. It never inspects the bytes themselves,
// which makes it suitable as the writer side of an io.TeeReader.
type WriteCounter struct {
	// Total is the expected size in bytes; a value <= 0 means "unknown"
	// (EnhancedDownload passes resp.ContentLength, which can be -1).
	Total int64
	// Current is the number of bytes observed so far.
	Current int64
	// Rate is Current as an integer percentage of Total, or 0 while Total
	// is unknown (<= 0).
	Rate int64
	// Package is the name passed through to ProgressCb.
	Package string
	// ProgressCb, if non-nil, is called after every Write.
	ProgressCb ProgressCallback
}

// ProgressCallback is the function type used to report download progress:
// bytesDownloaded and totalBytes are byte counts (totalBytes may be <= 0 when
// the size is unknown), rate is an integer percentage, and packageName
// identifies the download.
type ProgressCallback func(bytesDownloaded, totalBytes int64, rate int64, packageName string)

// Write records len(p) additional bytes of progress and notifies ProgressCb.
// It always returns len(p) and a nil error.
func (w *WriteCounter) Write(p []byte) (int, error) {
	n := len(p)
	w.Current += int64(n)
	// Only compute a percentage when the total is known: Total == 0 would
	// panic with integer division by zero, and a negative Total (unknown
	// Content-Length) would produce a meaningless negative rate.
	if w.Total > 0 {
		w.Rate = w.Current * 100 / w.Total
	} else {
		w.Rate = 0
	}
	if w.ProgressCb != nil {
		w.ProgressCb(w.Current, w.Total, w.Rate, w.Package)
	}
	return n, nil
}

// SetTotal sets the expected total size in bytes used to compute Rate.
func (w *WriteCounter) SetTotal(total int64) {
	w.Total = total
}
// EnhancedDownload 下载文件并报告进度
func EnhancedDownload(url string, path string, overwrite bool, progressCb ProgressCallback) error {
// 如果路径为空,使用URL中的文件名
if path == "" {
path = filepath.Base(url)
}
2024-07-23 16:32:49 +08:00
2024-07-25 10:06:09 +08:00
// 检查文件是否已存在
if _, err := os.Stat(path); err == nil && !overwrite {
return fmt.Errorf("文件 %s 已存在且未设置覆盖", path)
}
2024-07-23 16:32:49 +08:00
2024-07-25 10:06:09 +08:00
// 创建HTTP请求
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("request %s error: %v", url, err)
}
defer resp.Body.Close()
2024-07-23 16:32:49 +08:00
2024-07-25 10:06:09 +08:00
// 创建输出文件
out, err := os.Create(path)
if err != nil {
return err
}
defer out.Close()
2024-07-23 16:32:49 +08:00
2024-07-25 10:06:09 +08:00
// 创建一个WriteCounter来跟踪进度
counter := &WriteCounter{
Total: resp.ContentLength,
Package: filepath.Base(path),
ProgressCb: progressCb,
}
2024-07-23 16:32:49 +08:00
2024-07-25 10:06:09 +08:00
// 使用TeeReader来同时写入文件和更新进度
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
if err != nil {
return fmt.Errorf("write error: %v", err)
}
2024-07-23 16:32:49 +08:00
2024-07-25 10:06:09 +08:00
fmt.Printf("\ndownload [ %v ] -> [ %s ] success\n", url, path)
return nil
}