Redmage/api/download_subreddit_images.go

293 lines
7.7 KiB
Go
Raw Normal View History

2024-04-09 22:37:26 +07:00
package api
import (
"context"
"errors"
2024-04-14 00:32:55 +07:00
"image/jpeg"
"io"
"math"
2024-04-10 17:13:07 +07:00
"net/http"
2024-04-14 00:32:55 +07:00
"net/url"
"os"
"path"
"strings"
"sync"
2024-04-14 00:32:55 +07:00
"github.com/disintegration/imaging"
"github.com/tigorlazuardi/redmage/api/reddit"
"github.com/tigorlazuardi/redmage/db/queries"
"github.com/tigorlazuardi/redmage/pkg/errs"
2024-04-14 00:32:55 +07:00
"github.com/tigorlazuardi/redmage/pkg/log"
"github.com/tigorlazuardi/redmage/pkg/telemetry"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
2024-04-09 22:37:26 +07:00
type DownloadSubredditParams struct {
2024-04-14 00:32:55 +07:00
Countback int
Devices []queries.Device
SubredditType reddit.SubredditType
2024-04-09 22:37:26 +07:00
}
// Sentinel errors reported when the preconditions for downloading images
// are not met. Compare with errors.Is.
var (
	// ErrNoDevices is reported when the download parameters carry no devices.
	ErrNoDevices = errors.New("api: no devices set")
	// ErrDownloadDirNotSet is reported when the "download.directory"
	// configuration value is empty.
	ErrDownloadDirNotSet = errors.New("api: download directory not set")
)
2024-04-09 22:37:26 +07:00
func (api *API) DownloadSubredditImages(ctx context.Context, subredditName string, params DownloadSubredditParams) error {
2024-04-10 17:13:07 +07:00
downloadDir := api.config.String("download.directory")
if downloadDir == "" {
return errs.Wrapw(ErrDownloadDirNotSet, "download directory must be set before images can be downloaded").Code(http.StatusBadRequest)
}
if len(params.Devices) == 0 {
2024-04-10 17:13:07 +07:00
return errs.Wrapw(ErrNoDevices, "downloading images requires at least one device configured").Code(http.StatusBadRequest)
}
2024-04-10 17:13:07 +07:00
2024-04-14 00:32:55 +07:00
ctx, span := tracer.Start(ctx, "*API.DownloadSubredditImages", trace.WithAttributes(attribute.String("subreddit", subredditName)))
defer span.End()
wg := sync.WaitGroup{}
countback := params.Countback
for page := 1; countback > 0; page += 1 {
limit := countback
if limit > 100 {
limit = 100
}
list, err := api.reddit.GetPosts(ctx, reddit.GetPostsParam{
Subreddit: subredditName,
Limit: limit,
Page: page,
SubredditType: params.SubredditType,
})
if err != nil {
return errs.Wrapw(err, "failed to get posts", "subreddit_name", subredditName, "params", params)
}
wg.Add(1)
go func(ctx context.Context, posts reddit.Listing) {
defer wg.Done()
err := api.downloadSubredditListImage(ctx, list, params)
if err != nil {
log.New(ctx).Err(err).Error("failed to download image")
}
}(ctx, list)
countback -= len(list.GetPosts())
}
wg.Wait()
return nil
}
func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.Listing, params DownloadSubredditParams) error {
ctx, span := tracer.Start(ctx, "*API.downloadSubredditImage")
defer span.End()
wg := sync.WaitGroup{}
for _, post := range list.GetPosts() {
if !post.IsImagePost() {
continue
}
devices := getDevicesThatAcceptPost(post, params.Devices)
if len(devices) == 0 {
continue
}
wg.Add(1)
api.imageSemaphore <- struct{}{}
go func(ctx context.Context, post reddit.Post) {
defer func() {
<-api.imageSemaphore
wg.Done()
}()
imageHandler, err := api.reddit.DownloadImage(ctx, post, api.downloadBroadcast)
if err != nil {
log.New(ctx).Err(err).Error("failed to download image")
return
}
defer imageHandler.Close()
// copy to temp dir first to avoid copying incomplete files.
tmpImageFile, err := api.copyImageToTempDir(ctx, imageHandler)
if err != nil {
log.New(ctx).Err(err).Error("failed to download image to temp file")
return
}
defer tmpImageFile.Close()
w, close, err := api.createDeviceImageWriters(post, devices)
if err != nil {
log.New(ctx).Err(err).Error("failed to create image files")
return
}
defer close()
_, err = io.Copy(w, tmpImageFile)
if err != nil {
log.New(ctx).Err(err).Error("failed to create save image files")
return
}
thumbnailPath := post.GetThumbnailTargetPath(api.config)
_, errStat := os.Stat(thumbnailPath)
if errStat == nil {
// file exist
return
}
if !errors.Is(errStat, os.ErrNotExist) {
log.New(ctx).Err(err).Error("failed to check thumbail existence", "path", thumbnailPath)
return
}
thumbnailSource, err := imaging.Open(tmpImageFile.filename)
if err != nil {
log.New(ctx).Err(err).Error("failed to open temp thumbnail file", "filename", tmpImageFile.filename)
return
}
thumbnail := imaging.Resize(thumbnailSource, 256, 0, imaging.Lanczos)
thumbnailFile, err := os.Create(thumbnailPath)
if err != nil {
log.New(ctx).Err(err).Error("failed to create thumbnail file", "filename", thumbnailPath)
return
}
defer thumbnailFile.Close()
err = jpeg.Encode(thumbnailFile, thumbnail, nil)
if err != nil {
log.New(ctx).Err(err).Error("failed to encode thumbnail file to jpeg", "filename", thumbnailPath)
return
}
}(ctx, post)
}
wg.Wait()
2024-04-09 22:37:26 +07:00
return nil
}
2024-04-14 00:32:55 +07:00
// createDeviceImageWriters opens (creating or truncating) the target image
// file of post for every given device and returns a writer that duplicates
// each write to all of them.
//
// The target path comes from GetWindowsWallpaperImageTargetPath when the
// device has WindowsWallpaperMode == 1, and from GetImageTargetPath otherwise.
//
// close closes every opened file and must be called by the caller once
// writing is done. On error no file is left open, and writer and close are nil.
func (api *API) createDeviceImageWriters(post reddit.Post, devices []queries.Device) (writer io.Writer, close func(), err error) {
	files := make([]*os.File, 0, len(devices))
	writers := make([]io.Writer, 0, len(devices))

	// closeAll reads files at call time, so it covers every file opened so far.
	closeAll := func() {
		for _, f := range files {
			_ = f.Close()
		}
	}

	for _, device := range devices {
		var filename string
		if device.WindowsWallpaperMode == 1 {
			filename = post.GetWindowsWallpaperImageTargetPath(api.config, device)
		} else {
			filename = post.GetImageTargetPath(api.config, device)
		}

		file, openErr := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
		if openErr != nil {
			closeAll()
			return nil, nil, errs.Wrapw(openErr, "failed to open device image file",
				"device_name", device.Name,
				"filename", filename,
			)
		}
		files = append(files, file)
		writers = append(writers, file)
	}

	return io.MultiWriter(writers...), closeAll, nil
}
// getDevicesThatAcceptPost returns the subset of devices for which
// shouldDownloadPostForDevice reports true, in their original order.
// The result is nil when no device accepts the post. The input slice is
// never modified.
func getDevicesThatAcceptPost(post reddit.Post, devices []queries.Device) []queries.Device {
	var accepted []queries.Device
	for _, device := range devices {
		if shouldDownloadPostForDevice(post, device) {
			// Append to the result, not to the input: appending to devices
			// would return every device and could write into the caller's
			// backing array.
			accepted = append(accepted, device)
		}
	}
	return accepted
}
// shouldDownloadPostForDevice reports whether the image of post is suitable
// for device. A post is rejected when:
//   - it is NSFW and the device has Nsfw == 0,
//   - its aspect ratio differs from the device's by more than
//     device.AspectRatioTolerance,
//   - its size falls outside the device's Min/Max bounds. A bound that is
//     zero or negative is treated as unset.
func shouldDownloadPostForDevice(post reddit.Post, device queries.Device) bool {
	if post.IsNSFW() && device.Nsfw == 0 {
		return false
	}

	ratioDelta := math.Abs(deviceAspectRatio(device) - post.GetImageAspectRatio())
	if ratioDelta > device.AspectRatioTolerance {
		return false
	}

	imgWidth, imgHeight := post.GetImageSize()
	tooWide := device.MaxX > 0 && imgWidth > device.MaxX
	tooTall := device.MaxY > 0 && imgHeight > device.MaxY
	tooNarrow := device.MinX > 0 && imgWidth < device.MinX
	tooShort := device.MinY > 0 && imgHeight < device.MinY

	return !(tooWide || tooTall || tooNarrow || tooShort)
}
// deviceAspectRatio returns ResolutionX divided by ResolutionY as a float64.
//
// NOTE(review): the division is done in floating point, so a zero
// ResolutionY would not panic but yield Inf or NaN — confirm callers never
// store a zero resolution.
func deviceAspectRatio(device queries.Device) float64 {
	resX := float64(device.ResolutionX)
	resY := float64(device.ResolutionY)
	return resX / resY
}
// tempFile pairs an open file handle with the path it was opened from, so a
// caller can stream its content and also re-open it by name.
type tempFile struct {
	// filename is the path the file was opened from.
	filename string
	// file is the open handle that Read and Close delegate to.
	file *os.File
}

