package logic import ( "context" "database/sql" "time" "git.cialloo.com/CiallooWeb/Blog/app/internal/svc" "git.cialloo.com/CiallooWeb/Blog/app/internal/types" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/lib/pq" "github.com/zeromicro/go-zero/core/logx" ) type ListPostsLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } // Get a list of blog posts func NewListPostsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ListPostsLogic { return &ListPostsLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx, } } func (l *ListPostsLogic) ListPosts(req *types.ListPostsReq) (resp *types.ListPostsResp, err error) { // Determine pagination parameters with sane defaults. page := req.Page if page < 1 { page = 1 } pageSize := req.PageSize if pageSize < 1 || pageSize > 100 { pageSize = 10 } offset := (page - 1) * pageSize var ( rows *sql.Rows totalCount int ) if len(req.TagIds) == 0 { countQuery := `SELECT COUNT(*) FROM posts` if err = l.svcCtx.DB.QueryRowContext(l.ctx, countQuery).Scan(&totalCount); err != nil { l.Errorf("Failed to get total count: %v", err) return nil, err } postsQuery := ` SELECT p.id, p.title, p.created_at, p.updated_at, p.cover_id FROM posts p ORDER BY p.created_at DESC LIMIT $1 OFFSET $2 ` rows, err = l.svcCtx.DB.QueryContext(l.ctx, postsQuery, pageSize, offset) if err != nil { l.Errorf("Failed to get posts: %v", err) return nil, err } } else { tagArray := pq.Array(req.TagIds) countQuery := ` SELECT COUNT(*) FROM ( SELECT p.id FROM posts p INNER JOIN post_hashtags ph ON p.id = ph.post_id WHERE ph.hashtag_id = ANY($1) GROUP BY p.id HAVING COUNT(DISTINCT ph.hashtag_id) = $2 ) filtered_posts ` if err = l.svcCtx.DB.QueryRowContext(l.ctx, countQuery, tagArray, len(req.TagIds)).Scan(&totalCount); err != nil { l.Errorf("Failed to get filtered total count: %v", err) return nil, err } postsQuery := ` SELECT p.id, p.title, p.created_at, p.updated_at, p.cover_id FROM posts p INNER JOIN post_hashtags ph ON p.id = ph.post_id WHERE ph.hashtag_id = ANY($1) GROUP BY p.id, p.title, p.created_at, p.updated_at, p.cover_id HAVING COUNT(DISTINCT ph.hashtag_id) = $2 ORDER BY p.created_at DESC LIMIT $3 OFFSET $4 ` rows, err = l.svcCtx.DB.QueryContext(l.ctx, postsQuery, tagArray, len(req.TagIds), pageSize, offset) if err != nil { l.Errorf("Failed to get filtered posts: %v", err) return nil, err } } defer rows.Close() posts := make([]types.ListPostsRespPosts, 0) for rows.Next() { var ( post types.ListPostsRespPosts coverID sql.NullInt64 createdAt time.Time updatedAt time.Time ) if scanErr := rows.Scan(&post.PostId, &post.Title, &createdAt, &updatedAt, &coverID); scanErr != nil { l.Errorf("Failed to scan post: %v", scanErr) continue } post.CreatedAt = createdAt.UnixMilli() post.UpdatedAt = updatedAt.UnixMilli() if coverID.Valid { const coverQuery = `SELECT file_key FROM files WHERE id = $1` var fileKey string if queryErr := l.svcCtx.DB.QueryRowContext(l.ctx, coverQuery, coverID.Int64).Scan(&fileKey); queryErr == nil { expiration := time.Duration(l.svcCtx.Config.S3.PresignedURLExpiration) * time.Second presignClient := s3.NewPresignClient(l.svcCtx.S3Client) getObjectInput := &s3.GetObjectInput{ Bucket: &l.svcCtx.Config.S3.Bucket, Key: &fileKey, } if presignedReq, presignErr := presignClient.PresignGetObject(l.ctx, getObjectInput, func(opts *s3.PresignOptions) { opts.Expires = expiration }); presignErr == nil { post.CoverImageUrl = presignedReq.URL } } } posts = append(posts, post) } if rowsErr := rows.Err(); rowsErr != nil { l.Errorf("Error while iterating posts: %v", rowsErr) return nil, rowsErr } return &types.ListPostsResp{ Posts: posts, TotalCount: totalCount, }, nil }