CCPP/run_all_datasets.py
2025-04-20 20:55:06 +08:00

203 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import csv
import pandas as pd
import concurrent.futures
from argparse import Namespace
from attack import attack_process
def create_datasets_csv(datasets_dir, output_csv="datasets_config.csv"):
"""
创建包含所有数据集及其参数的CSV文件
Args:
datasets_dir: 包含所有数据集文件夹的目录
output_csv: 输出CSV文件的路径
"""
# 检查目录是否存在
if not os.path.exists(datasets_dir):
raise FileNotFoundError(f"数据集目录 {datasets_dir} 不存在")
# 获取所有数据集文件夹
datasets = [d for d in os.listdir(datasets_dir) if os.path.isdir(os.path.join(datasets_dir, d))]
# 创建CSV文件
with open(output_csv, 'w', newline='') as csvfile:
fieldnames = [
'dataset_name',
'cuda',
'total_gpus',
'classes',
'target_class',
'pop_size',
'magnitude_factor',
'max_itr',
'run_tag',
'model',
'normalize',
'e',
'done'
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
# 为每个数据集添加一行,使用默认参数
for dataset in datasets:
writer.writerow({
'dataset_name': dataset,
'cuda': 'store_true',
'total_gpus': 1,
'target_class': -1,
'classes': 0,
'pop_size': 1,
'magnitude_factor': 0.04,
'max_itr': 50,
'run_tag': dataset,
'model': 'r',
'normalize': 'store_true',
'e': 1499,
'done': '0' # 初始状态为未完成
})
print(f"已创建数据集配置文件: {output_csv}")
return output_csv
def validate_dataset(dataset_path):
"""验证数据集路径是否有效"""
if not os.path.exists(dataset_path):
return False
# 可以添加更多的验证逻辑,如检查必要的文件是否存在
return True
def process_dataset(row, datasets_base_dir):
"""处理单个数据集利用多个GPU并行处理"""
dataset_name = row['dataset_name']
dataset_path = os.path.join(datasets_base_dir, dataset_name)
# 验证数据集
if not validate_dataset(dataset_path):
print(f"警告: 数据集 '{dataset_name}' 路径无效或格式不正确")
return False
# 获取总GPU数量
total_gpus = int(row['total_gpus'])
# 为每个GPU创建一个配置
configurations = []
for i in range(total_gpus):
config = {
'cuda': row['cuda'] == 'store_true',
'gpu': str(i), # 分配不同的GPU
'target_class': int(row['target_class']),
'classes':int(row['classes']),
'popsize': int(row['pop_size']),
'magnitude_factor': float(row['magnitude_factor']),
'maxitr': int(row['max_itr']),
'run_tag': row['run_tag'],
'model': row['model'],
'normalize': row['normalize'] == 'store_true',
'e': int(row['e']),
'dataset': dataset_name # 添加数据集名称
}
configurations.append(config)
# 并行处理同一个数据集的多个任务
success = True
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = []
for config in configurations:
arg = Namespace(**config)
future = executor.submit(attack_process, arg)
futures.append(future)
# 等待所有进程完成
for future in concurrent.futures.as_completed(futures):
try:
result = future.result()
# 如果任何一个进程失败,标记整个数据集处理为失败
if not result:
success = False
except Exception as e:
import traceback
print(f"处理数据集 '{dataset_name}' 时出错: {str(e)}")
print(traceback.format_exc())
success = False
return success
def main(csv_file, datasets_base_dir, max_workers=4):
"""
主程序顺序遍历CSV文件中的数据集并执行任务
每个数据集内部可能有并行处理由attack_process内部实现
Args:
csv_file: 包含数据集和参数的CSV文件
datasets_base_dir: 数据集根目录
max_workers: 数据集内部并行的最大工作进程数
"""
# 验证CSV文件
if not os.path.exists(csv_file):
raise FileNotFoundError(f"CSV文件 {csv_file} 不存在")
# 读取CSV文件
df = pd.read_csv(csv_file)
# 检查CSV格式是否正确
required_columns = [
'dataset_name',
'cuda',
'total_gpus',
'target_class',
'classes',
'pop_size',
'magnitude_factor',
'max_itr',
'run_tag',
'model',
'normalize',
'e',
'done'
]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise ValueError(f"CSV文件缺少必要的列: {', '.join(missing_columns)}")
# 只处理未完成的数据集
pending_datasets = df[df['done'] == 0]
total_pending = len(pending_datasets)
print(f"发现 {total_pending} 个未处理的数据集")
# 顺序处理数据集
for index, row in pending_datasets.iterrows():
dataset_name = row['dataset_name']
print(f"开始处理数据集: {dataset_name} ({pending_datasets.index[pending_datasets.index == index].item() + 1}/{total_pending})")
try:
# 处理单个数据集(内部可能有并行)
success = process_dataset(row, datasets_base_dir)
if success:
# 更新CSV中的done状态
df.at[index, 'done'] = 1
df.to_csv(csv_file, index=False)
print(f"完成数据集: {dataset_name}")
else:
print(f"处理数据集失败: {dataset_name}")
except Exception as e:
print(f"处理数据集 '{dataset_name}' 时出现异常: {str(e)}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='批量处理数据集')
parser.add_argument('--csv', type=str, default='datasets_config.csv', help='数据集配置CSV文件')
parser.add_argument('--datasets_dir',default='data' , type=str, help='数据集根目录')
parser.add_argument('--create_csv', action='store_true', help='是否创建新的CSV文件')
parser.add_argument('--max_workers', type=int, default=4, help='最大并行工作进程数')
args = parser.parse_args()
if args.create_csv:
create_datasets_csv(args.datasets_dir, args.csv)
main(args.csv, args.datasets_dir, args.max_workers)