# Source listing: 58 lines, 2.0 KiB, Python
import configparser
|
||
|
||
|
||
def read_and_group_data_by_class(input_file):
    """Read ``<input_file>_attack.txt`` and group its rows by class label.

    Each non-blank line is split on whitespace. The first token is used as
    the class label, and the full token list (label included) is stored
    under that label.

    Args:
        input_file: Path prefix; ``'_attack.txt'`` is appended to it.

    Returns:
        dict mapping each class label (str) to the list of token lists for
        that class, in file order. Keys appear in first-seen order.

    Raises:
        OSError: if the file cannot be opened.
    """
    data_by_class = {}
    with open(input_file + '_attack.txt', 'r') as file:
        for line in file:
            # str.split() with no argument also strips surrounding whitespace.
            parts = line.split()
            # Blank or whitespace-only lines carry no class label; skip them
            # instead of failing on parts[0].
            if not parts:
                continue
            data_by_class.setdefault(parts[0], []).append(parts)
    return data_by_class
|
||
|
||
def split_data(data, n):
    """Split *data* into exactly *n* contiguous chunks whose sizes differ by at most one.

    With ``q, r = divmod(len(data), n)``, the first ``r`` chunks hold ``q + 1``
    items and the remaining chunks hold ``q``. When ``len(data) < n``, the
    trailing chunks are empty. Every item lands in exactly one chunk, and the
    original order is preserved.

    Args:
        data: A sliceable sequence (e.g. a list).
        n: Number of chunks to produce; must be >= 1.

    Returns:
        list of ``n`` slices of *data*.

    Raises:
        ValueError: if ``n < 1``.

    >>> split_data([1, 2, 3], 5)
    [[1], [2], [3], [], []]
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    base, extra = divmod(len(data), n)
    splits = []
    start = 0
    for i in range(n):
        # The first `extra` chunks each absorb one leftover item.
        end = start + base + (1 if i < extra else 0)
        splits.append(data[start:end])
        start = end
    return splits
|
||
|
||
def write_split_data(split_data_by_class, n, data_path):
    """Write shard ``k`` (for each ``k`` in ``range(n)``) to ``<data_path>_attack<k>.txt``.

    Each output file collects the k-th chunk of every class, in the dict's
    iteration order. Each row (a list of string tokens) is written as one
    space-joined line. Existing files are overwritten. A class with fewer
    than ``n`` chunks contributes nothing to the higher-numbered files.
    """
    for shard_idx in range(n):
        with open(f'{data_path}_attack{shard_idx}.txt', 'w') as out:
            for chunks in split_data_by_class.values():
                # Guard against a class that has fewer than n chunks.
                if shard_idx >= len(chunks):
                    continue
                out.writelines(" ".join(row) + "\n" for row in chunks[shard_idx])
|
||
|
||
def split_and_save_data(data_path, n):
    """Split ``<data_path>_attack.txt`` into ``n`` shard files, class by class.

    Rows are grouped by their class label, each class's rows are cut into
    ``n`` chunks, and chunk ``k`` of every class is written to
    ``<data_path>_attack<k>.txt``.
    """
    grouped = read_and_group_data_by_class(data_path)

    # Chunk every class independently, so each shard gets a share of each class.
    shards = {}
    for label, rows in grouped.items():
        shards[label] = split_data(rows, n)

    write_split_data(shards, n, data_path)
|
||
|
||
if __name__ == "__main__":
    # Temporary setup: the run tag is read from run_parallel_config.ini and
    # names both the data directory and the file prefix (data/<tag>/<tag>).
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it parsed. Fail with a clear message instead of a later
    # KeyError('MODEL').
    if not config.read("run_parallel_config.ini"):
        raise SystemExit("Config file 'run_parallel_config.ini' not found or unreadable")
    file_name = config['MODEL']['run_tag']
    data_path = 'data/' + file_name + '/' + file_name
    # Number of output shards: <data_path>_attack0.txt .. _attack{n-1}.txt
    n = 5
    split_and_save_data(data_path, n)
|