diff --git a/api/download_subreddit_images.go b/api/download_subreddit_images.go index cff0390..30556b3 100644 --- a/api/download_subreddit_images.go +++ b/api/download_subreddit_images.go @@ -116,6 +116,7 @@ func (api *API) downloadSubredditListImage(ctx context.Context, list reddit.List }() if imageFile := api.findImageFileForDevices(ctx, post, devices); imageFile != nil { + defer cleanup(imageFile) err := api.saveImageToFSAndDatabase(ctx, imageFile, subreddit, post, acceptedDevices) if err != nil { log.New(ctx).Err(err).Error("failed to download subreddit image") @@ -149,7 +150,7 @@ func (api *API) downloadSubredditImage(ctx context.Context, post reddit.Post, su if err != nil { return errs.Wrapw(err, "failed to download image to temp file") } - defer tmpImageFile.Close() + defer cleanup(tmpImageFile) thumbnailPath := post.GetThumbnailTargetPath(api.config) _, errStat := os.Stat(thumbnailPath) @@ -333,11 +334,12 @@ func (api *API) isImageEntryExists(ctx context.Context, post reddit.Post, device // Return nil if no image file exists for the devices. // // Ensure to close the file after use. -func (api *API) findImageFileForDevices(ctx context.Context, post reddit.Post, devices models.DeviceSlice) (oldImageFile *os.File) { +func (api *API) findImageFileForDevices(ctx context.Context, post reddit.Post, devices models.DeviceSlice) *os.File { for _, device := range devices { stat, err := os.Stat(post.GetImageTargetPath(api.config, device)) if err == nil { - oldImageFile, err = os.Open(post.GetImageTargetPath(api.config, device)) + var err error + oldImageFile, err := os.Open(post.GetImageTargetPath(api.config, device)) if err != nil { log.New(ctx).Err(err).Error("failed to open image file", "filename", post.GetImageTargetPath(api.config, device)) return nil @@ -438,6 +440,10 @@ func (te *tempFile) Close() error { return te.file.Close() } +func (te *tempFile) Name() string { + return te.filename +} + // copyImageToTempDir copies the image to a temporary directory and returns the file handle // // file must be closed by the caller after use. @@ -494,3 +500,13 @@ func (api *API) copyImageToTempDir(ctx context.Context, img reddit.PostImage) (t filename: tmpFilename, }, err } + +type removeableFile interface { + io.ReadCloser + Name() string +} + +func cleanup(file removeableFile) { + _ = file.Close() + _ = os.Remove(file.Name()) +}