api: implemented image download properly

This commit is contained in:
Tigor Hutasuhut 2024-04-27 15:16:14 +07:00
parent 19cd92004c
commit d99e28d997
15 changed files with 221 additions and 99 deletions

View file

@ -53,7 +53,7 @@ const downloadTopic = "subreddit_download"
var watermillLogger = &log.WatermillLogger{}
func New(deps Dependencies) *API {
ackDeadline := deps.Config.Duration("download.pubsub.ack.deadline")
ackDeadline := deps.Config.Duration("pubsub.ack.deadline")
subscriber, err := watermillSql.NewSubscriber(deps.PubsubDB, watermillSql.SubscriberConfig{
ConsumerGroup: "redmage",
AckDeadline: &ackDeadline,

View file

@ -15,6 +15,19 @@ type ImageMetadata struct {
// ImageKind distinguishes the kinds of image handled by this package.
type ImageKind int

// String returns the human-readable name of the kind: "Thumbnail" for
// KindThumbnail and "Image" for every other value.
func (im ImageKind) String() string {
	if im == KindThumbnail {
		return "Thumbnail"
	}
	return "Image"
}
// MarshalJSON encodes the kind as a JSON string holding the name returned by
// String. The name is wrapped in double quotes without any escaping.
func (im ImageKind) MarshalJSON() ([]byte, error) {
	name := im.String()
	encoded := make([]byte, 0, len(name)+2)
	encoded = append(encoded, '"')
	encoded = append(encoded, name...)
	encoded = append(encoded, '"')
	return encoded, nil
}
const (
KindImage ImageKind = iota
KindThumbnail

View file

@ -5,14 +5,15 @@ import (
"errors"
"image/jpeg"
"io"
"math"
"net/http"
"net/url"
"os"
"path"
"strings"
"sync"
"time"
"github.com/aarondl/opt/omit"
"github.com/disintegration/imaging"
"github.com/tigorlazuardi/redmage/api/reddit"
"github.com/tigorlazuardi/redmage/models"
@ -21,7 +22,6 @@ import (
"github.com/tigorlazuardi/redmage/pkg/telemetry"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
)
type DownloadSubredditParams struct {
@ -35,22 +35,22 @@ var (
ErrDownloadDirNotSet = errors.New("api: download directory not set")
)
func (api *API) DownloadSubredditImages(ctx context.Context, subredditName string, params DownloadSubredditParams) error {
func (api *API) DownloadSubredditImages(ctx context.Context, subreddit *models.Subreddit, devices models.DeviceSlice) error {
downloadDir := api.config.String("download.directory")
if downloadDir == "" {
return errs.Wrapw(ErrDownloadDirNotSet, "download directory must be set before images can be downloaded").Code(http.StatusBadRequest)
}
if len(params.Devices) == 0 {
if len(devices) == 0 {
return errs.Wrapw(ErrNoDevices, "downloading images requires at least one device configured").Code(http.StatusBadRequest)
}
ctx, span := tracer.Start(ctx, "*API.DownloadSubredditImages", trace.WithAttributes(attribute.String("subreddit", subredditName)))
ctx, span := tracer.Start(ctx, "*API.DownloadSubredditImages", trace.WithAttributes(attribute.String("subreddit", subreddit.Name)))
defer span.End()
wg := sync.WaitGroup{}
countback := params.Countback
countback := int(subreddit.Countback)
var (
list reddit.Listing
@ -61,20 +61,20 @@ func (api *API) DownloadSubredditImages(ctx context.Context, subredditName strin
if limit > countback {
limit = countback
}
log.New(ctx).Info("getting posts", "subreddit_name", subredditName, "limit", limit, "countback", countback)
log.New(ctx).Debug("getting posts", "subreddit", subreddit, "current_countback", countback, "current_limit", limit)
list, err = api.reddit.GetPosts(ctx, reddit.GetPostsParam{
Subreddit: subredditName,
Subreddit: subreddit.Name,
Limit: limit,
After: list.GetLastAfter(),
SubredditType: params.SubredditType,
SubredditType: reddit.SubredditType(subreddit.Subtype),
})
if err != nil {
return errs.Wrapw(err, "failed to get posts", "subreddit_name", subredditName, "params", params)
return errs.Wrapw(err, "failed to get posts", "subreddit", subreddit)
}
wg.Add(1)
go func(ctx context.Context, posts reddit.Listing) {
defer wg.Done()
err := api.downloadSubredditListImage(ctx, list, params)
err := api.downloadSubredditListImage(ctx, list, subreddit, devices)
if err != nil {
log.New(ctx).Err(err).Error("failed to download image")
}
@ -90,7 +90,7 @@ func (api *API) DownloadSubredditImages(ctx context.Context, subredditName strin
return nil
}
func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.Listing, params DownloadSubredditParams) error {
func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.Listing, subreddit *models.Subreddit, devices models.DeviceSlice) error {
ctx, span := tracer.Start(ctx, "*API.downloadSubredditListImage")
defer span.End()
@ -100,10 +100,11 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List
if !post.IsImagePost() {
continue
}
devices := api.getDevicesThatAcceptPost(ctx, post, params.Devices)
devices := api.getDevicesThatAcceptPost(ctx, post, devices)
if len(devices) == 0 {
continue
}
log.New(ctx).Debug("downloading image", "post_id", post.GetID(), "post_url", post.GetImageURL(), "devices", devices)
wg.Add(1)
api.imageSemaphore <- struct{}{}
go func(ctx context.Context, post reddit.Post) {
@ -112,7 +113,7 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List
wg.Done()
}()
if err := api.downloadSubredditImage(ctx, post, devices); err != nil {
if err := api.downloadSubredditImage(ctx, post, subreddit, devices); err != nil {
log.New(ctx).Err(err).Error("failed to download subreddit image")
}
}(ctx, post)
@ -123,7 +124,7 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List
return nil
}
func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, devices models.DeviceSlice) error {
func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, subreddit *models.Subreddit, devices models.DeviceSlice) error {
ctx, span := tracer.Start(ctx, "*API.downloadSubredditImage")
defer span.End()
@ -140,15 +141,6 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
}
defer tmpImageFile.Close()
w, close, err := api.createDeviceImageWriters(post, devices)
if err != nil {
return errs.Wrapw(err, "failed to create image files")
}
defer close()
_, err = io.Copy(w, tmpImageFile)
if err != nil {
return errs.Wrapw(err, "failed to save image files")
}
thumbnailPath := post.GetThumbnailTargetPath(api.config)
_, errStat := os.Stat(thumbnailPath)
if errStat == nil {
@ -163,7 +155,11 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
thumbnailSource, err := imaging.Open(tmpImageFile.filename)
if err != nil {
return errs.Wrapw(err, "failed to open temp thumbnail file", "filename", tmpImageFile.filename)
return errs.Wrapw(err, "failed to open temp thumbnail file",
"filename", tmpImageFile.filename,
"post_url", post.GetPermalink(),
"image_url", post.GetImageURL(),
)
}
thumbnail := imaging.Resize(thumbnailSource, 256, 0, imaging.Lanczos)
@ -178,7 +174,46 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
return errs.Wrapw(err, "failed to encode thumbnail file to jpeg", "filename", thumbnailPath)
}
// TODO: create entry to database
w, close, err := api.createDeviceImageWriters(post, devices)
if err != nil {
return errs.Wrapw(err, "failed to create image files")
}
log.New(ctx).Debug("saving image files", "post_id", post.GetID(), "post_url", post.GetImageURL(), "devices", devices)
defer close()
_, err = io.Copy(w, tmpImageFile)
if err != nil {
return errs.Wrapw(err, "failed to save image files")
}
var many []*models.ImageSetter
for _, device := range devices {
var nsfw int32
if post.IsNSFW() {
nsfw = 1
}
many = append(many, &models.ImageSetter{
SubredditID: omit.From(subreddit.ID),
DeviceID: omit.From(device.ID),
Title: omit.From(post.GetTitle()),
PostID: omit.From(post.GetID()),
PostURL: omit.From(post.GetImageURL()),
PostCreated: omit.From(post.GetCreated().Format(time.RFC3339)),
PostName: omit.From(post.GetName()),
Poster: omit.From(post.GetAuthor()),
PosterURL: omit.From(post.GetAuthorURL()),
ImageRelativePath: omit.From(post.GetImageRelativePath(device)),
ThumbnailRelativePath: omit.From(post.GetThumbnailRelativePath()),
ImageOriginalURL: omit.From(post.GetImageURL()),
ThumbnailOriginalURL: omit.From(post.GetThumbnailURL()),
NSFW: omit.From(nsfw),
})
}
log.New(ctx).Debug("inserting images to database", "images", many)
_, err = models.Images.InsertMany(ctx, api.db, many...)
if err != nil {
return errs.Wrapw(err, "failed to insert images to database", "params", many)
}
return nil
}
@ -226,22 +261,11 @@ func (api *API) createDeviceImageWriters(post reddit.Post, devices models.Device
}
func (api *API) getDevicesThatAcceptPost(ctx context.Context, post reddit.Post, devices models.DeviceSlice) (devs models.DeviceSlice) {
var mu sync.Mutex
errgrp, ctx := errgroup.WithContext(ctx)
for _, device := range devices {
if shouldDownloadPostForDevice(post, device) {
device := device
errgrp.Go(func() error {
if !api.isImageExists(ctx, post, device) {
mu.Lock()
defer mu.Unlock()
devs = append(devices, device)
}
return nil
})
if shouldDownloadPostForDevice(post, device) && !api.isImageExists(ctx, post, device) {
devs = append(devs, device)
}
}
_ = errgrp.Wait()
return devs
}
@ -249,27 +273,58 @@ func (api *API) isImageExists(ctx context.Context, post reddit.Post, device *mod
ctx, span := tracer.Start(ctx, "*API.IsImageExists")
defer span.End()
// Image does not exist in target image.
if _, err := os.Stat(post.GetImageTargetPath(api.config, device)); err != nil {
return false
}
_, err := models.Images.Query(ctx, api.db,
models.SelectWhere.Images.DeviceID.EQ(device.ID),
models.SelectWhere.Images.PostID.EQ(post.GetID()),
).One()
if err != nil {
if err.Error() == "sql: no rows in result set" {
return false
}
}
return err == nil
// Image does not exist in target path.
if _, err := os.Stat(post.GetImageTargetPath(api.config, device)); err != nil {
return false
}
return true
}
func shouldDownloadPostForDevice(post reddit.Post, device *models.Device) bool {
if post.IsNSFW() && device.NSFW == 0 {
return false
}
if math.Abs(deviceAspectRatio(device)-post.GetImageAspectRatio()) > device.AspectRatioTolerance { // outside of aspect ratio tolerance
devAspectRatio := deviceAspectRatio(device)
rangeStart := devAspectRatio - device.AspectRatioTolerance
rangeEnd := devAspectRatio + device.AspectRatioTolerance
imgAspectRatio := post.GetImageAspectRatio()
width, height := post.GetImageSize()
log.New(context.Background()).Debug("checking image aspect ratio",
"device", device.Slug,
"device_height", device.ResolutionY,
"device_width", device.ResolutionX,
"device_aspect_ratio", devAspectRatio,
"image_aspect_ratio", imgAspectRatio,
"range_start", rangeStart,
"range_end", rangeEnd,
"success_fulfill_download_range_start", (imgAspectRatio > rangeStart),
"success_fulfill_download_range_end", (imgAspectRatio < rangeEnd),
"url", post.GetImageURL(),
"image.width", width,
"image.height", height,
)
if imgAspectRatio < rangeStart {
return false
}
width, height := post.GetImageSize()
if imgAspectRatio > rangeEnd {
return false
}
if device.MaxX > 0 && width > int64(device.MaxX) {
return false
}

View file

@ -7,7 +7,6 @@ import (
"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/tigorlazuardi/redmage/api/reddit"
"github.com/tigorlazuardi/redmage/models"
"github.com/tigorlazuardi/redmage/pkg/errs"
"github.com/tigorlazuardi/redmage/pkg/log"
@ -18,11 +17,11 @@ import (
func (api *API) StartSubredditDownloadPubsub(messages <-chan *message.Message) {
for msg := range messages {
log.New(context.Background()).Info("received pubsub message",
log.New(context.Background()).Debug("received pubsub message",
"message", msg,
"len", len(api.subredditSemaphore),
"cap", cap(api.subredditSemaphore),
"download.concurrency.subreddts", api.config.Int("download.concurrency.subreddits"),
"download.concurrency.subreddits", api.config.Int("download.concurrency.subreddits"),
)
api.subredditSemaphore <- struct{}{}
go func(msg *message.Message) {
@ -50,11 +49,7 @@ func (api *API) StartSubredditDownloadPubsub(messages <-chan *message.Message) {
return
}
err = api.DownloadSubredditImages(ctx, subreddit.Name, DownloadSubredditParams{
Countback: int(subreddit.Countback),
Devices: devices,
SubredditType: reddit.SubredditType(subreddit.Subtype),
})
err = api.DownloadSubredditImages(ctx, subreddit, devices)
if err != nil {
log.New(ctx).Err(err).Error("failed to download subreddit images", "subreddit", subreddit)
return

View file

@ -2,8 +2,29 @@ package reddit
import (
"net/http"
"github.com/tigorlazuardi/redmage/config"
)
// Client is the HTTP behaviour the reddit package depends on: executing a
// single request and returning its response. *http.Client satisfies it.
type Client interface {
	// Do sends req and returns the response or an error.
	Do(req *http.Request) (*http.Response, error)
}
// NewRedditHTTPClient builds an *http.Client whose transport is the round
// tripper produced by createRoundTripper for cfg, and returns it as a Client.
func NewRedditHTTPClient(cfg *config.Config) Client {
	client := new(http.Client)
	client.Transport = createRoundTripper(cfg)
	return client
}
// roundTripperFunc adapts an ordinary function with the RoundTrip signature
// so it can be used wherever an http.RoundTripper is expected.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes the underlying function with req and returns its results
// unchanged.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
// createRoundTripper returns a round tripper that sets the User-Agent header
// of every outgoing request to the value configured under
// "download.useragent" and then delegates to http.DefaultTransport.
//
// The http.RoundTripper contract states that RoundTrip must not modify the
// caller's request, so the header is set on a clone rather than on r itself.
func createRoundTripper(cfg *config.Config) roundTripperFunc {
	return func(r *http.Request) (*http.Response, error) {
		req := r.Clone(r.Context())
		// Clone preserves a nil header map; Set on a nil map would panic.
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		req.Header.Set("User-Agent", cfg.String("download.useragent"))
		return http.DefaultTransport.RoundTrip(req)
	}
}

View file

@ -76,6 +76,12 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
if err != nil {
return nil, errs.Wrapw(err, "reddit: failed to execute request", "url", url)
}
if resp.StatusCode >= 400 {
return nil, errs.Fail("unexpected status code when trying to download images",
"url", url,
"status_code", resp.StatusCode,
)
}
idleSpeedStr := reddit.Config.String("download.timeout.idlespeed")
metricSpeed, _ := units.ParseMetricBytes(idleSpeedStr)
if metricSpeed == 0 {
@ -116,7 +122,7 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
ContentLength: units.MetricBytes(resp.ContentLength),
Downloaded: units.MetricBytes(downloaded),
Subreddit: post.GetSubreddit(),
PostURL: post.GetPermalink(),
PostURL: post.GetPostURL(),
PostID: post.GetID(),
Error: closeErr,
})
@ -138,7 +144,7 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
Kind: kind,
},
Subreddit: post.GetSubreddit(),
PostURL: post.GetPermalink(),
PostURL: post.GetPostURL(),
PostID: post.GetID(),
})
_, err := io.Copy(writer, resp.Body)

View file

@ -8,7 +8,6 @@ import (
"log/slog"
"net/http"
"strings"
"time"
"github.com/tigorlazuardi/redmage/pkg/errs"
)
@ -68,14 +67,18 @@ func (reddit *Reddit) GetPosts(ctx context.Context, params GetPostsParam) (posts
return posts, errs.Wrapw(err, "reddit: failed to create http request instance", "url", url, "params", params)
}
req.Header.Set("User-Agent", reddit.Config.String("download.useragent"))
res, err := reddit.Client.Do(req)
if err != nil {
return posts, errs.Wrapw(err, "reddit: failed to execute http request", "url", url, "params", params)
}
defer res.Body.Close()
if res.StatusCode == http.StatusTooManyRequests {
retryAfter, _ := time.ParseDuration(res.Header.Get("Retry-After"))
return posts, errs.Fail("reddit: too many requests", "retry_after", retryAfter.String())
return posts, errs.Fail("reddit: too many requests",
"retry_after", res.Header.Get("Retry-After"),
"url", url,
)
}
if res.StatusCode != http.StatusOK {

View file

@ -6,6 +6,7 @@ import (
"path"
"path/filepath"
"strings"
"time"
"github.com/tigorlazuardi/redmage/config"
"github.com/tigorlazuardi/redmage/models"
@ -145,7 +146,7 @@ type PostData struct {
IsCreatedFromAdsUI bool `json:"is_created_from_ads_ui"`
AuthorPremium bool `json:"author_premium"`
Thumbnail string `json:"thumbnail"`
Edited bool `json:"edited"`
Edited any `json:"edited"`
AuthorFlairCSSClass string `json:"author_flair_css_class"`
AuthorFlairRichtext []AuthorFlairRichtext `json:"author_flair_richtext"`
Gildings Gildings `json:"gildings"`
@ -231,6 +232,22 @@ func (post *Post) GetImageURL() string {
return post.Data.URL
}
// GetCreated returns the post creation time, interpreting Data.Created
// (converted to int64) as a Unix timestamp in seconds.
func (post *Post) GetCreated() time.Time {
	seconds := int64(post.Data.Created)
	return time.Unix(seconds, 0)
}
// GetAuthor returns the author name stored in the post data.
func (post *Post) GetAuthor() string {
	author := post.Data.Author
	return author
}
// GetTitle returns the title stored in the post data.
func (post *Post) GetTitle() string {
	title := post.Data.Title
	return title
}
// GetAuthorURL returns the reddit.com user page URL for the post's author,
// built by appending Data.Author to the fixed user-page prefix.
func (post *Post) GetAuthorURL() string {
	const userPagePrefix = "https://www.reddit.com/user/"
	return userPagePrefix + post.Data.Author
}
func (post *Post) GetImageAspectRatio() float64 {
width, height := post.GetImageSize()
if height == 0 {
@ -336,6 +353,10 @@ func (post *Post) GetPermalink() string {
return post.Data.Permalink
}
// GetPostURL returns the absolute reddit.com URL of the post, built from
// Data.Permalink.
//
// Reddit permalinks are site-relative paths that begin with a slash (for
// example "/r/pics/comments/abc/title/"). The leading slash is trimmed so the
// result does not contain a double slash after the host. Permalinks without a
// leading slash are joined unchanged.
func (post *Post) GetPostURL() string {
	relative := strings.TrimPrefix(post.Data.Permalink, "/")
	return fmt.Sprintf("https://reddit.com/%s", relative)
}
// GetID returns the identifier stored in the post data.
func (post *Post) GetID() string {
	id := post.Data.ID
	return id
}

View file

@ -1,8 +1,6 @@
package cli
import (
"fmt"
"github.com/adrg/xdg"
"github.com/joho/godotenv"
"github.com/tigorlazuardi/redmage/config"
@ -29,9 +27,6 @@ func initConfig() {
LoadFlags(RootCmd.PersistentFlags()).
Build()
fmt.Println("download.concurrency.subreddits", cfg.Get("download.concurrency.subreddits"))
fmt.Println("download.concurrency.images", cfg.Get("download.concurrency.images"))
handler := log.NewHandler(cfg)
log.SetDefault(handler)
}

View file

@ -2,7 +2,6 @@ package cli
import (
"io/fs"
"net/http"
"os"
"github.com/spf13/cobra"
@ -30,23 +29,29 @@ var serveCmd = &cobra.Command{
database, err := db.Open(cfg)
if err != nil {
log.New(cmd.Context()).Err(err).Error("failed to connect database")
log.New(cmd.Context()).Err(err).Error("failed to open connection to database")
os.Exit(1)
}
pubsubDatabase, err := db.OpenSilent(cfg)
pubsubDb, err := db.OpenPubsub(cfg)
if err != nil {
log.New(cmd.Context()).Err(err).Error("failed to open connection to pubsub database")
os.Exit(1)
}
loggedDb := db.ApplyLogger(cfg, database)
if err != nil {
log.New(cmd.Context()).Err(err).Error("failed to connect database")
os.Exit(1)
}
red := &reddit.Reddit{
Client: http.DefaultClient,
Client: reddit.NewRedditHTTPClient(cfg),
Config: cfg,
}
api := api.New(api.Dependencies{
DB: database,
PubsubDB: pubsubDatabase,
DB: loggedDb,
PubsubDB: pubsubDb,
Config: cfg,
Reddit: red,
})

View file

@ -15,6 +15,10 @@ var DefaultConfig = map[string]any{
"db.string": "data.db",
"db.automigrate": true,
"pubsub.db.driver": "sqlite3",
"pubsub.db.string": "pubsub.db",
"pubsub.ack.deadline": "3h",
"download.concurrency.images": 5,
"download.concurrency.subreddits": 3,

View file

@ -19,12 +19,14 @@ var Migrations fs.FS
func Open(cfg *config.Config) (*sql.DB, error) {
driver := cfg.String("db.driver")
dsn := cfg.String("db.string")
db, err := OpenSilent(cfg)
db, err := otelsql.Open(driver, dsn, otelsql.WithAttributes(
semconv.DBSystemSqlite,
))
if err != nil {
return db, err
return db, errs.Wrapw(err, "failed to open database", "driver", driver)
}
if cfg.Bool("db.automigrate") {
goose.SetLogger(&gooseLogger{})
goose.SetLogger(goose.NopLogger())
goose.SetBaseFS(Migrations)
if err := goose.SetDialect(driver); err != nil {
@ -35,22 +37,22 @@ func Open(cfg *config.Config) (*sql.DB, error) {
return db, errs.Wrapw(err, "failed to migrate database", "dialect", driver)
}
}
db = sqldblogger.OpenDriver(dsn, db.Driver(), sqlLogger{},
sqldblogger.WithSQLQueryAsMessage(true),
)
return db, err
}
func OpenSilent(cfg *config.Config) (*sql.DB, error) {
driver := cfg.String("db.driver")
dsn := cfg.String("db.string")
db, err := otelsql.Open(driver, dsn, otelsql.WithAttributes(
semconv.DBSystemSqlite,
))
// OpenPubsub opens the database handle used for pubsub, reading the driver
// name from "pubsub.db.driver" and the data source name from
// "pubsub.db.string". It uses plain database/sql.
//
// On failure the error is wrapped with the driver name attached as context.
func OpenPubsub(cfg *config.Config) (*sql.DB, error) {
	var (
		driver = cfg.String("pubsub.db.driver")
		dsn    = cfg.String("pubsub.db.string")
	)
	db, err := sql.Open(driver, dsn)
	if err == nil {
		return db, nil
	}
	return db, errs.Wrapw(err, "failed to open database", "driver", driver)
}
// ApplyLogger returns a *sql.DB obtained from sqldblogger.OpenDriver, using
// the driver of db, the data source name configured under "db.string", and
// sqlLogger as the log sink, with the SQL query used as the log message.
//
// NOTE(review): the returned handle presumably owns its own connection pool,
// separate from db; confirm against the sqldblogger documentation.
func ApplyLogger(cfg *config.Config, db *sql.DB) *sql.DB {
	var (
		dsn     = cfg.String("db.string")
		drv     = db.Driver()
		sink    = sqlLogger{}
		queries = sqldblogger.WithSQLQueryAsMessage(true)
	)
	return sqldblogger.OpenDriver(dsn, drv, sink, queries)
}

View file

@ -1,14 +1,16 @@
POST http://localhost:8080/api/v1/devices HTTP/1.1
Host: localhost:8080
Content-Type: application/json
Content-Length: 178
Content-Length: 211
{
"name": "Sync Laptop",
"slug": "sync-l",
"resolution_x": 1920,
"resolution_y": 1080,
"name": "S20FE",
"slug": "s20fe",
"resolution_x": 1080,
"resolution_y": 2400,
"nsfw": 1,
"aspect_ratio_tolerance": 0.2,
"enable": 1
"enable": 1,
"min_x": 1080,
"min_y": 2400
}

View file

@ -1,10 +1,10 @@
POST http://localhost:8080/api/v1/subreddits HTTP/1.1
Host: localhost:8080
Content-Length: 91
Content-Length: 109
{
"name": "awoo",
"enable": 1,
"name": "animemidriff",
"enable_schedule": 1,
"schedule": "@daily",
"countback": 10
"countback": 300
}

View file

@ -1,8 +1,8 @@
POST http://localhost:8080/api/v1/subreddits/start HTTP/1.1
Host: localhost:8080
Content-Type: application/json
Content-Length: 29
Content-Length: 35
{
"subreddit": "awoo"
"subreddit": "fantasymoe"
}