CCPP/drawing_tool.py
2025-04-20 20:55:06 +08:00

566 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.font_manager import FontProperties
from matplotlib.ticker import FuncFormatter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.patheffects as pe
# Global typography for publication figures: Times New Roman with enlarged
# default sizes for text, titles, axis labels, tick labels and legends.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 22,  # default text size
    'axes.titlesize': 26,  # axes titles
    'axes.labelsize': 26,  # axis labels
    'xtick.labelsize': 26,  # x tick labels
    'ytick.labelsize': 26,  # y tick labels
    'legend.fontsize': 20,  # legend entries
})
import matplotlib
from matplotlib.patches import FancyBboxPatch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from analyzeTool import generate_heat_maps_array, count_classes_samples
# Prefer the native macOS backend for interactive windows. pyplot is already
# imported at this point, so matplotlib.use() performs an immediate backend
# switch and raises ImportError when the MacOSX backend cannot be loaded
# (any non-macOS host). Keep the default backend in that case instead of
# making the whole module fail to import.
try:
    matplotlib.use('MacOSX')
except ImportError:
    pass
# Plot an original time series against its perturbed counterpart.
def draw_ori_perturbed(ori_data, pert_data, dashed_lists=None, title=None, filename='comp_ori_pert_result.pdf',
                       loc_sample='upper left', loc_sub='upper left', inset_range=None, xy=(0, 0), xytext=(0, 0)):
    """Plot an original and a perturbed series on one axes and save the figure.

    Args:
        ori_data: Original series, either a whitespace-separated string of numbers
            or any sequence of numbers.
        pert_data: Perturbed series, in the same form as ``ori_data``.
        dashed_lists: Optional iterable of ``(start, end)`` index pairs. For each pair
            a black dashed rectangle is drawn spanning x = ``start`` .. ``end`` and the
            min/max of both series over those indices (``end`` inclusive). Pairs
            selecting no samples are skipped.
        title: Optional caption placed below the axes. When not None, the layout is
            tightened and the bottom margin raised to 0.15 to make room for it.
        filename: Path passed to ``plt.savefig``.
        loc_sample: Legend location for the main axes.
        loc_sub: Location of the zoomed inset inside the main axes.
        inset_range: Optional ``(start, end)`` slice of the series shown in a
            30% x 30% inset.
        xy: Arrow head position (data coordinates); only used with ``inset_range``.
        xytext: Arrow tail position (data coordinates); only used with ``inset_range``.
    """
    color_1 = '#4B0082'
    color_2 = '#FF8C00'

    def _to_floats(data):
        # Accept the whitespace-separated string form as well as numeric sequences.
        values = data.split() if isinstance(data, str) else data
        return [float(value) for value in values]

    ori_ts = _to_floats(ori_data)
    perturbed_ts = _to_floats(pert_data)
    length = len(ori_ts)
    x = np.arange(0, length, 1)  # sample indices 0 .. length-1

    plt.figure(figsize=(10, 5))
    plt.plot(x, ori_ts, label='Original', color=color_1, linewidth=5)
    plt.plot(x, perturbed_ts, label='Perturbed', color=color_2, linewidth=5)

    # Dashed boxes highlighting selected regions of both series.
    if dashed_lists is None:
        dashed_lists = []
    for start, end in dashed_lists:
        # The rectangle's right edge sits at x=end, so the sample at index `end`
        # lies inside it; slice stops are exclusive, hence `end + 1`.
        window = ori_ts[start:end + 1] + perturbed_ts[start:end + 1]
        if not window:
            continue  # range selects no samples: nothing to outline
        top = max(window)
        bottom = min(window)
        plt.plot(
            [start, start, end, end, start],
            [bottom, top, top, bottom, bottom],
            linestyle='--',
            linewidth=4,
            color='black'
        )

    plt.xlim(-10, length + 10)  # small horizontal padding around the series
    plt.legend(loc=loc_sample)
    plt.grid(True)
    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_linewidth(2)  # thicker frame

    # Optional zoomed inset, plus an arrow pointing towards it.
    if inset_range is not None:
        plt.annotate('', xy=xy, xytext=xytext,
                     arrowprops=dict(arrowstyle='->', color='black', lw=3))
        inset_start, inset_end = inset_range
        ax_inset = inset_axes(ax, width="30%", height="30%", loc=loc_sub, borderpad=1)
        ax_inset.plot(x[inset_start:inset_end], ori_ts[inset_start:inset_end], label='Original', color=color_1,
                      linewidth=4)
        ax_inset.plot(x[inset_start:inset_end], perturbed_ts[inset_start:inset_end], label='Perturbed',
                      color=color_2, linewidth=4)
        ax_inset.set_xlim(inset_start, inset_end)
        ax_inset.grid(True)
        # The inset only shows shape, so its tick marks are removed.
        ax_inset.set_xticks([])
        ax_inset.set_yticks([])
        for spine in ax_inset.spines.values():
            spine.set_linewidth(2)

    if title is not None:
        # Make room below the axes for the caption.
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.15)
        plt.figtext(0.5, 0.02, title, ha='center', fontsize=12)
    plt.savefig(filename)
    plt.show()
