import os import csv import pandas as pd import concurrent.futures from argparse import Namespace from attack import attack_process def create_datasets_csv(datasets_dir, output_csv="datasets_config.csv"): """ 创建包含所有数据集及其参数的CSV文件 Args: datasets_dir: 包含所有数据集文件夹的目录 output_csv: 输出CSV文件的路径 """ # 检查目录是否存在 if not os.path.exists(datasets_dir): raise FileNotFoundError(f"数据集目录 {datasets_dir} 不存在") # 获取所有数据集文件夹 datasets = [d for d in os.listdir(datasets_dir) if os.path.isdir(os.path.join(datasets_dir, d))] # 创建CSV文件 with open(output_csv, 'w', newline='') as csvfile: fieldnames = [ 'dataset_name', 'cuda', 'total_gpus', 'classes', 'target_class', 'pop_size', 'magnitude_factor', 'max_itr', 'run_tag', 'model', 'normalize', 'e', 'done' ] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() # 为每个数据集添加一行,使用默认参数 for dataset in datasets: writer.writerow({ 'dataset_name': dataset, 'cuda': 'store_true', 'total_gpus': 1, 'target_class': -1, 'classes': 0, 'pop_size': 1, 'magnitude_factor': 0.04, 'max_itr': 50, 'run_tag': dataset, 'model': 'r', 'normalize': 'store_true', 'e': 1499, 'done': '0' # 初始状态为未完成 }) print(f"已创建数据集配置文件: {output_csv}") return output_csv def validate_dataset(dataset_path): """验证数据集路径是否有效""" if not os.path.exists(dataset_path): return False # 可以添加更多的验证逻辑,如检查必要的文件是否存在 return True def process_dataset(row, datasets_base_dir): """处理单个数据集,利用多个GPU并行处理""" dataset_name = row['dataset_name'] dataset_path = os.path.join(datasets_base_dir, dataset_name) # 验证数据集 if not validate_dataset(dataset_path): print(f"警告: 数据集 '{dataset_name}' 路径无效或格式不正确") return False # 获取总GPU数量 total_gpus = int(row['total_gpus']) # 为每个GPU创建一个配置 configurations = [] for i in range(total_gpus): config = { 'cuda': row['cuda'] == 'store_true', 'gpu': str(i), # 分配不同的GPU 'target_class': int(row['target_class']), 'classes':int(row['classes']), 'popsize': int(row['pop_size']), 'magnitude_factor': float(row['magnitude_factor']), 'maxitr': int(row['max_itr']), 'run_tag': row['run_tag'], 'model': row['model'], 'normalize': row['normalize'] == 'store_true', 'e': int(row['e']), 'dataset': dataset_name # 添加数据集名称 } configurations.append(config) # 并行处理同一个数据集的多个任务 success = True with concurrent.futures.ProcessPoolExecutor() as executor: futures = [] for config in configurations: arg = Namespace(**config) future = executor.submit(attack_process, arg) futures.append(future) # 等待所有进程完成 for future in concurrent.futures.as_completed(futures): try: result = future.result() # 如果任何一个进程失败,标记整个数据集处理为失败 if not result: success = False except Exception as e: import traceback print(f"处理数据集 '{dataset_name}' 时出错: {str(e)}") print(traceback.format_exc()) success = False return success def main(csv_file, datasets_base_dir, max_workers=4): """ 主程序:顺序遍历CSV文件中的数据集并执行任务 每个数据集内部可能有并行处理(由attack_process内部实现) Args: csv_file: 包含数据集和参数的CSV文件 datasets_base_dir: 数据集根目录 max_workers: 数据集内部并行的最大工作进程数 """ # 验证CSV文件 if not os.path.exists(csv_file): raise FileNotFoundError(f"CSV文件 {csv_file} 不存在") # 读取CSV文件 df = pd.read_csv(csv_file) # 检查CSV格式是否正确 required_columns = [ 'dataset_name', 'cuda', 'total_gpus', 'target_class', 'classes', 'pop_size', 'magnitude_factor', 'max_itr', 'run_tag', 'model', 'normalize', 'e', 'done' ] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: raise ValueError(f"CSV文件缺少必要的列: {', '.join(missing_columns)}") # 只处理未完成的数据集 pending_datasets = df[df['done'] == 0] total_pending = len(pending_datasets) print(f"发现 {total_pending} 个未处理的数据集") # 顺序处理数据集 for index, row in pending_datasets.iterrows(): dataset_name = row['dataset_name'] print(f"开始处理数据集: {dataset_name} ({pending_datasets.index[pending_datasets.index == index].item() + 1}/{total_pending})") try: # 处理单个数据集(内部可能有并行) success = process_dataset(row, datasets_base_dir) if success: # 更新CSV中的done状态 df.at[index, 'done'] = 1 df.to_csv(csv_file, index=False) print(f"完成数据集: {dataset_name}") else: print(f"处理数据集失败: {dataset_name}") except Exception as e: print(f"处理数据集 '{dataset_name}' 时出现异常: {str(e)}") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='批量处理数据集') parser.add_argument('--csv', type=str, default='datasets_config.csv', help='数据集配置CSV文件') parser.add_argument('--datasets_dir',default='data' , type=str, help='数据集根目录') parser.add_argument('--create_csv', action='store_true', help='是否创建新的CSV文件') parser.add_argument('--max_workers', type=int, default=4, help='最大并行工作进程数') args = parser.parse_args() if args.create_csv: create_datasets_csv(args.datasets_dir, args.csv) main(args.csv, args.datasets_dir, args.max_workers)