diff --git a/tools/file_clean/file_clean.go b/tools/file_clean/file_clean.go new file mode 100644 index 0000000..4576bbd --- /dev/null +++ b/tools/file_clean/file_clean.go @@ -0,0 +1,357 @@ +package main + +import ( + "context" + "database/sql" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "sort" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + _ "github.com/lib/pq" +) + +type postRef struct { + PostID int64 + Source string +} + +type fileIndex struct { + keyToID map[string]int64 + idToKey map[int64]string +} + +func main() { + var ( + dsn = flag.String("pg", "", "PostgreSQL DSN") + s3Region = flag.String("s3-region", "", "S3 region") + s3Bucket = flag.String("s3-bucket", "", "S3 bucket name") + s3Endpoint = flag.String("s3-endpoint", "", "S3 endpoint (optional)") + s3Access = flag.String("s3-access-key", "", "S3 access key ID") + s3Secret = flag.String("s3-secret-key", "", "S3 secret access key") + deleteFlag = flag.Bool("delete", false, "Delete unused files (default: dry run)") + ) + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nExample:\n") + fmt.Fprintf(os.Stderr, " %s -pg 'postgres://user:pass@localhost/db?sslmode=disable' -s3-region us-east-1 -s3-bucket mybucket -s3-access-key AKIA... -s3-secret-key secret...\n", os.Args[0]) + } + + flag.Parse() + + if *dsn == "" || *s3Region == "" || *s3Bucket == "" || *s3Access == "" || *s3Secret == "" { + flag.Usage() + log.Fatal("missing required flags") + } + + ctx := context.Background() + + db, err := sql.Open("postgres", *dsn) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + if err := db.PingContext(ctx); err != nil { + log.Fatalf("failed to ping database: %v", err) + } + + s3Client, err := newS3Client(ctx, *s3Region, *s3Endpoint, *s3Access, *s3Secret) + if err != nil { + log.Fatalf("failed to create s3 client: %v", err) + } + + filesIdx, err := loadFiles(ctx, db) + if err != nil { + log.Fatalf("failed to load files: %v", err) + } + + postKeys, refs, err := collectPostFileKeys(ctx, db, filesIdx) + if err != nil { + log.Fatalf("failed to collect post file keys: %v", err) + } + + bucketKeys, err := listBucketKeys(ctx, s3Client, *s3Bucket) + if err != nil { + log.Fatalf("failed to list bucket objects: %v", err) + } + + report := compare(postKeys, bucketKeys) + printReport(report, refs, filesIdx) + + if *deleteFlag { + if err := deleteUnused(ctx, s3Client, *s3Bucket, db, report.unusedInBucket, filesIdx); err != nil { + log.Fatalf("failed to delete unused files: %v", err) + } + } else if len(report.unusedInBucket) > 0 { + log.Printf("dry run: %d unused files not deleted. run with -delete to remove", len(report.unusedInBucket)) + } +} + +func newS3Client(ctx context.Context, region, endpoint, accessKey, secretKey string) (*s3.Client, error) { + cfg, err := awsconfig.LoadDefaultConfig(ctx, + awsconfig.WithRegion(region), + awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")), + ) + if err != nil { + return nil, err + } + + client := s3.NewFromConfig(cfg, func(o *s3.Options) { + if endpoint != "" { + o.BaseEndpoint = aws.String(endpoint) + } + o.UsePathStyle = true + }) + + return client, nil +} + +func loadFiles(ctx context.Context, db *sql.DB) (fileIndex, error) { + rows, err := db.QueryContext(ctx, `SELECT file_key, id FROM files`) + if err != nil { + return fileIndex{}, err + } + defer rows.Close() + + idx := fileIndex{ + keyToID: make(map[string]int64), + idToKey: make(map[int64]string), + } + + for rows.Next() { + var ( + key string + id int64 + ) + if err := rows.Scan(&key, &id); err != nil { + return fileIndex{}, err + } + idx.keyToID[key] = id + idx.idToKey[id] = key + } + + return idx, rows.Err() +} + +func collectPostFileKeys(ctx context.Context, db *sql.DB, files fileIndex) (map[string]struct{}, map[string][]postRef, error) { + rows, err := db.QueryContext(ctx, `SELECT id, content, cover_id FROM posts`) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + keys := make(map[string]struct{}) + refs := make(map[string][]postRef) + + for rows.Next() { + var ( + postID int64 + content string + coverID sql.NullInt64 + ) + + if err := rows.Scan(&postID, &content, &coverID); err != nil { + return nil, nil, err + } + + contentKeys, err := extractAllFileKeys(content) + if err != nil { + return nil, nil, fmt.Errorf("post %d: %w", postID, err) + } + + addRefs(keys, refs, contentKeys, postID, "content") + + if coverID.Valid { + if coverKey, ok := files.idToKey[coverID.Int64]; ok { + addRefs(keys, refs, []string{coverKey}, postID, "cover") + } else { + log.Printf("warn: cover id %d referenced by post %d missing in files table", coverID.Int64, postID) + } + } + } + + return keys, refs, rows.Err() +} + +func addRefs(keys map[string]struct{}, refs map[string][]postRef, list []string, postID int64, source string) { + seen := make(map[string]struct{}) + for _, key := range list { + clean := strings.TrimSpace(key) + if clean == "" { + continue + } + if _, dup := seen[clean]; dup { + continue + } + seen[clean] = struct{}{} + keys[clean] = struct{}{} + refs[clean] = append(refs[clean], postRef{PostID: postID, Source: source}) + } +} + +func listBucketKeys(ctx context.Context, client *s3.Client, bucket string) (map[string]struct{}, error) { + result := make(map[string]struct{}) + var token *string + + for { + out, err := client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + ContinuationToken: token, + }) + if err != nil { + return nil, err + } + + for _, obj := range out.Contents { + if obj.Key == nil || *obj.Key == "" { + continue + } + result[*obj.Key] = struct{}{} + } + + if aws.ToBool(out.IsTruncated) && out.NextContinuationToken != nil { + token = out.NextContinuationToken + continue + } + + break + } + + return result, nil +} + +type reportData struct { + unusedInBucket []string + missingFromBucket []string +} + +func compare(postKeys, bucketKeys map[string]struct{}) reportData { + var report reportData + + for key := range postKeys { + if _, ok := bucketKeys[key]; !ok { + report.missingFromBucket = append(report.missingFromBucket, key) + } + } + + for key := range bucketKeys { + if _, ok := postKeys[key]; !ok { + report.unusedInBucket = append(report.unusedInBucket, key) + } + } + + sort.Strings(report.missingFromBucket) + sort.Strings(report.unusedInBucket) + + return report +} + +func printReport(report reportData, refs map[string][]postRef, files fileIndex) { + log.Printf("unique files referenced by posts: %d", len(refs)) + log.Printf("files missing in bucket: %d", len(report.missingFromBucket)) + log.Printf("files unused in bucket: %d", len(report.unusedInBucket)) + + if len(report.missingFromBucket) > 0 { + fmt.Println("\nFiles referenced by posts but missing in bucket:") + for _, key := range report.missingFromBucket { + fmt.Printf("- %s\n", key) + for _, ref := range refs[key] { + fmt.Printf(" * post %d (%s)\n", ref.PostID, ref.Source) + } + } + } + + if len(report.unusedInBucket) > 0 { + fmt.Println("\nFiles present in bucket but not referenced by any post:") + for _, key := range report.unusedInBucket { + fmt.Printf("- %s", key) + if id, ok := files.keyToID[key]; ok { + fmt.Printf(" (files.id=%d)", id) + } + fmt.Println() + } + } +} + +func deleteUnused(ctx context.Context, client *s3.Client, bucket string, db *sql.DB, keys []string, files fileIndex) error { + if len(keys) == 0 { + return nil + } + + log.Printf("deleting %d objects", len(keys)) + + for _, key := range keys { + _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return fmt.Errorf("delete %s: %w", key, err) + } + + if id, ok := files.keyToID[key]; ok { + if _, err := db.ExecContext(ctx, `DELETE FROM files WHERE id = $1`, id); err != nil { + return fmt.Errorf("delete files record for %s: %w", key, err) + } + delete(files.keyToID, key) + delete(files.idToKey, id) + } + + log.Printf("deleted %s", key) + } + + return nil +} + +func extractAllFileKeys(content string) ([]string, error) { + if strings.TrimSpace(content) == "" { + return nil, nil + } + + var data map[string]interface{} + if err := json.Unmarshal([]byte(content), &data); err != nil { + return nil, err + } + + fileTypes := map[string]struct{}{ + "image": {}, + "archive": {}, + "audio": {}, + "video": {}, + } + + var keys []string + collectFileKeys(data, fileTypes, &keys) + return keys, nil +} + +func collectFileKeys(node interface{}, typeSet map[string]struct{}, keys *[]string) { + switch v := node.(type) { + case map[string]interface{}: + if nodeType, ok := v["type"].(string); ok { + if _, allowed := typeSet[nodeType]; allowed { + if fileKey, ok := v["fileKey"].(string); ok && fileKey != "" { + *keys = append(*keys, fileKey) + } + } + } + for _, value := range v { + collectFileKeys(value, typeSet, keys) + } + case []interface{}: + for _, item := range v { + collectFileKeys(item, typeSet, keys) + } + } +}