diff --git a/api/reddit/check_subreddit.go b/api/reddit/check_subreddit.go new file mode 100644 index 0000000..7fbaa39 --- /dev/null +++ b/api/reddit/check_subreddit.go @@ -0,0 +1,96 @@ +package reddit + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/tigorlazuardi/redmage/pkg/errs" +) + +type CheckSubredditParams struct { + Subreddit string `json:"subreddit"` + SubredditType SubredditType `json:"subreddit_type"` +} + +// CheckSubreddit checks a subreddit existence and will return error if subreddit not found. +// +// The actual is the subreddit with proper capitalization if no error is returned. +func (reddit *Reddit) CheckSubreddit(ctx context.Context, params CheckSubredditParams) (actual string, err error) { + ctx, span := tracer.Start(ctx, "*Reddit.CheckSubreddit") + defer span.End() + + url := fmt.Sprintf("https://reddit.com/%s/%s.json?limit=1", params.SubredditType, params.Subreddit) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return actual, errs.Wrapw(err, "failed to create request", "url", url, "params", params) + } + req.Header.Set("User-Agent", reddit.Config.String("download.useragent")) + + resp, err := reddit.Client.Do(req) + if err != nil { + return actual, errs.Wrapw(err, "failed to execute request", "url", url, "params", params) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // This happens for user pages. + // For subreddits, they will be 200 or 301/302 status code and has to be specially handled below. + return actual, errs.Wrapw(err, "user not found", "url", url, "params", params).Code(http.StatusNotFound) + } + + if resp.StatusCode >= 400 { + msg := fmt.Sprintf("unexpected %d status code from reddit", resp.StatusCode) + return actual, errs. + Fail(msg, "url", url, "params", params, "response.status", resp.StatusCode). + Code(http.StatusFailedDependency) + } + + if resp.StatusCode == http.StatusTooManyRequests { + var msg string + dur, _ := time.ParseDuration(resp.Header.Get("Retry-After") + "s") + if dur > 0 { + msg = fmt.Sprintf("too many requests. Please retry after %s", dur) + } else { + msg = "too many requests. Please try again later" + } + return actual, errs.Fail(msg, + "params", params, + "url", url, + "response.location", resp.Request.URL.String(), + ).Code(http.StatusTooManyRequests) + } + if resp.StatusCode >= 400 { + msg := fmt.Sprintf("unexpected %d status code from reddit", resp.StatusCode) + return actual, errs.Fail(msg, + "params", params, + "url", url, + "response.location", resp.Request.URL.String(), + ).Code(http.StatusTooManyRequests) + } + + if resp.Request.URL.Path == "/subreddits/search.json" { + return actual, errs.Fail("subreddit not found", + "params", params, + "url", url, + "response.location", resp.Request.URL.String(), + ).Code(http.StatusNotFound) + } + + var body Listing + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return actual, errs.Wrapw(err, "failed to decode json body") + } + sub := body.GetSubreddit() + if sub == "" { + return actual, errs.Fail("subreddit not found", + "params", params, + "url", url, + "response.location", resp.Request.URL.String(), + ).Code(http.StatusNotFound) + } + + return sub, nil +} diff --git a/api/reddit/download_images.go b/api/reddit/download_images.go index 2a26edd..82d9f83 100644 --- a/api/reddit/download_images.go +++ b/api/reddit/download_images.go @@ -35,6 +35,8 @@ func (po *PostImage) Close() error { // // If downloading image or thumbnail fails func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) { + ctx, span := tracer.Start(ctx, "*Reddit.DownloadImage") + defer span.End() imageUrl := post.GetImageURL() image.URL = imageUrl @@ -43,6 +45,8 @@ func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster } func (reddit *Reddit) DownloadThumbnail(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) { + ctx, span := tracer.Start(ctx, "*Reddit.DownloadThumbnail") + defer span.End() imageUrl := post.GetThumbnailURL() image.URL = imageUrl diff --git a/api/reddit/get_posts.go b/api/reddit/get_posts.go index e2382e4..6484165 100644 --- a/api/reddit/get_posts.go +++ b/api/reddit/get_posts.go @@ -14,6 +14,24 @@ import ( type SubredditType int +func (su *SubredditType) UnmarshalJSON(b []byte) error { + switch string(b) { + case "null": + return nil + case `"user"`, `"u"`, "1": + *su = SubredditTypeUser + return nil + case `"r"`, `"subreddit"`, "0": + *su = SubredditTypeSub + return nil + } + return errs. + Fail("subreddit type not recognized. Valid values are 'user', 'u', 'r', 'subreddit', 0, 1, and null", + "got", string(b), + ). + Code(http.StatusBadRequest) +} + const ( SubredditTypeSub SubredditType = iota SubredditTypeUser @@ -28,6 +46,10 @@ func (s SubredditType) Code() string { } } +func (s SubredditType) String() string { + return s.Code() +} + type GetPostsParam struct { Subreddit string Limit int @@ -36,6 +58,9 @@ type GetPostsParam struct { } func (reddit *Reddit) GetPosts(ctx context.Context, params GetPostsParam) (posts Listing, err error) { + ctx, span := tracer.Start(ctx, "*Reddit.GetPosts") + defer span.End() + url := fmt.Sprintf("https://reddit.com/%s/%s.json?limit=%d&after=%s", params.SubredditType.Code(), params.Subreddit, params.Limit, params.After) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { diff --git a/api/reddit/post.go b/api/reddit/post.go index 9ea0149..37259c8 100644 --- a/api/reddit/post.go +++ b/api/reddit/post.go @@ -19,6 +19,14 @@ func (l *Listing) GetPosts() []Post { return l.Data.Children } +func (l *Listing) GetSubreddit() string { + length := len(l.Data.Children) + if length == 0 { + return "" + } + return l.Data.Children[length-1].Data.Subreddit +} + // GetLastAfter returns the last post namee for pagination. // // Returns empty string if there is no more posts to look up. @@ -143,7 +151,7 @@ type PostData struct { ContentCategories any `json:"content_categories"` IsSelf bool `json:"is_self"` SubredditType string `json:"subreddit_type"` - Created int `json:"created"` + Created float64 `json:"created"` LinkFlairType string `json:"link_flair_type"` Wls int `json:"wls"` RemovedByCategory any `json:"removed_by_category"` @@ -197,7 +205,7 @@ type PostData struct { Stickied bool `json:"stickied"` URL string `json:"url"` SubredditSubscribers int `json:"subreddit_subscribers"` - CreatedUtc int `json:"created_utc"` + CreatedUtc float64 `json:"created_utc"` NumCrossposts int `json:"num_crossposts"` Media any `json:"media"` IsVideo bool `json:"is_video"` diff --git a/api/reddit/tracer.go b/api/reddit/tracer.go new file mode 100644 index 0000000..5dfefb5 --- /dev/null +++ b/api/reddit/tracer.go @@ -0,0 +1,7 @@ +package reddit + +import ( + "go.opentelemetry.io/otel" +) + +var tracer = otel.Tracer("reddit") diff --git a/api/subreddits_check.go b/api/subreddits_check.go new file mode 100644 index 0000000..8fb0a98 --- /dev/null +++ b/api/subreddits_check.go @@ -0,0 +1,16 @@ +package api + +import ( + "context" + + "github.com/tigorlazuardi/redmage/api/reddit" +) + +type SubredditCheckParam = reddit.CheckSubredditParams + +func (api *API) SubredditCheck(ctx context.Context, params SubredditCheckParam) (actual string, err error) { + ctx, span := tracer.Start(ctx, "*API.SubredditCheck") + defer span.End() + + return api.reddit.CheckSubreddit(ctx, params) +} diff --git a/config/default.go b/config/default.go index 37f094a..035b9ec 100644 --- a/config/default.go +++ b/config/default.go @@ -21,6 +21,7 @@ var DefaultConfig = map[string]any{ "download.timeout.headers": "10s", "download.timeout.idle": "5s", "download.timeout.idlespeed": "10KB", + "download.useragent": "redmage", "download.pubsub.ack.deadline": "3h", diff --git a/pkg/errs/errs.go b/pkg/errs/errs.go index 2f630b5..fb4b7f9 100644 --- a/pkg/errs/errs.go +++ b/pkg/errs/errs.go @@ -71,33 +71,22 @@ func (er *Err) LogValue() slog.Value { } func (er *Err) Error() string { - var ( - s = strings.Builder{} - source = er.origin - msg = source.Error() - unwrap = errors.Unwrap(source) - ) - if unwrap == nil { - if er.message != "" { - s.WriteString(er.message) - s.WriteString(": ") - } - s.WriteString(msg) - return s.String() + s := strings.Builder{} + if er.message != "" { + s.WriteString(er.message) } - for unwrap := errors.Unwrap(source); unwrap != nil; source = unwrap { - originMsg := unwrap.Error() - var write string - if cut, found := strings.CutSuffix(msg, originMsg); found { - write = cut - } else { - write = msg + for unwrap := errors.Unwrap(er); unwrap != nil; { + if e, ok := unwrap.(Error); ok && e.GetMessage() != "" { + s.WriteString(e.GetMessage()) + s.WriteString(": ") + continue } - msg = originMsg - if write != "" { - s.WriteString(write) + s.WriteString(unwrap.Error()) + next := errors.Unwrap(unwrap) + if next != nil { s.WriteString(": ") } + unwrap = next } return s.String() } diff --git a/pkg/errs/query.go b/pkg/errs/query.go index 358cb94..5572e6d 100644 --- a/pkg/errs/query.go +++ b/pkg/errs/query.go @@ -1,26 +1,39 @@ package errs -import "errors" +import ( + "errors" +) func FindCodeOrDefault(err error, def int) int { - unwrap := errors.Unwrap(err) - for unwrap != nil { - if coder, ok := err.(interface{ GetCode() int }); ok { + if coder, ok := err.(interface{ GetCode() int }); ok { + code := coder.GetCode() + if code != 0 { + return code + } + } + + for unwrap := errors.Unwrap(err); unwrap != nil; unwrap = errors.Unwrap(unwrap) { + if coder, ok := unwrap.(interface{ GetCode() int }); ok { code := coder.GetCode() if code != 0 { return code } } - unwrap = errors.Unwrap(unwrap) } return def } func FindMessage(err error) string { - unwrap := errors.Unwrap(err) - for unwrap != nil { - if messager, ok := err.(interface{ GetMessage() string }); ok { + if messager, ok := err.(interface{ GetMessage() string }); ok { + message := messager.GetMessage() + if message != "" { + return message + } + } + + for unwrap := errors.Unwrap(err); unwrap != nil; unwrap = errors.Unwrap(unwrap) { + if messager, ok := unwrap.(interface{ GetMessage() string }); ok { message := messager.GetMessage() if message != "" { return message @@ -36,5 +49,6 @@ func HTTPMessage(err error) (code int, message string) { if code >= 500 { return code, err.Error() } - return code, FindMessage(err) + message = FindMessage(err) + return code, message } diff --git a/rest/subreddits/check.http b/rest/subreddits/check.http new file mode 100644 index 0000000..4af0ad3 --- /dev/null +++ b/rest/subreddits/check.http @@ -0,0 +1,8 @@ +POST http://localhost:8080/api/v1/subreddits/check HTTP/1.1 +Host: localhost:8080 +Content-Type: application/json +Content-Length: 37 + +{ + "subreddit": "Wallpapers" +} diff --git a/rest/subreddits/create.http b/rest/subreddits/create.http index ac8645c..1f3b89c 100644 --- a/rest/subreddits/create.http +++ b/rest/subreddits/create.http @@ -1,6 +1,6 @@ POST http://localhost:8080/api/v1/subreddits HTTP/1.1 Host: localhost:8080 -Content-Length: 69 +Content-Length: 91 { "name": "awoo", diff --git a/server/routes/routes.go b/server/routes/routes.go index d3c07a1..0afe43e 100644 --- a/server/routes/routes.go +++ b/server/routes/routes.go @@ -39,6 +39,7 @@ func (routes *Routes) registerV1APIRoutes(router chi.Router) { router.Post("/subreddits/start", routes.SubredditStartDownloadAPI) router.Get("/subreddits", routes.SubredditsListAPI) router.Post("/subreddits", routes.SubredditsCreateAPI) + router.Post("/subreddits/check", routes.SubredditsCheckAPI) router.Get("/devices", routes.APIDeviceList) router.Post("/devices", routes.APIDeviceCreate) diff --git a/server/routes/subreddit_check.go b/server/routes/subreddit_check.go new file mode 100644 index 0000000..7dcaa94 --- /dev/null +++ b/server/routes/subreddit_check.go @@ -0,0 +1,54 @@ +package routes + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/tigorlazuardi/redmage/api" + "github.com/tigorlazuardi/redmage/pkg/errs" + "github.com/tigorlazuardi/redmage/pkg/log" +) + +func (routes *Routes) SubredditsCheckAPI(rw http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), "*Routes.SubredditsCheck") + defer span.End() + + var ( + enc = json.NewEncoder(rw) + dec = json.NewDecoder(r.Body) + ) + + var body api.SubredditCheckParam + if err := dec.Decode(&body); err != nil { + rw.WriteHeader(http.StatusBadRequest) + _ = enc.Encode(map[string]string{"error": fmt.Sprintf("failed to decode json body: %s", err)}) + return + } + + if err := validateSubredditCheckParam(body); err != nil { + rw.WriteHeader(http.StatusBadRequest) + _ = enc.Encode(map[string]string{"error": err.Error()}) + return + } + + actual, err := routes.API.SubredditCheck(ctx, body) + if err != nil { + log.New(ctx).Err(err).Error("failed to check subreddit") + code, message := errs.HTTPMessage(err) + rw.WriteHeader(code) + _ = enc.Encode(map[string]string{"error": message}) + return + } + + _ = enc.Encode(map[string]string{"subreddit": actual}) +} + +func validateSubredditCheckParam(body api.SubredditCheckParam) error { + if body.Subreddit == "" { + return errors.New("subreddit name is required") + } + + return nil +}