Skip to content

Commit ec90fa5

Browse files
authored
feat: support to download as stream with bearer token auth (#488)
* feat: add bearer token support * support to download as stream * fix the code errors * upgrade ubuntu 20.04 to 22.04 * add comment for the exported functions --------- Co-authored-by: Rick <linuxsuren@users.noreply.github.com>
1 parent d375e85 commit ec90fa5

File tree

7 files changed

+230
-23
lines changed

7 files changed

+230
-23
lines changed

.github/workflows/backup.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
BackupGit:
11-
runs-on: ubuntu-20.04
11+
runs-on: ubuntu-22.04
1212
steps:
1313
- uses: actions/checkout@v3.6.0
1414
- name: backup

.github/workflows/coverage-report.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77

88
jobs:
99
TestAndReport:
10-
runs-on: ubuntu-20.04
10+
runs-on: ubuntu-22.04
1111
steps:
1212
- name: Set up Go
1313
uses: actions/setup-go@v4

.github/workflows/pull-request.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ env:
1010
jobs:
1111
build:
1212
name: Build
13-
runs-on: ubuntu-20.04
13+
runs-on: ubuntu-22.04
1414
steps:
1515
- name: Set up Go
1616
uses: actions/setup-go@v4
@@ -51,7 +51,7 @@ jobs:
5151

5252
GoLint:
5353
name: Lint
54-
runs-on: ubuntu-20.04
54+
runs-on: ubuntu-22.04
5555
steps:
5656
- name: Set up Go
5757
uses: actions/setup-go@v4
@@ -66,7 +66,7 @@ jobs:
6666
golint-path: ./...
6767
Security:
6868
name: Security
69-
runs-on: ubuntu-20.04
69+
runs-on: ubuntu-22.04
7070
env:
7171
GO111MODULE: on
7272
steps:
@@ -78,7 +78,7 @@ jobs:
7878
args: '-exclude=G402,G204,G304,G110,G306,G107 ./...'
7979
CodeQL:
8080
name: CodeQL
81-
runs-on: ubuntu-20.04
81+
runs-on: ubuntu-22.04
8282
env:
8383
GO111MODULE: on
8484
steps:
@@ -92,15 +92,15 @@ jobs:
9292
uses: github/codeql-action/analyze@v1
9393
MarkdownLinkCheck:
9494
name: MarkdownLinkCheck
95-
runs-on: ubuntu-20.04
95+
runs-on: ubuntu-22.04
9696
steps:
9797
- uses: actions/checkout@v3.6.0
9898
- uses: gaurav-nelson/github-action-markdown-link-check@1.0.13
9999
with:
100100
use-verbose-mode: 'yes'
101101

102102
image:
103-
runs-on: ubuntu-20.04
103+
runs-on: ubuntu-22.04
104104
steps:
105105
- name: Checkout
106106
uses: actions/checkout@v4

.github/workflows/release-drafter.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77

88
jobs:
99
UpdateReleaseDraft:
10-
runs-on: ubuntu-20.04
10+
runs-on: ubuntu-22.04
1111
steps:
1212
- uses: release-drafter/release-drafter@v5
1313
env:

.github/workflows/release.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ env:
99

1010
jobs:
1111
goreleaser:
12-
runs-on: ubuntu-20.04
12+
runs-on: ubuntu-22.04
1313
steps:
1414
- name: Checkout
1515
uses: actions/checkout@v3.6.0
@@ -43,7 +43,7 @@ jobs:
4343
oras push ${{ env.REGISTRY }}/linuxsuren/hd:$TAG release
4444
4545
image:
46-
runs-on: ubuntu-20.04
46+
runs-on: ubuntu-22.04
4747
steps:
4848
- name: Checkout
4949
uses: actions/checkout@v4

pkg/net/http.go

Lines changed: 97 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ func (h *HTTPDownloader) fetchProxyFromEnv(scheme string) {
9797
}
9898
}
9999

