56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
|
# -*- coding: utf-8 -*-
|
|||
|
import configparser
|
|||
|
|
|||
|
|
|||
|
def _split_points(n_rows, large_threshold=100):
    """Return ``(cp_end, attack_start)`` slice bounds for a class with *n_rows* rows.

    Rows ``[:cp_end]`` go to the cp file and rows ``[attack_start:]`` go to the
    attack file. The attack file always receives the second half of the class.
    For classes with more than *large_threshold* rows, only the first 10% go to
    the cp file, because interval search does not need that many samples. Rows
    between 10% and 50% of such a class are therefore written to neither file.
    Smaller classes are split evenly at the midpoint.
    """
    half = n_rows // 2
    if n_rows > large_threshold:
        return n_rows // 10, half
    return half, half


def process_and_split_data(file_name, data_dir='data', large_threshold=100):
    """Split ``<data_dir>/<file_name>/<file_name>_TEST.txt`` per class into cp/attack files.

    Each non-blank input line is split on whitespace. Its first field is treated
    as the class label. Rows are grouped by label, and for each class a leading
    portion is written to ``..._cp.txt`` and the trailing half to
    ``..._attack.txt``; see ``_split_points`` for the exact proportions. Output
    fields are re-joined with single spaces.

    Args:
        file_name: Dataset name; used both as the sub-directory and the file prefix.
        data_dir: Root directory holding the dataset folders (default ``'data'``).
        large_threshold: Classes with more than this many rows keep only 10% in
            the cp file instead of 50%.

    Raises:
        OSError: If the input file cannot be read or an output file cannot be written.
    """
    from collections import defaultdict
    import os

    data_path = os.path.join(data_dir, file_name, file_name)

    # Parse rows, skipping blank lines: they split to [] and would have no label.
    with open(data_path + '_TEST.txt', 'r', encoding='utf-8') as src:
        data = [fields for fields in (line.split() for line in src) if fields]

    # Labels are compared as strings (so '10' sorts before '2'). The stable sort
    # keeps each class's rows in file order while fixing the class output order.
    data.sort(key=lambda row: row[0])

    categorized_data = defaultdict(list)
    for row in data:
        categorized_data[row[0]].append(row)

    cp_file_name = data_path + '_cp.txt'
    attack_file_name = data_path + '_attack.txt'

    with open(cp_file_name, 'w', encoding='utf-8') as cp_file, \
            open(attack_file_name, 'w', encoding='utf-8') as attack_file:
        for rows in categorized_data.values():
            cp_end, attack_start = _split_points(len(rows), large_threshold)
            cp_file.writelines(' '.join(row) + '\n' for row in rows[:cp_end])
            attack_file.writelines(' '.join(row) + '\n' for row in rows[attack_start:])
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
# Script entry point: the dataset name (e.g. 'Car') is read from the shared run config.
if __name__ == '__main__':
    config_path = "run_parallel_config.ini"
    config = configparser.ConfigParser()
    # ConfigParser.read silently skips missing files and returns the list it
    # actually read; fail loudly here instead of with a bare KeyError later.
    if not config.read(config_path):
        raise SystemExit(f"Config file not found or unreadable: {config_path}")
    # config.get raises NoSectionError/NoOptionError naming what is missing.
    file_name = config.get('MODEL', 'run_tag')
    process_and_split_data(file_name)
|
|||
|
|