CCPP/dataPre.py
2025-04-20 20:55:06 +08:00

56 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import configparser
def process_and_split_data(file_name):
    """Split a dataset's TEST file into per-class "cp" and "attack" files.

    Reads ``data/<file_name>/<file_name>_TEST.txt``, where every non-blank
    line is a whitespace-separated record whose first field is treated as
    the class label. Records are grouped by label (labels are processed in
    ascending string order, records keep their file order within a label)
    and written, re-joined with single spaces, to two files next to the
    input:

    * ``<file_name>_cp.txt`` receives the leading records of each class:
      the first half for classes with at most 100 records, but only the
      first tenth for larger classes, to keep this file small.
    * ``<file_name>_attack.txt`` receives the second half of each class
      (from index ``len(rows) // 2`` onwards).

    For classes with more than 100 records, the records between the first
    tenth and the midpoint are therefore written to neither file.

    Blank or whitespace-only lines in the input are ignored.

    Args:
        file_name: Dataset name; used both as the sub-directory of
            ``data/`` and as the file-name prefix.

    Raises:
        OSError: If the input file cannot be read or an output file cannot
            be written.
    """
    from collections import defaultdict

    data_path = 'data/' + file_name + '/' + file_name

    # Group records by their class label (first column).
    rows_by_label = defaultdict(list)
    with open(data_path + '_TEST.txt', 'r') as source:
        for line in source:
            fields = line.split()
            if not fields:
                # A blank line has no label column; indexing it would
                # raise IndexError, so skip it.
                continue
            rows_by_label[fields[0]].append(fields)

    with open(data_path + '_cp.txt', 'w') as cp_file, \
            open(data_path + '_attack.txt', 'w') as attack_file:
        # Labels are compared as strings, so e.g. '10' sorts before '2'.
        for label in sorted(rows_by_label):
            rows = rows_by_label[label]
            midpoint = len(rows) // 2
            # Large classes contribute only their first 10% to the cp
            # file; small classes contribute their first half.
            cp_count = len(rows) // 10 if len(rows) > 100 else midpoint

            cp_file.writelines(
                ' '.join(row) + '\n' for row in rows[:cp_count]
            )
            attack_file.writelines(
                ' '.join(row) + '\n' for row in rows[midpoint:]
            )
# Script entry point. The original note assumed a file named 'Car.txt'; the
# dataset name is now taken from the config file instead.
if __name__ == '__main__':
    # Temporary measure: the dataset to process is read from the
    # [MODEL] section's run_tag entry in the config file.
    config = configparser.ConfigParser()
    config.read("run_parallel_config.ini")
    model_section = config['MODEL']
    file_name = model_section['run_tag']
    process_and_split_data(file_name)