subreddit: added check endpoint

This commit is contained in:
Tigor Hutasuhut 2024-04-26 10:51:09 +07:00
parent bd6092fdee
commit c6d84e3de2
13 changed files with 258 additions and 35 deletions

View file

@ -0,0 +1,96 @@
package reddit
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/tigorlazuardi/redmage/pkg/errs"
)
type CheckSubredditParams struct {
Subreddit string `json:"subreddit"`
SubredditType SubredditType `json:"subreddit_type"`
}
// CheckSubreddit checks a subreddit existence and will return error if subreddit not found.
//
// The actual is the subreddit with proper capitalization if no error is returned.
func (reddit *Reddit) CheckSubreddit(ctx context.Context, params CheckSubredditParams) (actual string, err error) {
ctx, span := tracer.Start(ctx, "*Reddit.CheckSubreddit")
defer span.End()
url := fmt.Sprintf("https://reddit.com/%s/%s.json?limit=1", params.SubredditType, params.Subreddit)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return actual, errs.Wrapw(err, "failed to create request", "url", url, "params", params)
}
req.Header.Set("User-Agent", reddit.Config.String("download.useragent"))
resp, err := reddit.Client.Do(req)
if err != nil {
return actual, errs.Wrapw(err, "failed to execute request", "url", url, "params", params)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// This happens for user pages.
// For subreddits, they will be 200 or 301/302 status code and has to be specially handled below.
return actual, errs.Wrapw(err, "user not found", "url", url, "params", params).Code(http.StatusNotFound)
}
if resp.StatusCode >= 400 {
msg := fmt.Sprintf("unexpected %d status code from reddit", resp.StatusCode)
return actual, errs.
Fail(msg, "url", url, "params", params, "response.status", resp.StatusCode).
Code(http.StatusFailedDependency)
}
if resp.StatusCode == http.StatusTooManyRequests {
var msg string
dur, _ := time.ParseDuration(resp.Header.Get("Retry-After") + "s")
if dur > 0 {
msg = fmt.Sprintf("too many requests. Please retry after %s", dur)
} else {
msg = "too many requests. Please try again later"
}
return actual, errs.Fail(msg,
"params", params,
"url", url,
"response.location", resp.Request.URL.String(),
).Code(http.StatusTooManyRequests)
}
if resp.StatusCode >= 400 {
msg := fmt.Sprintf("unexpected %d status code from reddit", resp.StatusCode)
return actual, errs.Fail(msg,
"params", params,
"url", url,
"response.location", resp.Request.URL.String(),
).Code(http.StatusTooManyRequests)
}
if resp.Request.URL.Path == "/subreddits/search.json" {
return actual, errs.Fail("subreddit not found",
"params", params,
"url", url,
"response.location", resp.Request.URL.String(),
).Code(http.StatusNotFound)
}
var body Listing
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
return actual, errs.Wrapw(err, "failed to decode json body")
}
sub := body.GetSubreddit()
if sub == "" {
return actual, errs.Fail("subreddit not found",
"params", params,
"url", url,
"response.location", resp.Request.URL.String(),
).Code(http.StatusNotFound)
}
return sub, nil
}

View file

