CCPP/utils.py
2025-04-20 20:55:06 +08:00

71 lines
1.9 KiB
Python

import configparser
import torch
import numpy as np
from torch.utils.data import Dataset
from query_probability import load_ucr
import warnings
# Install a blanket "ignore" filter: from the moment this module is imported,
# every Python warning in the process is suppressed, not only those raised here.
warnings.filterwarnings('ignore')
class UcrDataset(Dataset):
    """Dataset over a table loaded by ``load_ucr``.

    The loaded table is stored with one extra singleton axis: shape
    ``(rows, columns, 1)`` when ``channel_last`` is truthy, otherwise
    ``(rows, 1, columns)``. Column 0 of every row is handed back separately
    from the remaining columns (NOTE(review): presumably the class label,
    as in the UCR archive text format -- confirm against ``load_ucr``).
    """

    def __init__(self, txt_file, channel_last, normalize):
        """Load the file and add the channel axis.

        :param txt_file: path of the data file, forwarded to ``load_ucr``
        :param channel_last: truthy -> singleton axis goes last,
            falsy -> singleton axis goes in the middle
        :param normalize: forwarded unchanged to ``load_ucr``
        """
        # assumes load_ucr returns a 2-D array (rows x columns) -- TODO confirm
        table = load_ucr(txt_file, normalize)
        self.channel_last = channel_last
        n_rows, n_cols = table.shape[0], table.shape[1]
        if channel_last:
            target_shape = [n_rows, n_cols, 1]
        else:
            target_shape = [n_rows, 1, n_cols]
        self.data = np.reshape(table, target_shape)

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(columns 1.., column 0)`` of row ``idx``.

        The singleton channel axis is kept in both returned arrays.
        """
        row = self.data
        if self.channel_last:
            return row[idx, 1:, :], row[idx, 0, :]
        return row[idx, :, 1:], row[idx, :, 0]

    def get_seq_len(self):
        """Return the number of columns per row excluding column 0."""
        column_axis = 1 if self.channel_last else 2
        return self.data.shape[column_axis] - 1
class AdvDataset(Dataset):
    """Dataset over a numeric text file read with ``np.loadtxt``.

    Every row of the file is one sample, stored as shape
    ``(rows, columns, 1)``. Column 1 is returned as the target and columns
    2 onward as the series; column 0 is never returned
    (NOTE(review): presumably an index or the original label written by
    whatever generates these files -- confirm).
    """

    def __init__(self, txt_file):
        """Load ``txt_file`` and append a trailing channel axis.

        :param txt_file: anything ``np.loadtxt`` accepts (path, file object,
            or iterable of lines)
        """
        # ndmin=2 keeps a single-row file as shape (1, n_columns). Without it
        # loadtxt squeezes the result to 1-D and ``shape[1]`` below raises
        # IndexError for a file that holds exactly one sample.
        table = np.loadtxt(txt_file, ndmin=2)
        self.data = np.reshape(table, [table.shape[0], table.shape[1], 1])

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(columns 2.., column 1)`` of row ``idx``.

        Both returned arrays keep the trailing channel axis.
        """
        return self.data[idx, 2:, :], self.data[idx, 1, :]

    def get_seq_len(self):
        """Return the number of series columns, i.e. all columns minus the first two."""
        return self.data.shape[1] - 2
def UCR_dataloader(dataset, batch_size, *, shuffle=True, num_workers=0):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    :param dataset: dataset to draw samples from
    :param batch_size: number of samples per batch
    :param shuffle: reshuffle the data every epoch. Defaults to True, which
        was previously hard-coded; pass False for a deterministic order
        (e.g. evaluation).
    :param num_workers: number of loader worker processes. Defaults to 0
        (load in the calling process), which was previously hard-coded.
    :return: the configured ``DataLoader``
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
def read_config(file_name):
    """Parse an INI file and return the populated ``ConfigParser``.

    :param file_name: path of the INI file, passed to ``ConfigParser.read``
    :return: the parser; it has no sections when the file could not be
        read, because ``ConfigParser.read`` skips unreadable files silently
    """
    # Build the parser first, then load the INI contents into it.
    parser = configparser.ConfigParser()
    parser.read(file_name)
    return parser