# Plot original vs. perturbed series with a magnified inset of one window.
def draw_ori_perturbed_with_inset(ori_data, pert_data, inset_start, inset_end, margin=0.5,
                                  filename='comp_ori_pert_inset_result.pdf'):
    """Plot both series, add a zoomed inset of ``[inset_start, inset_end)`` and save.

    Args:
        ori_data: Original series, either a whitespace-separated string of numbers
            or any sequence of numbers.
        pert_data: Perturbed series, in the same form as ``ori_data``.
        inset_start: First index of the magnified window.
        inset_end: End index (exclusive for the y-range computation) of the window;
            also used as the inset's right x-limit.
        margin: Vertical padding added above and below both series in the inset
            and in the shaded marker box.
        filename: Output path passed to ``savefig``.

    Raises:
        ValueError: If the window selects no samples.
    """
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

    def _to_floats(data):
        # Accept the whitespace-separated string form as well as numeric sequences.
        values = data.split() if isinstance(data, str) else data
        return [float(value) for value in values]

    ori_ts = _to_floats(ori_data)
    perturbed_ts = _to_floats(pert_data)
    length = len(ori_ts)
    x = np.arange(0, length, 1)  # sample indices 0 .. length-1

    # Main plot.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(x, ori_ts, label='Original: Class 0', color='royalblue', linewidth=2)
    ax.plot(x, perturbed_ts, label='Perturbed: Class 1', color='tomato', linewidth=2)
    ax.set_xlim(-20, length + 20)
    ax.grid(True)

    # Vertical extent of the zoom window over BOTH series, so the perturbed curve
    # is not clipped out of the inset where it departs from the original.
    window = ori_ts[inset_start:inset_end] + perturbed_ts[inset_start:inset_end]
    if not window:
        raise ValueError(f'inset window [{inset_start}, {inset_end}) selects no samples')
    y_low = min(window) - margin
    y_high = max(window) + margin

    # Inset: 40% of the parent size, anchored by (x0, y0, width, height) in axes coords.
    ax_inset = inset_axes(ax, width='40%', height='40%',
                          bbox_to_anchor=(-0.2, -0.1, 0.8, 0.8),
                          bbox_transform=ax.transAxes)
    ax_inset.plot(x, ori_ts, color='royalblue', linewidth=2)
    ax_inset.plot(x, perturbed_ts, color='tomato', linewidth=2)
    ax_inset.set_xlim(inset_start, inset_end)
    ax_inset.set_ylim(y_low, y_high)
    ax_inset.set_xticks([])
    ax_inset.set_yticks([])

    # Connector lines from the zoomed region to the inset; the high zorder keeps
    # them above the data lines. loc1/loc2 select the corners the lines start from.
    mark_inset(ax, ax_inset, loc1=3, loc2=1, fc="none", ec="0.5", lw=2, zorder=10)

    # Translucent shadow over the magnified region of the main plot.
    bbox_shadow = FancyBboxPatch((inset_start, y_low),
                                 inset_end - inset_start,
                                 y_high - y_low,
                                 boxstyle="round,pad=0.1",
                                 ec="none", fc="black", alpha=0.2,
                                 transform=ax.transData, zorder=2)
    ax.add_patch(bbox_shadow)

    fig.savefig(filename)
    plt.show()
