CCPP/utils_pack/draw_line_compare.py
Qihang Zhang 1cd9841c1a feat(utils_pack): 新增时间序列对比绘图脚本和数据集绘图脚本
- 新增 draw_line_compare.py 支持原始与扰动时间序列对比及局部放大
- 支持空格和逗号分隔格式数据加载,命令行参数灵活配置
- 新增 draw_all_dataset.py 用于绘制Excel中CPadv与CPsoadv数据对比折线图
- 设置全局字体为 Times New Roman,保证图表美观一致
- 新增 line_compare_command.txt 提供绘图脚本示例命令
- 删除无用的 print_results_path.py 文件,清理项目冗余代码
2025-04-20 23:00:31 +08:00

153 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import argparse
from matplotlib.patches import Rectangle
import matplotlib.transforms as transforms
# 设置全局字体为Times New Roman
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['mathtext.fontset'] = 'cm' # 确保数学公式也使用合适的字体
def load_data(file_path):
"""加载数据,支持空格分隔和逗号分隔"""
try:
# 尝试以空格分隔符读取
df = pd.read_csv(file_path, header=None, sep=r'\s+')
print(f"已使用空格分隔符读取文件: {file_path}")
return df
except Exception as e:
print(f"以空格分隔符读取失败,尝试以逗号分隔符读取: {e}")
try:
# 尝试以逗号分隔符读取
df = pd.read_csv(file_path, header=None)
print(f"已使用逗号分隔符读取文件: {file_path}")
return df
except Exception as e:
print(f"读取文件失败: {e}")
raise
def plot_time_series(original_file, perturbed_file, sequence_id, beta, zoom_start=None, zoom_end=None, sub_offset=0.1):
"""绘制指定序号的原始和扰动时间序列对比图,并添加局部放大图
参数:
original_file -- 原始时间序列文件路径
perturbed_file -- 扰动时间序列文件路径
sequence_id -- 要绘制的序列ID
beta -- beta参数值
zoom_start -- 放大区域的起始索引如果为None则不添加放大图
zoom_end -- 放大区域的结束索引如果为None则不添加放大图
"""
# 加载数据
original_data = load_data(original_file)
perturbed_data = load_data(perturbed_file)
# 确保序号是整数
sequence_id = int(sequence_id)
# 在两个文件中查找指定序号的行
original_row = original_data[original_data[0] == sequence_id]
perturbed_row = perturbed_data[perturbed_data[0] == sequence_id]
print(f"数据中的唯一序号: {original_data[0].unique()}")
if original_row.empty or perturbed_row.empty:
print(f"序号 {sequence_id} 在一个或两个文件中不存在!")
return
# 提取时间序列数据从第3列开始
original_series = original_row.iloc[0, 2:].values
perturbed_series = perturbed_row.iloc[0, 2:].values
# 创建x轴数据点
x = np.arange(len(original_series))
# 设置图形大小和布局
fig = plt.figure(figsize=(8, 5))
# 创建主图
ax_main = plt.subplot(111)
ax_main.plot(x, original_series, label='origin', color='blue', linewidth=2)
ax_main.plot(x, perturbed_series, label='perturbed', color='red', linewidth=2)
ax_main.legend(prop={'family': 'Times New Roman', 'size': 16})
ax_main.set_xlabel(r'$\beta=$' + f'{beta}', fontsize=24)
ax_main.grid(True, linestyle='--', alpha=0.7)
# 增大主图坐标轴刻度字体
ax_main.tick_params(axis='both', labelsize=20)
# 如果提供了放大区域,添加子图和矩形框
if zoom_start is not None and zoom_end is not None:
# 确保缩放范围有效
zoom_start = max(0, min(zoom_start, len(original_series)-1))
zoom_end = max(zoom_start+1, min(zoom_end, len(original_series)))
# 计算放大区域的y轴范围
min_y = min(min(original_series[zoom_start:zoom_end]), min(perturbed_series[zoom_start:zoom_end]))
max_y = max(max(original_series[zoom_start:zoom_end]), max(perturbed_series[zoom_start:zoom_end]))
y_margin = (max_y - min_y) * 0.1 # 10% 的边距
# 添加矩形框表示放大区域
rect = Rectangle((zoom_start, min_y - y_margin),
zoom_end - zoom_start,
(max_y - min_y) + 2 * y_margin,
fill=False, edgecolor='gray', linestyle='--', linewidth=1.5)
ax_main.add_patch(rect)
# 创建放大子图 - 调整位置到上方,缩小尺寸
# 对主图对象进行调整
box = ax_main.get_position()
ax_main.set_position([box.x0, box.y0, box.width, box.height * 0.9])
# 放大图放在上方,位置和大小调整
left = sub_offset # 右侧位置
bottom = 0.65 # 上方位置
width = 0.35 # 缩小宽度
height = 0.30 # 缩小高度
# 创建子图
ax_zoom = fig.add_axes([left, bottom, width, height])
ax_zoom.plot(x[zoom_start:zoom_end], original_series[zoom_start:zoom_end], color='blue', linewidth=2)
ax_zoom.plot(x[zoom_start:zoom_end], perturbed_series[zoom_start:zoom_end], color='red', linewidth=2)
ax_zoom.grid(True, linestyle='--', alpha=0.7)
# 隐藏子图的坐标轴刻度数字
ax_zoom.set_xticklabels([])
ax_zoom.set_yticklabels([])
# 调整子图显示范围,添加一点边距
ax_zoom.set_xlim(zoom_start, zoom_end)
ax_zoom.set_ylim(min_y - y_margin, max_y + y_margin)
plt.tight_layout()
# 保存图片
output_filename = f'time_series_comparison_{beta}_{sequence_id}.pdf'
plt.savefig(output_filename)
print(f"图像已保存为 {output_filename}")
# 显示图像
plt.show()
# python draw_line_compare.py --original /Users/catb/Library/CloudStorage/CloudMounter-B40-4/home/BJTU/project/CPadv/CCPP实验结果/factor/ECG200/0_02/ori_time_series0.txt --perturbed /Users/catb/Library/CloudStorage/CloudMounter-B40-4/home/BJTU/project/CPadv/CCPP实验结果/factor/ECG200/0_02/attack_time_series0.txt --id 2 --beta 0.02 --zoom-start 60 --zoom-end 80
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='绘制时间序列对比图')
parser.add_argument('--original', type=str, required=True, help='原始时间序列文件路径')
parser.add_argument('--perturbed', type=str, required=True, help='攻击/扰动时间序列文件路径')
parser.add_argument('--id', type=int, required=True, help='要绘制的时间序列序号')
parser.add_argument('--beta', type=float, default=0.01, help='beta参数值')
parser.add_argument('--zoom-start', type=int, help='放大区域的起始索引')
parser.add_argument('--zoom-end', type=int, help='放大区域的结束索引')
parser.add_argument('--sub_offset', type=float, help='放大子图向右偏移量')
args = parser.parse_args()
plot_time_series(
args.original,
args.perturbed,
args.id,
args.beta,
args.zoom_start,
args.zoom_end,
args.sub_offset
)