490 lines
14 KiB
Go
490 lines
14 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"image/jpeg"
|
|
"io"
|
|
"io/fs"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/aarondl/opt/omit"
|
|
"github.com/disintegration/imaging"
|
|
"github.com/tigorlazuardi/redmage/api/reddit"
|
|
"github.com/tigorlazuardi/redmage/models"
|
|
"github.com/tigorlazuardi/redmage/pkg/errs"
|
|
"github.com/tigorlazuardi/redmage/pkg/log"
|
|
"github.com/tigorlazuardi/redmage/pkg/telemetry"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
type DownloadSubredditParams struct {
|
|
Countback int
|
|
Devices models.DeviceSlice
|
|
SubredditType reddit.SubredditType
|
|
}
|
|
|
|
var (
|
|
ErrNoDevices = errors.New("api: no devices set")
|
|
ErrDownloadDirNotSet = errors.New("api: download directory not set")
|
|
)
|
|
|
|
func (api *API) DownloadSubredditImages(ctx context.Context, subreddit *models.Subreddit, devices models.DeviceSlice) error {
|
|
downloadDir := api.config.String("download.directory")
|
|
if downloadDir == "" {
|
|
return errs.Wrapw(ErrDownloadDirNotSet, "download directory must be set before images can be downloaded").Code(http.StatusBadRequest)
|
|
}
|
|
|
|
if len(devices) == 0 {
|
|
return errs.Wrapw(ErrNoDevices, "downloading images requires at least one device configured").Code(http.StatusBadRequest)
|
|
}
|
|
|
|
ctx, span := tracer.Start(ctx, "*API.DownloadSubredditImages", trace.WithAttributes(attribute.String("subreddit", subreddit.Name)))
|
|
defer span.End()
|
|
|
|
wg := sync.WaitGroup{}
|
|
|
|
countback := int(subreddit.Countback)
|
|
|
|
var (
|
|
list reddit.Listing
|
|
err error
|
|
)
|
|
for countback > 0 {
|
|
limit := 100
|
|
if limit > countback {
|
|
limit = countback
|
|
}
|
|
log.New(ctx).Debug("getting posts", "subreddit", subreddit, "current_countback", countback, "current_limit", limit)
|
|
list, err = api.reddit.GetPosts(ctx, reddit.GetPostsParam{
|
|
Subreddit: subreddit.Name,
|
|
Limit: limit,
|
|
After: list.GetLastAfter(),
|
|
SubredditType: reddit.SubredditType(subreddit.Subtype),
|
|
})
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to get posts", "subreddit", subreddit)
|
|
}
|
|
wg.Add(1)
|
|
go func(ctx context.Context, posts reddit.Listing) {
|
|
defer wg.Done()
|
|
err := api.downloadSubredditListImage(ctx, list, subreddit, devices)
|
|
if err != nil {
|
|
log.New(ctx).Err(err).Error("failed to download image")
|
|
}
|
|
}(ctx, list)
|
|
if len(list.GetPosts()) == 0 {
|
|
break
|
|
}
|
|
countback -= len(list.GetPosts())
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.Listing, subreddit *models.Subreddit, devices models.DeviceSlice) error {
|
|
ctx, span := tracer.Start(ctx, "*API.downloadSubredditListImage")
|
|
defer span.End()
|
|
|
|
wg := sync.WaitGroup{}
|
|
|
|
for _, post := range list.GetPosts() {
|
|
if !post.IsImagePost() {
|
|
continue
|
|
}
|
|
acceptedDevices := api.getDevicesThatAcceptPost(ctx, post, devices)
|
|
if len(acceptedDevices) == 0 {
|
|
continue
|
|
}
|
|
log.New(ctx).Debug("downloading image", "post_id", post.GetID(), "post_url", post.GetImageURL(), "devices", acceptedDevices)
|
|
wg.Add(1)
|
|
api.imageSemaphore <- struct{}{}
|
|
go func(ctx context.Context, post reddit.Post) {
|
|
defer func() {
|
|
<-api.imageSemaphore
|
|
wg.Done()
|
|
}()
|
|
|
|
if imageFile := api.findImageFileForDevices(ctx, post, devices); imageFile != nil {
|
|
err := api.saveImageToFSAndDatabase(ctx, imageFile, subreddit, post, acceptedDevices)
|
|
if err != nil {
|
|
log.New(ctx).Err(err).Error("failed to download subreddit image")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err := api.downloadSubredditImage(ctx, post, subreddit, acceptedDevices); err != nil {
|
|
log.New(ctx).Err(err).Error("failed to download subreddit image")
|
|
}
|
|
}(ctx, post)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, subreddit *models.Subreddit, devices models.DeviceSlice) error {
|
|
ctx, span := tracer.Start(ctx, "*API.downloadSubredditImage")
|
|
defer span.End()
|
|
|
|
imageHandler, err := api.reddit.DownloadImage(ctx, post, api.downloadBroadcast)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to download image")
|
|
}
|
|
defer imageHandler.Close()
|
|
|
|
// copy to temp dir first to avoid copying incomplete files.
|
|
tmpImageFile, err := api.copyImageToTempDir(ctx, imageHandler)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to download image to temp file")
|
|
}
|
|
defer tmpImageFile.Close()
|
|
|
|
thumbnailPath := post.GetThumbnailTargetPath(api.config)
|
|
_, errStat := os.Stat(thumbnailPath)
|
|
if errStat == nil {
|
|
// file exist
|
|
return nil
|
|
}
|
|
if !errors.Is(errStat, os.ErrNotExist) {
|
|
return errs.Wrapw(err, "failed to check thumbnail existence", "path", thumbnailPath)
|
|
}
|
|
|
|
_ = os.MkdirAll(post.GetThumbnailTargetDir(api.config), 0o777)
|
|
|
|
thumbnailSource, err := imaging.Open(tmpImageFile.filename)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to open temp thumbnail file",
|
|
"filename", tmpImageFile.filename,
|
|
"post_url", post.GetPermalink(),
|
|
"image_url", post.GetImageURL(),
|
|
)
|
|
}
|
|
|
|
thumbnail := imaging.Resize(thumbnailSource, 256, 0, imaging.Lanczos)
|
|
thumbnailFile, err := os.Create(thumbnailPath)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to create thumbnail file", "filename", thumbnailPath)
|
|
}
|
|
defer thumbnailFile.Close()
|
|
|
|
err = jpeg.Encode(thumbnailFile, thumbnail, nil)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to encode thumbnail file to jpeg", "filename", thumbnailPath)
|
|
}
|
|
|
|
return api.saveImageToFSAndDatabase(ctx, tmpImageFile, subreddit, post, devices)
|
|
}
|
|
|
|
type stat interface {
|
|
Stat() (fs.FileInfo, error)
|
|
}
|
|
|
|
func (api *API) saveImageToFSAndDatabase(ctx context.Context, image io.ReadCloser, subreddit *models.Subreddit, post reddit.Post, devices models.DeviceSlice) (err error) {
|
|
ctx, span := tracer.Start(ctx, "*API.saveImageToFSAndDatabase")
|
|
defer span.End()
|
|
defer image.Close()
|
|
|
|
w, close, err := api.createDeviceImageWriters(post, devices)
|
|
defer close()
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to create image files")
|
|
}
|
|
log.New(ctx).Debug("saving image files", "post_id", post.GetID(), "post_url", post.GetImageURL(), "devices", devices)
|
|
size, err := io.Copy(w, image)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to save image files")
|
|
}
|
|
if size == 0 {
|
|
if s, ok := image.(stat); ok {
|
|
if fi, err := s.Stat(); err == nil {
|
|
size = fi.Size()
|
|
}
|
|
}
|
|
}
|
|
|
|
var many []*models.ImageSetter
|
|
now := time.Now()
|
|
for _, device := range devices {
|
|
var nsfw int32
|
|
if post.IsNSFW() {
|
|
nsfw = 1
|
|
}
|
|
width, height := post.GetImageSize()
|
|
|
|
many = append(many, &models.ImageSetter{
|
|
Subreddit: omit.From(subreddit.Name),
|
|
Device: omit.From(device.Slug),
|
|
PostTitle: omit.From(post.GetTitle()),
|
|
PostURL: omit.From(post.GetPostURL()),
|
|
PostCreated: omit.From(post.GetCreated().Unix()),
|
|
PostName: omit.From(post.GetName()),
|
|
PostAuthor: omit.From(post.GetAuthor()),
|
|
PostAuthorURL: omit.From(post.GetAuthorURL()),
|
|
ImageWidth: omit.From(int32(width)),
|
|
ImageHeight: omit.From(int32(height)),
|
|
ImageSize: omit.From(size),
|
|
ImageRelativePath: omit.From(post.GetImageRelativePath(device)),
|
|
ThumbnailRelativePath: omit.From(post.GetThumbnailRelativePath()),
|
|
ImageOriginalURL: omit.From(post.GetImageURL()),
|
|
NSFW: omit.From(nsfw),
|
|
CreatedAt: omit.From(now.Unix()),
|
|
UpdatedAt: omit.From(now.Unix()),
|
|
})
|
|
}
|
|
|
|
log.New(ctx).Debug("inserting images to database", "images", many)
|
|
_, err = models.Images.InsertMany(ctx, api.db, many...)
|
|
if err != nil {
|
|
return errs.Wrapw(err, "failed to insert images to database", "params", many)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (api *API) createDeviceImageWriters(post reddit.Post, devices models.DeviceSlice) (writer io.Writer, close func(), err error) {
|
|
// open file for each device
|
|
var files []*os.File
|
|
var writers []io.Writer
|
|
for _, device := range devices {
|
|
var filename string
|
|
if device.WindowsWallpaperMode == 1 {
|
|
filename = post.GetWindowsWallpaperImageTargetPath(api.config, device)
|
|
dir := post.GetWindowsWallpaperImageTargetDir(api.config, device)
|
|
_ = os.MkdirAll(dir, 0o777)
|
|
} else {
|
|
filename = post.GetImageTargetPath(api.config, device)
|
|
dir := post.GetImageTargetDir(api.config, device)
|
|
if err := os.MkdirAll(dir, 0o777); err != nil {
|
|
for _, f := range files {
|
|
_ = f.Close()
|
|
}
|
|
return nil, nil, errs.Wrapw(err, "failed to create target image dir")
|
|
}
|
|
}
|
|
file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
|
|
if err != nil {
|
|
for _, f := range files {
|
|
_ = f.Close()
|
|
}
|
|
return nil, nil, errs.Wrapw(err, "failed to open target image file",
|
|
"device_name", device.Name,
|
|
"device_slug", device.Slug,
|
|
"filename", filename,
|
|
)
|
|
}
|
|
files = append(files, file)
|
|
writers = append(writers, file)
|
|
}
|
|
|
|
return io.MultiWriter(writers...), func() {
|
|
for _, file := range files {
|
|
_ = file.Close()
|
|
}
|
|
}, nil
|
|
}
|
|
|
|
func (api *API) getDevicesThatAcceptPost(ctx context.Context, post reddit.Post, devices models.DeviceSlice) (devs models.DeviceSlice) {
|
|
for _, device := range devices {
|
|
if shouldDownloadPostForDevice(post, device) && !api.isImageEntryExists(ctx, post, device) {
|
|
devs = append(devs, device)
|
|
}
|
|
}
|
|
return devs
|
|
}
|
|
|
|
// isImageEntryExists checks if the image entry already exists in the database and
|
|
// the image file actually exists in the filesystem.
|
|
func (api *API) isImageEntryExists(ctx context.Context, post reddit.Post, device *models.Device) (found bool) {
|
|
ctx, span := tracer.Start(ctx, "*API.IsImageExists")
|
|
defer span.End()
|
|
|
|
exist, errQuery := models.Images.Query(ctx, api.db,
|
|
models.SelectWhere.Images.Device.EQ(device.Slug),
|
|
models.SelectWhere.Images.PostName.EQ(post.GetName()),
|
|
).Exists()
|
|
|
|
_, errStat := os.Stat(post.GetImageTargetPath(api.config, device))
|
|
|
|
return exist && errQuery == nil && errStat == nil
|
|
}
|
|
|
|
// findImageFileForDevice finds if any of the image file exists for given devices.
|
|
//
|
|
// This helps to avoid downloading the same image for different devices.
|
|
//
|
|
// Return nil if no image file exists for the devices.
|
|
//
|
|
// Ensure to close the file after use.
|
|
func (api *API) findImageFileForDevices(ctx context.Context, post reddit.Post, devices models.DeviceSlice) (oldImageFile *os.File) {
|
|
for _, device := range devices {
|
|
stat, err := os.Stat(post.GetImageTargetPath(api.config, device))
|
|
if err == nil {
|
|
oldImageFile, err = os.Open(post.GetImageTargetPath(api.config, device))
|
|
if err != nil {
|
|
log.New(ctx).Err(err).Error("failed to open image file", "filename", post.GetImageTargetPath(api.config, device))
|
|
return nil
|
|
}
|
|
defer oldImageFile.Close()
|
|
|
|
tempFilename := filepath.Join(os.TempDir(), "redmage", stat.Name())
|
|
|
|
tempFileWrite, err := os.Create(tempFilename)
|
|
if err != nil {
|
|
log.New(ctx).Err(err).Error("failed to create temp file", "filename", post.GetImageTargetPath(api.config, device))
|
|
return nil
|
|
}
|
|
defer tempFileWrite.Close()
|
|
|
|
_, err = io.Copy(tempFileWrite, oldImageFile)
|
|
if err != nil {
|
|
log.New(ctx).Err(err).Error("failed to copy image file", "filename", post.GetImageTargetPath(api.config, device))
|
|
return nil
|
|
}
|
|
|
|
rf, err := os.Open(tempFilename)
|
|
if err != nil {
|
|
log.New(ctx).Err(err).Error("failed to open temp file", "filename", tempFileWrite.Name())
|
|
return nil
|
|
}
|
|
|
|
return rf
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func shouldDownloadPostForDevice(post reddit.Post, device *models.Device) bool {
|
|
if post.IsNSFW() && device.NSFW == 0 {
|
|
return false
|
|
}
|
|
devAspectRatio := deviceAspectRatio(device)
|
|
rangeStart := devAspectRatio - device.AspectRatioTolerance
|
|
rangeEnd := devAspectRatio + device.AspectRatioTolerance
|
|
|
|
imgAspectRatio := post.GetImageAspectRatio()
|
|
|
|
width, height := post.GetImageSize()
|
|
log.New(context.Background()).Debug("checking image aspect ratio",
|
|
"device", device.Slug,
|
|
"device_height", device.ResolutionY,
|
|
"device_width", device.ResolutionX,
|
|
"device_aspect_ratio", devAspectRatio,
|
|
"image_aspect_ratio", imgAspectRatio,
|
|
"range_start", rangeStart,
|
|
"range_end", rangeEnd,
|
|
"success_fulfill_download_range_start", (imgAspectRatio > rangeStart),
|
|
"success_fulfill_download_range_end", (imgAspectRatio < rangeEnd),
|
|
"url", post.GetImageURL(),
|
|
"image.width", width,
|
|
"image.height", height,
|
|
)
|
|
|
|
if imgAspectRatio < rangeStart {
|
|
return false
|
|
}
|
|
|
|
if imgAspectRatio > rangeEnd {
|
|
return false
|
|
}
|
|
|
|
if device.MaxX > 0 && width > int64(device.MaxX) {
|
|
return false
|
|
}
|
|
if device.MaxY > 0 && height > int64(device.MaxY) {
|
|
return false
|
|
}
|
|
if device.MinX > 0 && width < int64(device.MinX) {
|
|
return false
|
|
}
|
|
if device.MinY > 0 && height < int64(device.MinY) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func deviceAspectRatio(device *models.Device) float64 {
|
|
return float64(device.ResolutionX) / float64(device.ResolutionY)
|
|
}
|
|
|
|
type tempFile struct {
|
|
filename string
|
|
file *os.File
|
|
}
|
|
|
|
func (te *tempFile) Read(p []byte) (n int, err error) {
|
|
return te.file.Read(p)
|
|
}
|
|
|
|
func (te *tempFile) Close() error {
|
|
return te.file.Close()
|
|
}
|
|
|
|
// copyImageToTempDir copies the image to a temporary directory and returns the file handle
|
|
//
|
|
// file must be closed by the caller after use.
|
|
//
|
|
// file is nil if an error occurred.
|
|
func (api *API) copyImageToTempDir(ctx context.Context, img reddit.PostImage) (tmp *tempFile, err error) {
|
|
_, span := tracer.Start(ctx, "*API.copyImageToTempDir")
|
|
defer func() { telemetry.EndWithStatus(span, err) }()
|
|
|
|
// ignore error because url is always valid if this
|
|
// function is called
|
|
url, _ := url.Parse(img.URL)
|
|
|
|
split := strings.Split(url.Path, "/")
|
|
imageFilename := split[len(split)-1]
|
|
tmpDirname := path.Join(os.TempDir(), "redmage")
|
|
err = os.MkdirAll(tmpDirname, 0o777)
|
|
if err != nil {
|
|
return nil, errs.Wrapw(err, "failed to create temporary dir", "dir_name", tmpDirname)
|
|
}
|
|
tmpFilename := path.Join(tmpDirname, imageFilename)
|
|
|
|
file, err := os.OpenFile(tmpFilename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o777)
|
|
if err != nil {
|
|
return nil, errs.Wrapw(err, "failed to open temp image file",
|
|
"temp_file_path", tmpFilename,
|
|
"image_url", img.URL,
|
|
)
|
|
}
|
|
|
|
// File must be closed by end of function because kernel stuffs.
|
|
//
|
|
// A fresh fd must be used to properly get the new data.
|
|
defer file.Close()
|
|
|
|
_, err = io.Copy(file, img.File)
|
|
if err != nil {
|
|
return nil, errs.Wrapw(err, "failed to download image to temp file",
|
|
"temp_file_path", tmpFilename,
|
|
"image_url", img.URL,
|
|
)
|
|
}
|
|
|
|
filew, err := os.OpenFile(tmpFilename, os.O_RDONLY, 0o777)
|
|
if err != nil {
|
|
return nil, errs.Wrapw(err, "failed to download image to temp file",
|
|
"temp_file_path", tmpFilename,
|
|
"image_url", img.URL,
|
|
)
|
|
}
|
|
|
|
return &tempFile{
|
|
file: filew,
|
|
filename: tmpFilename,
|
|
}, err
|
|
}
|