api: implemented image download properly

This commit is contained in:
Tigor Hutasuhut 2024-04-27 15:16:14 +07:00
parent 19cd92004c
commit d99e28d997
15 changed files with 221 additions and 99 deletions

View file

@ -53,7 +53,7 @@ const downloadTopic = "subreddit_download"
var watermillLogger = &log.WatermillLogger{} var watermillLogger = &log.WatermillLogger{}
func New(deps Dependencies) *API { func New(deps Dependencies) *API {
ackDeadline := deps.Config.Duration("download.pubsub.ack.deadline") ackDeadline := deps.Config.Duration("pubsub.ack.deadline")
subscriber, err := watermillSql.NewSubscriber(deps.PubsubDB, watermillSql.SubscriberConfig{ subscriber, err := watermillSql.NewSubscriber(deps.PubsubDB, watermillSql.SubscriberConfig{
ConsumerGroup: "redmage", ConsumerGroup: "redmage",
AckDeadline: &ackDeadline, AckDeadline: &ackDeadline,

View file

@ -15,6 +15,19 @@ type ImageMetadata struct {
type ImageKind int type ImageKind int
func (im ImageKind) String() string {
switch im {
case KindThumbnail:
return "Thumbnail"
default:
return "Image"
}
}
func (im ImageKind) MarshalJSON() ([]byte, error) {
return []byte(`"` + im.String() + `"`), nil
}
const ( const (
KindImage ImageKind = iota KindImage ImageKind = iota
KindThumbnail KindThumbnail

View file

@ -5,14 +5,15 @@ import (
"errors" "errors"
"image/jpeg" "image/jpeg"
"io" "io"
"math"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path" "path"
"strings" "strings"
"sync" "sync"
"time"
"github.com/aarondl/opt/omit"
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
"github.com/tigorlazuardi/redmage/api/reddit" "github.com/tigorlazuardi/redmage/api/reddit"
"github.com/tigorlazuardi/redmage/models" "github.com/tigorlazuardi/redmage/models"
@ -21,7 +22,6 @@ import (
"github.com/tigorlazuardi/redmage/pkg/telemetry" "github.com/tigorlazuardi/redmage/pkg/telemetry"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
) )
type DownloadSubredditParams struct { type DownloadSubredditParams struct {
@ -35,22 +35,22 @@ var (
ErrDownloadDirNotSet = errors.New("api: download directory not set") ErrDownloadDirNotSet = errors.New("api: download directory not set")
) )
func (api *API) DownloadSubredditImages(ctx context.Context, subredditName string, params DownloadSubredditParams) error { func (api *API) DownloadSubredditImages(ctx context.Context, subreddit *models.Subreddit, devices models.DeviceSlice) error {
downloadDir := api.config.String("download.directory") downloadDir := api.config.String("download.directory")
if downloadDir == "" { if downloadDir == "" {
return errs.Wrapw(ErrDownloadDirNotSet, "download directory must be set before images can be downloaded").Code(http.StatusBadRequest) return errs.Wrapw(ErrDownloadDirNotSet, "download directory must be set before images can be downloaded").Code(http.StatusBadRequest)
} }
if len(params.Devices) == 0 { if len(devices) == 0 {
return errs.Wrapw(ErrNoDevices, "downloading images requires at least one device configured").Code(http.StatusBadRequest) return errs.Wrapw(ErrNoDevices, "downloading images requires at least one device configured").Code(http.StatusBadRequest)
} }
ctx, span := tracer.Start(ctx, "*API.DownloadSubredditImages", trace.WithAttributes(attribute.String("subreddit", subredditName))) ctx, span := tracer.Start(ctx, "*API.DownloadSubredditImages", trace.WithAttributes(attribute.String("subreddit", subreddit.Name)))
defer span.End() defer span.End()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
countback := params.Countback countback := int(subreddit.Countback)
var ( var (
list reddit.Listing list reddit.Listing
@ -61,20 +61,20 @@ func (api *API) DownloadSubredditImages(ctx context.Context, subredditName strin
if limit > countback { if limit > countback {
limit = countback limit = countback
} }
log.New(ctx).Info("getting posts", "subreddit_name", subredditName, "limit", limit, "countback", countback) log.New(ctx).Debug("getting posts", "subreddit", subreddit, "current_countback", countback, "current_limit", limit)
list, err = api.reddit.GetPosts(ctx, reddit.GetPostsParam{ list, err = api.reddit.GetPosts(ctx, reddit.GetPostsParam{
Subreddit: subredditName, Subreddit: subreddit.Name,
Limit: limit, Limit: limit,
After: list.GetLastAfter(), After: list.GetLastAfter(),
SubredditType: params.SubredditType, SubredditType: reddit.SubredditType(subreddit.Subtype),
}) })
if err != nil { if err != nil {
return errs.Wrapw(err, "failed to get posts", "subreddit_name", subredditName, "params", params) return errs.Wrapw(err, "failed to get posts", "subreddit", subreddit)
} }
wg.Add(1) wg.Add(1)
go func(ctx context.Context, posts reddit.Listing) { go func(ctx context.Context, posts reddit.Listing) {
defer wg.Done() defer wg.Done()
err := api.downloadSubredditListImage(ctx, list, params) err := api.downloadSubredditListImage(ctx, list, subreddit, devices)
if err != nil { if err != nil {
log.New(ctx).Err(err).Error("failed to download image") log.New(ctx).Err(err).Error("failed to download image")
} }
@ -90,7 +90,7 @@ func (api *API) DownloadSubredditImages(ctx context.Context, subredditName strin
return nil return nil
} }
func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.Listing, params DownloadSubredditParams) error { func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.Listing, subreddit *models.Subreddit, devices models.DeviceSlice) error {
ctx, span := tracer.Start(ctx, "*API.downloadSubredditListImage") ctx, span := tracer.Start(ctx, "*API.downloadSubredditListImage")
defer span.End() defer span.End()
@ -100,10 +100,11 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List
if !post.IsImagePost() { if !post.IsImagePost() {
continue continue
} }
devices := api.getDevicesThatAcceptPost(ctx, post, params.Devices) devices := api.getDevicesThatAcceptPost(ctx, post, devices)
if len(devices) == 0 { if len(devices) == 0 {
continue continue
} }
log.New(ctx).Debug("downloading image", "post_id", post.GetID(), "post_url", post.GetImageURL(), "devices", devices)
wg.Add(1) wg.Add(1)
api.imageSemaphore <- struct{}{} api.imageSemaphore <- struct{}{}
go func(ctx context.Context, post reddit.Post) { go func(ctx context.Context, post reddit.Post) {
@ -112,7 +113,7 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List
wg.Done() wg.Done()
}() }()
if err := api.downloadSubredditImage(ctx, post, devices); err != nil { if err := api.downloadSubredditImage(ctx, post, subreddit, devices); err != nil {
log.New(ctx).Err(err).Error("failed to download subreddit image") log.New(ctx).Err(err).Error("failed to download subreddit image")
} }
}(ctx, post) }(ctx, post)
@ -123,7 +124,7 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List
return nil return nil
} }
func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, devices models.DeviceSlice) error { func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, subreddit *models.Subreddit, devices models.DeviceSlice) error {
ctx, span := tracer.Start(ctx, "*API.downloadSubredditImage") ctx, span := tracer.Start(ctx, "*API.downloadSubredditImage")
defer span.End() defer span.End()
@ -140,15 +141,6 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
} }
defer tmpImageFile.Close() defer tmpImageFile.Close()
w, close, err := api.createDeviceImageWriters(post, devices)
if err != nil {
return errs.Wrapw(err, "failed to create image files")
}
defer close()
_, err = io.Copy(w, tmpImageFile)
if err != nil {
return errs.Wrapw(err, "failed to save image files")
}
thumbnailPath := post.GetThumbnailTargetPath(api.config) thumbnailPath := post.GetThumbnailTargetPath(api.config)
_, errStat := os.Stat(thumbnailPath) _, errStat := os.Stat(thumbnailPath)
if errStat == nil { if errStat == nil {
@ -163,7 +155,11 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
thumbnailSource, err := imaging.Open(tmpImageFile.filename) thumbnailSource, err := imaging.Open(tmpImageFile.filename)
if err != nil { if err != nil {
return errs.Wrapw(err, "failed to open temp thumbnail file", "filename", tmpImageFile.filename) return errs.Wrapw(err, "failed to open temp thumbnail file",
"filename", tmpImageFile.filename,
"post_url", post.GetPermalink(),
"image_url", post.GetImageURL(),
)
} }
thumbnail := imaging.Resize(thumbnailSource, 256, 0, imaging.Lanczos) thumbnail := imaging.Resize(thumbnailSource, 256, 0, imaging.Lanczos)
@ -178,7 +174,46 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, de
return errs.Wrapw(err, "failed to encode thumbnail file to jpeg", "filename", thumbnailPath) return errs.Wrapw(err, "failed to encode thumbnail file to jpeg", "filename", thumbnailPath)
} }
// TODO: create entry to database w, close, err := api.createDeviceImageWriters(post, devices)
if err != nil {
return errs.Wrapw(err, "failed to create image files")
}
log.New(ctx).Debug("saving image files", "post_id", post.GetID(), "post_url", post.GetImageURL(), "devices", devices)
defer close()
_, err = io.Copy(w, tmpImageFile)
if err != nil {
return errs.Wrapw(err, "failed to save image files")
}
var many []*models.ImageSetter
for _, device := range devices {
var nsfw int32
if post.IsNSFW() {
nsfw = 1
}
many = append(many, &models.ImageSetter{
SubredditID: omit.From(subreddit.ID),
DeviceID: omit.From(device.ID),
Title: omit.From(post.GetTitle()),
PostID: omit.From(post.GetID()),
PostURL: omit.From(post.GetImageURL()),
PostCreated: omit.From(post.GetCreated().Format(time.RFC3339)),
PostName: omit.From(post.GetName()),
Poster: omit.From(post.GetAuthor()),
PosterURL: omit.From(post.GetAuthorURL()),
ImageRelativePath: omit.From(post.GetImageRelativePath(device)),
ThumbnailRelativePath: omit.From(post.GetThumbnailRelativePath()),
ImageOriginalURL: omit.From(post.GetImageURL()),
ThumbnailOriginalURL: omit.From(post.GetThumbnailURL()),
NSFW: omit.From(nsfw),
})
}
log.New(ctx).Debug("inserting images to database", "images", many)
_, err = models.Images.InsertMany(ctx, api.db, many...)
if err != nil {
return errs.Wrapw(err, "failed to insert images to database", "params", many)
}
return nil return nil
} }
@ -226,22 +261,11 @@ func (api *API) createDeviceImageWriters(post reddit.Post, devices models.Device
} }
func (api *API) getDevicesThatAcceptPost(ctx context.Context, post reddit.Post, devices models.DeviceSlice) (devs models.DeviceSlice) { func (api *API) getDevicesThatAcceptPost(ctx context.Context, post reddit.Post, devices models.DeviceSlice) (devs models.DeviceSlice) {
var mu sync.Mutex
errgrp, ctx := errgroup.WithContext(ctx)
for _, device := range devices { for _, device := range devices {
if shouldDownloadPostForDevice(post, device) { if shouldDownloadPostForDevice(post, device) && !api.isImageExists(ctx, post, device) {
device := device devs = append(devs, device)
errgrp.Go(func() error {
if !api.isImageExists(ctx, post, device) {
mu.Lock()
defer mu.Unlock()
devs = append(devices, device)
}
return nil
})
} }
} }
_ = errgrp.Wait()
return devs return devs
} }
@ -249,27 +273,58 @@ func (api *API) isImageExists(ctx context.Context, post reddit.Post, device *mod
ctx, span := tracer.Start(ctx, "*API.IsImageExists") ctx, span := tracer.Start(ctx, "*API.IsImageExists")
defer span.End() defer span.End()
// Image does not exist in target image.
if _, err := os.Stat(post.GetImageTargetPath(api.config, device)); err != nil {
return false
}
_, err := models.Images.Query(ctx, api.db, _, err := models.Images.Query(ctx, api.db,
models.SelectWhere.Images.DeviceID.EQ(device.ID), models.SelectWhere.Images.DeviceID.EQ(device.ID),
models.SelectWhere.Images.PostID.EQ(post.GetID()), models.SelectWhere.Images.PostID.EQ(post.GetID()),
).One() ).One()
if err != nil {
if err.Error() == "sql: no rows in result set" {
return false
}
}
return err == nil // Image does not exist in target path.
if _, err := os.Stat(post.GetImageTargetPath(api.config, device)); err != nil {
return false
}
return true
} }
func shouldDownloadPostForDevice(post reddit.Post, device *models.Device) bool { func shouldDownloadPostForDevice(post reddit.Post, device *models.Device) bool {
if post.IsNSFW() && device.NSFW == 0 { if post.IsNSFW() && device.NSFW == 0 {
return false return false
} }
if math.Abs(deviceAspectRatio(device)-post.GetImageAspectRatio()) > device.AspectRatioTolerance { // outside of aspect ratio tolerance devAspectRatio := deviceAspectRatio(device)
rangeStart := devAspectRatio - device.AspectRatioTolerance
rangeEnd := devAspectRatio + device.AspectRatioTolerance
imgAspectRatio := post.GetImageAspectRatio()
width, height := post.GetImageSize()
log.New(context.Background()).Debug("checking image aspect ratio",
"device", device.Slug,
"device_height", device.ResolutionY,
"device_width", device.ResolutionX,
"device_aspect_ratio", devAspectRatio,
"image_aspect_ratio", imgAspectRatio,
"range_start", rangeStart,
"range_end", rangeEnd,
"success_fulfill_download_range_start", (imgAspectRatio > rangeStart),
"success_fulfill_download_range_end", (imgAspectRatio < rangeEnd),
"url", post.GetImageURL(),
"image.width", width,
"image.height", height,
)
if imgAspectRatio < rangeStart {
return false return false
} }
width, height := post.GetImageSize()
if imgAspectRatio > rangeEnd {
return false
}
if device.MaxX > 0 && width > int64(device.MaxX) { if device.MaxX > 0 && width > int64(device.MaxX) {
return false return false
} }

View file

@ -7,7 +7,6 @@ import (
"github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message"
"github.com/tigorlazuardi/redmage/api/reddit"
"github.com/tigorlazuardi/redmage/models" "github.com/tigorlazuardi/redmage/models"
"github.com/tigorlazuardi/redmage/pkg/errs" "github.com/tigorlazuardi/redmage/pkg/errs"
"github.com/tigorlazuardi/redmage/pkg/log" "github.com/tigorlazuardi/redmage/pkg/log"
@ -18,11 +17,11 @@ import (
func (api *API) StartSubredditDownloadPubsub(messages <-chan *message.Message) { func (api *API) StartSubredditDownloadPubsub(messages <-chan *message.Message) {
for msg := range messages { for msg := range messages {
log.New(context.Background()).Info("received pubsub message", log.New(context.Background()).Debug("received pubsub message",
"message", msg, "message", msg,
"len", len(api.subredditSemaphore), "len", len(api.subredditSemaphore),
"cap", cap(api.subredditSemaphore), "cap", cap(api.subredditSemaphore),
"download.concurrency.subreddts", api.config.Int("download.concurrency.subreddits"), "download.concurrency.subreddits", api.config.Int("download.concurrency.subreddits"),
) )
api.subredditSemaphore <- struct{}{} api.subredditSemaphore <- struct{}{}
go func(msg *message.Message) { go func(msg *message.Message) {
@ -50,11 +49,7 @@ func (api *API) StartSubredditDownloadPubsub(messages <-chan *message.Message) {
return return
} }
err = api.DownloadSubredditImages(ctx, subreddit.Name, DownloadSubredditParams{ err = api.DownloadSubredditImages(ctx, subreddit, devices)
Countback: int(subreddit.Countback),
Devices: devices,
SubredditType: reddit.SubredditType(subreddit.Subtype),
})
if err != nil { if err != nil {
log.New(ctx).Err(err).Error("failed to download subreddit images", "subreddit", subreddit) log.New(ctx).Err(err).Error("failed to download subreddit images", "subreddit", subreddit)
return return

View file

@ -2,8 +2,29 @@ package reddit
import ( import (
"net/http" "net/http"
"github.com/tigorlazuardi/redmage/config"
) )
type Client interface { type Client interface {
Do(*http.Request) (*http.Response, error) Do(*http.Request) (*http.Response, error)
} }
func NewRedditHTTPClient(cfg *config.Config) Client {
return &http.Client{
Transport: createRoundTripper(cfg),
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (ro roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return ro(req)
}
func createRoundTripper(cfg *config.Config) roundTripperFunc {
return func(r *http.Request) (*http.Response, error) {
r.Header.Set("User-Agent", cfg.String("download.useragent"))
return http.DefaultTransport.RoundTrip(r)
}
}

View file

@ -76,6 +76,12 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
if err != nil { if err != nil {
return nil, errs.Wrapw(err, "reddit: failed to execute request", "url", url) return nil, errs.Wrapw(err, "reddit: failed to execute request", "url", url)
} }
if resp.StatusCode >= 400 {
return nil, errs.Fail("unexpected status code when trying to download images",
"url", url,
"status_code", resp.StatusCode,
)
}
idleSpeedStr := reddit.Config.String("download.timeout.idlespeed") idleSpeedStr := reddit.Config.String("download.timeout.idlespeed")
metricSpeed, _ := units.ParseMetricBytes(idleSpeedStr) metricSpeed, _ := units.ParseMetricBytes(idleSpeedStr)
if metricSpeed == 0 { if metricSpeed == 0 {
@ -116,7 +122,7 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
ContentLength: units.MetricBytes(resp.ContentLength), ContentLength: units.MetricBytes(resp.ContentLength),
Downloaded: units.MetricBytes(downloaded), Downloaded: units.MetricBytes(downloaded),
Subreddit: post.GetSubreddit(), Subreddit: post.GetSubreddit(),
PostURL: post.GetPermalink(), PostURL: post.GetPostURL(),
PostID: post.GetID(), PostID: post.GetID(),
Error: closeErr, Error: closeErr,
}) })
@ -138,7 +144,7 @@ func (reddit *Reddit) downloadImage(ctx context.Context, post Post, kind bmessag
Kind: kind, Kind: kind,
}, },
Subreddit: post.GetSubreddit(), Subreddit: post.GetSubreddit(),
PostURL: post.GetPermalink(), PostURL: post.GetPostURL(),
PostID: post.GetID(), PostID: post.GetID(),
}) })
_, err := io.Copy(writer, resp.Body) _, err := io.Copy(writer, resp.Body)

View file

@ -8,7 +8,6 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/tigorlazuardi/redmage/pkg/errs" "github.com/tigorlazuardi/redmage/pkg/errs"
) )
@ -68,14 +67,18 @@ func (reddit *Reddit) GetPosts(ctx context.Context, params GetPostsParam) (posts
return posts, errs.Wrapw(err, "reddit: failed to create http request instance", "url", url, "params", params) return posts, errs.Wrapw(err, "reddit: failed to create http request instance", "url", url, "params", params)
} }
req.Header.Set("User-Agent", reddit.Config.String("download.useragent"))
res, err := reddit.Client.Do(req) res, err := reddit.Client.Do(req)
if err != nil { if err != nil {
return posts, errs.Wrapw(err, "reddit: failed to execute http request", "url", url, "params", params) return posts, errs.Wrapw(err, "reddit: failed to execute http request", "url", url, "params", params)
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode == http.StatusTooManyRequests { if res.StatusCode == http.StatusTooManyRequests {
retryAfter, _ := time.ParseDuration(res.Header.Get("Retry-After")) return posts, errs.Fail("reddit: too many requests",
return posts, errs.Fail("reddit: too many requests", "retry_after", retryAfter.String()) "retry_after", res.Header.Get("Retry-After"),
"url", url,
)
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {

View file

@ -6,6 +6,7 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/tigorlazuardi/redmage/config" "github.com/tigorlazuardi/redmage/config"
"github.com/tigorlazuardi/redmage/models" "github.com/tigorlazuardi/redmage/models"
@ -145,7 +146,7 @@ type PostData struct {
IsCreatedFromAdsUI bool `json:"is_created_from_ads_ui"` IsCreatedFromAdsUI bool `json:"is_created_from_ads_ui"`
AuthorPremium bool `json:"author_premium"` AuthorPremium bool `json:"author_premium"`
Thumbnail string `json:"thumbnail"` Thumbnail string `json:"thumbnail"`
Edited bool `json:"edited"` Edited any `json:"edited"`
AuthorFlairCSSClass string `json:"author_flair_css_class"` AuthorFlairCSSClass string `json:"author_flair_css_class"`
AuthorFlairRichtext []AuthorFlairRichtext `json:"author_flair_richtext"` AuthorFlairRichtext []AuthorFlairRichtext `json:"author_flair_richtext"`
Gildings Gildings `json:"gildings"` Gildings Gildings `json:"gildings"`
@ -231,6 +232,22 @@ func (post *Post) GetImageURL() string {
return post.Data.URL return post.Data.URL
} }
func (post *Post) GetCreated() time.Time {
return time.Unix(int64(post.Data.Created), 0)
}
func (post *Post) GetAuthor() string {
return post.Data.Author
}
func (post *Post) GetTitle() string {
return post.Data.Title
}
func (post *Post) GetAuthorURL() string {
return fmt.Sprintf("https://www.reddit.com/user/%s", post.Data.Author)
}
func (post *Post) GetImageAspectRatio() float64 { func (post *Post) GetImageAspectRatio() float64 {
width, height := post.GetImageSize() width, height := post.GetImageSize()
if height == 0 { if height == 0 {
@ -336,6 +353,10 @@ func (post *Post) GetPermalink() string {
return post.Data.Permalink return post.Data.Permalink
} }
func (post *Post) GetPostURL() string {
return fmt.Sprintf("https://reddit.com/%s", post.Data.Permalink)
}
func (post *Post) GetID() string { func (post *Post) GetID() string {
return post.Data.ID return post.Data.ID
} }

View file

@ -1,8 +1,6 @@
package cli package cli
import ( import (
"fmt"
"github.com/adrg/xdg" "github.com/adrg/xdg"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/tigorlazuardi/redmage/config" "github.com/tigorlazuardi/redmage/config"
@ -29,9 +27,6 @@ func initConfig() {
LoadFlags(RootCmd.PersistentFlags()). LoadFlags(RootCmd.PersistentFlags()).
Build() Build()
fmt.Println("download.concurrency.subreddits", cfg.Get("download.concurrency.subreddits"))
fmt.Println("download.concurrency.images", cfg.Get("download.concurrency.images"))
handler := log.NewHandler(cfg) handler := log.NewHandler(cfg)
log.SetDefault(handler) log.SetDefault(handler)
} }

View file

@ -2,7 +2,6 @@ package cli
import ( import (
"io/fs" "io/fs"
"net/http"
"os" "os"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -30,23 +29,29 @@ var serveCmd = &cobra.Command{
database, err := db.Open(cfg) database, err := db.Open(cfg)
if err != nil { if err != nil {
log.New(cmd.Context()).Err(err).Error("failed to connect database") log.New(cmd.Context()).Err(err).Error("failed to open connection to database")
os.Exit(1) os.Exit(1)
} }
pubsubDatabase, err := db.OpenSilent(cfg) pubsubDb, err := db.OpenPubsub(cfg)
if err != nil {
log.New(cmd.Context()).Err(err).Error("failed to open connection to pubsub database")
os.Exit(1)
}
loggedDb := db.ApplyLogger(cfg, database)
if err != nil { if err != nil {
log.New(cmd.Context()).Err(err).Error("failed to connect database") log.New(cmd.Context()).Err(err).Error("failed to connect database")
os.Exit(1) os.Exit(1)
} }
red := &reddit.Reddit{ red := &reddit.Reddit{
Client: http.DefaultClient, Client: reddit.NewRedditHTTPClient(cfg),
Config: cfg, Config: cfg,
} }
api := api.New(api.Dependencies{ api := api.New(api.Dependencies{
DB: database, DB: loggedDb,
PubsubDB: pubsubDatabase, PubsubDB: pubsubDb,
Config: cfg, Config: cfg,
Reddit: red, Reddit: red,
}) })

View file

@ -15,6 +15,10 @@ var DefaultConfig = map[string]any{
"db.string": "data.db", "db.string": "data.db",
"db.automigrate": true, "db.automigrate": true,
"pubsub.db.driver": "sqlite3",
"pubsub.db.string": "pubsub.db",
"pubsub.ack.deadline": "3h",
"download.concurrency.images": 5, "download.concurrency.images": 5,
"download.concurrency.subreddits": 3, "download.concurrency.subreddits": 3,

View file

@ -19,12 +19,14 @@ var Migrations fs.FS
func Open(cfg *config.Config) (*sql.DB, error) { func Open(cfg *config.Config) (*sql.DB, error) {
driver := cfg.String("db.driver") driver := cfg.String("db.driver")
dsn := cfg.String("db.string") dsn := cfg.String("db.string")
db, err := OpenSilent(cfg) db, err := otelsql.Open(driver, dsn, otelsql.WithAttributes(
semconv.DBSystemSqlite,
))
if err != nil { if err != nil {
return db, err return db, errs.Wrapw(err, "failed to open database", "driver", driver)
} }
if cfg.Bool("db.automigrate") { if cfg.Bool("db.automigrate") {
goose.SetLogger(&gooseLogger{}) goose.SetLogger(goose.NopLogger())
goose.SetBaseFS(Migrations) goose.SetBaseFS(Migrations)
if err := goose.SetDialect(driver); err != nil { if err := goose.SetDialect(driver); err != nil {
@ -35,22 +37,22 @@ func Open(cfg *config.Config) (*sql.DB, error) {
return db, errs.Wrapw(err, "failed to migrate database", "dialect", driver) return db, errs.Wrapw(err, "failed to migrate database", "dialect", driver)
} }
} }
db = sqldblogger.OpenDriver(dsn, db.Driver(), sqlLogger{},
sqldblogger.WithSQLQueryAsMessage(true),
)
return db, err return db, err
} }
func OpenSilent(cfg *config.Config) (*sql.DB, error) { func OpenPubsub(cfg *config.Config) (*sql.DB, error) {
driver := cfg.String("db.driver") driver := cfg.String("pubsub.db.driver")
dsn := cfg.String("db.string") dsn := cfg.String("pubsub.db.string")
db, err := otelsql.Open(driver, dsn, otelsql.WithAttributes( db, err := sql.Open(driver, dsn)
semconv.DBSystemSqlite,
))
if err != nil { if err != nil {
return db, errs.Wrapw(err, "failed to open database", "driver", driver) return db, errs.Wrapw(err, "failed to open database", "driver", driver)
} }
return db, err return db, err
} }
func ApplyLogger(cfg *config.Config, db *sql.DB) *sql.DB {
dsn := cfg.String("db.string")
return sqldblogger.OpenDriver(dsn, db.Driver(), sqlLogger{},
sqldblogger.WithSQLQueryAsMessage(true),
)
}

View file

@ -1,14 +1,16 @@
POST http://localhost:8080/api/v1/devices HTTP/1.1 POST http://localhost:8080/api/v1/devices HTTP/1.1
Host: localhost:8080 Host: localhost:8080
Content-Type: application/json Content-Type: application/json
Content-Length: 178 Content-Length: 211
{ {
"name": "Sync Laptop", "name": "S20FE",
"slug": "sync-l", "slug": "s20fe",
"resolution_x": 1920, "resolution_x": 1080,
"resolution_y": 1080, "resolution_y": 2400,
"nsfw": 1, "nsfw": 1,
"aspect_ratio_tolerance": 0.2, "aspect_ratio_tolerance": 0.2,
"enable": 1 "enable": 1,
"min_x": 1080,
"min_y": 2400
} }

View file

@ -1,10 +1,10 @@
POST http://localhost:8080/api/v1/subreddits HTTP/1.1 POST http://localhost:8080/api/v1/subreddits HTTP/1.1
Host: localhost:8080 Host: localhost:8080
Content-Length: 91 Content-Length: 109
{ {
"name": "awoo", "name": "animemidriff",
"enable": 1, "enable_schedule": 1,
"schedule": "@daily", "schedule": "@daily",
"countback": 10 "countback": 300
} }

View file

@ -1,8 +1,8 @@
POST http://localhost:8080/api/v1/subreddits/start HTTP/1.1 POST http://localhost:8080/api/v1/subreddits/start HTTP/1.1
Host: localhost:8080 Host: localhost:8080
Content-Type: application/json Content-Type: application/json
Content-Length: 29 Content-Length: 35
{ {
"subreddit": "awoo" "subreddit": "fantasymoe"
} }