71 lines
1.9 KiB
Python
71 lines
1.9 KiB
Python
|
import configparser
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from torch.utils.data import Dataset
|
||
|
from query_probability import load_ucr
|
||
|
import warnings
|
||
|
|
||
|
warnings.filterwarnings('ignore')
|
||
|
|
||
|
|
||
|
class UcrDataset(Dataset):
|
||
|
def __init__(self, txt_file, channel_last, normalize):
|
||
|
'''
|
||
|
:param txt_file: path of file
|
||
|
:param channel_last
|
||
|
'''
|
||
|
# self.data = np.loadtxt(txt_file)
|
||
|
self.data = load_ucr(txt_file, normalize)
|
||
|
self.channel_last = channel_last
|
||
|
if self.channel_last:
|
||
|
self.data = np.reshape(self.data, [self.data.shape[0], self.data.shape[1], 1])
|
||
|
else:
|
||
|
self.data = np.reshape(self.data, [self.data.shape[0], 1, self.data.shape[1]])
|
||
|
|
||
|
def __len__(self):
|
||
|
|
||
|
return len(self.data)
|
||
|
|
||
|
def __getitem__(self, idx):
|
||
|
|
||
|
if not self.channel_last:
|
||
|
return self.data[idx, :, 1:], self.data[idx, :, 0]
|
||
|
else:
|
||
|
return self.data[idx, 1:, :], self.data[idx, 0, :]
|
||
|
|
||
|
def get_seq_len(self):
|
||
|
if self.channel_last:
|
||
|
return self.data.shape[1] - 1
|
||
|
else:
|
||
|
return self.data.shape[2] - 1
|
||
|
|
||
|
|
||
|
class AdvDataset(Dataset):
|
||
|
def __init__(self, txt_file):
|
||
|
self.data = np.loadtxt(txt_file)
|
||
|
self.data = np.reshape(self.data, [self.data.shape[0], self.data.shape[1], 1])
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.data)
|
||
|
|
||
|
def __getitem__(self, idx):
|
||
|
return self.data[idx, 2:, :], self.data[idx, 1, :]
|
||
|
|
||
|
def get_seq_len(self):
|
||
|
return self.data.shape[1] - 2
|
||
|
|
||
|
|
||
|
def UCR_dataloader(dataset, batch_size):
|
||
|
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
|
||
|
|
||
|
return data_loader
|
||
|
|
||
|
|
||
|
def read_config(file_name):
|
||
|
# 创建一个配置解析器
|
||
|
config = configparser.ConfigParser()
|
||
|
# 从INI文件读取配置
|
||
|
config.read(file_name)
|
||
|
return config
|