100-
// DownloadFile download a file with the progress
101-
func (h *HTTPDownloader) DownloadFile() error {
100+
// DownloadAsStream downloads the file as stream
101+
func (h *HTTPDownloader) DownloadAsStream(writer io.Writer) (err error) {
102102
filepath, downloadURL, showProgress := h.TargetFilePath, h.URL, h.ShowProgress
103103
// Get the data
104104
if h.Context == nil {
@@ -115,7 +115,10 @@ func (h *HTTPDownloader) DownloadFile() error {
115115

116116
if h.UserName != "" && h.Password != "" {
117117
req.SetBasicAuth(h.UserName, h.Password)
118+
} else if h.Password != "" {
119+
req.Header.Set("Authorization", "Bearer "+h.Password)
118120
}
121+
119122
var tr http.RoundTripper
120123
if h.RoundTripper != nil {
121124
tr = h.RoundTripper
@@ -178,22 +181,32 @@ func (h *HTTPDownloader) DownloadFile() error {
178181
}
179182
}
180183

181-
if err := os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil {
182-
return err
184+
h.progressIndicator.Writer = writer
185+
h.progressIndicator.Init()
186+
187+
// Write the body to file
188+
_, err = io.Copy(h.progressIndicator, resp.Body)
189+
return
190+
}
191+
192+
// DownloadFile download a file with the progress
193+
func (h *HTTPDownloader) DownloadFile() (err error) {
194+
filepath := h.TargetFilePath
195+
if err = os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil {
196+
return
183197
}
184198

185199
// Create the file
186-
out, err := os.Create(filepath)
200+
var out io.WriteCloser
201+
out, err = os.Create(filepath)
187202
if err != nil {
188-
_ = out.Close()
189-
return err
203+
return
190204
}
205+
defer func() {
206+
_ = out.Close()
207+
}()
191208

192-
h.progressIndicator.Writer = out
193-
h.progressIndicator.Init()
194-
195-
// Write the body to file
196-
_, err = io.Copy(h.progressIndicator, resp.Body)
209+
err = h.DownloadAsStream(out)
197210
return err
198211
}
199212

@@ -269,6 +282,39 @@ func (c *ContinueDownloader) WithBasicAuth(username, password string) *ContinueD
269282
return c
270283
}
271284

285+
// DownloadWithContinueAsStream downloads the files continuously
286+
func (c *ContinueDownloader) DownloadWithContinueAsStream(targetURL string, output io.Writer, index, continueAt, end int64, showProgress bool) (err error) {
287+
c.downloader = &HTTPDownloader{
288+
URL: targetURL,
289+
ShowProgress: showProgress,
290+
NoProxy: c.noProxy,
291+
RoundTripper: c.roundTripper,
292+
InsecureSkipVerify: c.insecureSkipVerify,
293+
UserName: c.UserName,
294+
Password: c.Password,
295+
Context: c.Context,
296+
Timeout: c.Timeout,
297+
}
298+
if index >= 0 {
299+
c.downloader.Title = fmt.Sprintf("Downloading part %d", index)
300+
}
301+
302+
if continueAt >= 0 {
303+
c.downloader.Header = make(map[string]string, 1)
304+
305+
if end > continueAt {
306+
c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-%d", continueAt, end)
307+
} else {
308+
c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", continueAt)
309+
}
310+
}
311+
312+
if err = c.downloader.DownloadAsStream(output); err != nil {
313+
err = fmt.Errorf("cannot download from %s, error: %v", targetURL, err)
314+
}
315+
return
316+
}
317+
272318
// DownloadWithContinue downloads the files continuously
273319
func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, index, continueAt, end int64, showProgress bool) (err error) {
274320
c.downloader = &HTTPDownloader{
@@ -303,6 +349,45 @@ func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, inde
303349
return
304350
}
305351

352+
// DetectSizeWithRoundTripperAndAuthStream returns the size of target resource
353+
func DetectSizeWithRoundTripperAndAuthStream(targetURL string, output io.Writer, showProgress, noProxy, insecureSkipVerify bool,
354+
roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) {
355+
downloader := HTTPDownloader{
356+
URL: targetURL,
357+
ShowProgress: showProgress,
358+
RoundTripper: roundTripper,
359+
NoProxy: false, // below HTTP request does not need proxy
360+
InsecureSkipVerify: insecureSkipVerify,
361+
UserName: username,
362+
Password: password,
363+
Timeout: timeout,
364+
}
365+
366+
var detectOffset int64
367+
var lenErr error
368+
369+
detectOffset = 2
370+
downloader.Header = make(map[string]string, 1)
371+
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", detectOffset)
372+
373+
downloader.PreStart = func(resp *http.Response) bool {
374+
rangeSupport = resp.StatusCode == http.StatusPartialContent
375+
contentLen := resp.Header.Get("Content-Length")
376+
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
377+
total += detectOffset
378+
} else {
379+
rangeSupport = false
380+
}
381+
// always return false because we just want to get the header from response
382+
return false
383+
}
384+
385+
if err = downloader.DownloadAsStream(output); err != nil || lenErr != nil {
386+
err = fmt.Errorf("cannot download from %s, response error: %v, content length error: %v", targetURL, err, lenErr)
387+
}
388+
return
389+
}
390+
306391
// DetectSizeWithRoundTripperAndAuth returns the size of target resource
307392
func DetectSizeWithRoundTripperAndAuth(targetURL, output string, showProgress, noProxy, insecureSkipVerify bool,
308393
roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) {

pkg/net/multi_thread.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package net
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"net/http"
78
"os"
89
"os/signal"
@@ -70,6 +71,127 @@ func (d *MultiThreadDownloader) WithBasicAuth(username, password string) *MultiT
7071
return d
7172
}
7273

74+
// WithBearerToken sets the bearer token
75+
func (d *MultiThreadDownloader) WithBearerToken(bearerToken string) *MultiThreadDownloader {
76+
d.password = bearerToken
77+
return d
78+
}
79+
80+
// DownloadWithContext starts to download the target URL with context
81+
func (d *MultiThreadDownloader) DownloadWithContext(ctx context.Context, targetURL string, outputWriter io.Writer, thread int) (err error) {
82+
// get the total size of the target file
83+
var total int64
84+
var rangeSupport bool
85+
if total, rangeSupport, err = DetectSizeWithRoundTripperAndAuthStream(targetURL, outputWriter, d.showProgress,
86+
d.noProxy, d.insecureSkipVerify, d.roundTripper, d.username, d.password, d.timeout); rangeSupport && err != nil {
87+
return
88+
}
89+
90+
if rangeSupport {
91+
unit := total / int64(thread)
92+
offset := total - unit*int64(thread)
93+
var wg sync.WaitGroup
94+
var m sync.Mutex
95+
partItems := make(map[int]string)
96+
97+
defer func() {
98+
// remove all partial files
99+
for _, part := range partItems {
100+
_ = os.RemoveAll(part)
101+
}
102+
}()
103+
104+
c := make(chan os.Signal, 1)
105+
signal.Notify(c, os.Interrupt)
106+
ctx, cancel := context.WithCancel(context.Background())
107+
var canceled bool
108+
109+
go func() {
110+
<-c
111+
canceled = true
112+
cancel()
113+
}()
114+
115+
fmt.Printf("start to download with %d threads, size: %d, unit: %d", thread, total, unit)
116+
for i := 0; i < thread; i++ {
117+
fmt.Println() // TODO take position, should take over by progerss bars
118+
wg.Add(1)
119+
go func(index int, wg *sync.WaitGroup, ctx context.Context) {
120+
defer wg.Done()
121+
outputFile, err := os.CreateTemp(os.TempDir(), fmt.Sprintf("part-%d", index))
122+
if err != nil {
123+
fmt.Println("failed to create template file", err)
124+
}
125+
outputFile.Close()
126+
127+
m.Lock()
128+
partItems[index] = outputFile.Name()
129+
m.Unlock()
130+
131+
end := unit*int64(index+1) - 1
132+
if index == thread-1 {
133+
// this is the last part
134+
end += offset
135+
}
136+
start := unit * int64(index)
137+
138+
downloader := &ContinueDownloader{}
139+
downloader.WithoutProxy(d.noProxy).
140+
WithRoundTripper(d.roundTripper).
141+
WithInsecureSkipVerify(d.insecureSkipVerify).
142+
WithBasicAuth(d.username, d.password).
143+
WithContext(ctx).WithTimeout(d.timeout)
144+
if downloadErr := downloader.DownloadWithContinue(targetURL, outputFile.Name(),
145+
int64(index), start, end, d.showProgress); downloadErr != nil {
146+
fmt.Println(downloadErr)
147+
}
148+
}(i, &wg, ctx)
149+
}
150+
151+
wg.Wait()
152+
// ProgressIndicator{}.Close()
153+
if canceled {
154+
err = fmt.Errorf("download process canceled")
155+
return
156+
}
157+
158+
// make the cursor right
159+
// TODO the progress component should take over it
160+
if thread > 1 {
161+
// line := GetCurrentLine()
162+
time.Sleep(time.Second)
163+
fmt.Printf("\033[%dE\n", thread) // move to the target line
164+
time.Sleep(time.Second * 5)
165+
}
166+
167+
for i := 0; i < thread; i++ {
168+
partFile := partItems[i]
169+
if data, ferr := os.ReadFile(partFile); ferr == nil {
170+
if _, err = outputWriter.Write(data); err != nil {
171+
err = fmt.Errorf("failed to write file: '%s'", partFile)
172+
break
173+
} else if !d.keepParts {
174+
_ = os.RemoveAll(partFile)
175+
}
176+
} else {
177+
err = fmt.Errorf("failed to read file: '%s'", partFile)
178+
break
179+
}
180+
}
181+
} else {
182+
fmt.Println("cannot download it using multiple threads, failed to one")
183+
downloader := &ContinueDownloader{}
184+
downloader.WithoutProxy(d.noProxy)
185+
downloader.WithRoundTripper(d.roundTripper)
186+
downloader.WithInsecureSkipVerify(d.insecureSkipVerify)
187+
downloader.WithTimeout(d.timeout)
188+
downloader.WithBasicAuth(d.username, d.password)
189+
err = downloader.DownloadWithContinueAsStream(targetURL, outputWriter, -1, 0, 0, true)
190+
d.suggestedFilename = downloader.GetSuggestedFilename()
191+
}
192+
return
193+
}
194+
73195
// Download starts to download the target URL
74196
func (d *MultiThreadDownloader) Download(targetURL, targetFilePath string, thread int) (err error) {
75197
// get the total size of the target file

0 commit comments

Comments
 (0)