// Read reads from the underlying file handle, starting at its current offset.
func (t *tempFile) Read(buf []byte) (int, error) {
	return t.file.Read(buf)
}

// Close closes the underlying file handle.
func (t *tempFile) Close() error {
	return t.file.Close()
}
// copyImageToTempDir copies the image into the "redmage" directory below
// os.TempDir(), naming the file after the last path segment of the image URL.
// The returned handle is rewound to the start of the file, ready for reading.
//
// file must be closed by the caller after use.
//
// file is nil if an error occurred.
func (api *API) copyImageToTempDir(ctx context.Context, img reddit.PostImage) (tmp *tempFile, err error) {
	_, span := tracer.Start(ctx, "*API.copyImageToTempDir")
	defer func() { telemetry.EndWithStatus(span, err) }()

	imageURL, err := url.Parse(img.URL)
	if err != nil {
		// The URL is expected to be valid here, but a nil *url.URL would
		// panic on the next line, so fail explicitly instead.
		return nil, errs.Wrapw(err, "failed to parse image url",
			"image_url", img.URL,
		)
	}
	segments := strings.Split(imageURL.Path, "/")
	imageFilename := segments[len(segments)-1]

	tmpDirname := path.Join(os.TempDir(), "redmage")
	// Directories need the execute bit to be traversable; with 0644 no file
	// could be created inside for non-privileged users.
	if err = os.MkdirAll(tmpDirname, 0755); err != nil {
		return nil, errs.Wrapw(err, "failed to create temp image directory",
			"temp_dir_path", tmpDirname,
			"image_url", img.URL,
		)
	}

	tmpFilename := path.Join(tmpDirname, imageFilename)
	file, err := os.OpenFile(tmpFilename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return nil, errs.Wrapw(err, "failed to open temp image file",
			"temp_file_path", tmpFilename,
			"image_url", img.URL,
		)
	}

	if _, err = io.Copy(file, img.File); err != nil {
		_ = file.Close()
		return nil, errs.Wrapw(err, "failed to download image to temp file",
			"temp_file_path", tmpFilename,
			"image_url", img.URL,
		)
	}

	// Writing left the offset at the end of the file. Rewind so that the
	// caller reading through this handle gets the image, not zero bytes.
	if _, err = file.Seek(0, io.SeekStart); err != nil {
		_ = file.Close()
		return nil, errs.Wrapw(err, "failed to rewind temp image file",
			"temp_file_path", tmpFilename,
			"image_url", img.URL,
		)
	}

	return &tempFile{
		file:     file,
		filename: tmpFilename,
	}, nil
}