From 0ae4d1705bdb6c20e125114b594bcfde1fd7a2d7 Mon Sep 17 00:00:00 2001 From: cialloo Date: Sun, 26 Oct 2025 17:44:49 +0800 Subject: [PATCH] Refactor GetPost and ListPosts logic to use time.Time for timestamps; convert to Unix milliseconds and improve error handling --- app/internal/logic/getpostlogic.go | 15 ++- app/internal/logic/listpostslogic.go | 138 +++++++++++++-------------- 2 files changed, 79 insertions(+), 74 deletions(-) diff --git a/app/internal/logic/getpostlogic.go b/app/internal/logic/getpostlogic.go index e0feadf..719a232 100644 --- a/app/internal/logic/getpostlogic.go +++ b/app/internal/logic/getpostlogic.go @@ -31,6 +31,7 @@ func NewGetPostLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetPostLo func (l *GetPostLogic) GetPost(req *types.GetPostReq) (resp *types.GetPostResp, err error) { var post types.GetPostResp var coverID sql.NullInt64 + var createdAt, updatedAt time.Time // Query post with cover image query := ` @@ -39,7 +40,7 @@ func (l *GetPostLogic) GetPost(req *types.GetPostReq) (resp *types.GetPostResp, WHERE p.id = $1 ` err = l.svcCtx.DB.QueryRowContext(l.ctx, query, req.PostId).Scan( - &post.PostId, &post.Title, &post.Content, &post.CreatedAt, &post.UpdatedAt, &coverID, + &post.PostId, &post.Title, &post.Content, &createdAt, &updatedAt, &coverID, ) if err != nil { if err == sql.ErrNoRows { @@ -50,6 +51,14 @@ func (l *GetPostLogic) GetPost(req *types.GetPostReq) (resp *types.GetPostResp, return nil, err } + // Convert timestamps to Unix milliseconds + post.CreatedAt = createdAt.UnixMilli() + post.UpdatedAt = updatedAt.UnixMilli() + + // Convert timestamps to Unix milliseconds + post.CreatedAt = createdAt.UnixMilli() + post.UpdatedAt = updatedAt.UnixMilli() + // If cover image exists, get its URL if coverID.Valid { coverQuery := `SELECT file_key FROM files WHERE id = $1` @@ -72,9 +81,5 @@ func (l *GetPostLogic) GetPost(req *types.GetPostReq) (resp *types.GetPostResp, } } - // Convert timestamps to Unix - post.CreatedAt = post.CreatedAt * 1000 // Convert to milliseconds if needed - post.UpdatedAt = post.UpdatedAt * 1000 - return &post, nil } diff --git a/app/internal/logic/listpostslogic.go b/app/internal/logic/listpostslogic.go index 8a5a65c..4687299 100644 --- a/app/internal/logic/listpostslogic.go +++ b/app/internal/logic/listpostslogic.go @@ -9,6 +9,7 @@ import ( "git.cialloo.com/CiallooWeb/Blog/app/internal/types" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/lib/pq" "github.com/zeromicro/go-zero/core/logx" ) @@ -28,35 +29,65 @@ func NewListPostsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ListPos } func (l *ListPostsLogic) ListPosts(req *types.ListPostsReq) (resp *types.ListPostsResp, err error) { - // Set default values + // Determine pagination parameters with sane defaults. page := req.Page if page < 1 { page = 1 } + pageSize := req.PageSize if pageSize < 1 || pageSize > 100 { - pageSize = 10 // Default page size + pageSize = 10 } + offset := (page - 1) * pageSize - // Build query with optional tag filter - var countQuery string - var postsQuery string - var args []interface{} + var ( + rows *sql.Rows + totalCount int + ) - if len(req.TagIds) > 0 { - // Filter by tags - posts that have ALL specified tags - countQuery = ` - SELECT COUNT(DISTINCT p.id) + if len(req.TagIds) == 0 { + countQuery := `SELECT COUNT(*) FROM posts` + if err = l.svcCtx.DB.QueryRowContext(l.ctx, countQuery).Scan(&totalCount); err != nil { + l.Errorf("Failed to get total count: %v", err) + return nil, err + } + + postsQuery := ` + SELECT p.id, p.title, p.created_at, p.updated_at, p.cover_id FROM posts p - INNER JOIN post_hashtags ph ON p.id = ph.post_id - WHERE ph.hashtag_id = ANY($1) - GROUP BY p.id - HAVING COUNT(DISTINCT ph.hashtag_id) = $2 + ORDER BY p.created_at DESC + LIMIT $1 OFFSET $2 ` - postsQuery = ` - SELECT DISTINCT p.id, p.title, p.created_at, p.updated_at, p.cover_id + rows, err = l.svcCtx.DB.QueryContext(l.ctx, postsQuery, pageSize, offset) + if err != nil { + l.Errorf("Failed to get posts: %v", err) + return nil, err + } + } else { + tagArray := pq.Array(req.TagIds) + + countQuery := ` + SELECT COUNT(*) + FROM ( + SELECT p.id + FROM posts p + INNER JOIN post_hashtags ph ON p.id = ph.post_id + WHERE ph.hashtag_id = ANY($1) + GROUP BY p.id + HAVING COUNT(DISTINCT ph.hashtag_id) = $2 + ) filtered_posts + ` + + if err = l.svcCtx.DB.QueryRowContext(l.ctx, countQuery, tagArray, len(req.TagIds)).Scan(&totalCount); err != nil { + l.Errorf("Failed to get filtered total count: %v", err) + return nil, err + } + + postsQuery := ` + SELECT p.id, p.title, p.created_at, p.updated_at, p.cover_id FROM posts p INNER JOIN post_hashtags ph ON p.id = ph.post_id WHERE ph.hashtag_id = ANY($1) @@ -66,83 +97,47 @@ func (l *ListPostsLogic) ListPosts(req *types.ListPostsReq) (resp *types.ListPos LIMIT $3 OFFSET $4 ` - args = []interface{}{req.TagIds, len(req.TagIds), pageSize, offset} - } else { - // No tag filter - get all posts - countQuery = `SELECT COUNT(*) FROM posts` - postsQuery = ` - SELECT p.id, p.title, p.created_at, p.updated_at, p.cover_id - FROM posts p - ORDER BY p.created_at DESC - LIMIT $1 OFFSET $2 - ` - args = []interface{}{pageSize, offset} - } - - // Get total count - var totalCount int - if len(req.TagIds) > 0 { - // Count posts matching tag filter - rows, err := l.svcCtx.DB.QueryContext(l.ctx, countQuery, req.TagIds, len(req.TagIds)) + rows, err = l.svcCtx.DB.QueryContext(l.ctx, postsQuery, tagArray, len(req.TagIds), pageSize, offset) if err != nil { - l.Errorf("Failed to get total count: %v", err) - return nil, err - } - defer rows.Close() - - totalCount = 0 - for rows.Next() { - totalCount++ - } - } else { - err = l.svcCtx.DB.QueryRowContext(l.ctx, countQuery).Scan(&totalCount) - if err != nil { - l.Errorf("Failed to get total count: %v", err) + l.Errorf("Failed to get filtered posts: %v", err) return nil, err } } - // Get posts with pagination - rows, err := l.svcCtx.DB.QueryContext(l.ctx, postsQuery, args...) - if err != nil { - l.Errorf("Failed to get posts: %v", err) - return nil, err - } defer rows.Close() - posts := []types.ListPostsRespPosts{} + posts := make([]types.ListPostsRespPosts, 0) for rows.Next() { - var post types.ListPostsRespPosts - var coverID sql.NullInt64 - var createdAt, updatedAt time.Time + var ( + post types.ListPostsRespPosts + coverID sql.NullInt64 + createdAt time.Time + updatedAt time.Time + ) - err := rows.Scan(&post.PostId, &post.Title, &createdAt, &updatedAt, &coverID) - if err != nil { - l.Errorf("Failed to scan post: %v", err) + if scanErr := rows.Scan(&post.PostId, &post.Title, &createdAt, &updatedAt, &coverID); scanErr != nil { + l.Errorf("Failed to scan post: %v", scanErr) continue } - // Convert timestamps to Unix milliseconds post.CreatedAt = createdAt.UnixMilli() post.UpdatedAt = updatedAt.UnixMilli() - // If cover image exists, get its URL if coverID.Valid { - coverQuery := `SELECT file_key FROM files WHERE id = $1` + const coverQuery = `SELECT file_key FROM files WHERE id = $1` var fileKey string - err = l.svcCtx.DB.QueryRowContext(l.ctx, coverQuery, coverID.Int64).Scan(&fileKey) - if err == nil { - // Generate presigned URL for cover image + + if queryErr := l.svcCtx.DB.QueryRowContext(l.ctx, coverQuery, coverID.Int64).Scan(&fileKey); queryErr == nil { expiration := time.Duration(l.svcCtx.Config.S3.PresignedURLExpiration) * time.Second presignClient := s3.NewPresignClient(l.svcCtx.S3Client) getObjectInput := &s3.GetObjectInput{ Bucket: &l.svcCtx.Config.S3.Bucket, Key: &fileKey, } - presignedReq, err := presignClient.PresignGetObject(l.ctx, getObjectInput, func(opts *s3.PresignOptions) { + + if presignedReq, presignErr := presignClient.PresignGetObject(l.ctx, getObjectInput, func(opts *s3.PresignOptions) { opts.Expires = expiration - }) - if err == nil { + }); presignErr == nil { post.CoverImageUrl = presignedReq.URL } } @@ -151,6 +146,11 @@ func (l *ListPostsLogic) ListPosts(req *types.ListPostsReq) (resp *types.ListPos posts = append(posts, post) } + if rowsErr := rows.Err(); rowsErr != nil { + l.Errorf("Error while iterating posts: %v", rowsErr) + return nil, rowsErr + } + return &types.ListPostsResp{ Posts: posts, TotalCount: totalCount,