def draw_times_diff(data1, data2, data3, data1_label, data2_label, data3_label,
                    filename='line_chart_result.pdf'):
    """Plot MSE of three series against the population-size multiple and save as PDF.

    Args:
        data1: MSE values for the first series (one per population-size multiple).
        data2: MSE values for the second series; same length as ``data1``.
        data3: MSE values for the third series; same length as ``data1``.
        data1_label: Legend label for ``data1``.
        data2_label: Legend label for ``data2``.
        data3_label: Legend label for ``data3``.
        filename: Output PDF path.
    """
    # Multiples start at 1. Derive the count from the data instead of assuming
    # exactly five points.
    x = np.arange(1, len(data1) + 1)

    plt.figure(figsize=(10, 5))
    series = (
        (data1, data1_label, 'royalblue'),
        (data2, data2_label, 'tomato'),
        (data3, data3_label, 'sandybrown'),
    )
    for values, label, color in series:
        plt.plot(x, values, label=label, color=color, marker='o')

    plt.xlabel('Times of population size')
    plt.ylabel('MSE')
    plt.grid(True)
    plt.tight_layout()  # keep axis labels from being clipped
    # Single legend call; previously a default legend was created and then replaced.
    plt.legend(loc='center right')
    plt.savefig(filename, format='pdf')  # vector output
    plt.show()
def draw_times_diff_bar(data1, data2, data3, data4,
                        line_data1, line_data2, line_data3, line_data4,
                        data1_label, data2_label, data3_label, data4_label,
                        filename='bar_chart_with_lines_and_legend_above.pdf', leftLabel='MSE (10^-3)', rightLabel='ANI',
                        xlabel='Times of population size', x_value=None, times=1000):
    """Grouped bars (left axis) with matching line plots (right axis), saved as PDF.

    Args:
        data1, data2, data3, data4: Bar heights per series (one per x position);
            all must share the same length.
        line_data1, line_data2, line_data3, line_data4: Values drawn as lines on
            the right-hand axis, in the same colour as the corresponding bars.
        data1_label, data2_label, data3_label, data4_label: Legend labels for the bars.
        filename: Output PDF path.
        leftLabel: Left y-axis label.
        rightLabel: Right y-axis label.
        xlabel: x-axis label.
        x_value: Tick labels for the x positions. Defaults to '1', '2', ... up to
            the number of data points.
        times: Factor applied to left-axis tick values before display (e.g. 1000
            to show values in units of 10^-3).
    """
    n_points = len(data1)
    if x_value is None:
        x_value = [str(i) for i in range(1, n_points + 1)]
    x = np.arange(1, n_points + 1)
    bar_width = 0.15

    colors = ('#05B9E2', '#54B345', '#F27970', '#8983BF')
    bar_series = (data1, data2, data3, data4)
    bar_labels = (data1_label, data2_label, data3_label, data4_label)
    line_series = (line_data1, line_data2, line_data3, line_data4)
    # Offsets in bar widths that centre the four bars of a group on each tick.
    offsets = (-1.5, -0.5, 0.5, 1.5)

    _fig, ax1 = plt.subplots(figsize=(8, 5))
    for values, label, color, offset in zip(bar_series, bar_labels, colors, offsets):
        ax1.bar(x + bar_width * offset, values, width=bar_width, label=label, color=color, alpha=0.6,
                edgecolor='black')

    # Left axis: labels, ticks, horizontal grid and thick frame.
    ax1.set_xlabel(xlabel)
    ax1.set_ylabel(leftLabel)
    ax1.set_xticks(x)
    ax1.set_xticklabels(x_value)
    ax1.tick_params(axis='y')
    ax1.grid(axis='y')  # horizontal grid lines only
    for spine in ax1.spines.values():
        spine.set_linewidth(2)

    def scale_tick(value, _pos):
        # Display tick values multiplied by `times`, as whole numbers.
        return f'{value * times:.0f}'

    ax1.yaxis.set_major_formatter(FuncFormatter(scale_tick))

    # Right axis: one line per series, coloured like its bars.
    ax2 = ax1.twinx()
    for values, color in zip(line_series, colors):
        ax2.plot(x, values, color=color, marker='o', linestyle='-', linewidth=2, markersize=8)
    ax2.set_ylabel(rightLabel)
    ax2.tick_params(axis='y')

    # Bar legend above the plot. ax1.legend() already attaches the legend to the
    # axes; registering it again with add_artist made it draw twice.
    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2), ncol=4, frameon=False)

    plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave headroom for the legend
    plt.savefig(filename, format='pdf')
    plt.show()
