api: refactor db instance

This commit is contained in:
Tigor Hutasuhut 2024-04-25 13:28:35 +07:00
parent 847b9c1ab9
commit e5bfe7a29a
7 changed files with 12 additions and 14 deletions

View file

@ -23,8 +23,7 @@ import (
) )
type API struct { type API struct {
db *sql.DB db bob.Executor
exec bob.Executor
scheduler *cron.Cron scheduler *cron.Cron
scheduleMap map[cron.EntryID]*models.Subreddit scheduleMap map[cron.EntryID]*models.Subreddit
@ -78,8 +77,7 @@ func New(deps Dependencies) *API {
panic(err) panic(err)
} }
api := &API{ api := &API{
db: deps.DB, db: bob.New(deps.DB),
exec: bob.New(deps.DB),
scheduler: cron.New(), scheduler: cron.New(),
scheduleMap: make(map[cron.EntryID]*models.Subreddit, 8), scheduleMap: make(map[cron.EntryID]*models.Subreddit, 8),
downloadBroadcast: broadcast.NewRelay[bmessage.ImageDownloadMessage](), downloadBroadcast: broadcast.NewRelay[bmessage.ImageDownloadMessage](),
@ -96,7 +94,7 @@ func New(deps Dependencies) *API {
} }
func (api *API) StartScheduler(ctx context.Context) error { func (api *API) StartScheduler(ctx context.Context) error {
subreddits, err := models.Subreddits.Query(ctx, api.exec, nil).All() subreddits, err := models.Subreddits.Query(ctx, api.db, nil).All()
if err != nil { if err != nil {
return errs.Wrapw(err, "failed to get all subreddits") return errs.Wrapw(err, "failed to get all subreddits")
} }

View file

@ -15,7 +15,7 @@ func (api *API) DevicesCreate(ctx context.Context, params *DeviceCreateParams) (
ctx, span := tracer.Start(ctx, "*API.DevicesCreate") ctx, span := tracer.Start(ctx, "*API.DevicesCreate")
defer span.End() defer span.End()
device, err := models.Devices.Insert(ctx, api.exec, params) device, err := models.Devices.Insert(ctx, api.db, params)
if err != nil { if err != nil {
var sqliteErr sqlite3.Error var sqliteErr sqlite3.Error
if errors.As(err, &sqliteErr) { if errors.As(err, &sqliteErr) {

View file

@ -68,12 +68,12 @@ func (api *API) DevicesList(ctx context.Context, params DevicesListParams) (resu
ctx, span := tracer.Start(ctx, "*API.DevicesList") ctx, span := tracer.Start(ctx, "*API.DevicesList")
defer span.End() defer span.End()
result.Devices, err = models.Devices.Query(ctx, api.exec, params.Query()...).All() result.Devices, err = models.Devices.Query(ctx, api.db, params.Query()...).All()
if err != nil { if err != nil {
return result, errs.Wrapw(err, "failed to query devices", "params", params) return result, errs.Wrapw(err, "failed to query devices", "params", params)
} }
result.Total, err = models.Devices.Query(ctx, api.exec, params.CountQuery()...).Count() result.Total, err = models.Devices.Query(ctx, api.db, params.CountQuery()...).Count()
if err != nil { if err != nil {
return result, errs.Wrapw(err, "failed to count devices", "params", params) return result, errs.Wrapw(err, "failed to count devices", "params", params)
} }

View file

@ -13,7 +13,7 @@ func (api *API) DevicesUpdate(ctx context.Context, id int, update *models.Device
ctx, span := tracer.Start(ctx, "*API.DevicesUpdate") ctx, span := tracer.Start(ctx, "*API.DevicesUpdate")
defer span.End() defer span.End()
err = models.Devices.Update(ctx, api.exec, update, &models.Device{ID: int32(id)}) err = models.Devices.Update(ctx, api.db, update, &models.Device{ID: int32(id)})
if err != nil { if err != nil {
var sqliteErr sqlite3.Error var sqliteErr sqlite3.Error
if errors.As(err, &sqliteErr) { if errors.As(err, &sqliteErr) {

View file

@ -240,7 +240,7 @@ func (api *API) isImageExists(ctx context.Context, post reddit.Post, device *mod
return false return false
} }
_, err := models.Images.Query(ctx, api.exec, _, err := models.Images.Query(ctx, api.db,
models.SelectWhere.Images.DeviceID.EQ(device.ID), models.SelectWhere.Images.DeviceID.EQ(device.ID),
models.SelectWhere.Images.PostID.EQ(post.GetID()), models.SelectWhere.Images.PostID.EQ(post.GetID()),
).One() ).One()

View file

@ -26,13 +26,13 @@ func (api *API) startSubredditDownloadPubsub(messages <-chan *message.Message) {
subredditName := string(msg.Payload) subredditName := string(msg.Payload)
span.SetAttributes(attribute.String("subreddit", subredditName)) span.SetAttributes(attribute.String("subreddit", subredditName))
subreddit, err := models.Subreddits.Query(ctx, api.exec, models.SelectWhere.Subreddits.Name.EQ(subredditName)).One() subreddit, err := models.Subreddits.Query(ctx, api.db, models.SelectWhere.Subreddits.Name.EQ(subredditName)).One()
if err != nil { if err != nil {
log.New(ctx).Err(err).Error("failed to find subreddit", "subreddit", subredditName) log.New(ctx).Err(err).Error("failed to find subreddit", "subreddit", subredditName)
return return
} }
devices, err := models.Devices.Query(ctx, api.exec).All() devices, err := models.Devices.Query(ctx, api.db).All()
if err != nil { if err != nil {
log.New(ctx).Err(err).Error("failed to query devices") log.New(ctx).Err(err).Error("failed to query devices")
return return

View file

@ -56,12 +56,12 @@ func (api *API) ListSubreddits(ctx context.Context, arg ListSubredditsParams) (r
ctx, span := tracer.Start(ctx, "api.ListSubreddits") ctx, span := tracer.Start(ctx, "api.ListSubreddits")
defer span.End() defer span.End()
result.Data, err = models.Subreddits.Query(ctx, api.exec, arg.Query()...).All() result.Data, err = models.Subreddits.Query(ctx, api.db, arg.Query()...).All()
if err != nil { if err != nil {
return result, errs.Wrapw(err, "failed to list subreddits", "query", arg) return result, errs.Wrapw(err, "failed to list subreddits", "query", arg)
} }
result.Total, err = models.Subreddits.Query(ctx, api.exec, arg.CountQuery()...).Count() result.Total, err = models.Subreddits.Query(ctx, api.db, arg.CountQuery()...).Count()
if err != nil { if err != nil {
return result, errs.Wrapw(err, "failed to count subreddits", "query", arg) return result, errs.Wrapw(err, "failed to count subreddits", "query", arg)
} }