Skip to content

Commit e518038

Browse files
authored
self upgrade multi thread (#4)
* Refactory with multiple thread download * Remove commented code lines
1 parent 4452afd commit e518038

File tree

3 files changed

+168
-160
lines changed

3 files changed

+168
-160
lines changed

cmd/root.go

Lines changed: 3 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,8 @@ import (
55
extver "github.com/linuxsuren/cobra-extension/version"
66
"github.com/linuxsuren/http-downloader/pkg"
77
"github.com/spf13/cobra"
8-
"io/ioutil"
9-
"net/http"
10-
"os"
118
"runtime"
12-
"strconv"
139
"strings"
14-
"sync"
1510
)
1611

1712
// NewRoot returns the root command
@@ -90,7 +85,7 @@ func (o *downloadOption) providerURLParse(path string) (url string, err error) {
9085

9186
if len(addr) == 3 {
9287
name = addr[2]
93-
} else {
88+
} else if len(addr) > 3 {
9489
err = fmt.Errorf("only support format xx/xx or xx/xx/xx")
9590
}
9691

@@ -141,121 +136,9 @@ func (o *downloadOption) preRunE(cmd *cobra.Command, args []string) (err error)
141136

142137
func (o *downloadOption) runE(cmd *cobra.Command, args []string) (err error) {
143138
if o.Thread <= 1 {
144-
err = o.download(o.Output, o.ContinueAt, 0)
139+
err = pkg.DownloadWithContinue(o.URL, o.Output, o.ContinueAt, 0, o.ShowProgress)
145140
} else {
146-
// get the total size of the target file
147-
var total int64
148-
var rangeSupport bool
149-
if total, rangeSupport, err = o.detectSize(o.Output); err != nil {
150-
return
151-
}
152-
153-
if rangeSupport {
154-
unit := total / int64(o.Thread)
155-
offset := total - unit*int64(o.Thread)
156-
var wg sync.WaitGroup
157-
158-
cmd.Printf("start to download with %d threads, size: %d, unit: %d\n", o.Thread, total, unit)
159-
for i := 0; i < o.Thread; i++ {
160-
wg.Add(1)
161-
go func(index int, wg *sync.WaitGroup) {
162-
defer wg.Done()
163-
164-
end := unit*int64(index+1) - 1
165-
if index == o.Thread-1 {
166-
// this is the last part
167-
end += offset
168-
}
169-
start := unit * int64(index)
170-
171-
if downloadErr := o.download(fmt.Sprintf("%s-%d", o.Output, index), start, end); downloadErr != nil {
172-
cmd.PrintErrln(downloadErr)
173-
}
174-
}(i, &wg)
175-
}
176-
177-
wg.Wait()
178-
179-
// concat all these partial files
180-
var f *os.File
181-
if f, err = os.OpenFile(o.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
182-
defer func() {
183-
_ = f.Close()
184-
}()
185-
186-
for i := 0; i < o.Thread; i++ {
187-
partFile := fmt.Sprintf("%s-%d", o.Output, i)
188-
if data, ferr := ioutil.ReadFile(partFile); ferr == nil {
189-
if _, err = f.Write(data); err != nil {
190-
err = fmt.Errorf("failed to write file: '%s'", partFile)
191-
break
192-
} else {
193-
_ = os.RemoveAll(partFile)
194-
}
195-
} else {
196-
err = fmt.Errorf("failed to read file: '%s'", partFile)
197-
break
198-
}
199-
}
200-
}
201-
} else {
202-
cmd.Println("cannot download it using multiple threads, failed to one")
203-
err = o.download(o.Output, o.ContinueAt, 0)
204-
}
205-
}
206-
return
207-
}
208-
209-
func (o *downloadOption) detectSize(output string) (total int64, rangeSupport bool, err error) {
210-
downloader := pkg.HTTPDownloader{
211-
TargetFilePath: output,
212-
URL: o.URL,
213-
ShowProgress: o.ShowProgress,
214-
}
215-
216-
var detectOffset int64
217-
var lenErr error
218-
219-
detectOffset = 2
220-
downloader.Header = make(map[string]string, 1)
221-
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", detectOffset)
222-
223-
downloader.PreStart = func(resp *http.Response) bool {
224-
rangeSupport = resp.StatusCode == http.StatusPartialContent
225-
contentLen := resp.Header.Get("Content-Length")
226-
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
227-
total += detectOffset
228-
}
229-
// always return false because we just want to get the header from response
230-
return false
231-
}
232-
233-
if err = downloader.DownloadFile(); err != nil || lenErr != nil {
234-
err = fmt.Errorf("cannot download from %s, response error: %v, content length error: %v", o.URL, err, lenErr)
235-
}
236-
return
237-
}
238-
239-
func (o *downloadOption) download(output string, continueAt, end int64) (err error) {
240-
downloader := pkg.HTTPDownloader{
241-
TargetFilePath: output,
242-
URL: o.URL,
243-
ShowProgress: o.ShowProgress,
244-
}
245-
246-
if continueAt >= 0 {
247-
downloader.Header = make(map[string]string, 1)
248-
249-
//fmt.Println("range", continueAt, end)
250-
if end > continueAt {
251-
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-%d", continueAt, end)
252-
} else {
253-
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", continueAt)
254-
}
255-
}
256-
257-
if err = downloader.DownloadFile(); err != nil {
258-
err = fmt.Errorf("cannot download from %s, error: %v", o.URL, err)
141+
err = pkg.DownloadFileWithMultipleThread(o.URL, o.Output, o.Thread, o.ShowProgress)
259142
}
260143
return
261144
}

pkg/http.go

Lines changed: 108 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import (
55
"encoding/base64"
66
"fmt"
77
"io"
8+
"io/ioutil"
89
"net/http"
910
"net/url"
1011
"os"
1112
"path"
1213
"strconv"
13-
14-
"github.com/gosuri/uiprogress"
14+
"sync"
1515
)
1616

1717
const (
@@ -39,6 +39,8 @@ type HTTPDownloader struct {
3939
// PreStart returns false will don't continue
4040
PreStart func(*http.Response) bool
4141

42+
Thread int
43+
4244
Debug bool
4345
RoundTripper http.RoundTripper
4446
}
@@ -85,9 +87,6 @@ func (h *HTTPDownloader) fetchProxyFromEnv(scheme string) {
8587
}
8688
}
8789

88-
//Range: bytes=10-
89-
//HTTP/1.1 206 Partial Content
90-
9190
// DownloadFile download a file with the progress
9291
func (h *HTTPDownloader) DownloadFile() error {
9392
filepath, downloadURL, showProgress := h.TargetFilePath, h.URL, h.ShowProgress
@@ -175,52 +174,121 @@ func (h *HTTPDownloader) DownloadFile() error {
175174
return err
176175
}
177176

178-
// ProgressIndicator hold the progress of io operation
179-
type ProgressIndicator struct {
180-
Writer io.Writer
181-
Reader io.Reader
182-
Title string
177+
// DownloadFileWithMultipleThread downloads the files with multiple threads
178+
func DownloadFileWithMultipleThread(targetURL, targetFilePath string, thread int, showProgress bool) (err error) {
179+
// get the total size of the target file
180+
var total int64
181+
var rangeSupport bool
182+
if total, rangeSupport, err = DetectSize(targetURL, targetFilePath, true); err != nil {
183+
return
184+
}
185+
186+
if rangeSupport {
187+
unit := total / int64(thread)
188+
offset := total - unit*int64(thread)
189+
var wg sync.WaitGroup
190+
191+
fmt.Printf("start to download with %d threads, size: %d, unit: %d\n", thread, total, unit)
192+
for i := 0; i < thread; i++ {
193+
wg.Add(1)
194+
go func(index int, wg *sync.WaitGroup) {
195+
defer wg.Done()
196+
197+
end := unit*int64(index+1) - 1
198+
if index == thread-1 {
199+
// this is the last part
200+
end += offset
201+
}
202+
start := unit * int64(index)
203+
204+
if downloadErr := DownloadWithContinue(targetURL, fmt.Sprintf("%s-%d", targetFilePath, index), start, end, showProgress); downloadErr != nil {
205+
fmt.Println(downloadErr)
206+
}
207+
}(i, &wg)
208+
}
183209

184-
// bytes.Buffer
185-
Total float64
186-
count float64
187-
bar *uiprogress.Bar
210+
wg.Wait()
211+
212+
// concat all these partial files
213+
var f *os.File
214+
if f, err = os.OpenFile(targetFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil {
215+
defer func() {
216+
_ = f.Close()
217+
}()
218+
219+
for i := 0; i < thread; i++ {
220+
partFile := fmt.Sprintf("%s-%d", targetFilePath, i)
221+
if data, ferr := ioutil.ReadFile(partFile); ferr == nil {
222+
if _, err = f.Write(data); err != nil {
223+
err = fmt.Errorf("failed to write file: '%s'", partFile)
224+
break
225+
} else {
226+
_ = os.RemoveAll(partFile)
227+
}
228+
} else {
229+
err = fmt.Errorf("failed to read file: '%s'", partFile)
230+
break
231+
}
232+
}
233+
}
234+
} else {
235+
fmt.Println("cannot download it using multiple threads, failed to one")
236+
err = DownloadWithContinue(targetURL, targetFilePath, 0, 0, true)
237+
}
238+
return
188239
}
189240

190-
// Init set the default value for progress indicator
191-
func (i *ProgressIndicator) Init() {
192-
uiprogress.Start() // start rendering
193-
i.bar = uiprogress.AddBar(100) // Add a new bar
241+
// DownloadWithContinue downloads the files continuously
242+
func DownloadWithContinue(targetURL, output string, continueAt, end int64, showProgress bool) (err error) {
243+
downloader := HTTPDownloader{
244+
TargetFilePath: output,
245+
URL: targetURL,
246+
ShowProgress: showProgress,
247+
}
194248

195-
// optionally, append and prepend completion and elapsed time
196-
i.bar.AppendCompleted()
197-
// i.bar.PrependElapsed()
249+
if continueAt >= 0 {
250+
downloader.Header = make(map[string]string, 1)
198251

199-
if i.Title != "" {
200-
i.bar.PrependFunc(func(_ *uiprogress.Bar) string {
201-
return fmt.Sprintf("%s: ", i.Title)
202-
})
252+
if end > continueAt {
253+
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-%d", continueAt, end)
254+
} else {
255+
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", continueAt)
256+
}
203257
}
204-
}
205258

206-
// Write writes the progress
207-
func (i *ProgressIndicator) Write(p []byte) (n int, err error) {
208-
n, err = i.Writer.Write(p)
209-
i.setBar(n)
259+
if err = downloader.DownloadFile(); err != nil {
260+
err = fmt.Errorf("cannot download from %s, error: %v", targetURL, err)
261+
}
210262
return
211263
}
212264

213-
// Read reads the progress
214-
func (i *ProgressIndicator) Read(p []byte) (n int, err error) {
215-
n, err = i.Reader.Read(p)
216-
i.setBar(n)
217-
return
218-
}
265+
// DetectSize returns the size of target resource
266+
func DetectSize(targetURL, output string, showProgress bool) (total int64, rangeSupport bool, err error) {
267+
downloader := HTTPDownloader{
268+
TargetFilePath: output,
269+
URL: targetURL,
270+
ShowProgress: showProgress,
271+
}
272+
273+
var detectOffset int64
274+
var lenErr error
275+
276+
detectOffset = 2
277+
downloader.Header = make(map[string]string, 1)
278+
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", detectOffset)
219279

220-
func (i *ProgressIndicator) setBar(n int) {
221-
i.count += float64(n)
280+
downloader.PreStart = func(resp *http.Response) bool {
281+
rangeSupport = resp.StatusCode == http.StatusPartialContent
282+
contentLen := resp.Header.Get("Content-Length")
283+
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
284+
total += detectOffset
285+
}
286+
// always return false because we just want to get the header from response
287+
return false
288+
}
222289

223-
if i.bar != nil {
224-
i.bar.Set((int)(i.count * 100 / i.Total))
290+
if err = downloader.DownloadFile(); err != nil || lenErr != nil {
291+
err = fmt.Errorf("cannot download from %s, response error: %v, content length error: %v", targetURL, err, lenErr)
225292
}
293+
return
226294
}

pkg/progress.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package pkg
2+
3+
import (
4+
"fmt"
5+
"github.com/gosuri/uiprogress"
6+
"io"
7+
)
8+
9+
// ProgressIndicator hold the progress of io operation
10+
type ProgressIndicator struct {
11+
Writer io.Writer
12+
Reader io.Reader
13+
Title string
14+
15+
// bytes.Buffer
16+
Total float64
17+
count float64
18+
bar *uiprogress.Bar
19+
}
20+
21+
// Init set the default value for progress indicator
22+
func (i *ProgressIndicator) Init() {
23+
uiprogress.Start() // start rendering
24+
i.bar = uiprogress.AddBar(100) // Add a new bar
25+
26+
// optionally, append and prepend completion and elapsed time
27+
i.bar.AppendCompleted()
28+
// i.bar.PrependElapsed()
29+
30+
if i.Title != "" {
31+
i.bar.PrependFunc(func(_ *uiprogress.Bar) string {
32+
return fmt.Sprintf("%s: ", i.Title)
33+
})
34+
}
35+
}
36+
37+
// Write writes the progress
38+
func (i *ProgressIndicator) Write(p []byte) (n int, err error) {
39+
n, err = i.Writer.Write(p)
40+
i.setBar(n)
41+
return
42+
}
43+
44+
// Read reads the progress
45+
func (i *ProgressIndicator) Read(p []byte) (n int, err error) {
46+
n, err = i.Reader.Read(p)
47+
i.setBar(n)
48+
return
49+
}
50+
51+
func (i *ProgressIndicator) setBar(n int) {
52+
i.count += float64(n)
53+
54+
if i.bar != nil {
55+
i.bar.Set((int)(i.count * 100 / i.Total))
56+
}
57+
}

0 commit comments

Comments
 (0)