108 lines
3 KiB
Go
108 lines
3 KiB
Go
package reddit
|
|
|
|
import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/alecthomas/units"
	"github.com/tigorlazuardi/redmage/api/bmessage"
	"github.com/tigorlazuardi/redmage/pkg/errs"
	"golang.org/x/sync/errgroup"
)
|
|
|
|
type DownloadStatusBroadcaster interface {
|
|
Broadcast(bmessage.ImageDownloadMessage)
|
|
}
|
|
|
|
// NullDownloadStatusBroadcaster is a DownloadStatusBroadcaster whose Broadcast
// method does nothing. It carries no state.
type NullDownloadStatusBroadcaster struct{}
|
|
|
|
// Broadcast implements DownloadStatusBroadcaster by dropping the message.
func (NullDownloadStatusBroadcaster) Broadcast(_ bmessage.ImageDownloadMessage) {
	// Intentionally empty: the null broadcaster discards every message.
}
|
|
|
|
// PostImage pairs the image and thumbnail URLs of a post with readers for
// their downloaded content.
type PostImage struct {
	// ImageURL is the URL the image is downloaded from.
	ImageURL string
	// ImageFile yields the downloaded image content.
	ImageFile io.Reader
	// ThumbnailURL is the URL the thumbnail is downloaded from.
	ThumbnailURL string
	// ThumbnailFile yields the downloaded thumbnail content.
	ThumbnailFile io.Reader
}
|
|
|
|
func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) {
|
|
imageUrl, thumbnailUrl := post.GetImageURL(), post.GetThumbnailURL()
|
|
image.ImageURL = imageUrl
|
|
image.ThumbnailURL = thumbnailUrl
|
|
|
|
group, groupCtx := errgroup.WithContext(ctx)
|
|
group.Go(func() error {
|
|
var err error
|
|
image.ImageFile, err = reddit.downloadImage(groupCtx, post, bmessage.KindImage, broadcaster)
|
|
return err
|
|
})
|
|
group.Go(func() error {
|
|
var err error
|
|
image.ThumbnailFile, err = reddit.downloadImage(groupCtx, post, bmessage.KindThumbnail, broadcaster)
|
|
return err
|
|
})
|
|
|
|
err = group.Wait()
|
|
return image, err
|
|
}
|
|
|
|
func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessage.ImageKind, broadcaster DownloadStatusBroadcaster) (io.Reader, error) {
|
|
var (
|
|
url string
|
|
height int
|
|
width int
|
|
)
|
|
if kind == bmessage.KindImage {
|
|
url = post.GetImageURL()
|
|
width, height = post.GetImageSize()
|
|
} else {
|
|
url = post.GetThumbnailURL()
|
|
width, height = post.GetThumbnailSize()
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, reddit.Config.Duration("download.timeout.headers"))
|
|
defer cancel()
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return nil, errs.Wrapw(err, "reddit: failed to create request", "url", url)
|
|
}
|
|
resp, err := reddit.Client.Do(req)
|
|
if err != nil {
|
|
return nil, errs.Wrapw(err, "reddit: failed to execute request", "url", url)
|
|
}
|
|
idleSpeedStr := reddit.Config.String("download.timeout.idlespeed")
|
|
metricSpeed, _ := units.ParseMetricBytes(idleSpeedStr)
|
|
if metricSpeed == 0 {
|
|
metricSpeed = 10 * units.KB
|
|
}
|
|
idr := &ImageDownloadReader{
|
|
OnProgress: func(downloaded int64, contentLength int64, err error) {
|
|
broadcaster.Broadcast(bmessage.ImageDownloadMessage{
|
|
Metadata: bmessage.ImageMetadata{
|
|
URL: url,
|
|
Height: height,
|
|
Width: width,
|
|
Kind: kind,
|
|
},
|
|
ContentLength: units.MetricBytes(resp.ContentLength),
|
|
Downloaded: units.MetricBytes(downloaded),
|
|
Subreddit: post.GetSubreddit(),
|
|
PostURL: post.GetPermalink(),
|
|
PostID: post.GetID(),
|
|
Error: err,
|
|
})
|
|
},
|
|
IdleTimeout: reddit.Config.Duration("download.timeout.idle"),
|
|
IdleSpeedThreshold: metricSpeed,
|
|
}
|
|
|
|
resp = idr.WrapHTTPResponse(resp)
|
|
reader, writer := io.Pipe()
|
|
go func() {
|
|
defer resp.Body.Close()
|
|
_, err := io.Copy(writer, resp.Body)
|
|
_ = writer.CloseWithError(err)
|
|
}()
|
|
return reader, nil
|
|
}
|