♻️ refactor(attacker): 添加随机变点测试代码
- 新增随机变点生成逻辑用于测试 - 通过断言确保随机变点数量与原算法一致 - 方便后续调试和对比不同变点策略效果
This commit is contained in:
parent
1cd9841c1a
commit
bf480237e0
25
attacker.py
25
attacker.py
@ -349,6 +349,31 @@ class Attacker:
|
|||||||
bounds.append((-1 * perturbed_magnitude, perturbed_magnitude))
|
bounds.append((-1 * perturbed_magnitude, perturbed_magnitude))
|
||||||
|
|
||||||
print('The length of shapelet interval', steps_count)
|
print('The length of shapelet interval', steps_count)
|
||||||
|
|
||||||
|
if False: # 需要时手动修改为True,用于测试随机变点
|
||||||
|
# 生成随机变点
|
||||||
|
# 假设ori_ts的长度可以从shape[0]获取
|
||||||
|
length = ori_ts.shape[0]
|
||||||
|
|
||||||
|
# 获取与steps_count相同数量的随机变点位置
|
||||||
|
random_indices = np.random.choice(length, size=int(steps_count), replace=False)
|
||||||
|
|
||||||
|
random_pos = np.zeros(length, dtype=int)
|
||||||
|
random_pos[random_indices] = 1
|
||||||
|
|
||||||
|
# 用于存储随机变点的扰动边界
|
||||||
|
random_bounds = []
|
||||||
|
for i in range(len(random_pos)):
|
||||||
|
if random_pos[i] == 1:
|
||||||
|
random_bounds.append((-1 * perturbed_magnitude, perturbed_magnitude))
|
||||||
|
|
||||||
|
print('The length of random interval', random_pos.sum())
|
||||||
|
|
||||||
|
# 验证随机变点数量与原算法相同
|
||||||
|
assert random_pos.sum() == steps_count
|
||||||
|
# 覆盖原有的变点
|
||||||
|
attack_pos = random_pos
|
||||||
|
|
||||||
#print('The length of bounds', len(bounds))
|
#print('The length of bounds', len(bounds))
|
||||||
popmul = max(1, popsize // len(bounds))
|
popmul = max(1, popsize // len(bounds))
|
||||||
# Record of the number of iterations
|
# Record of the number of iterations
|
||||||
|
Loading…
Reference in New Issue
Block a user