- 新增 draw_line_compare.py 支持原始与扰动时间序列对比及局部放大 - 支持空格和逗号分隔格式数据加载,命令行参数灵活配置 - 新增 draw_all_dataset.py 用于绘制Excel中CPadv与CPsoadv数据对比折线图 - 设置全局字体为 Times New Roman,保证图表美观一致 - 新增 line_compare_command.txt 提供绘图脚本示例命令 - 删除无用的 print_results_path.py 文件,清理项目冗余代码
153 lines
6.5 KiB
Python
153 lines
6.5 KiB
Python
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import argparse
|
||
from matplotlib.patches import Rectangle
|
||
import matplotlib.transforms as transforms
|
||
|
||
# 设置全局字体为Times New Roman
|
||
plt.rcParams['font.family'] = 'Times New Roman'
|
||
plt.rcParams['mathtext.fontset'] = 'cm' # 确保数学公式也使用合适的字体
|
||
|
||
def load_data(file_path):
|
||
"""加载数据,支持空格分隔和逗号分隔"""
|
||
try:
|
||
# 尝试以空格分隔符读取
|
||
df = pd.read_csv(file_path, header=None, sep=r'\s+')
|
||
print(f"已使用空格分隔符读取文件: {file_path}")
|
||
return df
|
||
except Exception as e:
|
||
print(f"以空格分隔符读取失败,尝试以逗号分隔符读取: {e}")
|
||
try:
|
||
# 尝试以逗号分隔符读取
|
||
df = pd.read_csv(file_path, header=None)
|
||
print(f"已使用逗号分隔符读取文件: {file_path}")
|
||
return df
|
||
except Exception as e:
|
||
print(f"读取文件失败: {e}")
|
||
raise
|
||
|
||
def plot_time_series(original_file, perturbed_file, sequence_id, beta, zoom_start=None, zoom_end=None, sub_offset=0.1):
|
||
"""绘制指定序号的原始和扰动时间序列对比图,并添加局部放大图
|
||
|
||
参数:
|
||
original_file -- 原始时间序列文件路径
|
||
perturbed_file -- 扰动时间序列文件路径
|
||
sequence_id -- 要绘制的序列ID
|
||
beta -- beta参数值
|
||
zoom_start -- 放大区域的起始索引(如果为None则不添加放大图)
|
||
zoom_end -- 放大区域的结束索引(如果为None则不添加放大图)
|
||
"""
|
||
# 加载数据
|
||
original_data = load_data(original_file)
|
||
perturbed_data = load_data(perturbed_file)
|
||
|
||
# 确保序号是整数
|
||
sequence_id = int(sequence_id)
|
||
|
||
# 在两个文件中查找指定序号的行
|
||
original_row = original_data[original_data[0] == sequence_id]
|
||
perturbed_row = perturbed_data[perturbed_data[0] == sequence_id]
|
||
|
||
print(f"数据中的唯一序号: {original_data[0].unique()}")
|
||
|
||
if original_row.empty or perturbed_row.empty:
|
||
print(f"序号 {sequence_id} 在一个或两个文件中不存在!")
|
||
return
|
||
|
||
# 提取时间序列数据(从第3列开始)
|
||
original_series = original_row.iloc[0, 2:].values
|
||
perturbed_series = perturbed_row.iloc[0, 2:].values
|
||
|
||
# 创建x轴数据点
|
||
x = np.arange(len(original_series))
|
||
|
||
# 设置图形大小和布局
|
||
fig = plt.figure(figsize=(8, 5))
|
||
|
||
# 创建主图
|
||
ax_main = plt.subplot(111)
|
||
ax_main.plot(x, original_series, label='origin', color='blue', linewidth=2)
|
||
ax_main.plot(x, perturbed_series, label='perturbed', color='red', linewidth=2)
|
||
ax_main.legend(prop={'family': 'Times New Roman', 'size': 16})
|
||
ax_main.set_xlabel(r'$\beta=$' + f'{beta}', fontsize=24)
|
||
ax_main.grid(True, linestyle='--', alpha=0.7)
|
||
|
||
# 增大主图坐标轴刻度字体
|
||
ax_main.tick_params(axis='both', labelsize=20)
|
||
|
||
# 如果提供了放大区域,添加子图和矩形框
|
||
if zoom_start is not None and zoom_end is not None:
|
||
# 确保缩放范围有效
|
||
zoom_start = max(0, min(zoom_start, len(original_series)-1))
|
||
zoom_end = max(zoom_start+1, min(zoom_end, len(original_series)))
|
||
|
||
# 计算放大区域的y轴范围
|
||
min_y = min(min(original_series[zoom_start:zoom_end]), min(perturbed_series[zoom_start:zoom_end]))
|
||
max_y = max(max(original_series[zoom_start:zoom_end]), max(perturbed_series[zoom_start:zoom_end]))
|
||
y_margin = (max_y - min_y) * 0.1 # 10% 的边距
|
||
|
||
# 添加矩形框表示放大区域
|
||
rect = Rectangle((zoom_start, min_y - y_margin),
|
||
zoom_end - zoom_start,
|
||
(max_y - min_y) + 2 * y_margin,
|
||
fill=False, edgecolor='gray', linestyle='--', linewidth=1.5)
|
||
ax_main.add_patch(rect)
|
||
|
||
# 创建放大子图 - 调整位置到上方,缩小尺寸
|
||
# 对主图对象进行调整
|
||
box = ax_main.get_position()
|
||
ax_main.set_position([box.x0, box.y0, box.width, box.height * 0.9])
|
||
|
||
# 放大图放在上方,位置和大小调整
|
||
left = sub_offset # 右侧位置
|
||
bottom = 0.65 # 上方位置
|
||
width = 0.35 # 缩小宽度
|
||
height = 0.30 # 缩小高度
|
||
|
||
# 创建子图
|
||
ax_zoom = fig.add_axes([left, bottom, width, height])
|
||
ax_zoom.plot(x[zoom_start:zoom_end], original_series[zoom_start:zoom_end], color='blue', linewidth=2)
|
||
ax_zoom.plot(x[zoom_start:zoom_end], perturbed_series[zoom_start:zoom_end], color='red', linewidth=2)
|
||
ax_zoom.grid(True, linestyle='--', alpha=0.7)
|
||
|
||
# 隐藏子图的坐标轴刻度数字
|
||
ax_zoom.set_xticklabels([])
|
||
ax_zoom.set_yticklabels([])
|
||
|
||
# 调整子图显示范围,添加一点边距
|
||
ax_zoom.set_xlim(zoom_start, zoom_end)
|
||
ax_zoom.set_ylim(min_y - y_margin, max_y + y_margin)
|
||
|
||
plt.tight_layout()
|
||
|
||
# 保存图片
|
||
output_filename = f'time_series_comparison_{beta}_{sequence_id}.pdf'
|
||
plt.savefig(output_filename)
|
||
print(f"图像已保存为 {output_filename}")
|
||
|
||
# 显示图像
|
||
plt.show()
|
||
|
||
# python draw_line_compare.py --original /Users/catb/Library/CloudStorage/CloudMounter-B40-4/home/BJTU/project/CPadv/CCPP实验结果/factor/ECG200/0_02/ori_time_series0.txt --perturbed /Users/catb/Library/CloudStorage/CloudMounter-B40-4/home/BJTU/project/CPadv/CCPP实验结果/factor/ECG200/0_02/attack_time_series0.txt --id 2 --beta 0.02 --zoom-start 60 --zoom-end 80
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(description='绘制时间序列对比图')
|
||
parser.add_argument('--original', type=str, required=True, help='原始时间序列文件路径')
|
||
parser.add_argument('--perturbed', type=str, required=True, help='攻击/扰动时间序列文件路径')
|
||
parser.add_argument('--id', type=int, required=True, help='要绘制的时间序列序号')
|
||
parser.add_argument('--beta', type=float, default=0.01, help='beta参数值')
|
||
parser.add_argument('--zoom-start', type=int, help='放大区域的起始索引')
|
||
parser.add_argument('--zoom-end', type=int, help='放大区域的结束索引')
|
||
parser.add_argument('--sub_offset', type=float, help='放大子图向右偏移量')
|
||
|
||
args = parser.parse_args()
|
||
|
||
plot_time_series(
|
||
args.original,
|
||
args.perturbed,
|
||
args.id,
|
||
args.beta,
|
||
args.zoom_start,
|
||
args.zoom_end,
|
||
args.sub_offset
|
||
) |