@ -35,6 +35,8 @@ func (po *PostImage) Close() error {
//
// If downloading image or thumbnail fails
func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) {
ctx, span := tracer.Start(ctx, "*Reddit.DownloadImage")
defer span.End()
imageUrl := post.GetImageURL()
image.URL = imageUrl
@ -43,6 +45,8 @@ func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster
}
func (reddit *Reddit) DownloadThumbnail(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) {
ctx, span := tracer.Start(ctx, "*Reddit.DownloadThumbnail")
defer span.End()
imageUrl := post.GetThumbnailURL()
image.URL = imageUrl

View file

@ -14,6 +14,24 @@ import (
type SubredditType int
func (su *SubredditType) UnmarshalJSON(b []byte) error {
switch string(b) {
case "null":
return nil
case `"user"`, `"u"`, "1":
*su = SubredditTypeUser
return nil
case `"r"`, `"subreddit"`, "0":
*su = SubredditTypeSub
return nil
}
return errs.
Fail("subreddit type not recognized. Valid values are 'user', 'u', 'r', 'subreddit', 0, 1, and null",
"got", string(b),
).
Code(http.StatusBadRequest)
}
const (
SubredditTypeSub SubredditType = iota
SubredditTypeUser
@ -28,6 +46,10 @@ func (s SubredditType) Code() string {
}
}
func (s SubredditType) String() string {
return s.Code()
}
type GetPostsParam struct {
Subreddit string
Limit int
@ -36,6 +58,9 @@ type GetPostsParam struct {
}
func (reddit *Reddit) GetPosts(ctx context.Context, params GetPostsParam) (posts Listing, err error) {
ctx, span := tracer.Start(ctx, "*Reddit.GetPosts")
defer span.End()
url := fmt.Sprintf("https://reddit.com/%s/%s.json?limit=%d&after=%s", params.SubredditType.Code(), params.Subreddit, params.Limit, params.After)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {

View file

@ -19,6 +19,14 @@ func (l *Listing) GetPosts() []Post {
return l.Data.Children
}
func (l *Listing) GetSubreddit() string {
length := len(l.Data.Children)
if length == 0 {
return ""
}
return l.Data.Children[length-1].Data.Subreddit
}
// GetLastAfter returns the last post namee for pagination.
//
// Returns empty string if there is no more posts to look up.
@ -143,7 +151,7 @@ type PostData struct {
ContentCategories any `json:"content_categories"`
IsSelf bool `json:"is_self"`
SubredditType string `json:"subreddit_type"`
Created int `json:"created"`
Created float64 `json:"created"`
LinkFlairType string `json:"link_flair_type"`
Wls int `json:"wls"`
RemovedByCategory any `json:"removed_by_category"`
@ -197,7 +205,7 @@ type PostData struct {
Stickied bool `json:"stickied"`
URL string `json:"url"`
SubredditSubscribers int `json:"subreddit_subscribers"`
CreatedUtc int `json:"created_utc"`
CreatedUtc float64 `json:"created_utc"`
NumCrossposts int `json:"num_crossposts"`
Media any `json:"media"`
IsVideo bool `json:"is_video"`

7
api/reddit/tracer.go Normal file
View file

@ -0,0 +1,7 @@
package reddit
import (
"go.opentelemetry.io/otel"
)
var tracer = otel.Tracer("reddit")

16
api/subreddits_check.go Normal file
View file

@ -0,0 +1,16 @@
package api
import (
"context"
"github.com/tigorlazuardi/redmage/api/reddit"
)
type SubredditCheckParam = reddit.CheckSubredditParams
func (api *API) SubredditCheck(ctx context.Context, params SubredditCheckParam) (actual string, err error) {
ctx, span := tracer.Start(ctx, "*API.SubredditCheck")
defer span.End()
return api.reddit.CheckSubreddit(ctx, params)
}

View file

@ -21,6 +21,7 @@ var DefaultConfig = map[string]any{
"download.timeout.headers": "10s",
"download.timeout.idle": "5s",
"download.timeout.idlespeed": "10KB",
"download.useragent": "redmage",
"download.pubsub.ack.deadline": "3h",

View file

@ -71,33 +71,22 @@ func (er *Err) LogValue() slog.Value {
}
func (er *Err) Error() string {
var (
s = strings.Builder{}
source = er.origin
msg = source.Error()
unwrap = errors.Unwrap(source)
)
if unwrap == nil {
s := strings.Builder{}
if er.message != "" {
s.WriteString(er.message)
}
for unwrap := errors.Unwrap(er); unwrap != nil; {
if e, ok := unwrap.(Error); ok && e.GetMessage() != "" {
s.WriteString(e.GetMessage())
s.WriteString(": ")
continue
}
s.WriteString(unwrap.Error())
next := errors.Unwrap(unwrap)
if next != nil {
s.WriteString(": ")
}
s.WriteString(msg)
return s.String()
}
for unwrap := errors.Unwrap(source); unwrap != nil; source = unwrap {
originMsg := unwrap.Error()
var write string
if cut, found := strings.CutSuffix(msg, originMsg); found {
write = cut
} else {
write = msg
}
msg = originMsg
if write != "" {
s.WriteString(write)
s.WriteString(": ")
}
unwrap = next
}
return s.String()
}

View file

@ -1,31 +1,44 @@
package errs
import "errors"
import (
"errors"
)
func FindCodeOrDefault(err error, def int) int {
unwrap := errors.Unwrap(err)
for unwrap != nil {
if coder, ok := err.(interface{ GetCode() int }); ok {
code := coder.GetCode()
if code != 0 {
return code
}
}
unwrap = errors.Unwrap(unwrap)
for unwrap := errors.Unwrap(err); unwrap != nil; unwrap = errors.Unwrap(unwrap) {
if coder, ok := unwrap.(interface{ GetCode() int }); ok {
code := coder.GetCode()
if code != 0 {
return code
}
}
}
return def
}
func FindMessage(err error) string {
unwrap := errors.Unwrap(err)
for unwrap != nil {
if messager, ok := err.(interface{ GetMessage() string }); ok {
message := messager.GetMessage()
if message != "" {
return message
}
}
for unwrap := errors.Unwrap(err); unwrap != nil; unwrap = errors.Unwrap(unwrap) {
if messager, ok := unwrap.(interface{ GetMessage() string }); ok {
message := messager.GetMessage()
if message != "" {
return message
}
}
}
return err.Error()
@ -36,5 +49,6 @@ func HTTPMessage(err error) (code int, message string) {
if code >= 500 {
return code, err.Error()
}
return code, FindMessage(err)
message = FindMessage(err)
return code, message
}

View file

@ -0,0 +1,8 @@
POST http://localhost:8080/api/v1/subreddits/check HTTP/1.1
Host: localhost:8080
Content-Type: application/json
Content-Length: 37
{
"subreddit": "Wallpapers"
}

View file

@ -1,6 +1,6 @@
POST http://localhost:8080/api/v1/subreddits HTTP/1.1
Host: localhost:8080
Content-Length: 69
Content-Length: 91
{
"name": "awoo",

View file

@ -39,6 +39,7 @@ func (routes *Routes) registerV1APIRoutes(router chi.Router) {
router.Post("/subreddits/start", routes.SubredditStartDownloadAPI)
router.Get("/subreddits", routes.SubredditsListAPI)
router.Post("/subreddits", routes.SubredditsCreateAPI)
router.Post("/subreddits/check", routes.SubredditsCheckAPI)
router.Get("/devices", routes.APIDeviceList)
router.Post("/devices", routes.APIDeviceCreate)

View file

@ -0,0 +1,54 @@
package routes
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/tigorlazuardi/redmage/api"
"github.com/tigorlazuardi/redmage/pkg/errs"
"github.com/tigorlazuardi/redmage/pkg/log"
)
func (routes *Routes) SubredditsCheckAPI(rw http.ResponseWriter, r *http.Request) {
ctx, span := tracer.Start(r.Context(), "*Routes.SubredditsCheck")
defer span.End()
var (
enc = json.NewEncoder(rw)
dec = json.NewDecoder(r.Body)
)
var body api.SubredditCheckParam
if err := dec.Decode(&body); err != nil {
rw.WriteHeader(http.StatusBadRequest)
_ = enc.Encode(map[string]string{"error": fmt.Sprintf("failed to decode json body: %s", err)})
return
}
if err := validateSubredditCheckParam(body); err != nil {
rw.WriteHeader(http.StatusBadRequest)
_ = enc.Encode(map[string]string{"error": err.Error()})
return
}
actual, err := routes.API.SubredditCheck(ctx, body)
if err != nil {
log.New(ctx).Err(err).Error("failed to check subreddit")
code, message := errs.HTTPMessage(err)
rw.WriteHeader(code)
_ = enc.Encode(map[string]string{"error": message})
return
}
_ = enc.Encode(map[string]string{"subreddit": actual})
}
func validateSubredditCheckParam(body api.SubredditCheckParam) error {
if body.Subreddit == "" {
return errors.New("subreddit name is required")
}
return nil
}