Redmage/api/reddit/download_images.go

108 lines
3 KiB
Go

package reddit
import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/alecthomas/units"
	"github.com/tigorlazuardi/redmage/api/bmessage"
	"github.com/tigorlazuardi/redmage/pkg/errs"
	"golang.org/x/sync/errgroup"
)
// DownloadStatusBroadcaster receives progress updates emitted while a post's
// image and thumbnail are being downloaded.
//
// The same broadcaster is shared by the image and thumbnail downloads, which
// run in separate goroutines, so Broadcast may be invoked concurrently.
type DownloadStatusBroadcaster interface {
	// Broadcast publishes a single download progress (or error) update.
	Broadcast(msg bmessage.ImageDownloadMessage)
}
// NullDownloadStatusBroadcaster is a DownloadStatusBroadcaster that silently
// discards every message. Its zero value is ready to use.
type NullDownloadStatusBroadcaster struct{}

// Broadcast satisfies DownloadStatusBroadcaster and does nothing.
func (NullDownloadStatusBroadcaster) Broadcast(_ bmessage.ImageDownloadMessage) {
	// Intentionally empty: progress updates are dropped.
}
// PostImage holds the URLs of a post's full-size image and thumbnail together
// with readers that stream their downloaded contents.
type PostImage struct {
// ImageURL is the full-size image URL (taken from post.GetImageURL()).
ImageURL string
// ImageFile streams the full-size image bytes. When produced by this file's
// download code it is the read half of an io.Pipe fed by a background
// goroutine, so it should be read to EOF (or closed via io.Closer) for that
// goroutine and its HTTP response to be released.
ImageFile io.Reader
// ThumbnailURL is the thumbnail URL (taken from post.GetThumbnailURL()).
ThumbnailURL string
// ThumbnailFile streams the thumbnail bytes; the same read-to-EOF-or-close
// requirement as ImageFile applies.
ThumbnailFile io.Reader
}
// DownloadImage concurrently starts downloading the full-size image and the
// thumbnail of post. It returns both as streaming readers in a PostImage.
// ImageURL and ThumbnailURL are always populated, even on error.
//
// Progress for both downloads is reported to broadcaster. If either download
// fails to start, any reader that was obtained is closed so its background
// goroutine and HTTP connection are released. Both file readers in the
// returned PostImage are then nil and the error is returned.
//
// On success, callers must read ImageFile and ThumbnailFile to EOF or close
// them (type-assert to io.Closer). Otherwise their background goroutines and
// HTTP connections are never released.
func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) {
	image.ImageURL = post.GetImageURL()
	image.ThumbnailURL = post.GetThumbnailURL()

	// A plain errgroup.Group is used on purpose. The context derived by
	// errgroup.WithContext is cancelled as soon as Wait returns, which would
	// abort both response bodies while the caller is still reading them.
	// Each download's header phase is bounded by its own header timeout.
	var group errgroup.Group
	group.Go(func() error {
		var err error
		image.ImageFile, err = reddit.downloadImage(ctx, post, bmessage.KindImage, broadcaster)
		return err
	})
	group.Go(func() error {
		var err error
		image.ThumbnailFile, err = reddit.downloadImage(ctx, post, bmessage.KindThumbnail, broadcaster)
		return err
	})

	if err = group.Wait(); err != nil {
		// Release whichever download did start. Closing the pipe reader makes
		// the copy goroutine's next write fail, so it closes the body and exits.
		release := func(r io.Reader) {
			if c, ok := r.(io.Closer); ok {
				_ = c.Close()
			}
		}
		release(image.ImageFile)
		release(image.ThumbnailFile)
		image.ImageFile, image.ThumbnailFile = nil, nil
		return image, err
	}
	return image, nil
}
// downloadImage starts downloading either the full-size image or the
// thumbnail of post, selected by kind. It returns a reader that streams the
// response body.
//
// Only the wait for response headers is bounded by the
// "download.timeout.headers" duration. The body is then streamed without an
// overall deadline and is guarded by ImageDownloadReader's idle timeout
// ("download.timeout.idle") and idle speed threshold
// ("download.timeout.idlespeed"). The threshold falls back to 10 KB when the
// setting is empty or unparsable. Progress and errors are reported to
// broadcaster.
//
// A non-2xx response status is returned as an error instead of being streamed
// as image data.
//
// The returned reader is the read half of an io.Pipe fed by a background
// goroutine. Callers must read it to EOF or close it (it implements
// io.Closer). Otherwise the goroutine and the HTTP response are never released.
func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessage.ImageKind, broadcaster DownloadStatusBroadcaster) (io.Reader, error) {
	var (
		url    string
		height int
		width  int
	)
	if kind == bmessage.KindImage {
		url = post.GetImageURL()
		width, height = post.GetImageSize()
	} else {
		url = post.GetThumbnailURL()
		width, height = post.GetThumbnailSize()
	}

	// The request context also governs reading the response body, and the
	// body is read by the goroutine below after this function returns.
	// Cancelling it on return (as a deferred cancel would) aborts every body
	// read. So only the header wait is bounded: a timer cancels the context if
	// headers are late and is disarmed once Do returns. The context itself is
	// cancelled when streaming ends.
	ctx, cancel := context.WithCancel(ctx)
	headerTimer := time.AfterFunc(reddit.Config.Duration("download.timeout.headers"), cancel)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		headerTimer.Stop()
		cancel()
		return nil, errs.Wrapw(err, "reddit: failed to create request", "url", url)
	}
	resp, err := reddit.Client.Do(req)
	headerTimer.Stop()
	if err != nil {
		cancel()
		return nil, errs.Wrapw(err, "reddit: failed to execute request", "url", url)
	}
	// Reject error responses rather than streaming an error page as image data.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		_ = resp.Body.Close()
		cancel()
		return nil, errs.Wrapw(
			fmt.Errorf("unexpected response status %q", resp.Status),
			"reddit: failed to execute request", "url", url,
		)
	}

	// An empty or unparsable setting parses to 0; the parse error is
	// deliberately ignored and the 10 KB default is used instead.
	metricSpeed, _ := units.ParseMetricBytes(reddit.Config.String("download.timeout.idlespeed"))
	if metricSpeed == 0 {
		metricSpeed = 10 * units.KB
	}

	idr := &ImageDownloadReader{
		OnProgress: func(downloaded int64, contentLength int64, err error) {
			broadcaster.Broadcast(bmessage.ImageDownloadMessage{
				Metadata: bmessage.ImageMetadata{
					URL:    url,
					Height: height,
					Width:  width,
					Kind:   kind,
				},
				ContentLength: units.MetricBytes(resp.ContentLength),
				Downloaded:    units.MetricBytes(downloaded),
				Subreddit:     post.GetSubreddit(),
				PostURL:       post.GetPermalink(),
				PostID:        post.GetID(),
				Error:         err,
			})
		},
		IdleTimeout:        reddit.Config.Duration("download.timeout.idle"),
		IdleSpeedThreshold: metricSpeed,
	}
	resp = idr.WrapHTTPResponse(resp)

	reader, writer := io.Pipe()
	go func() {
		// Release the request context only after the body is fully consumed
		// (or the consumer closed the pipe, which makes the copy fail).
		defer cancel()
		defer resp.Body.Close()
		_, err := io.Copy(writer, resp.Body)
		// A nil err closes the pipe cleanly, so the reader observes io.EOF.
		_ = writer.CloseWithError(err)
	}()
	return reader, nil
}