diff --git a/example/highlevel/post.go b/example/highlevel/post.go index 338c035..5ec235d 100644 --- a/example/highlevel/post.go +++ b/example/highlevel/post.go @@ -29,7 +29,14 @@ func main() { log.Printf("Post 1337 has following tags: %s", post.Tags) postBuilder := client.GetPosts() - posts, err = client.GetNPosts(230, postBuilder) + posts1, err := client.GetNPosts(600, postBuilder) + if err != nil { + log.Panic(err) + } + log.Println(len(posts1)) + + postBuilder = client.GetPosts().Tags("how_to_dragon_your_train") + posts, err = client.GetAllPosts(postBuilder) if err != nil { log.Panic(err) } diff --git a/pkg/e621/builder/posts.go b/pkg/e621/builder/posts.go index 0ba5736..de739c6 100644 --- a/pkg/e621/builder/posts.go +++ b/pkg/e621/builder/posts.go @@ -11,7 +11,8 @@ import ( type PostsBuilder interface { Tags(tags string) PostsBuilder PageAfter(postID model.PostID) PostsBuilder - pageBefore(postID model.PostID) PostsBuilder + PageBefore(postID model.PostID) PostsBuilder + PageNumber(number int) PostsBuilder SetLimit(limitUser int) PostsBuilder Execute() ([]model.Post, error) } @@ -40,11 +41,16 @@ func (g *getPosts) PageAfter(postID model.PostID) PostsBuilder { return g } -func (g *getPosts) pageBefore(postID model.PostID) PostsBuilder { +func (g *getPosts) PageBefore(postID model.PostID) PostsBuilder { g.query["page"] = "b" + strconv.Itoa(int(postID)) return g } +func (g *getPosts) PageNumber(number int) PostsBuilder { + g.query["page"] = strconv.Itoa(number) + return g +} + func (g *getPosts) SetLimit(limitUser int) PostsBuilder { g.query["limit"] = strconv.Itoa(limitUser) return g diff --git a/pkg/e621/client.go b/pkg/e621/client.go index 64c589e..4732d19 100644 --- a/pkg/e621/client.go +++ b/pkg/e621/client.go @@ -7,6 +7,8 @@ import ( "git.dragse.it/anthrove/e621-sdk-go/pkg/e621/utils" _ "github.com/joho/godotenv/autoload" "golang.org/x/time/rate" + "log" + "math" "net/http" "strconv" ) @@ -79,13 +81,23 @@ func (c *Client) GetNPosts(n int, postBuilder builder.PostsBuilder) ([]model.Pos } for len(posts) < n { - postBuilder.PageAfter(posts[len(posts)-1].ID).SetLimit(n - len(posts)) + lastPostID := posts[len(posts)-1].ID + postBuilder.PageBefore(lastPostID) + postBuilder.SetLimit(n - len(posts)) newPosts, err := postBuilder.Execute() if err != nil { return nil, err } + if len(newPosts) == 0 { + break + } posts = append(posts, newPosts...) + log.Printf("Post ID: %d | Post Size: %d", lastPostID, len(posts)) } return posts, nil } + +func (c *Client) GetAllPosts(postBuilder builder.PostsBuilder) ([]model.Post, error) { + return c.GetNPosts(math.MaxInt, postBuilder) +}