♻️ refactor(attacker): 添加随机变点测试代码
- 新增随机变点生成逻辑用于测试 - 通过断言确保随机变点数量与原算法一致 - 方便后续调试和对比不同变点策略效果
This commit is contained in:
parent
1cd9841c1a
commit
bf480237e0
25
attacker.py
25
attacker.py
@ -349,6 +349,31 @@ class Attacker:
|
||||
bounds.append((-1 * perturbed_magnitude, perturbed_magnitude))
|
||||
|
||||
print('The length of shapelet interval', steps_count)
|
||||
|
||||
if False: # 需要时手动修改为True,用于测试随机变点
|
||||
# 生成随机变点
|
||||
# 假设ori_ts的长度可以从shape[0]获取
|
||||
length = ori_ts.shape[0]
|
||||
|
||||
# 获取与steps_count相同数量的随机变点位置
|
||||
random_indices = np.random.choice(length, size=int(steps_count), replace=False)
|
||||
|
||||
random_pos = np.zeros(length, dtype=int)
|
||||
random_pos[random_indices] = 1
|
||||
|
||||
# 用于存储随机变点的扰动边界
|
||||
random_bounds = []
|
||||
for i in range(len(random_pos)):
|
||||
if random_pos[i] == 1:
|
||||
random_bounds.append((-1 * perturbed_magnitude, perturbed_magnitude))
|
||||
|
||||
print('The length of random interval', random_pos.sum())
|
||||
|
||||
# 验证随机变点数量与原算法相同
|
||||
assert random_pos.sum() == steps_count
|
||||
# 覆盖原有的变点
|
||||
attack_pos = random_pos
|
||||
|
||||
#print('The length of bounds', len(bounds))
|
||||
popmul = max(1, popsize // len(bounds))
|
||||
# Record of the number of iterations
|
||||
|
Loading…
Reference in New Issue
Block a user