# -*- coding: utf-8 -*-
"""Split a dataset's ``_TEST.txt`` file into per-class ``_cp`` and ``_attack`` files."""
import configparser
from collections import defaultdict

# Classes with more rows than this contribute only their first tenth to the
# cp file; smaller classes contribute their first half.
_LARGE_CLASS_THRESHOLD = 100


def _split_bounds(row_count):
    """Return ``(cp_end, attack_start)`` slice indices for a class of ``row_count`` rows."""
    if row_count > _LARGE_CLASS_THRESHOLD:
        # With a lot of data, fewer rows are needed in the cp file.
        cp_end = row_count // 10
    else:
        cp_end = row_count // 2
    # The attack file always starts at the midpoint of the class.
    return cp_end, row_count // 2


def process_and_split_data(file_name):
    """Split ``data/<file_name>/<file_name>_TEST.txt`` into cp and attack files.

    Each non-blank input line is split on whitespace and its first field is
    used as the class label. Rows are sorted by that label (as a string) and
    grouped per class. For every class:

    * the leading rows go to ``data/<file_name>/<file_name>_cp.txt`` -- the
      first half of the class, or only the first tenth when the class has
      more than 100 rows;
    * the rows from the midpoint onwards go to
      ``data/<file_name>/<file_name>_attack.txt``.

    For classes with more than 100 rows, the rows between the first tenth and
    the midpoint are written to neither file. Output rows are re-joined with
    single spaces. Both output files are overwritten.

    Args:
        file_name: Dataset name; used both as the directory name under
            ``data/`` and as the file-name prefix.

    Raises:
        OSError: If the input file cannot be read or an output file cannot
            be written (e.g. ``FileNotFoundError`` for a missing dataset).
    """
    data_path = 'data/' + file_name + '/' + file_name

    # Read and parse the input. Blank / whitespace-only lines are skipped:
    # they parse to an empty list and would otherwise raise IndexError when
    # the class label is looked up below.
    with open(data_path + '_TEST.txt', 'r') as file:
        data = [fields for fields in (line.split() for line in file) if fields]

    # Sort by class label, assumed to be the first column of every row.
    # The sort is stable, so rows keep their file order within a class.
    data.sort(key=lambda fields: fields[0])

    # Group rows by class label.
    categorized_data = defaultdict(list)
    for row in data:
        categorized_data[row[0]].append(row)

    cp_file_name = data_path + '_cp.txt'
    attack_file_name = data_path + '_attack.txt'
    with open(cp_file_name, 'w') as cp_file, \
            open(attack_file_name, 'w') as attack_file:
        for rows in categorized_data.values():
            cp_end, attack_start = _split_bounds(len(rows))
            # Leading portion of the class goes to the cp file.
            cp_file.writelines(' '.join(row) + '\n' for row in rows[:cp_end])
            # Second half of the class goes to the attack file.
            attack_file.writelines(
                ' '.join(row) + '\n' for row in rows[attack_start:]
            )


if __name__ == '__main__':
    # The dataset name is taken from the run configuration file.
    config = configparser.ConfigParser()
    config.read("run_parallel_config.ini")
    file_name = config['MODEL']['run_tag']
    process_and_split_data(file_name)