CCPP/utils_pack/draw_line_compare.py

153 lines
6.5 KiB
Python
Raw Normal View History

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import argparse
from matplotlib.patches import Rectangle
import matplotlib.transforms as transforms
# 设置全局字体为Times New Roman
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['mathtext.fontset'] = 'cm' # 确保数学公式也使用合适的字体
def load_data(file_path):
"""加载数据,支持空格分隔和逗号分隔"""
try:
# 尝试以空格分隔符读取
df = pd.read_csv(file_path, header=None, sep=r'\s+')
print(f"已使用空格分隔符读取文件: {file_path}")
return df
except Exception as e:
print(f"以空格分隔符读取失败,尝试以逗号分隔符读取: {e}")
try:
# 尝试以逗号分隔符读取
df = pd.read_csv(file_path, header=None)
print(f"已使用逗号分隔符读取文件: {file_path}")
return df
except Exception as e:
print(f"读取文件失败: {e}")
raise
def plot_time_series(original_file, perturbed_file, sequence_id, beta, zoom_start=None, zoom_end=None, sub_offset=0.1):
"""绘制指定序号的原始和扰动时间序列对比图,并添加局部放大图
参数:
original_file -- 原始时间序列文件路径
perturbed_file -- 扰动时间序列文件路径
sequence_id -- 要绘制的序列ID
beta -- beta参数值
zoom_start -- 放大区域的起始索引如果为None则不添加放大图
zoom_end -- 放大区域的结束索引如果为None则不添加放大图
"""
# 加载数据
original_data = load_data(original_file)
perturbed_data = load_data(perturbed_file)
# 确保序号是整数
sequence_id = int(sequence_id)
# 在两个文件中查找指定序号的行
original_row = original_data[original_data[0] == sequence_id]
perturbed_row = perturbed_data[perturbed_data[0] == sequence_id]
print(f"数据中的唯一序号: {original_data[0].unique()}")
if original_row.empty or perturbed_row.empty:
print(f"序号 {sequence_id} 在一个或两个文件中不存在!")
return
# 提取时间序列数据从第3列开始
original_series = original_row.iloc[0, 2:].values
perturbed_series = perturbed_row.iloc[0, 2:].values
# 创建x轴数据点
x = np.arange(len(original_series))
# 设置图形大小和布局
fig = plt.figure(figsize=(8, 5))
# 创建主图
ax_main = plt.subplot(111)
ax_main.plot(x, original_series, label='origin', color='blue', linewidth=2)
ax_main.plot(x, perturbed_series, label='perturbed', color='red', linewidth=2)
ax_main.legend(prop={'family': 'Times New Roman', 'size': 16})
ax_main.set_xlabel(r'$\beta=$' + f'{beta}', fontsize=24)
ax_main.grid(True, linestyle='--', alpha=0.7)
# 增大主图坐标轴刻度字体
ax_main.tick_params(axis='both', labelsize=20)
# 如果提供了放大区域,添加子图和矩形框
if zoom_start is not None and zoom_end is not None:
# 确保缩放范围有效
zoom_start = max(0, min(zoom_start, len(original_series)-1))
zoom_end = max(zoom_start+1, min(zoom_end, len(original_series)))
# 计算放大区域的y轴范围
min_y = min(min(original_series[zoom_start:zoom_end]), min(perturbed_series[zoom_start:zoom_end]))
max_y = max(max(original_series[zoom_start:zoom_end]), max(perturbed_series[zoom_start:zoom_end]))
y_margin = (max_y - min_y) * 0.1 # 10% 的边距
# 添加矩形框表示放大区域
rect = Rectangle((zoom_start, min_y - y_margin),
zoom_end - zoom_start,
(max_y - min_y) + 2 * y_margin,
fill=False, edgecolor='gray', linestyle='--', linewidth=1.5)
ax_main.add_patch(rect)
# 创建放大子图 - 调整位置到上方,缩小尺寸
# 对主图对象进行调整
box = ax_main.get_position()
ax_main.set_position([box.x0, box.y0, box.width, box.height * 0.9])
# 放大图放在上方,位置和大小调整
left = sub_offset # 右侧位置
bottom = 0.65 # 上方位置
width = 0.35 # 缩小宽度
height = 0.30 # 缩小高度
# 创建子图
ax_zoom = fig.add_axes([left, bottom, width, height])
ax_zoom.plot(x[zoom_start:zoom_end], original_series[zoom_start:zoom_end], color='blue', linewidth=2)
ax_zoom.plot(x[zoom_start:zoom_end], perturbed_series[zoom_start:zoom_end], color='red', linewidth=2)
ax_zoom.grid(True, linestyle='--', alpha=0.7)
# 隐藏子图的坐标轴刻度数字
ax_zoom.set_xticklabels([])
ax_zoom.set_yticklabels([])
# 调整子图显示范围,添加一点边距
ax_zoom.set_xlim(zoom_start, zoom_end)
ax_zoom.set_ylim(min_y - y_margin, max_y + y_margin)
plt.tight_layout()
# 保存图片
output_filename = f'time_series_comparison_{beta}_{sequence_id}.pdf'
plt.savefig(output_filename)
print(f"图像已保存为 {output_filename}")
# 显示图像
plt.show()
# python draw_line_compare.py --original /Users/catb/Library/CloudStorage/CloudMounter-B40-4/home/BJTU/project/CPadv/CCPP实验结果/factor/ECG200/0_02/ori_time_series0.txt --perturbed /Users/catb/Library/CloudStorage/CloudMounter-B40-4/home/BJTU/project/CPadv/CCPP实验结果/factor/ECG200/0_02/attack_time_series0.txt --id 2 --beta 0.02 --zoom-start 60 --zoom-end 80
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='绘制时间序列对比图')
parser.add_argument('--original', type=str, required=True, help='原始时间序列文件路径')
parser.add_argument('--perturbed', type=str, required=True, help='攻击/扰动时间序列文件路径')
parser.add_argument('--id', type=int, required=True, help='要绘制的时间序列序号')
parser.add_argument('--beta', type=float, default=0.01, help='beta参数值')
parser.add_argument('--zoom-start', type=int, help='放大区域的起始索引')
parser.add_argument('--zoom-end', type=int, help='放大区域的结束索引')
parser.add_argument('--sub_offset', type=float, help='放大子图向右偏移量')
args = parser.parse_args()
plot_time_series(
args.original,
args.perturbed,
args.id,
args.beta,
args.zoom_start,
args.zoom_end,
args.sub_offset
)