diff --git a/api/api.go b/api/api.go index 91a5c63..18f2e9c 100644 --- a/api/api.go +++ b/api/api.go @@ -3,6 +3,7 @@ package api import ( "context" "database/sql" + "sync" "github.com/stephenafamo/bob" "github.com/teivah/broadcast" @@ -32,6 +33,8 @@ type API struct { subscriber message.Subscriber publisher message.Publisher + + mu *sync.Mutex } type Dependencies struct { @@ -59,6 +62,7 @@ func New(deps Dependencies) *API { reddit: deps.Reddit, subscriber: deps.Subscriber, publisher: deps.Publisher, + mu: &sync.Mutex{}, } api.scheduler = scheduler.New(api.scheduleRun) diff --git a/api/devices_create.go b/api/devices_create.go index bf8929c..6fb4c66 100644 --- a/api/devices_create.go +++ b/api/devices_create.go @@ -13,26 +13,28 @@ import ( type DeviceCreateParams = models.DeviceSetter -func (api *API) DevicesCreate(ctx context.Context, params *models.Device) (*models.Device, error) { +func (api *API) DevicesCreate(ctx context.Context, params *models.Device) (device *models.Device, err error) { ctx, span := tracer.Start(ctx, "*API.DevicesCreate") defer span.End() now := time.Now() - device, err := models.Devices.Insert(ctx, api.db, &models.DeviceSetter{ - Slug: omit.From(params.Slug), - Name: omit.From(params.Name), - ResolutionX: omit.From(params.ResolutionX), - ResolutionY: omit.From(params.ResolutionY), - AspectRatioTolerance: omit.From(params.AspectRatioTolerance), - MinX: omit.From(params.MinX), - MinY: omit.From(params.MinY), - MaxX: omit.From(params.MaxX), - MaxY: omit.From(params.MaxY), - NSFW: omit.From(params.NSFW), - WindowsWallpaperMode: omit.From(params.WindowsWallpaperMode), - Enable: omit.From(params.Enable), - CreatedAt: omit.From(now.Unix()), - UpdatedAt: omit.From(now.Unix()), + api.lockf(func() { + device, err = models.Devices.Insert(ctx, api.db, &models.DeviceSetter{ + Slug: omit.From(params.Slug), + Name: omit.From(params.Name), + ResolutionX: omit.From(params.ResolutionX), + ResolutionY: omit.From(params.ResolutionY), + AspectRatioTolerance: omit.From(params.AspectRatioTolerance), + MinX: omit.From(params.MinX), + MinY: omit.From(params.MinY), + MaxX: omit.From(params.MaxX), + MaxY: omit.From(params.MaxY), + NSFW: omit.From(params.NSFW), + WindowsWallpaperMode: omit.From(params.WindowsWallpaperMode), + Enable: omit.From(params.Enable), + CreatedAt: omit.From(now.Unix()), + UpdatedAt: omit.From(now.Unix()), + }) }) if err != nil { var sqliteErr sqlite3.Error diff --git a/api/devices_update.go b/api/devices_update.go index 9bc2850..5baa0d0 100644 --- a/api/devices_update.go +++ b/api/devices_update.go @@ -15,7 +15,9 @@ func (api *API) DevicesUpdate(ctx context.Context, slug string, update *models.D device = &models.Device{Slug: slug} - err = models.Devices.Update(ctx, api.db, update, device) + api.lockf(func() { + err = models.Devices.Update(ctx, api.db, update, device) + }) if err != nil { var sqliteErr sqlite3.Error if errors.As(err, &sqliteErr) { diff --git a/api/download_subreddit_images.go b/api/download_subreddit_images.go index f4c2a8b..8bddf12 100644 --- a/api/download_subreddit_images.go +++ b/api/download_subreddit_images.go @@ -246,7 +246,9 @@ func (api *API) saveImageToFSAndDatabase(ctx context.Context, image io.ReadClose } log.New(ctx).Debug("inserting images to database", "images", many) - _, err = models.Images.InsertMany(ctx, api.db, many...) + api.lockf(func() { + _, err = models.Images.InsertMany(ctx, api.db, many...) + }) if err != nil { return errs.Wrapw(err, "failed to insert images to database", "params", many) } diff --git a/api/lock.go b/api/lock.go new file mode 100644 index 0000000..d75846d --- /dev/null +++ b/api/lock.go @@ -0,0 +1,13 @@ +package api + +// lockf is a helper function to ensure to +// stop other goroutines from accessing the +// same resources at the same time. +// +// e.g. Use this function to wrap any write +// database calls to avoid `database locked error` +func (api *API) lockf(f func()) { + api.mu.Lock() + defer api.mu.Unlock() + f() +} diff --git a/api/schedule_history_insert.go b/api/schedule_history_insert.go index 0267acd..ee584c2 100644 --- a/api/schedule_history_insert.go +++ b/api/schedule_history_insert.go @@ -23,11 +23,13 @@ func (api *API) scheduleHistoryInsert(ctx context.Context, exec bob.Executor, pa now := time.Now() - history, err = models.ScheduleHistories.Insert(ctx, exec, &models.ScheduleHistorySetter{ - Subreddit: omit.FromCond(params.Subreddit, params.Subreddit != ""), - Status: omit.From(params.Status.Int8()), - ErrorMessage: omit.FromCond(params.ErrorMessage, params.Status == ScheduleStatusError), - CreatedAt: omit.From(now.Unix()), + api.lockf(func() { + history, err = models.ScheduleHistories.Insert(ctx, exec, &models.ScheduleHistorySetter{ + Subreddit: omit.FromCond(params.Subreddit, params.Subreddit != ""), + Status: omit.From(params.Status.Int8()), + ErrorMessage: omit.FromCond(params.ErrorMessage, params.Status == ScheduleStatusError), + CreatedAt: omit.From(now.Unix()), + }) }) if err != nil { return history, errs.Wrapw(err, "failed to insert schedule history", "params", params) diff --git a/api/subreddits_create.go b/api/subreddits_create.go index 2305139..bd2d1ad 100644 --- a/api/subreddits_create.go +++ b/api/subreddits_create.go @@ -27,7 +27,9 @@ func (api *API) SubredditsCreate(ctx context.Context, params *models.Subreddit) UpdatedAt: omit.From(now.Unix()), } - subreddit, err = models.Subreddits.Insert(ctx, api.db, set) + api.lockf(func() { + subreddit, err = models.Subreddits.Insert(ctx, api.db, set) + }) if err != nil { var sqliteErr sqlite3.Error if errors.As(err, &sqliteErr) { diff --git a/db/db.go b/db/db.go index 93aff1a..9fc4822 100644 --- a/db/db.go +++ b/db/db.go @@ -38,9 +38,6 @@ func Open(cfg *config.Config) (*sql.DB, error) { if err != nil { return db, errs.Wrapw(err, "failed to open database", "driver", driver, "db.string", dsn) } - if driver == "sqlite3" { - db.SetMaxOpenConns(1) // SQLITE is not thread safe. This is to prevent database is locked error. - } if cfg.Bool("db.automigrate") { goose.SetLogger(goose.NopLogger()) goose.SetBaseFS(Migrations)