56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
# -*- coding: utf-8 -*-
|
||
import configparser
|
||
|
||
|
||
def _split_points(n_rows, large_threshold=100, large_cp_divisor=10):
    """Return ``(cp_end, attack_start)`` slice bounds for a class with *n_rows* rows.

    The attack part always starts at the middle of the class. The cp part is
    the first half, except for classes with more than *large_threshold* rows,
    where it is only the first ``n_rows // large_cp_divisor`` rows.
    """
    attack_start = n_rows // 2
    if n_rows > large_threshold:
        # Large classes: the cp file only needs a small sample of rows, so keep
        # it short instead of writing a full half of the class.
        cp_end = n_rows // large_cp_divisor
    else:
        cp_end = attack_start
    return cp_end, attack_start


def process_and_split_data(file_name, data_dir='data', large_threshold=100,
                           large_cp_divisor=10):
    """Split ``<file_name>_TEST.txt`` per class into a "cp" file and an "attack" file.

    Reads ``<data_dir>/<file_name>/<file_name>_TEST.txt``. Each non-blank line
    is treated as whitespace-separated fields whose first field is assumed to be
    the class label. Rows are grouped by label. Groups are ordered by the
    label's string sort order, and the original row order is kept within each
    group. For every group:

    * ``<file_name>_cp.txt`` gets the first ``len(rows) // 2`` rows. If the group
      has more than *large_threshold* rows, it gets only the first
      ``len(rows) // large_cp_divisor`` rows.
    * ``<file_name>_attack.txt`` always gets the second half,
      ``rows[len(rows) // 2:]``.

    For large groups, rows between the two cut points go to neither file.
    Output rows are re-joined with single spaces, one row per line. Both
    output files are written next to the input and overwritten if present.

    Args:
        file_name: Dataset name; used as both the sub-directory and file prefix.
        data_dir: Root directory containing the dataset directories.
        large_threshold: Groups with more rows than this use the reduced cp size.
        large_cp_divisor: For large groups, the cp part is ``len(rows) // large_cp_divisor``.

    Raises:
        OSError: If the input file cannot be read or an output file cannot be written.
    """
    import os
    from collections import defaultdict

    data_path = os.path.join(data_dir, file_name, file_name)

    # Parse each line into its whitespace-separated fields. Skip blank lines
    # (e.g. trailing empty lines), which have no label column to index.
    with open(data_path + '_TEST.txt', 'r', encoding='utf-8') as file:
        data = [fields for fields in (line.split() for line in file) if fields]

    # list.sort is stable, so rows of the same class keep their input order.
    # Labels are compared as strings.
    data.sort(key=lambda row: row[0])

    # Group rows by label. Dicts preserve insertion order, so groups come out
    # in sorted-label order.
    categorized_data = defaultdict(list)
    for row in data:
        categorized_data[row[0]].append(row)

    cp_file_name = data_path + '_cp.txt'
    attack_file_name = data_path + '_attack.txt'

    with open(cp_file_name, 'w', encoding='utf-8') as cp_file, \
            open(attack_file_name, 'w', encoding='utf-8') as attack_file:
        for rows in categorized_data.values():
            cp_end, attack_start = _split_points(
                len(rows), large_threshold, large_cp_divisor)
            cp_file.writelines(' '.join(row) + '\n' for row in rows[:cp_end])
            attack_file.writelines(
                ' '.join(row) + '\n' for row in rows[attack_start:])
|
||
|
||
|
||
|
||
|
||
# Script entry point. The dataset name is taken from the [MODEL] run_tag entry
# of run_parallel_config.ini (e.g. run_tag = Car processes data/Car/Car_TEST.txt).
if __name__ == '__main__':
    # Temporary: read the dataset name to split from the shared run configuration.
    config_path = "run_parallel_config.ini"
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it did read. Check it so a missing config fails clearly
    # instead of as an opaque KeyError on config['MODEL'].
    if not config.read(config_path):
        raise FileNotFoundError('config file not found: ' + config_path)
    file_name = config['MODEL']['run_tag']
    process_and_split_data(file_name)
|
||
|