def draw_diff_pert_info(data1, data2, data5, data6, labels, marker=None, linewidth=2, filename='line_chart_result.pdf'):
    """Plot success rate against the perturbation factor beta for four datasets.

    Args:
        data1: Success rates plotted as 'Car'.
        data2: Success rates plotted as 'Italy'.
        data5: Success rates plotted as 'ECG200'.
        data6: Success rates plotted as 'Lightning7'.
        labels: x tick labels, one per beta value; their count sets the number of
            x positions.
        marker: Optional matplotlib marker style applied to every line.
        linewidth: Line width applied to every line.
        filename: Output PDF path.
    """
    x = np.arange(len(labels))
    # (values, legend label, colour) in drawing order.
    series = (
        (data1, 'Car', '#F27970'),
        (data5, 'ECG200', '#54B345'),
        (data2, 'Italy', '#05B9E2'),
        # Dataset name corrected from the misspelt 'Ligning7' (UCR 'Lightning7').
        (data6, 'Lightning7', '#8983BF'),
    )

    plt.figure(figsize=(7, 5))
    for values, name, color in series:
        plt.plot(x, values, label=name, color=color, marker=marker, linewidth=linewidth, markersize=8)

    plt.xlabel('$\\beta$')
    plt.ylabel('Success Rate')
    plt.xticks(x, labels, rotation=0, ha='center')
    # No legend is drawn (the legend call was left disabled); the labels are still
    # attached, so plt.legend(...) can be enabled later.
    plt.grid(True)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # avoid overlapping decorations
    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_linewidth(2)  # thicker frame
    plt.savefig(filename, format='pdf')  # vector output
    plt.show()
# Plot the MSE of every method across all datasets (the large comparison table).
def draw_all_data_set_info(data1, data2, data3, data4, data5, labels,
                           filename='line_chart_result.pdf',
                           caption='Fig. 9. MSE results of TSadv over different $\\beta$ and GATN on 42 UCR datasets'):
    """Plot five per-dataset MSE series as lines over the dataset names and save as PDF.

    An entry equal to -1 marks a missing result. Rows containing -1 are moved to
    the right end of the plot, and each -1 is plotted as NaN (a gap in the line).
    The caller's sequences are not modified.

    Args:
        data1: GATN MSE per dataset.
        data2: advGAN MSE per dataset.
        data3: TSadv MSE (beta=0.04) per dataset.
        data4: CPadv MSE (beta=0.04) per dataset.
        data5: CPadv MSE (beta=0.02) per dataset.
        labels: Dataset names used as x tick labels; same length as the data.
        filename: Output PDF path.
        caption: Text placed centred at the bottom of the figure.
    """
    import numbers

    def _move_neg_ones_to_end(columns):
        """Stably move -1 rows to the end, keeping all columns row-aligned.

        Each numeric column is used in turn as the sort key, so rows with a -1 in
        a later column end up after rows with a -1 in an earlier column. Columns
        are reordered in place (they are private copies here).
        """
        for key_column in columns:
            if not key_column or not isinstance(key_column[0], numbers.Number):
                continue  # empty or non-numeric column (e.g. the dataset names)
            is_missing = [value == -1 for value in key_column]
            order = sorted(range(len(key_column)), key=is_missing.__getitem__)  # stable
            for column in columns:
                column[:] = [column[idx] for idx in order]

    def _neg_ones_to_nan(values):
        # -1 marks a missing result; NaN points are not drawn by matplotlib.
        return np.array([np.nan if value == -1 else value for value in values])

    # Work on copies so the caller's sequences are not reordered as a side effect.
    columns = [list(seq) for seq in (labels, data1, data2, data3, data4, data5)]
    _move_neg_ones_to_end(columns)
    labels = columns[0]
    data1, data2, data3, data4, data5 = (_neg_ones_to_nan(column) for column in columns[1:])

    x = np.arange(len(labels))
    plt.figure(figsize=(10, 5))
    plt.plot(x, data1, label='GATN MSE', color='black')
    plt.plot(x, data2, label='advGAN MSE', color='royalblue')
    plt.plot(x, data3, label='TSadv MSE $\\beta=0.04$', color='tomato')
    plt.plot(x, data4, label='CPadv MSE $\\beta=0.04$', color='sandybrown')
    plt.plot(x, data5, label='CPadv MSE $\\beta=0.02$', color='firebrick')

    plt.ylabel('MSE')
    plt.xticks(x, labels, rotation=30, ha='right', fontsize=5, fontweight='bold')
    plt.legend(loc='upper right')
    plt.grid(False)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the caption
    plt.figtext(0.5, 0.01, caption, ha='center', fontsize=12)
    plt.savefig(filename, format='pdf')
    plt.show()
