package main import ( "context" "database/sql" "encoding/json" "flag" "fmt" "log" "os" "sort" "strings" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" _ "github.com/lib/pq" ) type postRef struct { PostID int64 Source string } type fileIndex struct { keyToID map[string]int64 idToKey map[int64]string } func main() { var ( dsn = flag.String("pg", "", "PostgreSQL DSN") s3Region = flag.String("s3-region", "", "S3 region") s3Bucket = flag.String("s3-bucket", "", "S3 bucket name") s3Endpoint = flag.String("s3-endpoint", "", "S3 endpoint (optional)") s3Access = flag.String("s3-access-key", "", "S3 access key ID") s3Secret = flag.String("s3-secret-key", "", "S3 secret access key") deleteFlag = flag.Bool("delete", false, "Delete unused files (default: dry run)") ) flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Options:\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nExample:\n") fmt.Fprintf(os.Stderr, " %s -pg 'postgres://user:pass@localhost/db?sslmode=disable' -s3-region us-east-1 -s3-bucket mybucket -s3-access-key AKIA... -s3-secret-key secret...\n", os.Args[0]) } flag.Parse() if *dsn == "" || *s3Region == "" || *s3Bucket == "" || *s3Access == "" || *s3Secret == "" { flag.Usage() log.Fatal("missing required flags") } ctx := context.Background() db, err := sql.Open("postgres", *dsn) if err != nil { log.Fatalf("failed to open database: %v", err) } defer db.Close() if err := db.PingContext(ctx); err != nil { log.Fatalf("failed to ping database: %v", err) } s3Client, err := newS3Client(ctx, *s3Region, *s3Endpoint, *s3Access, *s3Secret) if err != nil { log.Fatalf("failed to create s3 client: %v", err) } filesIdx, err := loadFiles(ctx, db) if err != nil { log.Fatalf("failed to load files: %v", err) } postKeys, refs, err := collectPostFileKeys(ctx, db, filesIdx) if err != nil { log.Fatalf("failed to collect post file keys: %v", err) } bucketKeys, err := listBucketKeys(ctx, s3Client, *s3Bucket) if err != nil { log.Fatalf("failed to list bucket objects: %v", err) } report := compare(postKeys, bucketKeys) printReport(report, refs, filesIdx) if *deleteFlag { if err := deleteUnused(ctx, s3Client, *s3Bucket, db, report.unusedInBucket, filesIdx); err != nil { log.Fatalf("failed to delete unused files: %v", err) } } else if len(report.unusedInBucket) > 0 { log.Printf("dry run: %d unused files not deleted. run with -delete to remove", len(report.unusedInBucket)) } } func newS3Client(ctx context.Context, region, endpoint, accessKey, secretKey string) (*s3.Client, error) { cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region), awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")), ) if err != nil { return nil, err } client := s3.NewFromConfig(cfg, func(o *s3.Options) { if endpoint != "" { o.BaseEndpoint = aws.String(endpoint) } o.UsePathStyle = true }) return client, nil } func loadFiles(ctx context.Context, db *sql.DB) (fileIndex, error) { rows, err := db.QueryContext(ctx, `SELECT file_key, id FROM files`) if err != nil { return fileIndex{}, err } defer rows.Close() idx := fileIndex{ keyToID: make(map[string]int64), idToKey: make(map[int64]string), } for rows.Next() { var ( key string id int64 ) if err := rows.Scan(&key, &id); err != nil { return fileIndex{}, err } idx.keyToID[key] = id idx.idToKey[id] = key } return idx, rows.Err() } func collectPostFileKeys(ctx context.Context, db *sql.DB, files fileIndex) (map[string]struct{}, map[string][]postRef, error) { rows, err := db.QueryContext(ctx, `SELECT id, content, cover_id FROM posts`) if err != nil { return nil, nil, err } defer rows.Close() keys := make(map[string]struct{}) refs := make(map[string][]postRef) for rows.Next() { var ( postID int64 content string coverID sql.NullInt64 ) if err := rows.Scan(&postID, &content, &coverID); err != nil { return nil, nil, err } contentKeys, err := extractAllFileKeys(content) if err != nil { return nil, nil, fmt.Errorf("post %d: %w", postID, err) } addRefs(keys, refs, contentKeys, postID, "content") if coverID.Valid { if coverKey, ok := files.idToKey[coverID.Int64]; ok { addRefs(keys, refs, []string{coverKey}, postID, "cover") } else { log.Printf("warn: cover id %d referenced by post %d missing in files table", coverID.Int64, postID) } } } return keys, refs, rows.Err() } func addRefs(keys map[string]struct{}, refs map[string][]postRef, list []string, postID int64, source string) { seen := make(map[string]struct{}) for _, key := range list { clean := strings.TrimSpace(key) if clean == "" { continue } if _, dup := seen[clean]; dup { continue } seen[clean] = struct{}{} keys[clean] = struct{}{} refs[clean] = append(refs[clean], postRef{PostID: postID, Source: source}) } } func listBucketKeys(ctx context.Context, client *s3.Client, bucket string) (map[string]struct{}, error) { result := make(map[string]struct{}) var token *string for { out, err := client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(bucket), ContinuationToken: token, }) if err != nil { return nil, err } for _, obj := range out.Contents { if obj.Key == nil || *obj.Key == "" { continue } result[*obj.Key] = struct{}{} } if aws.ToBool(out.IsTruncated) && out.NextContinuationToken != nil { token = out.NextContinuationToken continue } break } return result, nil } type reportData struct { unusedInBucket []string missingFromBucket []string } func compare(postKeys, bucketKeys map[string]struct{}) reportData { var report reportData for key := range postKeys { if _, ok := bucketKeys[key]; !ok { report.missingFromBucket = append(report.missingFromBucket, key) } } for key := range bucketKeys { if _, ok := postKeys[key]; !ok { report.unusedInBucket = append(report.unusedInBucket, key) } } sort.Strings(report.missingFromBucket) sort.Strings(report.unusedInBucket) return report } func printReport(report reportData, refs map[string][]postRef, files fileIndex) { log.Printf("unique files referenced by posts: %d", len(refs)) log.Printf("files missing in bucket: %d", len(report.missingFromBucket)) log.Printf("files unused in bucket: %d", len(report.unusedInBucket)) if len(report.missingFromBucket) > 0 { fmt.Println("\nFiles referenced by posts but missing in bucket:") for _, key := range report.missingFromBucket { fmt.Printf("- %s\n", key) for _, ref := range refs[key] { fmt.Printf(" * post %d (%s)\n", ref.PostID, ref.Source) } } } if len(report.unusedInBucket) > 0 { fmt.Println("\nFiles present in bucket but not referenced by any post:") for _, key := range report.unusedInBucket { fmt.Printf("- %s", key) if id, ok := files.keyToID[key]; ok { fmt.Printf(" (files.id=%d)", id) } fmt.Println() } } } func deleteUnused(ctx context.Context, client *s3.Client, bucket string, db *sql.DB, keys []string, files fileIndex) error { if len(keys) == 0 { return nil } log.Printf("deleting %d objects", len(keys)) for _, key := range keys { _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(bucket), Key: aws.String(key), }) if err != nil { return fmt.Errorf("delete %s: %w", key, err) } if id, ok := files.keyToID[key]; ok { if _, err := db.ExecContext(ctx, `DELETE FROM files WHERE id = $1`, id); err != nil { return fmt.Errorf("delete files record for %s: %w", key, err) } delete(files.keyToID, key) delete(files.idToKey, id) } log.Printf("deleted %s", key) } return nil } func extractAllFileKeys(content string) ([]string, error) { if strings.TrimSpace(content) == "" { return nil, nil } var data map[string]interface{} if err := json.Unmarshal([]byte(content), &data); err != nil { return nil, err } fileTypes := map[string]struct{}{ "image": {}, "archive": {}, "audio": {}, "video": {}, } var keys []string collectFileKeys(data, fileTypes, &keys) return keys, nil } func collectFileKeys(node interface{}, typeSet map[string]struct{}, keys *[]string) { switch v := node.(type) { case map[string]interface{}: if nodeType, ok := v["type"].(string); ok { if _, allowed := typeSet[nodeType]; allowed { if fileKey, ok := v["fileKey"].(string); ok && fileKey != "" { *keys = append(*keys, fileKey) } } } for _, value := range v { collectFileKeys(value, typeSet, keys) } case []interface{}: for _, item := range v { collectFileKeys(item, typeSet, keys) } } }