58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
|
import configparser
|
|||
|
|
|||
|
|
|||
|
def read_and_group_data_by_class(input_file):
    """Read ``<input_file>_attack.txt`` and group its rows by their first field.

    Each non-blank line is split on whitespace. The first token is used as
    the class label, and the full token list (label included) is appended
    to that label's group.

    Args:
        input_file: Path prefix; ``'_attack.txt'`` is appended to it to form
            the name of the file that is read.

    Returns:
        A dict mapping each class label (str) to a list of token lists,
        preserving the order in which lines appear in the file.
    """
    data_by_class = {}
    with open(input_file + '_attack.txt', 'r') as file:
        for line in file:
            # str.split() with no argument already strips surrounding
            # whitespace, so this equals line.strip().split().
            parts = line.split()
            # Skip blank / whitespace-only lines (e.g. a trailing empty line);
            # indexing parts[0] on them would raise IndexError.
            if not parts:
                continue
            data_by_class.setdefault(parts[0], []).append(parts)
    return data_by_class
|
|||
|
|
|||
|
def split_data(data, n):
    """Split a sequence into ``n`` contiguous chunks whose sizes differ by at most 1.

    The first ``len(data) % n`` chunks get one extra element. When
    ``len(data) < n``, the trailing chunks are empty, so exactly ``n``
    chunks are always returned.

    Args:
        data: A sliceable sequence (e.g. a list).
        n: Number of chunks to produce; must be >= 1.

    Returns:
        A list of ``n`` slices of ``data`` that, concatenated, reproduce it.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # Base size may be 0 when there are fewer items than chunks. Clamping it
    # to 1 would over-fill the first chunks when combined with the remainder.
    base_size, remainder = divmod(len(data), n)
    splits = []
    start_idx = 0
    for i in range(n):
        end_idx = start_idx + base_size + (1 if i < remainder else 0)
        splits.append(data[start_idx:end_idx])
        start_idx = end_idx
    return splits
|
|||
|
|
|||
|
def write_split_data(split_data_by_class, n, data_path):
    """Write chunk ``i`` of every class into ``<data_path>_attack<i>.txt``.

    For each ``i`` in ``range(n)``, one output file is created (or
    truncated). For every class, in dict order, the rows of its ``i``-th
    chunk are written one per line, with tokens joined by a single space.

    Args:
        split_data_by_class: Mapping of class label -> list of chunks, where
            each chunk is a list of token lists.
        n: Number of output files to produce.
        data_path: Path prefix for the output files.
    """
    for part_index in range(n):
        target = f'{data_path}_attack{part_index}.txt'
        with open(target, 'w') as out:
            for chunks in split_data_by_class.values():
                # A class may have fewer than n chunks; skip it for this file.
                if part_index >= len(chunks):
                    continue
                out.writelines(" ".join(row) + "\n" for row in chunks[part_index])
|
|||
|
|
|||
|
def split_and_save_data(data_path, n):
    """Read ``<data_path>_attack.txt``, split each class into ``n`` parts, and save them.

    Pipeline: group rows by class label, split every class's rows into ``n``
    chunks, then write chunk ``i`` of all classes to
    ``<data_path>_attack<i>.txt``.

    Args:
        data_path: Path prefix shared by the input and output files.
        n: Number of output files / chunks per class.
    """
    grouped = read_and_group_data_by_class(data_path)

    per_class_splits = {}
    for label, rows in grouped.items():
        per_class_splits[label] = split_data(rows, n)

    write_split_data(per_class_splits, n, data_path)
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Temporary setup: the run tag (used as the dataset directory and file
    # prefix) is read from the config file.
    config = configparser.ConfigParser()
    # ConfigParser.read returns the list of files it parsed successfully. An
    # empty list means the file is missing/unreadable. Without this check it
    # would surface later as an opaque KeyError: 'MODEL'.
    if not config.read("run_parallel_config.ini"):
        raise SystemExit("Could not read run_parallel_config.ini")
    file_name = config['MODEL']['run_tag']
    data_path = f'data/{file_name}/{file_name}'
    # Number of output files each class is split across.
    # NOTE(review): an earlier comment said "7 files, at most 7" while the
    # value is 5 — confirm the intended count.
    n = 5
    split_and_save_data(data_path, n)
|