def heat_maps_of_classes(ori_data1, ori_data2, title, labels=None, cbar_ticks1=None, cbar_ticks2=None,
                         cmap_color='Blues', filename='heatmap.pdf'):
    """Render two 2-D arrays as side-by-side heat maps with colorbars and save as PDF.

    Args:
        ori_data1: First 2-D array-like (rows x columns) to visualise.
        ori_data2: Second 2-D array-like to visualise.
        title: Caption drawn centred below the two panels.
        labels: Two ``(xlabel, ylabel)`` tuples, one per panel. Defaults to
            ``[('', ''), ('', '')]``.
        cbar_ticks1: Tick values for the first colorbar. When given, their min and
            max also fix the colour range; when None the range follows the data.
        cbar_ticks2: Same as ``cbar_ticks1``, for the second panel.
        cmap_color: Matplotlib colormap name (default 'Blues').
        filename: Output PDF path.
    """
    # Font applied to colorbar and axis tick labels.
    font = FontProperties()
    font.set_family('serif')
    font.set_name('Times New Roman')
    font.set_size(20)

    if labels is None:
        labels = [('', ''), ('', '')]
    panels = [np.array(ori_data1), np.array(ori_data2)]
    tick_sets = [cbar_ticks1, cbar_ticks2]

    fig, axs = plt.subplots(1, 2, figsize=(8, 3), gridspec_kw={'wspace': 0.4})

    for ax, data, ticks in zip(axs, panels, tick_sets):
        has_ticks = ticks is not None and len(ticks) > 0
        # Pin the colour range to the requested ticks. Without ticks, let matplotlib
        # autoscale; calling min()/max() on None crashed before.
        vmin = min(ticks) if has_ticks else None
        vmax = max(ticks) if has_ticks else None
        image = ax.matshow(data, cmap=cmap_color, alpha=0.8, vmin=vmin, vmax=vmax)

        # Cell grid: minor ticks on the cell boundaries.
        ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
        ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)

        # Colorbar in a slim axes appended to the right of the panel.
        cbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)
        cbar = fig.colorbar(image, cax=cbar_ax)
        if has_ticks:
            cbar.set_ticks(ticks)
        cbar.ax.yaxis.set_tick_params(labelsize=12)
        for tick_label in cbar.ax.get_yticklabels():
            tick_label.set_fontproperties(font)

    for ax, data in zip(axs, panels):
        ax.tick_params(axis='both', which='major', labelsize=12)
        # matshow puts the x axis on top; move it to the bottom and flip rows so
        # row 0 is at the bottom.
        ax.xaxis.set_ticks_position('bottom')
        ax.xaxis.set_label_position('bottom')
        ax.invert_yaxis()
        for tick_label in (ax.get_xticklabels() + ax.get_yticklabels()):
            tick_label.set_fontproperties(font)
        # One major tick per cell.
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_yticks(np.arange(data.shape[0]))

    # Per-panel axis labels (documented but previously never applied).
    for ax, (xlabel, ylabel) in zip(axs, labels):
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    plt.figtext(0.5, 0.01, title, ha='center', va='center', fontsize=10, fontname='Times New Roman')
    plt.savefig(filename, format='pdf', bbox_inches='tight')  # vector output
    plt.show()
