subreddit: added check endpoint
This commit is contained in:
parent
bd6092fdee
commit
c6d84e3de2
96
api/reddit/check_subreddit.go
Normal file
96
api/reddit/check_subreddit.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package reddit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/tigorlazuardi/redmage/pkg/errs"
|
||||
)
|
||||
|
||||
type CheckSubredditParams struct {
|
||||
Subreddit string `json:"subreddit"`
|
||||
SubredditType SubredditType `json:"subreddit_type"`
|
||||
}
|
||||
|
||||
// CheckSubreddit checks a subreddit existence and will return error if subreddit not found.
|
||||
//
|
||||
// The actual is the subreddit with proper capitalization if no error is returned.
|
||||
func (reddit *Reddit) CheckSubreddit(ctx context.Context, params CheckSubredditParams) (actual string, err error) {
|
||||
ctx, span := tracer.Start(ctx, "*Reddit.CheckSubreddit")
|
||||
defer span.End()
|
||||
|
||||
url := fmt.Sprintf("https://reddit.com/%s/%s.json?limit=1", params.SubredditType, params.Subreddit)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
return actual, errs.Wrapw(err, "failed to create request", "url", url, "params", params)
|
||||
}
|
||||
req.Header.Set("User-Agent", reddit.Config.String("download.useragent"))
|
||||
|
||||
resp, err := reddit.Client.Do(req)
|
||||
if err != nil {
|
||||
return actual, errs.Wrapw(err, "failed to execute request", "url", url, "params", params)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// This happens for user pages.
|
||||
// For subreddits, they will be 200 or 301/302 status code and has to be specially handled below.
|
||||
return actual, errs.Wrapw(err, "user not found", "url", url, "params", params).Code(http.StatusNotFound)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
msg := fmt.Sprintf("unexpected %d status code from reddit", resp.StatusCode)
|
||||
return actual, errs.
|
||||
Fail(msg, "url", url, "params", params, "response.status", resp.StatusCode).
|
||||
Code(http.StatusFailedDependency)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
var msg string
|
||||
dur, _ := time.ParseDuration(resp.Header.Get("Retry-After") + "s")
|
||||
if dur > 0 {
|
||||
msg = fmt.Sprintf("too many requests. Please retry after %s", dur)
|
||||
} else {
|
||||
msg = "too many requests. Please try again later"
|
||||
}
|
||||
return actual, errs.Fail(msg,
|
||||
"params", params,
|
||||
"url", url,
|
||||
"response.location", resp.Request.URL.String(),
|
||||
).Code(http.StatusTooManyRequests)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
msg := fmt.Sprintf("unexpected %d status code from reddit", resp.StatusCode)
|
||||
return actual, errs.Fail(msg,
|
||||
"params", params,
|
||||
"url", url,
|
||||
"response.location", resp.Request.URL.String(),
|
||||
).Code(http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
if resp.Request.URL.Path == "/subreddits/search.json" {
|
||||
return actual, errs.Fail("subreddit not found",
|
||||
"params", params,
|
||||
"url", url,
|
||||
"response.location", resp.Request.URL.String(),
|
||||
).Code(http.StatusNotFound)
|
||||
}
|
||||
|
||||
var body Listing
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
return actual, errs.Wrapw(err, "failed to decode json body")
|
||||
}
|
||||
sub := body.GetSubreddit()
|
||||
if sub == "" {
|
||||
return actual, errs.Fail("subreddit not found",
|
||||
"params", params,
|
||||
"url", url,
|
||||
"response.location", resp.Request.URL.String(),
|
||||
).Code(http.StatusNotFound)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
|
@ -35,6 +35,8 @@ func (po *PostImage) Close() error {
|
|||
//
|
||||
// If downloading image or thumbnail fails
|
||||
func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) {
|
||||
ctx, span := tracer.Start(ctx, "*Reddit.DownloadImage")
|
||||
defer span.End()
|
||||
imageUrl := post.GetImageURL()
|
||||
image.URL = imageUrl
|
||||
|
||||
|
@ -43,6 +45,8 @@ func (reddit *Reddit) DownloadImage(ctx context.Context, post Post, broadcaster
|
|||
}
|
||||
|
||||
func (reddit *Reddit) DownloadThumbnail(ctx context.Context, post Post, broadcaster DownloadStatusBroadcaster) (image PostImage, err error) {
|
||||
ctx, span := tracer.Start(ctx, "*Reddit.DownloadThumbnail")
|
||||
defer span.End()
|
||||
imageUrl := post.GetThumbnailURL()
|
||||
image.URL = imageUrl
|
||||
|
||||
|
|
|
@ -14,6 +14,24 @@ import (
|
|||
|
||||
type SubredditType int
|
||||
|
||||
func (su *SubredditType) UnmarshalJSON(b []byte) error {
|
||||
switch string(b) {
|
||||
case "null":
|
||||
return nil
|
||||
case `"user"`, `"u"`, "1":
|
||||
*su = SubredditTypeUser
|
||||
return nil
|
||||
case `"r"`, `"subreddit"`, "0":
|
||||
*su = SubredditTypeSub
|
||||
return nil
|
||||
}
|
||||
return errs.
|
||||
Fail("subreddit type not recognized. Valid values are 'user', 'u', 'r', 'subreddit', 0, 1, and null",
|
||||
"got", string(b),
|
||||
).
|
||||
Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
const (
|
||||
SubredditTypeSub SubredditType = iota
|
||||
SubredditTypeUser
|
||||
|
@ -28,6 +46,10 @@ func (s SubredditType) Code() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (s SubredditType) String() string {
|
||||
return s.Code()
|
||||
}
|
||||
|
||||
type GetPostsParam struct {
|
||||
Subreddit string
|
||||
Limit int
|
||||
|
@ -36,6 +58,9 @@ type GetPostsParam struct {
|
|||
}
|
||||
|
||||
func (reddit *Reddit) GetPosts(ctx context.Context, params GetPostsParam) (posts Listing, err error) {
|
||||
ctx, span := tracer.Start(ctx, "*Reddit.GetPosts")
|
||||
defer span.End()
|
||||
|
||||
url := fmt.Sprintf("https://reddit.com/%s/%s.json?limit=%d&after=%s", params.SubredditType.Code(), params.Subreddit, params.Limit, params.After)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
|
|
|
@ -19,6 +19,14 @@ func (l *Listing) GetPosts() []Post {
|
|||
return l.Data.Children
|
||||
}
|
||||
|
||||
func (l *Listing) GetSubreddit() string {
|
||||
length := len(l.Data.Children)
|
||||
if length == 0 {
|
||||
return ""
|
||||
}
|
||||
return l.Data.Children[length-1].Data.Subreddit
|
||||
}
|
||||
|
||||
// GetLastAfter returns the last post namee for pagination.
|
||||
//
|
||||
// Returns empty string if there is no more posts to look up.
|
||||
|
@ -143,7 +151,7 @@ type PostData struct {
|
|||
ContentCategories any `json:"content_categories"`
|
||||
IsSelf bool `json:"is_self"`
|
||||
SubredditType string `json:"subreddit_type"`
|
||||
Created int `json:"created"`
|
||||
Created float64 `json:"created"`
|
||||
LinkFlairType string `json:"link_flair_type"`
|
||||
Wls int `json:"wls"`
|
||||
RemovedByCategory any `json:"removed_by_category"`
|
||||
|
@ -197,7 +205,7 @@ type PostData struct {
|
|||
Stickied bool `json:"stickied"`
|
||||
URL string `json:"url"`
|
||||
SubredditSubscribers int `json:"subreddit_subscribers"`
|
||||
CreatedUtc int `json:"created_utc"`
|
||||
CreatedUtc float64 `json:"created_utc"`
|
||||
NumCrossposts int `json:"num_crossposts"`
|
||||
Media any `json:"media"`
|
||||
IsVideo bool `json:"is_video"`
|
||||
|
|
7
api/reddit/tracer.go
Normal file
7
api/reddit/tracer.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package reddit
|
||||
|
||||
import (
|
||||
"go.opentelemetry.io/otel"
|
||||
)
|
||||
|
||||
var tracer = otel.Tracer("reddit")
|
16
api/subreddits_check.go
Normal file
16
api/subreddits_check.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tigorlazuardi/redmage/api/reddit"
|
||||
)
|
||||
|
||||
type SubredditCheckParam = reddit.CheckSubredditParams
|
||||
|
||||
func (api *API) SubredditCheck(ctx context.Context, params SubredditCheckParam) (actual string, err error) {
|
||||
ctx, span := tracer.Start(ctx, "*API.SubredditCheck")
|
||||
defer span.End()
|
||||
|
||||
return api.reddit.CheckSubreddit(ctx, params)
|
||||
}
|
|
@ -21,6 +21,7 @@ var DefaultConfig = map[string]any{
|
|||
"download.timeout.headers": "10s",
|
||||
"download.timeout.idle": "5s",
|
||||
"download.timeout.idlespeed": "10KB",
|
||||
"download.useragent": "redmage",
|
||||
|
||||
"download.pubsub.ack.deadline": "3h",
|
||||
|
||||
|
|
|
@ -71,33 +71,22 @@ func (er *Err) LogValue() slog.Value {
|
|||
}
|
||||
|
||||
func (er *Err) Error() string {
|
||||
var (
|
||||
s = strings.Builder{}
|
||||
source = er.origin
|
||||
msg = source.Error()
|
||||
unwrap = errors.Unwrap(source)
|
||||
)
|
||||
if unwrap == nil {
|
||||
if er.message != "" {
|
||||
s.WriteString(er.message)
|
||||
s.WriteString(": ")
|
||||
}
|
||||
s.WriteString(msg)
|
||||
return s.String()
|
||||
s := strings.Builder{}
|
||||
if er.message != "" {
|
||||
s.WriteString(er.message)
|
||||
}
|
||||
for unwrap := errors.Unwrap(source); unwrap != nil; source = unwrap {
|
||||
originMsg := unwrap.Error()
|
||||
var write string
|
||||
if cut, found := strings.CutSuffix(msg, originMsg); found {
|
||||
write = cut
|
||||
} else {
|
||||
write = msg
|
||||
for unwrap := errors.Unwrap(er); unwrap != nil; {
|
||||
if e, ok := unwrap.(Error); ok && e.GetMessage() != "" {
|
||||
s.WriteString(e.GetMessage())
|
||||
s.WriteString(": ")
|
||||
continue
|
||||
}
|
||||
msg = originMsg
|
||||
if write != "" {
|
||||
s.WriteString(write)
|
||||
s.WriteString(unwrap.Error())
|
||||
next := errors.Unwrap(unwrap)
|
||||
if next != nil {
|
||||
s.WriteString(": ")
|
||||
}
|
||||
unwrap = next
|
||||
}
|
||||
return s.String()
|
||||
}
|
||||
|
|
|
@ -1,26 +1,39 @@
|
|||
package errs
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func FindCodeOrDefault(err error, def int) int {
|
||||
unwrap := errors.Unwrap(err)
|
||||
for unwrap != nil {
|
||||
if coder, ok := err.(interface{ GetCode() int }); ok {
|
||||
if coder, ok := err.(interface{ GetCode() int }); ok {
|
||||
code := coder.GetCode()
|
||||
if code != 0 {
|
||||
return code
|
||||
}
|
||||
}
|
||||
|
||||
for unwrap := errors.Unwrap(err); unwrap != nil; unwrap = errors.Unwrap(unwrap) {
|
||||
if coder, ok := unwrap.(interface{ GetCode() int }); ok {
|
||||
code := coder.GetCode()
|
||||
if code != 0 {
|
||||
return code
|
||||
}
|
||||
}
|
||||
unwrap = errors.Unwrap(unwrap)
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
func FindMessage(err error) string {
|
||||
unwrap := errors.Unwrap(err)
|
||||
for unwrap != nil {
|
||||
if messager, ok := err.(interface{ GetMessage() string }); ok {
|
||||
if messager, ok := err.(interface{ GetMessage() string }); ok {
|
||||
message := messager.GetMessage()
|
||||
if message != "" {
|
||||
return message
|
||||
}
|
||||
}
|
||||
|
||||
for unwrap := errors.Unwrap(err); unwrap != nil; unwrap = errors.Unwrap(unwrap) {
|
||||
if messager, ok := unwrap.(interface{ GetMessage() string }); ok {
|
||||
message := messager.GetMessage()
|
||||
if message != "" {
|
||||
return message
|
||||
|
@ -36,5 +49,6 @@ func HTTPMessage(err error) (code int, message string) {
|
|||
if code >= 500 {
|
||||
return code, err.Error()
|
||||
}
|
||||
return code, FindMessage(err)
|
||||
message = FindMessage(err)
|
||||
return code, message
|
||||
}
|
||||
|
|
8
rest/subreddits/check.http
Normal file
8
rest/subreddits/check.http
Normal file
|
@ -0,0 +1,8 @@
|
|||
POST http://localhost:8080/api/v1/subreddits/check HTTP/1.1
|
||||
Host: localhost:8080
|
||||
Content-Type: application/json
|
||||
Content-Length: 37
|
||||
|
||||
{
|
||||
"subreddit": "Wallpapers"
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
POST http://localhost:8080/api/v1/subreddits HTTP/1.1
|
||||
Host: localhost:8080
|
||||
Content-Length: 69
|
||||
Content-Length: 91
|
||||
|
||||
{
|
||||
"name": "awoo",
|
||||
|
|
|
@ -39,6 +39,7 @@ func (routes *Routes) registerV1APIRoutes(router chi.Router) {
|
|||
router.Post("/subreddits/start", routes.SubredditStartDownloadAPI)
|
||||
router.Get("/subreddits", routes.SubredditsListAPI)
|
||||
router.Post("/subreddits", routes.SubredditsCreateAPI)
|
||||
router.Post("/subreddits/check", routes.SubredditsCheckAPI)
|
||||
|
||||
router.Get("/devices", routes.APIDeviceList)
|
||||
router.Post("/devices", routes.APIDeviceCreate)
|
||||
|
|
54
server/routes/subreddit_check.go
Normal file
54
server/routes/subreddit_check.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/tigorlazuardi/redmage/api"
|
||||
"github.com/tigorlazuardi/redmage/pkg/errs"
|
||||
"github.com/tigorlazuardi/redmage/pkg/log"
|
||||
)
|
||||
|
||||
func (routes *Routes) SubredditsCheckAPI(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := tracer.Start(r.Context(), "*Routes.SubredditsCheck")
|
||||
defer span.End()
|
||||
|
||||
var (
|
||||
enc = json.NewEncoder(rw)
|
||||
dec = json.NewDecoder(r.Body)
|
||||
)
|
||||
|
||||
var body api.SubredditCheckParam
|
||||
if err := dec.Decode(&body); err != nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
_ = enc.Encode(map[string]string{"error": fmt.Sprintf("failed to decode json body: %s", err)})
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateSubredditCheckParam(body); err != nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
_ = enc.Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
actual, err := routes.API.SubredditCheck(ctx, body)
|
||||
if err != nil {
|
||||
log.New(ctx).Err(err).Error("failed to check subreddit")
|
||||
code, message := errs.HTTPMessage(err)
|
||||
rw.WriteHeader(code)
|
||||
_ = enc.Encode(map[string]string{"error": message})
|
||||
return
|
||||
}
|
||||
|
||||
_ = enc.Encode(map[string]string{"subreddit": actual})
|
||||
}
|
||||
|
||||
func validateSubredditCheckParam(body api.SubredditCheckParam) error {
|
||||
if body.Subreddit == "" {
|
||||
return errors.New("subreddit name is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Reference in a new issue