diff --git a/internal/postgres/post_test.go b/internal/postgres/post_test.go index fba733a..735f6b5 100644 --- a/internal/postgres/post_test.go +++ b/internal/postgres/post_test.go @@ -176,6 +176,7 @@ func TestGetPostBySourceURL(t *testing.T) { BaseModel: models.BaseModel[models.AnthrovePostID]{ ID: models.AnthrovePostID(fmt.Sprintf("%025s", "1")), }, + Rating: "safe", } @@ -198,7 +199,7 @@ func TestGetPostBySourceURL(t *testing.T) { t.Fatal("Could not create source", err) } - err = CreateReferenceBetweenPostAndSource(ctx, gormDB, post.ID, models.AnthroveSourceDomain(source.Domain)) + err = CreateReferenceBetweenPostAndSource(ctx, gormDB, post.ID, models.AnthroveSourceDomain(source.Domain), "http://test.org") if err != nil { t.Fatal("Could not create source reference", err) } @@ -220,7 +221,7 @@ func TestGetPostBySourceURL(t *testing.T) { args: args{ ctx: ctx, db: gormDB, - sourceURL: source.Domain, + sourceURL: "http://test.org", }, want: post, wantErr: false, @@ -297,7 +298,7 @@ func TestGetPostBySourceID(t *testing.T) { t.Fatal("Could not create source", err) } - err = CreateReferenceBetweenPostAndSource(ctx, gormDB, post.ID, models.AnthroveSourceDomain(source.Domain)) + err = CreateReferenceBetweenPostAndSource(ctx, gormDB, post.ID, models.AnthroveSourceDomain(source.Domain), "http://test.otg") if err != nil { t.Fatal("Could not create source reference", err) } diff --git a/internal/postgres/user.go b/internal/postgres/user.go index 404c32c..aa5d5a5 100644 --- a/internal/postgres/user.go +++ b/internal/postgres/user.go @@ -283,8 +283,15 @@ func GetUserFavoriteWithPagination(ctx context.Context, db *gorm.DB, anthroveUse return &models.FavoriteList{Posts: favoritePosts}, nil } +// Workaround, should be changed later maybe, but its not that bad right now +type selectFrequencyTag struct { + tagName string `gorm:"tag_name"` + count int64 `gorm:"count"` + tagType models.TagType `gorm:"tag_type"` +} + func GetUserTagWitRelationToFavedPosts(ctx context.Context, db *gorm.DB, anthroveUserID models.AnthroveUserID) ([]models.TagsWithFrequency, error) { - var userFavorites []models.UserFavorites + var queryUserFavorites []selectFrequencyTag if anthroveUserID == "" { return nil, &otterError.EntityValidationFailed{Reason: otterError.AnthroveUserIDIsEmpty} @@ -294,7 +301,11 @@ func GetUserTagWitRelationToFavedPosts(ctx context.Context, db *gorm.DB, anthrov return nil, &otterError.EntityValidationFailed{Reason: otterError.AnthroveUserIDToShort} } - result := db.WithContext(ctx).Where("user_id = ?", string(anthroveUserID)).Find(&userFavorites) + result := db.WithContext(ctx).Raw( + `WITH user_posts AS ( + SELECT post_id FROM "UserFavorites" WHERE user_id = $1 + ) + SELECT post_tags.tag_name AS tag_name, count(*) AS count, (SELECT tag_type FROM "Tag" WHERE "Tag".name = post_tags.tag_name LIMIT 1) AS tag_type FROM post_tags, user_posts WHERE post_tags.post_id IN (user_posts.post_id) GROUP BY post_tags.tag_name`, anthroveUserID).Scan(&queryUserFavorites) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, &otterError.NoDataFound{} @@ -302,47 +313,17 @@ func GetUserTagWitRelationToFavedPosts(ctx context.Context, db *gorm.DB, anthrov return nil, result.Error } - tagFrequency := make(map[struct { - name string - typeName string - }]int) - - for _, userFavorite := range userFavorites { - var post models.Post - result = db.WithContext(ctx).Preload("Tags", func(db *gorm.DB) *gorm.DB { - return db.Order("tag_type ASC") - }).First(&post, "id = ?", userFavorite.PostID) - - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, &otterError.NoDataFound{} - } - return nil, result.Error - } - - for _, tag := range post.Tags { - tagFrequency[struct { - name string - typeName string - }{name: tag.Name, typeName: string(tag.Type)}]++ - } - } - - var tagsWithFrequency []models.TagsWithFrequency - for data, frequency := range tagFrequency { - tagsWithFrequency = append(tagsWithFrequency, models.TagsWithFrequency{ - Frequency: int64(frequency), - Tags: models.Tag{ - Name: data.name, - Type: models.TagType(data.typeName), - }, - }) + var userFavoritesFrequency = make([]models.TagsWithFrequency, len(queryUserFavorites)) + for i, query := range queryUserFavorites { + userFavoritesFrequency[i].Frequency = query.count + userFavoritesFrequency[i].Tags.Name = query.tagName + userFavoritesFrequency[i].Tags.Type = query.tagType } log.WithFields(log.Fields{ "anthrove_user_id": anthroveUserID, - "tag_amount": len(tagsWithFrequency), + "tag_amount": len(queryUserFavorites), }).Trace("database: got user tag node with relation to faved posts") - return tagsWithFrequency, nil + return userFavoritesFrequency, nil } diff --git a/internal/postgres/user_test.go b/internal/postgres/user_test.go index dd0932b..168829a 100644 --- a/internal/postgres/user_test.go +++ b/internal/postgres/user_test.go @@ -162,7 +162,7 @@ func TestCreateUserNodeWithSourceRelation(t *testing.T) { userID: "e1", username: "marius", }, - wantErr: true, + wantErr: false, }, { name: "Test 5: no userID", @@ -270,30 +270,33 @@ func TestGetUserSourceBySourceID(t *testing.T) { validUserID := models.AnthroveUserID(fmt.Sprintf("%025s", "User1")) invalidUserID := models.AnthroveUserID("XXX") + validSourceID := models.AnthroveSourceID(fmt.Sprintf("%025s", "Source1")) + + source := &models.Source{ + BaseModel: models.BaseModel[models.AnthroveSourceID]{ + ID: validSourceID, + }, + DisplayName: "e621", + Domain: "e621.net", + } + expectedResult := make(map[string]models.UserSource) expectedResult["e621"] = models.UserSource{ UserID: "e1", AccountUsername: "euser", Source: models.Source{ - DisplayName: "e621", - Domain: "e621.net", + DisplayName: source.DisplayName, + Domain: source.Domain, + Icon: source.Icon, }, } - source := &models.Source{ - BaseModel: models.BaseModel[models.AnthroveSourceID]{ - ID: expectedResult["e621"].Source.ID, - }, - DisplayName: expectedResult["e621"].Source.DisplayName, - Domain: expectedResult["e621"].Source.Domain, - } - err = CreateSource(ctx, gormDB, source) if err != nil { t.Fatal(err) } - err = CreateUserWithRelationToSource(ctx, gormDB, validUserID, models.AnthroveSourceID(expectedResult["e621"].SourceID), expectedResult["e621"].UserID, expectedResult["e621"].AccountUsername) + err = CreateUserWithRelationToSource(ctx, gormDB, validUserID, validSourceID, expectedResult["e621"].UserID, expectedResult["e621"].AccountUsername) if err != nil { t.Fatal(err) } @@ -389,6 +392,7 @@ func TestGetUserSourceBySourceID(t *testing.T) { wantErr: true, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := GetUserSourceBySourceID(tt.args.ctx, tt.args.db, tt.args.anthroveUserID, tt.args.sourceID) diff --git a/pkg/database/database.go b/pkg/database/database.go index f8be8ef..0060744 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -35,7 +35,7 @@ type OtterSpace interface { GetPostByAnthroveID(ctx context.Context, anthrovePostID models.AnthrovePostID) (*models.Post, error) // GetPostByURL retrieves a post by its source URL. - GetPostByURL(ctx context.Context, sourceUrl string) (*models.Post, error) + GetPostByURL(ctx context.Context, postURL string) (*models.Post, error) // GetPostBySourceID retrieves a post by its source ID. GetPostBySourceID(ctx context.Context, sourceID models.AnthroveSourceID) (*models.Post, error)