if __name__ == '__main__':
    # Script entry point. Every example below is commented out, so running the
    # module directly does nothing; uncomment one section to regenerate a figure.
    # Attack-sample comparison (original vs. perturbed series)
    # original_data = "0.95110175 0.03860069 0.001077176 0.001077176 0.001077176 -0.028107779 0.0052464556 -0.036446338 -0.040615617 -0.036446338 -0.028107779 -0.078139131 -0.08647769 -0.044784896 -0.078139131 -0.057292734 -0.057292734 -0.094816249 -0.08647769 -0.069800573 -0.069800573 -0.08647769 -0.1240012 -0.11566264 -0.078139131 -0.094816249 -0.094816249 -0.14067832 -0.14067832 -0.090646969 -0.14901688 -0.16986328 -0.098985528 -0.15735544 -0.1240012 -0.061462014 -0.0072613824 -0.0072613824 0.071954926 0.096970602 -0.032277059 -0.13233976 -0.19070967 -0.27826454 -0.21572535 -0.27409526 -0.26992598 -0.28660309 -0.25741814 -0.28243382 -0.24907958 -0.24907958 -0.24907958 -0.2240639 -0.098985528 1.0600742 1.7771901 -0.82027086 -3.8930297 -5.3647854 -4.0431238 -2.2294872 -1.0287348 -0.44920499 -0.23240247 -0.1240012 -0.08647769 -0.015599941 0.017754294 0.03860069 0.063616366 0.10530916 0.10530916 0.18869474 0.18035619 0.28458817 0.40549727 0.46803646 0.60979196 0.75988602 0.93499575 1.2018296 1.5312027 1.8522372 2.1565946 2.4943062 2.7402937 2.9445884 2.9154034 2.5443376 2.0148391 1.3811086 0.85161016 0.45135934 0.21371042 0.071954926 -0.0239385 -0.078139131 -0.10315481 -0.08647769 -0.13233976 -0.094816249 -0.090646969 -0.044784896 -0.053123455 -0.044784896 -0.057292734 -0.10315481 -0.08647769 -0.08647769 -0.11566264 -0.15735544 -0.165694 -0.20321751 -0.21572535 -0.28243382 -0.28243382 -0.26992598 -0.29077238 -0.3074495 -0.34914228 -0.34080373 -0.33246517 -0.34914228 -0.32829589 -0.34080373 -0.3658194 -0.3658194 -0.37832724 -0.38249653 -0.39083508 -0.39917364 -0.39500436 -0.39083508 -0.43252788 -0.40334292"
    # perturbed_data = "0.9546 0.0387 0.0011 0.0011 0.0011 -0.0282 0.0053 -0.0366 -0.0408 -0.0366 -0.0282 -0.0784 -0.0868 -0.0450 -0.0784 -0.0575 -0.0575 -0.0952 -0.0868 -0.0701 -0.0701 -0.0868 -0.1245 -0.1161 -0.0784 -0.0952 -0.0952 -0.1412 -0.1412 -0.0910 -0.1496 -0.1705 -0.0994 -0.1579 -0.1245 -0.0617 -0.0073 -0.0073 0.0722 0.0973 -0.0324 -0.1328 -0.1914 -0.2793 -0.2165 -0.2751 -0.2709 -0.2877 -0.2584 -0.1592 -0.2146 0.0321 -0.0189 0.0241 0.1806 1.3527 1.6821 -1.0405 -3.7122 -5.2477 -4.2234 -2.4362 -1.3062 -0.6095 -0.2333 -0.1245 -0.0868 -0.0157 0.0178 0.0387 0.0639 0.1057 0.1057 0.1894 0.1810 0.4762 0.2598 0.1544 0.2773 0.8633 0.8632 1.0042 1.3689 1.9531 2.3896 2.7358 2.7239 2.8305 2.7696 2.5007 2.0223 1.3862 0.8548 0.4530 0.2145 0.0722 -0.0240 -0.0784 -0.1035 -0.0868 -0.1328 -0.0952 -0.0910 -0.0450 -0.0533 -0.0450 -0.0575 -0.1035 -0.0868 -0.0868 -0.1161 -0.1579 -0.1663 -0.2040 -0.2165 -0.2835 -0.2835 -0.2709 -0.2918 -0.3086 -0.3504 -0.3421 -0.3337 -0.3504 -0.3295 -0.3421 -0.3672 -0.3672 -0.3797 -0.3839 -0.3923 -0.4006 -0.3965 -0.3923 -0.4341 -0.4048"
    # # # Sample 4 succeeded
    # # -0.4898 -0.4707 -0.4611 -0.4323 -0.4898 -0.4515 -0.4036 -0.4228 -0.4898 -0.4707 -0.3461 -0.4132 -0.4898 -0.3173 -0.4707 -0.2982 -0.3653 -0.2886 -0.3461 -0.2694 -0.3844 -0.3077 -0.3365 -0.3269 -0.2119 -0.2598 -0.2694 -0.2311 -0.3269 -0.2215 -0.2694 -0.1640 -0.1256 0.0660 0.0564 -0.0202 -0.1927 -0.4803 -0.5569 -0.3077 -0.3557 -0.2119 -0.2502 -0.2694 -0.2311 -0.2790 -0.0969 -0.0011 -0.0969 0.1653 0.2051 0.2683 0.8980 -0.2241 -4.2003 -6.0178 -4.5483 -3.0059 -1.3216 -0.5447 -0.2506 0.1809 -0.1301 0.4711 0.2481 0.2673 0.2385 0.2865 0.2673 0.3631 0.3823 0.3056 0.3823 0.4877 0.5261 0.9649 0.6908 0.7372 0.6360 0.8578 1.0912 1.5034 1.6566 1.7832 2.1861 2.5592 2.3825 1.6257 1.4009 0.9878 0.4590 0.3248 0.2194 0.2290 0.1906 0.1906 0.1715 0.1906 0.1906 0.1906 0.3344 0.1715 0.2290 0.2098 0.2481 0.2098 0.2194 0.2385 0.1715 0.2673 0.2290 0.2098 0.2098 0.2769 0.1715 0.2002 0.1810 0.1715 0.1906 0.1810 0.2098 0.1715 0.1906 0.1140 0.2290 0.1906 0.1619 0.2290 0.2481 0.2290 0.2481 0.2290 0.1715 0.1906 0.2002 0.2481
    # draw_ori_perturbed(original_data, perturbed_data,
    #                    [(49, 64), (75, 90)],
    #                    '')  # several dashed boxes may be listed
    # draw_ori_perturbed_with_inset(original_data, perturbed_data, 220, 230)
    # draw_ori_perturbed(original_data, perturbed_data)
    # Population-size multiples
    # car_data = [1, 1, 1, 1, 1]
    # ecg200_data = [10.2, 9.2, 8.3, 10, 8.7]
    # italy_data = [6.7, 5.8, 4.2, 6.2, 6.2]
    # car_data = [0.0031, 0.003, 0.00315, 0.0032, 0.00325]
    # ecg200_data = [0.0066, 0.00655, 0.00652, 0.0068, 0.0067]
    # italy_data = [0.0024, 0.00245, 0.0024, 0.0026, 0.0027]
    # NOTE(review): this argument list (three series + three labels) matches
    # draw_times_diff's signature; draw_times_diff_bar requires twelve positional
    # arguments (4 bar series, 4 line series, 4 labels) and would raise TypeError.
    # draw_times_diff_bar(car_data, ecg200_data, italy_data, 'Car', 'ECG200', 'ItalyPowerDemand')
    # MSE over all datasets (the large table)
    # labels = [
    #     "Car", "Chlorine", "CinCECGTorso", "Earthquakes", "ECG200",
    #     "ECG5000", "ECGFiveDays", "FordA", "FordB", "InsectWngSnd",
    #     "ItalyPowerDemand", "Lightning2", "Lightning7", "MoteStrain", "NonIFECGTho1",
    #     "NonIFECGTho2", "Phoneme", "Plane", "SonyAIBOSurf1", "SonyAIBOSurf2",
    #     "StarLightCurves", "Trace", "TwoLeadECG", "Wafer", "AllGestureWiimoteX",
    #     "AllGestureWiimoteY", "AllGestureWiimoteZ", "FreezerRegularTrain", "FreezerSmallTrain", "PickupGestureZ",
    #     "PigAirwayPressure", "PigArtPressure", "PigCVP", "ShakeGestureZ", "Fungi",
    #     "GesturePebbleZ1", "GesturePebbleZ2", "DodgerLoopDay", "DodgerLoopWeekend", "DodgerLoopGame",
    #     "EOGHorizontalSignal", "EOGVerticalSignal"
    # ]
    #
    # GATN = [
    #     0.1130, 0.2690, 0.0490, 0.1220, 0.1380,
    #     0.1530, 0.0830, 0.1130, 0.1190, 0.1580,
    #     0.0800, 0.1330, 0.1320, 0.1430, 0.0840,
    #     0.0870, 0.0750, 0.2730, 0.1680, 0.1790,
    #     0.0640, 0.0780, 0.1080, 0.1330, 0.1200,
    #     0.1220, 0.1090, 0.1490, 0.1310, 0.1350,
    #     0.0400, 0.0380, 0.0410, 0.1220, 0.2050,
    #     0.1290, 0.1300, 0.0950, 0.1370, 0.0260,
    #     0.0580, 0.0640
    # ]
    #
    # advGAN = [  # -1 marks a missing result
    #     0.246, 0.1223, 0.138, -1, 0.1257,
    #     0.1796, 0.0331, 0.2281, 0.2466, 0.1306,
    #     0.2473, 0.2102, 0.1856, 0.2194, 0.1587,
    #     0.2434, 0.1778, 0.1774, 0.1924, 0.2033,
    #     0.1479, 0.0406, 0.1169, 0.0743, 0.2397,
    #     0.1447, 0.1936, 0.0533, 0.0624, 0.1253,
    #     -1, -1, -1, 0.1421, 0.1388,
    #     0.124, 0.1386, -1, -1, -1,
    #     -1, -1
    # ]
    #
    # TS_adv4 = [
    #     0.0013, 0.0111, 0.017, 0.0052, 0.0079,
    #     0.0062, 0.0139, 0.0067, 0.0079, 0.0063,
    #     0.0036, 0.0383, 0.0138, 0.0048, 0.0085,
    #     0.0049, 0.0048, 0.0068, 0.0063, 0.0073,
    #     0.0022, 0.0024, 0.0034, 0.0031, 0.0201,
    #     0.0096, 0.0019, 0.0012, 0.001, 0.006,
    #     0.0022, 0.0014, 0.0052, 0.0087, 0.0086,
    #     0.0345, 0.0204, 0.003, 0.003, 0.0044,
    #     0.0023, 0.0043
    # ]
    #
    # CP_adv4 = [
    #     0.0029, 0.0042, 0.0179, 0.0042, 0.0065,
    #     0.0085, 0.0092, 0.0091, 0.0094, 0.0064,
    #     0.0042, 0.0208, 0.0143, 0.004, 0.0112,
    #     0.0097, 0.0186, 0.0055, 0.0051, 0.0069,
    #     0.0032, 0.0013, 0.0039, 0.0032, 0.0203,
    #     0.0126, 0.0057, 0.0011, 0.0017, 0.005,
    #     0.0045, 0.0044, 0.0058, 0.0097, 0.0066,
    #     0.0281, 0.0226, 0.0036, 0.0038, 0.0057,
    #     0.0043, 0.0051
    # ]
    #
    # CP_adv2 = [
    #     0.0008, 0.0011, 0.0046, 0.0011, 0.0018,
    #     0.0022, 0.0027, 0.0023, 0.0024, 0.0016,
    #     0.0008, 0.006, 0.0038, 0.0011, 0.0029,
    #     0.0026, 0.0046, 0.0006, 0.0014, 0.0016,
    #     0.0008, -1, 0.0012, 0.0008, 0.0052,
    #     0.0031, 0.0014, 0.0003, 0.0004, 0.0013,
    #     0.0011, 0.0011, 0.0013, 0.0023, 0.0022,
    #     0.0067, 0.0053, 0.0009, 0.001, 0.0014,
    #     0.0011, 0.0013
    # ]
    #
    # draw_all_data_set_info(GATN, advGAN, TS_adv4, CP_adv4, CP_adv2, labels)
    # MSE/ANI comparison under different perturbation factors
    # Perturbation factor vs. success rate
    # data1 = [52, 81, 81, 81, 81, 81]  # Car
    # # data2 = [1, 8, 8, 19, 18, 29]  # ItalyPowerDemand
    # data3 = [3, 12, 17, 42, 59, 77]  # SonyAIBORobotSurface2
    # data4 = [0, 4, 65, 78, 78, 78]  # Plane
    # data5 = [17, 22, 82, 87, 88, 88]  # ECG200
    # data6 = [23, 43, 81.5, 100, 100, 100]  # Ligning7
    # labels = ['0.01', '0.02', '0.04', '0.06', '0.08', '0.1']  # x-axis perturbation-factor tick labels
    #
    # NOTE(review): draw_diff_pert_info takes four data series before `labels`
    # (data1, data2, data5, data6, labels, marker=...). This call passes five, so
    # `labels` lands in the `marker` slot and the explicit marker='o' raises
    # "got multiple values for argument 'marker'".
    # draw_diff_pert_info(data1, data3, data4, data5, data6, labels,marker='o',linewidth=3)
    # Heat maps
    # data1 = generate_heat_maps_array(count_classes_samples('result_0.08_6_f/Lightning7', select_colum='Success'),
    #                                  classes_size=7)
    #
    # data2 = generate_heat_maps_array(count_classes_samples('result_0.08_8_f/Plane', select_colum='Success'),
    #                                  classes_size=7)
    #
    # heat_maps_of_classes(data1, data2, '', cbar_ticks1=[2,4,6,8,10], cbar_ticks2=[3,6,9,12,15,18])
    pass