diff --git a/apis/favourite/api.go b/apis/favourite/api.go index 2bbaa27..dc772d8 100644 --- a/apis/favourite/api.go +++ b/apis/favourite/api.go @@ -198,7 +198,7 @@ func DeleteFavorite(c *fiber.Ctx) error { } var data []int - err = DB.Transaction(func(tx *gorm.DB) error { + err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { // delete favorite err = DeleteUserFavorite(tx, userID, body.HoleID, body.FavoriteGroupID) if err != nil { @@ -294,7 +294,7 @@ func AddFavoriteGroup(c *fiber.Ctx) error { } var data FavoriteGroups - err = DB.Transaction(func(tx *gorm.DB) error { + err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { // add favorite group err = AddUserFavoriteGroup(tx, userID, body.Name) if err != nil { diff --git a/apis/subscription/api.go b/apis/subscription/api.go index 7735d34..c826760 100644 --- a/apis/subscription/api.go +++ b/apis/subscription/api.go @@ -118,18 +118,18 @@ func DeleteSubscription(c *fiber.Ctx) error { if err != nil { return err } + var data []int - // delete subscriptions - err = DB.Delete(UserSubscription{UserID: userID, HoleID: body.HoleID}).Error - if err != nil { - return err - } + err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { + // 删除订阅并更新计数 + if err := RemoveUserSubscription(tx, userID, body.HoleID); err != nil { + return err + } - // create response - data, err := UserGetSubscriptionData(DB, userID) - if err != nil { + var err error + data, err = UserGetSubscriptionData(tx, userID) return err - } + }) return c.JSON(&Response{ Message: "删除成功", diff --git a/models/hole.go b/models/hole.go index 45c2b91..0edac91 100644 --- a/models/hole.go +++ b/models/hole.go @@ -62,6 +62,9 @@ type Hole struct { // UserSubscriptionHoles UserSubscription Users `json:"-" gorm:"many2many:user_subscription;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` + FavoriteCount int `json:"favorite_count" gorm:"not null;default:0"` + SubscriptionCount int `json:"subscription_count" gorm:"not null;default:0"` + /// generated field // 兼容旧版 id @@ -513,3 +516,24 @@ func (hole *Hole) HoleHook() { } } } + +func (hole *Hole) RecalculateStats() { + var favoriteCount int64 + var subscriptionCount int64 + + // Recalculate the stats for the hole + DB.Model(&UserFavorite{}).Where("hole_id = ?", hole.ID).Count(&favoriteCount) + DB.Model(&UserSubscription{}).Where("hole_id = ?", hole.ID).Count(&subscriptionCount) + + hole.FavoriteCount = int(favoriteCount) + hole.SubscriptionCount = int(subscriptionCount) + + // Update the hole in the database + DB.Save(hole) + + // Update the cache + err := utils.SetCache(hole.CacheName(), hole, HoleCacheExpire) + if err != nil { + return + } +} diff --git a/models/user_favorite.go b/models/user_favorite.go index 460c179..31951ba 100644 --- a/models/user_favorite.go +++ b/models/user_favorite.go @@ -114,8 +114,14 @@ func AddUserFavorite(tx *gorm.DB, userID int, holeID int, favoriteGroupID int) e if err != nil { return err } - return tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}). + err = tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}). Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", gorm.Expr("count + 1")).Error + if err != nil { + return err + } + return tx.Model(&Hole{}).Where("id = ?", holeID). + UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error + } // UserGetFavoriteData get all favorite data of a user @@ -147,13 +153,39 @@ func DeleteUserFavorite(tx *gorm.DB, userID int, holeID int, favoriteGroupID int if !IsHolesExist(tx, []int{holeID}) { return common.NotFound("帖子不存在") } - return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { - err := tx.Delete(&UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID}).Error - if err != nil { + + // 检查记录是否存在 + var count int64 + if err := tx.Model(&UserFavorite{}). + Where("user_id = ? AND hole_id = ? AND favorite_group_id = ?", + userID, holeID, favoriteGroupID). + Count(&count).Error; err != nil { + return err + } + + if count > 0 { + // 删除收藏记录 + if err := tx.Where("user_id = ? AND hole_id = ? AND favorite_group_id = ?", + userID, holeID, favoriteGroupID). + Delete(&UserFavorite{}).Error; err != nil { return err } - return tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", gorm.Expr("count - 1")).Error - }) + + // 更新收藏夹计数 + if err := tx.Model(&FavoriteGroup{}). + Where("id = ? AND user_id = ?", favoriteGroupID, userID). + UpdateColumn("count", gorm.Expr("count - ?", 1)).Error; err != nil { + return err + } + + // 更新帖子收藏计数 + if err := tx.Model(&Hole{}).Where("id = ?", holeID). + UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error; err != nil { + return err + } + } + + return nil } // MoveUserFavorite move holes that are really in the fromFavoriteGroup diff --git a/models/user_subscription.go b/models/user_subscription.go index 6305d78..a940cb0 100644 --- a/models/user_subscription.go +++ b/models/user_subscription.go @@ -27,9 +27,50 @@ func UserGetSubscriptionData(tx *gorm.DB, userID int) ([]int, error) { } func AddUserSubscription(tx *gorm.DB, userID int, holeID int) error { - return tx.Clauses(clause.OnConflict{ - DoUpdates: clause.Assignments(Map{"created_at": time.Now()}), + + // 检查是否已存在订阅关系 + var exists int64 + if err := tx.Model(&UserSubscription{}).Where("user_id = ? AND hole_id = ?", userID, holeID).Count(&exists).Error; err != nil { + return err + } + + err := tx.Clauses(clause.OnConflict{ + DoUpdates: clause.Assignments(map[string]interface{}{"created_at": time.Now()}), }).Create(&UserSubscription{ UserID: userID, - HoleID: holeID}).Error + HoleID: holeID, + }).Error + + if err != nil { + return err + } + + if exists == 0 { + return tx.Model(&Hole{}).Where("id = ?", holeID). + UpdateColumn("subscription_count", gorm.Expr("subscription_count + ?", 1)).Error + } + + return nil +} + +func RemoveUserSubscription(tx *gorm.DB, userID int, holeID int) error { + // 检查记录是否存在 + var exists int64 + if err := tx.Model(&UserSubscription{}).Where("user_id = ? AND hole_id = ?", userID, holeID).Count(&exists).Error; err != nil { + return err + } + + // 只有当记录存在时才执行删除和计数更新 + if exists > 0 { + // 删除订阅 + if err := tx.Where("user_id = ? AND hole_id = ?", userID, holeID).Delete(&UserSubscription{}).Error; err != nil { + return err + } + + // 更新订阅计数 + return tx.Model(&Hole{}).Where("id = ?", holeID). + UpdateColumn("subscription_count", gorm.Expr("subscription_count - ?", 1)).Error + } + + return nil } diff --git a/tests/hole_test.go b/tests/hole_test.go index 62b4a78..5713e11 100644 --- a/tests/hole_test.go +++ b/tests/hole_test.go @@ -129,3 +129,14 @@ func TestDeleteHole(t *testing.T) { DB.Where("id = ?", 10).Find(&hole) assert.Equal(t, true, hole.Hidden) } + +func TestHoleStats(t *testing.T) { + + for i := 1; i <= 10; i++ { + var hole Hole + testAPIModel(t, "get", "/api/holes/"+strconv.Itoa(i), 200, &hole) + hole.RecalculateStats() + assert.Equal(t, 1, hole.FavoriteCount) + } + +} diff --git a/tests/init.go b/tests/init.go index 99d86e5..1143272 100644 --- a/tests/init.go +++ b/tests/init.go @@ -1,11 +1,10 @@ package tests import ( + "github.com/rs/zerolog/log" "strconv" "strings" - "github.com/rs/zerolog/log" - "treehole_next/config" . "treehole_next/models" )