203 lines
7.0 KiB
Python
203 lines
7.0 KiB
Python
import os
|
||
import csv
|
||
import pandas as pd
|
||
|
||
import concurrent.futures
|
||
from argparse import Namespace
|
||
from attack import attack_process
|
||
|
||
def create_datasets_csv(datasets_dir, output_csv="datasets_config.csv"):
|
||
"""
|
||
创建包含所有数据集及其参数的CSV文件
|
||
|
||
Args:
|
||
datasets_dir: 包含所有数据集文件夹的目录
|
||
output_csv: 输出CSV文件的路径
|
||
"""
|
||
# 检查目录是否存在
|
||
if not os.path.exists(datasets_dir):
|
||
raise FileNotFoundError(f"数据集目录 {datasets_dir} 不存在")
|
||
|
||
# 获取所有数据集文件夹
|
||
datasets = [d for d in os.listdir(datasets_dir) if os.path.isdir(os.path.join(datasets_dir, d))]
|
||
|
||
# 创建CSV文件
|
||
with open(output_csv, 'w', newline='') as csvfile:
|
||
fieldnames = [
|
||
'dataset_name',
|
||
'cuda',
|
||
'total_gpus',
|
||
'classes',
|
||
'target_class',
|
||
'pop_size',
|
||
'magnitude_factor',
|
||
'max_itr',
|
||
'run_tag',
|
||
'model',
|
||
'normalize',
|
||
'e',
|
||
'done'
|
||
]
|
||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||
writer.writeheader()
|
||
|
||
# 为每个数据集添加一行,使用默认参数
|
||
for dataset in datasets:
|
||
writer.writerow({
|
||
'dataset_name': dataset,
|
||
'cuda': 'store_true',
|
||
'total_gpus': 1,
|
||
'target_class': -1,
|
||
'classes': 0,
|
||
'pop_size': 1,
|
||
'magnitude_factor': 0.04,
|
||
'max_itr': 50,
|
||
'run_tag': dataset,
|
||
'model': 'r',
|
||
'normalize': 'store_true',
|
||
'e': 1499,
|
||
'done': '0' # 初始状态为未完成
|
||
})
|
||
|
||
print(f"已创建数据集配置文件: {output_csv}")
|
||
return output_csv
|
||
|
||
def validate_dataset(dataset_path):
|
||
"""验证数据集路径是否有效"""
|
||
if not os.path.exists(dataset_path):
|
||
return False
|
||
# 可以添加更多的验证逻辑,如检查必要的文件是否存在
|
||
return True
|
||
|
||
def process_dataset(row, datasets_base_dir):
|
||
"""处理单个数据集,利用多个GPU并行处理"""
|
||
dataset_name = row['dataset_name']
|
||
dataset_path = os.path.join(datasets_base_dir, dataset_name)
|
||
|
||
# 验证数据集
|
||
if not validate_dataset(dataset_path):
|
||
print(f"警告: 数据集 '{dataset_name}' 路径无效或格式不正确")
|
||
return False
|
||
|
||
# 获取总GPU数量
|
||
total_gpus = int(row['total_gpus'])
|
||
|
||
# 为每个GPU创建一个配置
|
||
configurations = []
|
||
for i in range(total_gpus):
|
||
config = {
|
||
'cuda': row['cuda'] == 'store_true',
|
||
'gpu': str(i), # 分配不同的GPU
|
||
'target_class': int(row['target_class']),
|
||
'classes':int(row['classes']),
|
||
'popsize': int(row['pop_size']),
|
||
'magnitude_factor': float(row['magnitude_factor']),
|
||
'maxitr': int(row['max_itr']),
|
||
'run_tag': row['run_tag'],
|
||
'model': row['model'],
|
||
'normalize': row['normalize'] == 'store_true',
|
||
'e': int(row['e']),
|
||
'dataset': dataset_name # 添加数据集名称
|
||
}
|
||
configurations.append(config)
|
||
|
||
# 并行处理同一个数据集的多个任务
|
||
success = True
|
||
with concurrent.futures.ProcessPoolExecutor() as executor:
|
||
futures = []
|
||
for config in configurations:
|
||
arg = Namespace(**config)
|
||
future = executor.submit(attack_process, arg)
|
||
futures.append(future)
|
||
|
||
# 等待所有进程完成
|
||
for future in concurrent.futures.as_completed(futures):
|
||
try:
|
||
result = future.result()
|
||
# 如果任何一个进程失败,标记整个数据集处理为失败
|
||
if not result:
|
||
success = False
|
||
except Exception as e:
|
||
import traceback
|
||
print(f"处理数据集 '{dataset_name}' 时出错: {str(e)}")
|
||
print(traceback.format_exc())
|
||
success = False
|
||
|
||
return success
|
||
|
||
def main(csv_file, datasets_base_dir, max_workers=4):
|
||
"""
|
||
主程序:顺序遍历CSV文件中的数据集并执行任务
|
||
每个数据集内部可能有并行处理(由attack_process内部实现)
|
||
Args:
|
||
csv_file: 包含数据集和参数的CSV文件
|
||
datasets_base_dir: 数据集根目录
|
||
max_workers: 数据集内部并行的最大工作进程数
|
||
"""
|
||
# 验证CSV文件
|
||
if not os.path.exists(csv_file):
|
||
raise FileNotFoundError(f"CSV文件 {csv_file} 不存在")
|
||
|
||
# 读取CSV文件
|
||
df = pd.read_csv(csv_file)
|
||
|
||
# 检查CSV格式是否正确
|
||
required_columns = [
|
||
'dataset_name',
|
||
'cuda',
|
||
'total_gpus',
|
||
'target_class',
|
||
'classes',
|
||
'pop_size',
|
||
'magnitude_factor',
|
||
'max_itr',
|
||
'run_tag',
|
||
'model',
|
||
'normalize',
|
||
'e',
|
||
'done'
|
||
]
|
||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||
if missing_columns:
|
||
raise ValueError(f"CSV文件缺少必要的列: {', '.join(missing_columns)}")
|
||
|
||
# 只处理未完成的数据集
|
||
pending_datasets = df[df['done'] == 0]
|
||
total_pending = len(pending_datasets)
|
||
print(f"发现 {total_pending} 个未处理的数据集")
|
||
|
||
# 顺序处理数据集
|
||
for index, row in pending_datasets.iterrows():
|
||
dataset_name = row['dataset_name']
|
||
print(f"开始处理数据集: {dataset_name} ({pending_datasets.index[pending_datasets.index == index].item() + 1}/{total_pending})")
|
||
|
||
try:
|
||
# 处理单个数据集(内部可能有并行)
|
||
success = process_dataset(row, datasets_base_dir)
|
||
|
||
if success:
|
||
# 更新CSV中的done状态
|
||
df.at[index, 'done'] = 1
|
||
df.to_csv(csv_file, index=False)
|
||
print(f"完成数据集: {dataset_name}")
|
||
else:
|
||
print(f"处理数据集失败: {dataset_name}")
|
||
except Exception as e:
|
||
print(f"处理数据集 '{dataset_name}' 时出现异常: {str(e)}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description='批量处理数据集')
|
||
parser.add_argument('--csv', type=str, default='datasets_config.csv', help='数据集配置CSV文件')
|
||
parser.add_argument('--datasets_dir',default='data' , type=str, help='数据集根目录')
|
||
parser.add_argument('--create_csv', action='store_true', help='是否创建新的CSV文件')
|
||
parser.add_argument('--max_workers', type=int, default=4, help='最大并行工作进程数')
|
||
|
||
args = parser.parse_args()
|
||
|
||
if args.create_csv:
|
||
create_datasets_csv(args.datasets_dir, args.csv)
|
||
|
||
main(args.csv, args.datasets_dir, args.max_workers) |