reddit: handled image downloads

This commit is contained in:
Tigor Hutasuhut 2024-04-26 22:13:04 +07:00
parent 6b35321e1e
commit 19cd92004c
23 changed files with 198 additions and 45 deletions

2
.gitignore vendored
View file

@ -31,3 +31,5 @@ db/queries/**/*.go
*.db
models/
/out

View file

@ -5,8 +5,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"os"
"os/signal"
"github.com/robfig/cron/v3"
"github.com/stephenafamo/bob"
@ -73,9 +71,7 @@ func New(deps Dependencies) *API {
if err != nil {
panic(err)
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
ch, err := subscriber.Subscribe(ctx, downloadTopic)
ch, err := subscriber.Subscribe(context.Background(), downloadTopic)
if err != nil {
panic(err)
}
@ -92,7 +88,7 @@ func New(deps Dependencies) *API {
publisher: publisher,
}
if err := api.StartScheduler(ctx); err != nil {
if err := api.StartScheduler(context.Background()); err != nil {
panic(err)
}
go api.StartSubredditDownloadPubsub(ch)
@ -100,7 +96,7 @@ func New(deps Dependencies) *API {
}
func (api *API) StartScheduler(ctx context.Context) error {
subreddits, err := models.Subreddits.Query(ctx, api.db, models.SelectWhere.Subreddits.Enable.EQ(1)).All()
subreddits, err := models.Subreddits.Query(ctx, api.db, models.SelectWhere.Subreddits.EnableSchedule.EQ(1)).All()
if err != nil {
return errs.Wrapw(err, "failed to get all subreddits")
}

View file

@ -1,6 +1,8 @@
package bmessage
import (
"encoding/json"
"github.com/alecthomas/units"
)
@ -47,12 +49,32 @@ const (
)
type ImageDownloadMessage struct {
Event DownloadEvent
Metadata ImageMetadata
ContentLength units.MetricBytes
Downloaded units.MetricBytes
Subreddit string
PostURL string
PostID string
Error error
Event DownloadEvent `json:"event"`
Metadata ImageMetadata `json:"metadata"`
ContentLength units.MetricBytes `json:"content_length"`
Downloaded units.MetricBytes `json:"downloaded"`
Subreddit string `json:"subreddit"`
PostURL string `json:"post_url"`
PostID string `json:"post_id"`
Error error `json:"error"`
}
func (im ImageDownloadMessage) MarshalJSON() ([]byte, error) {
type Alias ImageDownloadMessage
type W struct {
Alias
Error string `json:"error"`
}
errMsg := ""
if im.Error != nil {
errMsg = im.Error.Error()
}
w := W{
Alias: Alias(im),
Error: errMsg,
}
return json.Marshal(w)
}

32
api/context.go Normal file
View file

@ -0,0 +1,32 @@
package api
import (
"context"
"time"
)
type noTimeoutContext struct {
inner context.Context
}
func (no noTimeoutContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
func (no noTimeoutContext) Done() <-chan struct{} {
return nil
}
func (no noTimeoutContext) Err() error {
return nil
}
func (no noTimeoutContext) Value(key any) any {
return no.inner.Value(key)
}
func noCancelContext(ctx context.Context) context.Context {
return noTimeoutContext{
inner: ctx,
}
}

View file

@ -18,10 +18,15 @@ type DevicesListParams struct {
Offset int64
OrderBy string
Sort string
Active bool
}
func (dlp DevicesListParams) Query() (expr []bob.Mod[*dialect.SelectQuery]) {
expr = append(expr, dlp.CountQuery()...)
if dlp.Active {
expr = append(expr, models.SelectWhere.Devices.Enable.EQ(1))
}
if dlp.All {
return expr
}
@ -47,6 +52,10 @@ func (dlp DevicesListParams) Query() (expr []bob.Mod[*dialect.SelectQuery]) {
}
func (dlp DevicesListParams) CountQuery() (expr []bob.Mod[*dialect.SelectQuery]) {
if dlp.Active {
expr = append(expr, models.SelectWhere.Devices.Enable.EQ(1))
}
if dlp.Q != "" {
arg := sqlite.Arg("%" + dlp.Q + "%")
expr = append(expr,

View file

@ -61,6 +61,7 @@ func (api *API) DownloadSubredditImages(ctx context.Context, subredditName strin
if limit > countback {
limit = countback
}
log.New(ctx).Info("getting posts", "subreddit_name", subredditName, "limit", limit, "countback", countback)
list, err = api.reddit.GetPosts(ctx, reddit.GetPostsParam{
Subreddit: subredditName,
Limit: limit,
@ -158,6 +159,8 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
return errs.Wrapw(err, "failed to check thumbnail existence", "path", thumbnailPath)
}
_ = os.MkdirAll(post.GetThumbnailTargetDir(api.config), 0o777)
thumbnailSource, err := imaging.Open(tmpImageFile.filename)
if err != nil {
return errs.Wrapw(err, "failed to open temp thumbnail file", "filename", tmpImageFile.filename)
@ -175,6 +178,8 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
return errs.Wrapw(err, "failed to encode thumbnail file to jpeg", "filename", thumbnailPath)
}
// TODO: create entry to database
return nil
}
@ -186,15 +191,24 @@ func (api *API) createDeviceImageWriters(post reddit.Post, devices models.Device
var filename string
if device.WindowsWallpaperMode == 1 {
filename = post.GetWindowsWallpaperImageTargetPath(api.config, device)
dir := post.GetWindowsWallpaperImageTargetDir(api.config, device)
_ = os.MkdirAll(dir, 0o777)
} else {
filename = post.GetImageTargetPath(api.config, device)
dir := post.GetImageTargetDir(api.config, device)
if err := os.MkdirAll(dir, 0o777); err != nil {
for _, f := range files {
_ = f.Close()
}
return nil, nil, errs.Wrapw(err, "failed to create target image dir")
}
}
file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
for _, f := range files {
_ = f.Close()
}
return nil, nil, errs.Wrapw(err, "failed to open temp image file",
return nil, nil, errs.Wrapw(err, "failed to open target image file",
"device_name", device.Name,
"device_slug", device.Slug,
"filename", filename,
@ -304,10 +318,13 @@ func (api *API) copyImageToTempDir(ctx context.Context, img reddit.PostImage) (t
split := strings.Split(url.Path, "/")
imageFilename := split[len(split)-1]
tmpDirname := path.Join(os.TempDir(), "redmage")
_ = os.MkdirAll(tmpDirname, 0o644)
err = os.MkdirAll(tmpDirname, 0777)
if err != nil {
return nil, errs.Wrapw(err, "failed to create temporary dir", "dir_name", tmpDirname)
}
tmpFilename := path.Join(tmpDirname, imageFilename)
file, err := os.OpenFile(tmpFilename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
file, err := os.OpenFile(tmpFilename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o777)
if err != nil {
return nil, errs.Wrapw(err, "failed to open temp image file",
"temp_file_path", tmpFilename,
@ -315,9 +332,21 @@ func (api *API) copyImageToTempDir(ctx context.Context, img reddit.PostImage) (t
)
}
// File must be closed by end of function because kernel stuffs.
//
// A fresh fd must be used to properly get the new data.
defer file.Close()
_, err = io.Copy(file, img.File)
if err != nil {
_ = file.Close()
return nil, errs.Wrapw(err, "failed to download image to temp file",
"temp_file_path", tmpFilename,
"image_url", img.URL,
)
}
filew, err := os.OpenFile(tmpFilename, os.O_RDONLY, 0o777)
if err != nil {
return nil, errs.Wrapw(err, "failed to download image to temp file",
"temp_file_path", tmpFilename,
"image_url", img.URL,
@ -325,7 +354,7 @@ func (api *API) copyImageToTempDir(ctx context.Context, img reddit.PostImage) (t
}
return &tempFile{
file: file,
file: filew,
filename: tmpFilename,
}, err
}

View file

@ -18,6 +18,12 @@ import (
func (api *API) StartSubredditDownloadPubsub(messages <-chan *message.Message) {
for msg := range messages {
log.New(context.Background()).Info("received pubsub message",
"message", msg,
"len", len(api.subredditSemaphore),
"cap", cap(api.subredditSemaphore),
"download.concurrency.subreddts", api.config.Int("download.concurrency.subreddits"),
)
api.subredditSemaphore <- struct{}{}
go func(msg *message.Message) {
defer func() {

View file

@ -2,6 +2,7 @@ package reddit
import (
"context"
"errors"
"io"
"net/http"
@ -67,8 +68,6 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
url = post.GetThumbnailURL()
width, height = post.GetThumbnailSize()
}
ctx, cancel := context.WithTimeout(ctx, reddit.Config.Duration("download.timeout.headers"))
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, errs.Wrapw(err, "reddit: failed to create request", "url", url)
@ -91,6 +90,9 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
idr := &ImageDownloadReader{
OnProgress: func(downloaded int64, contentLength int64, err error) {
var event bmessage.DownloadEvent
if errors.Is(err, io.EOF) {
err = nil
}
if err != nil {
event = bmessage.DownloadError
} else {

View file

@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"strings"
"time"
"github.com/tigorlazuardi/redmage/pkg/errs"
)
@ -72,6 +73,10 @@ func (reddit *Reddit) GetPosts(ctx context.Context, params GetPostsParam) (posts
return posts, errs.Wrapw(err, "reddit: failed to execute http request", "url", url, "params", params)
}
defer res.Body.Close()
if res.StatusCode == http.StatusTooManyRequests {
retryAfter, _ := time.ParseDuration(res.Header.Get("Retry-After"))
return posts, errs.Fail("reddit: too many requests", "retry_after", retryAfter.String())
}
if res.StatusCode != http.StatusOK {
body, _ := io.ReadAll(res.Body)

View file

@ -4,6 +4,7 @@ import (
"fmt"
"net/url"
"path"
"path/filepath"
"strings"
"github.com/tigorlazuardi/redmage/config"
@ -244,18 +245,45 @@ func (post *Post) GetName() string {
func (post *Post) GetImageTargetPath(cfg *config.Config, device *models.Device) string {
baseDownloadDir := cfg.String("download.directory")
return path.Join(baseDownloadDir, device.Name, post.GetSubreddit(), post.GetImageFilename())
p := path.Join(baseDownloadDir, device.Slug, post.GetSubreddit(), post.GetImageFilename())
abs, _ := filepath.Abs(p)
return abs
}
func (post *Post) GetImageTargetDir(cfg *config.Config, device *models.Device) string {
baseDownloadDir := cfg.String("download.directory")
p := path.Join(baseDownloadDir, device.Slug, post.GetSubreddit())
abs, _ := filepath.Abs(p)
return abs
}
func (post *Post) GetWindowsWallpaperImageTargetPath(cfg *config.Config, device *models.Device) string {
baseDownloadDir := cfg.String("download.directory")
filename := fmt.Sprintf("%s_%s", post.GetSubreddit(), post.GetImageFilename())
return path.Join(baseDownloadDir, device.Name, filename)
p := path.Join(baseDownloadDir, device.Slug, filename)
abs, _ := filepath.Abs(p)
return abs
}
func (post *Post) GetWindowsWallpaperImageTargetDir(cfg *config.Config, device *models.Device) string {
baseDownloadDir := cfg.String("download.directory")
p := path.Join(baseDownloadDir, device.Slug)
abs, _ := filepath.Abs(p)
return abs
}
func (post *Post) GetThumbnailTargetPath(cfg *config.Config) string {
baseDownloadDir := cfg.String("download.directory")
return path.Join(baseDownloadDir, "_thumbnails", post.GetSubreddit(), post.GetImageFilename())
p := path.Join(baseDownloadDir, "_thumbnails", post.GetSubreddit(), post.GetImageFilename())
abs, _ := filepath.Abs(p)
return abs
}
func (post *Post) GetThumbnailTargetDir(cfg *config.Config) string {
baseDownloadDir := cfg.String("download.directory")
p := path.Join(baseDownloadDir, "_thumbnails", post.GetSubreddit())
abs, _ := filepath.Abs(p)
return abs
}
func (post *Post) GetThumbnailRelativePath() string {

View file

@ -17,7 +17,7 @@ func (api *API) SubredditsCreate(ctx context.Context, params *models.Subreddit)
set := &models.SubredditSetter{
Name: omit.From(params.Name),
Enable: omit.From(params.Enable),
EnableSchedule: omit.From(params.EnableSchedule),
Subtype: omit.From(params.Subtype),
Schedule: omit.From(params.Schedule),
Countback: omit.From(params.Countback),

View file

@ -1,6 +1,8 @@
package cli
import (
"fmt"
"github.com/adrg/xdg"
"github.com/joho/godotenv"
"github.com/tigorlazuardi/redmage/config"
@ -27,6 +29,9 @@ func initConfig() {
LoadFlags(RootCmd.PersistentFlags()).
Build()
fmt.Println("download.concurrency.subreddits", cfg.Get("download.concurrency.subreddits"))
fmt.Println("download.concurrency.images", cfg.Get("download.concurrency.images"))
handler := log.NewHandler(cfg)
log.SetDefault(handler)
}

View file

@ -35,7 +35,10 @@ func (builder *ConfigBuilder) BuildHandle() (*Config, error) {
func (builder *ConfigBuilder) LoadDefault() *ConfigBuilder {
provider := confmap.Provider(DefaultConfig, ".")
_ = builder.koanf.Load(provider, nil)
err := builder.koanf.Load(provider, nil)
if err != nil {
panic(err)
}
return builder
}

View file

@ -17,6 +17,7 @@ var DefaultConfig = map[string]any{
"download.concurrency.images": 5,
"download.concurrency.subreddits": 3,
"download.directory": "",
"download.timeout.headers": "10s",
"download.timeout.idle": "5s",

View file

@ -3,7 +3,7 @@
CREATE TABLE subreddits (
id INTEGER PRIMARY KEY,
name VARCHAR(30) NOT NULL,
enable INT NOT NULL DEFAULT 1,
enable_schedule INT NOT NULL DEFAULT 1,
subtype INT NOT NULL DEFAULT 0,
schedule VARCHAR(20) NOT NULL DEFAULT '0 0 * * *',
countback INT NOT NULL DEFAULT 100,

View file

@ -2,6 +2,7 @@
-- +goose StatementBegin
CREATE TABLE devices(
id INTEGER PRIMARY KEY,
enable INTEGER NOT NULL DEFAULT 1,
slug VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
resolution_x DOUBLE NOT NULL,
@ -19,6 +20,8 @@ CREATE TABLE devices(
CREATE UNIQUE INDEX idx_devices_name ON devices(slug);
CREATE INDEX idx_devices_enable ON devices(enable);
CREATE TRIGGER update_devices_timestamp AFTER UPDATE ON devices FOR EACH ROW
BEGIN
UPDATE devices SET updated_at = CURRENT_TIMESTAMP WHERE id = old.id;

View file

@ -33,6 +33,11 @@ BEGIN
UPDATE images SET updated_at = CURRENT_TIMESTAMP WHERE id = old.id;
END;
CREATE TRIGGER update_subreddits_timestamp_on_insert AFTER INSERT ON images FOR EACH ROW
BEGIN
UPDATE subreddits SET updated_at = CURRENT_TIMESTAMP WHERE id = new.subreddit_id;
END;
CREATE INDEX idx_subreddit_id ON images(subreddit_id);
CREATE INDEX idx_nsfw ON images(nsfw);
-- +goose StatementEnd

View file

@ -1,7 +1,7 @@
POST http://localhost:8080/api/v1/devices HTTP/1.1
Host: localhost:8080
Content-Type: application/json
Content-Length: 160
Content-Length: 178
{
"name": "Sync Laptop",
@ -9,5 +9,6 @@ Content-Length: 160
"resolution_x": 1920,
"resolution_y": 1080,
"nsfw": 1,
"aspect_ratio_tolerance": 0.2
"aspect_ratio_tolerance": 0.2,
"enable": 1
}

View file

@ -1 +1,2 @@
GET http://localhost:8080/api/v1/devices HTTP/1.1
Host: localhost:8080

2
rest/events.http Normal file
View file

@ -0,0 +1,2 @@
GET http://localhost:8080/api/v1/events HTTP/1.1
Host: localhost:8080

View file

@ -1,8 +1,8 @@
POST http://localhost:8080/api/v1/subreddits/start HTTP/1.1
Host: localhost:8080
Content-Type: application/json
Content-Length: 30
Content-Length: 29
{
"subreddit": "awoo3"
"subreddit": "awoo"
}

View file

@ -46,6 +46,7 @@ func (routes *Routes) APIDeviceCreate(rw http.ResponseWriter, r *http.Request) {
MaxY: omit.From(body.MaxY),
NSFW: omit.From(body.NSFW),
WindowsWallpaperMode: omit.From(body.WindowsWallpaperMode),
Enable: omit.From(body.Enable),
})
if err != nil {
log.New(ctx).Err(err).Error("failed to create device", "body", body)

View file

@ -70,10 +70,10 @@ func validateSubredditsCreate(body *models.Subreddit) error {
if body.Name == "" {
return errors.New("name is required")
}
if body.Enable > 1 {
body.Enable = 1
} else if body.Enable < 0 {
body.Enable = 0
if body.EnableSchedule > 1 {
body.EnableSchedule = 1
} else if body.EnableSchedule < 0 {
body.EnableSchedule = 0
}
if body.Subtype > 1 {
body.Subtype = 1