Add file_clean tool for managing S3 files and database references; implement file key extraction and reporting
All checks were successful
CI - Build and Push / Build and Push Docker Image (push) Successful in 8s

This commit is contained in:
2025-11-01 08:39:58 +08:00
parent cef65deb23
commit 1d92cad560

View File

@@ -0,0 +1,357 @@
package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"sort"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
_ "github.com/lib/pq"
)
type postRef struct {
PostID int64
Source string
}
type fileIndex struct {
keyToID map[string]int64
idToKey map[int64]string
}
func main() {
var (
dsn = flag.String("pg", "", "PostgreSQL DSN")
s3Region = flag.String("s3-region", "", "S3 region")
s3Bucket = flag.String("s3-bucket", "", "S3 bucket name")
s3Endpoint = flag.String("s3-endpoint", "", "S3 endpoint (optional)")
s3Access = flag.String("s3-access-key", "", "S3 access key ID")
s3Secret = flag.String("s3-secret-key", "", "S3 secret access key")
deleteFlag = flag.Bool("delete", false, "Delete unused files (default: dry run)")
)
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExample:\n")
fmt.Fprintf(os.Stderr, " %s -pg 'postgres://user:pass@localhost/db?sslmode=disable' -s3-region us-east-1 -s3-bucket mybucket -s3-access-key AKIA... -s3-secret-key secret...\n", os.Args[0])
}
flag.Parse()
if *dsn == "" || *s3Region == "" || *s3Bucket == "" || *s3Access == "" || *s3Secret == "" {
flag.Usage()
log.Fatal("missing required flags")
}
ctx := context.Background()
db, err := sql.Open("postgres", *dsn)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()
if err := db.PingContext(ctx); err != nil {
log.Fatalf("failed to ping database: %v", err)
}
s3Client, err := newS3Client(ctx, *s3Region, *s3Endpoint, *s3Access, *s3Secret)
if err != nil {
log.Fatalf("failed to create s3 client: %v", err)
}
filesIdx, err := loadFiles(ctx, db)
if err != nil {
log.Fatalf("failed to load files: %v", err)
}
postKeys, refs, err := collectPostFileKeys(ctx, db, filesIdx)
if err != nil {
log.Fatalf("failed to collect post file keys: %v", err)
}
bucketKeys, err := listBucketKeys(ctx, s3Client, *s3Bucket)
if err != nil {
log.Fatalf("failed to list bucket objects: %v", err)
}
report := compare(postKeys, bucketKeys)
printReport(report, refs, filesIdx)
if *deleteFlag {
if err := deleteUnused(ctx, s3Client, *s3Bucket, db, report.unusedInBucket, filesIdx); err != nil {
log.Fatalf("failed to delete unused files: %v", err)
}
} else if len(report.unusedInBucket) > 0 {
log.Printf("dry run: %d unused files not deleted. run with -delete to remove", len(report.unusedInBucket))
}
}
func newS3Client(ctx context.Context, region, endpoint, accessKey, secretKey string) (*s3.Client, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")),
)
if err != nil {
return nil, err
}
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
if endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
}
o.UsePathStyle = true
})
return client, nil
}
func loadFiles(ctx context.Context, db *sql.DB) (fileIndex, error) {
rows, err := db.QueryContext(ctx, `SELECT file_key, id FROM files`)
if err != nil {
return fileIndex{}, err
}
defer rows.Close()
idx := fileIndex{
keyToID: make(map[string]int64),
idToKey: make(map[int64]string),
}
for rows.Next() {
var (
key string
id int64
)
if err := rows.Scan(&key, &id); err != nil {
return fileIndex{}, err
}
idx.keyToID[key] = id
idx.idToKey[id] = key
}
return idx, rows.Err()
}
func collectPostFileKeys(ctx context.Context, db *sql.DB, files fileIndex) (map[string]struct{}, map[string][]postRef, error) {
rows, err := db.QueryContext(ctx, `SELECT id, content, cover_id FROM posts`)
if err != nil {
return nil, nil, err
}
defer rows.Close()
keys := make(map[string]struct{})
refs := make(map[string][]postRef)
for rows.Next() {
var (
postID int64
content string
coverID sql.NullInt64
)
if err := rows.Scan(&postID, &content, &coverID); err != nil {
return nil, nil, err
}
contentKeys, err := extractAllFileKeys(content)
if err != nil {
return nil, nil, fmt.Errorf("post %d: %w", postID, err)
}
addRefs(keys, refs, contentKeys, postID, "content")
if coverID.Valid {
if coverKey, ok := files.idToKey[coverID.Int64]; ok {
addRefs(keys, refs, []string{coverKey}, postID, "cover")
} else {
log.Printf("warn: cover id %d referenced by post %d missing in files table", coverID.Int64, postID)
}
}
}
return keys, refs, rows.Err()
}
func addRefs(keys map[string]struct{}, refs map[string][]postRef, list []string, postID int64, source string) {
seen := make(map[string]struct{})
for _, key := range list {
clean := strings.TrimSpace(key)
if clean == "" {
continue
}
if _, dup := seen[clean]; dup {
continue
}
seen[clean] = struct{}{}
keys[clean] = struct{}{}
refs[clean] = append(refs[clean], postRef{PostID: postID, Source: source})
}
}
func listBucketKeys(ctx context.Context, client *s3.Client, bucket string) (map[string]struct{}, error) {
result := make(map[string]struct{})
var token *string
for {
out, err := client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
ContinuationToken: token,
})
if err != nil {
return nil, err
}
for _, obj := range out.Contents {
if obj.Key == nil || *obj.Key == "" {
continue
}
result[*obj.Key] = struct{}{}
}
if aws.ToBool(out.IsTruncated) && out.NextContinuationToken != nil {
token = out.NextContinuationToken
continue
}
break
}
return result, nil
}
type reportData struct {
unusedInBucket []string
missingFromBucket []string
}
func compare(postKeys, bucketKeys map[string]struct{}) reportData {
var report reportData
for key := range postKeys {
if _, ok := bucketKeys[key]; !ok {
report.missingFromBucket = append(report.missingFromBucket, key)
}
}
for key := range bucketKeys {
if _, ok := postKeys[key]; !ok {
report.unusedInBucket = append(report.unusedInBucket, key)
}
}
sort.Strings(report.missingFromBucket)
sort.Strings(report.unusedInBucket)
return report
}
func printReport(report reportData, refs map[string][]postRef, files fileIndex) {
log.Printf("unique files referenced by posts: %d", len(refs))
log.Printf("files missing in bucket: %d", len(report.missingFromBucket))
log.Printf("files unused in bucket: %d", len(report.unusedInBucket))
if len(report.missingFromBucket) > 0 {
fmt.Println("\nFiles referenced by posts but missing in bucket:")
for _, key := range report.missingFromBucket {
fmt.Printf("- %s\n", key)
for _, ref := range refs[key] {
fmt.Printf(" * post %d (%s)\n", ref.PostID, ref.Source)
}
}
}
if len(report.unusedInBucket) > 0 {
fmt.Println("\nFiles present in bucket but not referenced by any post:")
for _, key := range report.unusedInBucket {
fmt.Printf("- %s", key)
if id, ok := files.keyToID[key]; ok {
fmt.Printf(" (files.id=%d)", id)
}
fmt.Println()
}
}
}
func deleteUnused(ctx context.Context, client *s3.Client, bucket string, db *sql.DB, keys []string, files fileIndex) error {
if len(keys) == 0 {
return nil
}
log.Printf("deleting %d objects", len(keys))
for _, key := range keys {
_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
if err != nil {
return fmt.Errorf("delete %s: %w", key, err)
}
if id, ok := files.keyToID[key]; ok {
if _, err := db.ExecContext(ctx, `DELETE FROM files WHERE id = $1`, id); err != nil {
return fmt.Errorf("delete files record for %s: %w", key, err)
}
delete(files.keyToID, key)
delete(files.idToKey, id)
}
log.Printf("deleted %s", key)
}
return nil
}
func extractAllFileKeys(content string) ([]string, error) {
if strings.TrimSpace(content) == "" {
return nil, nil
}
var data map[string]interface{}
if err := json.Unmarshal([]byte(content), &data); err != nil {
return nil, err
}
fileTypes := map[string]struct{}{
"image": {},
"archive": {},
"audio": {},
"video": {},
}
var keys []string
collectFileKeys(data, fileTypes, &keys)
return keys, nil
}
func collectFileKeys(node interface{}, typeSet map[string]struct{}, keys *[]string) {
switch v := node.(type) {
case map[string]interface{}:
if nodeType, ok := v["type"].(string); ok {
if _, allowed := typeSet[nodeType]; allowed {
if fileKey, ok := v["fileKey"].(string); ok && fileKey != "" {
*keys = append(*keys, fileKey)
}
}
}
for _, value := range v {
collectFileKeys(value, typeSet, keys)
}
case []interface{}:
for _, item := range v {
collectFileKeys(item, typeSet, keys)
}
}
}