🗑️ remove(get_kpl_list.py): 删除涨停板数据获取旧脚本
✨ feat(main_force_strategy.py): 新增主力资金流向策略分析功能 🔧 refactor(utils.py): 增加数据库操作支持和辅助函数
This commit is contained in:
parent
e7d2dc3b64
commit
b652322061
343
get_kpl_list.py
343
get_kpl_list.py
@ -1,343 +0,0 @@
|
|||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
import openpyxl
|
|
||||||
import pandas as pd
|
|
||||||
from openpyxl.styles import PatternFill
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from utils import load_config, get_trade_cal
|
|
||||||
|
|
||||||
# 加载配置并初始化tushare
|
|
||||||
config = load_config()
|
|
||||||
import tushare as ts
|
|
||||||
|
|
||||||
ts.set_token(config['tushare_token'])
|
|
||||||
pro = ts.pro_api()
|
|
||||||
|
|
||||||
|
|
||||||
def get_kpl_data(start_date=None, end_date=None):
    """
    Fetch limit-up ("kpl") records for every trading day in a date range.

    Trading days already present in the local cache file
    ``ori_kpl_list.xlsx`` are not requested again; only the missing days are
    downloaded through ``pro.kpl_list`` and merged with the cached rows.

    Args:
        start_date (str): First day of the range, formatted 'YYYYMMDD'.
        end_date (str): Last day of the range, formatted 'YYYYMMDD'.

    Returns:
        pandas.DataFrame: Cached plus newly fetched rows, newest trade date
        first. When nothing new was fetched, the cached data is returned
        as read (possibly empty).
    """
    cache_path = 'ori_kpl_list.xlsx'
    calendar = get_trade_cal(start_date, end_date)

    # Load whatever was saved by a previous run; a broken cache is reported
    # and then treated as "no cached dates".
    cached = pd.DataFrame()
    cached_days = set()
    if os.path.exists(cache_path):
        try:
            print(f"检测到已有数据文件: {cache_path}")
            cached = pd.read_excel(cache_path)
            if not cached.empty and 'trade_date' in cached.columns:
                # Compare dates as strings, the same form the calendar uses.
                cached['trade_date'] = cached['trade_date'].astype(str)
                cached_days = set(cached['trade_date'])
                print(f"已有数据包含 {len(cached_days)} 个交易日")
        except Exception as e:
            print(f"读取现有数据时出错: {e}")

    missing_days = [day for day in calendar if day not in cached_days]
    if not missing_days:
        print("所有数据均已存在,无需更新")
        return cached

    print(f"需要获取 {len(missing_days)} 个新交易日的数据")

    # Download one trading day at a time; a failing day is skipped after a
    # short pause so the remaining days are still attempted.
    frames = []
    for day in tqdm(missing_days):
        try:
            # No field list is passed, so every returned column is kept.
            daily = pro.kpl_list(trade_date=day, tag='涨停')
            if not daily.empty:
                if 'trade_date' in daily.columns:
                    daily['trade_date'] = daily['trade_date'].astype(str)
                frames.append(daily)
        except Exception as e:
            print(f"获取 {day} 数据时出错: {e}")
            time.sleep(1)

    if not frames:
        print("未获取到任何新数据")
        return cached

    fetched = pd.concat(frames, ignore_index=True)
    print(f"成功获取 {len(frames)} 个交易日的新数据,共 {len(fetched)} 条记录")

    if cached.empty:
        merged = fetched
    else:
        merged = pd.concat([cached, fetched], ignore_index=True)
        print(f"合并后共有 {len(merged)} 条记录")

    # Defensive de-duplication on (stock, day); the later row wins.
    if {'ts_code', 'trade_date'}.issubset(merged.columns):
        merged = merged.drop_duplicates(subset=['ts_code', 'trade_date'], keep='last')
        print(f"去重后共有 {len(merged)} 条记录")

    # Sort on the string form of the date, newest first.
    if 'trade_date' in merged.columns:
        merged['trade_date'] = merged['trade_date'].astype(str)
        merged = merged.sort_values(by='trade_date', ascending=False)
        print("数据已按交易日期排序")

    return merged
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_kpl_data():
    """
    Count daily limit-ups per sector and export them as a red heat map.

    Reads ``ori_kpl_list.xlsx``, counts rows per (``trade_date``, ``lu_desc``)
    pair, ranks the sectors by their total over the most recent 100 trading
    days found in the data, drops sectors whose total in that window is zero,
    and writes the table to ``sector_limit_up_analysis.xlsx``. Non-empty
    cells are then filled with a colour running from light red (low count)
    to pure red (20 or more).

    Returns:
        pandas.DataFrame | None: The exported table (dates as index, newest
        first; sectors as columns, most active first; empty string where the
        count is zero). An empty DataFrame when no sector is active in the
        window. ``None`` when the input file cannot be read or a required
        column is missing.
    """
    source_file = 'ori_kpl_list.xlsx'
    output_file = 'sector_limit_up_analysis.xlsx'
    recent_window = 100  # trading days used for ranking / filtering sectors
    heat_cap = 20        # counts at or above this value get the darkest red

    print("开始分析涨停板数据...")

    # 1. Load the raw limit-up records.
    try:
        ori_data = pd.read_excel(source_file)
        print(f"成功读取原始数据,共 {len(ori_data)} 条记录")
    except Exception as e:
        print(f"读取数据失败: {e}")
        return

    # Both columns are required; report a missing one instead of failing
    # later with a KeyError.
    for column, label in (('trade_date', '交易日期'), ('lu_desc', '板块信息')):
        if column not in ori_data.columns:
            print(f"错误: 原始数据中没有{label}字段 '{column}'")
            return

    # Dates are handled as strings throughout.
    ori_data['trade_date'] = ori_data['trade_date'].astype(str)
    all_dates = sorted(ori_data['trade_date'].unique(), reverse=True)  # newest first
    all_sectors = ori_data['lu_desc'].dropna().unique()
    print(f"数据包含 {len(all_dates)} 个交易日和 {len(all_sectors)} 个板块")

    # 2. Date x sector count matrix. The grouped count ignores rows without a
    #    sector; reindexing restores every date and the first-seen sector
    #    order, filling the gaps with zero.
    if len(all_dates) and len(all_sectors):
        counts = (
            ori_data.groupby(['trade_date', 'lu_desc']).size()
            .unstack(fill_value=0)
            .reindex(index=all_dates, columns=all_sectors, fill_value=0)
            .rename_axis(index=None, columns=None)
        )
    else:
        counts = pd.DataFrame(0, index=all_dates, columns=all_sectors)

    # 3. Rank sectors by their total over the most recent window.
    recent_days = min(recent_window, len(all_dates))
    sector_totals = counts.iloc[:recent_days].sum()
    active_totals = sector_totals[sector_totals > 0]
    sorted_sectors = active_totals.sort_values(ascending=False).index.tolist()

    print(f"已按最近{recent_days}个工作日的涨停总数对板块排序")
    print(
        f"共保留了{len(sorted_sectors)}个有涨停记录的板块,删除了{len(all_sectors) - len(sorted_sectors)}个无涨停记录的板块")

    if not sorted_sectors:
        print("警告: 在指定时间段内没有板块有涨停记录")
        return pd.DataFrame()

    print("涨停数量前10的板块:")
    for rank, sector in enumerate(sorted_sectors[:10], 1):
        print(f"{rank}. {sector}: {sector_totals[sector]}只")

    # 4. Final table: keep positive counts, blank out zeros.
    active_counts = counts[sorted_sectors]
    result = active_counts.astype(object).where(active_counts > 0, "")
    result.to_excel(output_file)

    def get_heatmap_color(value):
        """Map a count to an RGB hex string; None when value is not a number."""
        try:
            level = min(max(int(value), 0), heat_cap)
        except (TypeError, ValueError):
            return None
        # Red stays at FF; green and blue fall from FF (light) to 00 (deep).
        intensity = int(255 - (level / heat_cap * 255))
        channel = format(intensity, '02X')
        return f"FF{channel}{channel}"

    # 5. Re-open the workbook and colour every populated data cell.
    print("正在添加热力图样式...")
    workbook = openpyxl.load_workbook(output_file)
    worksheet = workbook.active

    for row in range(2, worksheet.max_row + 1):  # row 1 is the header
        for col in range(2, worksheet.max_column + 1):  # column 1 is the index
            cell = worksheet.cell(row=row, column=col)
            if not (cell.value and str(cell.value).strip()):
                continue
            color_code = get_heatmap_color(cell.value)
            if color_code:
                cell.fill = PatternFill(start_color=color_code, end_color=color_code, fill_type="solid")

    workbook.save(output_file)

    print(f"分析完成,结果已保存到 {output_file}")
    print(f"统计了 {len(result.columns)} 个活跃板块的涨停数据")
    print("已使用红色热力图标记涨停数量:0-20对应从浅红到深红")

    return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_sector_moneyflow_data(start_date=None, end_date=None):
    """
    Fetch sector money-flow records for every trading day in a date range.

    Trading days already present in the local cache file
    ``ori_sector_moneyflow.xlsx`` are not requested again; only the missing
    days are downloaded through ``pro.moneyflow_ind_dc`` and merged with the
    cached rows.

    Args:
        start_date (str): First day of the range, formatted 'YYYYMMDD'.
        end_date (str): Last day of the range, formatted 'YYYYMMDD'.

    Returns:
        pandas.DataFrame: Cached plus newly fetched rows, newest trade date
        first. When nothing new was fetched, the cached data is returned
        as read (possibly empty).
    """
    cache_path = 'ori_sector_moneyflow.xlsx'
    calendar = get_trade_cal(start_date, end_date)

    # Load whatever was saved by a previous run; a broken cache is reported
    # and then treated as "no cached dates".
    cached = pd.DataFrame()
    cached_days = set()
    if os.path.exists(cache_path):
        try:
            print(f"检测到已有数据文件: {cache_path}")
            cached = pd.read_excel(cache_path)
            if not cached.empty and 'trade_date' in cached.columns:
                # Compare dates as strings, the same form the calendar uses.
                cached['trade_date'] = cached['trade_date'].astype(str)
                cached_days = set(cached['trade_date'])
                print(f"已有数据包含 {len(cached_days)} 个交易日")
        except Exception as e:
            print(f"读取现有数据时出错: {e}")

    missing_days = [day for day in calendar if day not in cached_days]
    if not missing_days:
        print("所有数据均已存在,无需更新")
        return cached

    print(f"需要获取 {len(missing_days)} 个新交易日的数据")

    # Download one trading day at a time; a failing day is skipped after a
    # short pause so the remaining days are still attempted.
    frames = []
    for day in tqdm(missing_days):
        try:
            # One call returns the flow of every sector for that day.
            daily = pro.moneyflow_ind_dc(trade_date=day)
            if not daily.empty:
                if 'trade_date' in daily.columns:
                    daily['trade_date'] = daily['trade_date'].astype(str)
                frames.append(daily)
        except Exception as e:
            print(f"获取 {day} 数据时出错: {e}")
            time.sleep(1)

    if not frames:
        print("未获取到任何新数据")
        return cached

    fetched = pd.concat(frames, ignore_index=True)
    print(f"成功获取 {len(frames)} 个交易日的新数据,共 {len(fetched)} 条记录")

    if cached.empty:
        merged = fetched
    else:
        merged = pd.concat([cached, fetched], ignore_index=True)
        print(f"合并后共有 {len(merged)} 条记录")

    # Defensive de-duplication on (sector name, day); the later row wins.
    if {'name', 'trade_date'}.issubset(merged.columns):
        merged = merged.drop_duplicates(subset=['name', 'trade_date'], keep='last')
        print(f"去重后共有 {len(merged)} 条记录")

    # Sort on the string form of the date, newest first.
    if 'trade_date' in merged.columns:
        merged['trade_date'] = merged['trade_date'].astype(str)
        merged = merged.sort_values(by='trade_date', ascending=False)
        print("数据已按交易日期排序")

    return merged
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Date range to download; the end date of None is forwarded unchanged.
    start_date, end_date = '20250101', None

    # Limit-up data: refresh, then persist to the Excel cache.
    kpl_data = get_kpl_data(start_date, end_date)
    if kpl_data.empty:
        print("没有数据可保存")
    else:
        kpl_data.to_excel('ori_kpl_list.xlsx', index=False)
        print(f"数据已保存到 ori_kpl_list.xlsx, 共 {len(kpl_data)} 条记录")

    # Sector money-flow data: refresh, then persist to the Excel cache.
    sector_moneyflow_data = get_sector_moneyflow_data(start_date, end_date)
    if sector_moneyflow_data.empty:
        print("没有资金流向数据可保存")
    else:
        sector_moneyflow_data.to_excel('ori_sector_moneyflow.xlsx', index=False)
        print(f"资金流向数据已保存到 ori_sector_moneyflow.xlsx, 共 {len(sector_moneyflow_data)} 条记录")

    # The sector heat-map analysis is left disabled:
    # analyze_kpl_data()
|
|
313
main_force_strategy.py
Normal file
313
main_force_strategy.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from utils import load_config, get_trade_cal
|
||||||
|
from utils import save_df_to_db, load_df_from_db, get_existing_trade_dates
|
||||||
|
|
||||||
|
# 加载配置并初始化tushare
|
||||||
|
config = load_config()
|
||||||
|
import tushare as ts
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
ts.set_token(config['tushare_token'])
|
||||||
|
pro = ts.pro_api()
|
||||||
|
|
||||||
|
|
||||||
|
def get_sector_moneyflow_data(start_date=None, end_date=None):
    """
    Return sector money-flow data for a date range, using the database as cache.

    Trading days already stored in the ``sector_fund_flow`` table are not
    requested again. Missing days are downloaded through
    ``pro.moneyflow_ind_dc``, given a ``main_force_amount`` column, and
    appended to the table. The table is then read back in full.

    Args:
        start_date (str): First day of the range, formatted 'YYYYMMDD'.
        end_date (str): Last day of the range, formatted 'YYYYMMDD'.

    Returns:
        pandas.DataFrame: Whatever ``load_df_from_db('sector_fund_flow')``
        returns after the update.
    """
    table = 'sector_fund_flow'

    calendar = get_trade_cal(start_date, end_date)
    stored_days = get_existing_trade_dates(table)

    # Only days that the table does not contain yet need a download.
    pending = [day for day in calendar if day not in stored_days]
    if not pending:
        print("所有数据已在数据库中,无需更新")
        return load_df_from_db(table)

    print(f"需要获取 {len(pending)} 个新交易日的数据")

    frames = []
    for day in tqdm(pending):  # tqdm shows download progress
        try:
            daily = pro.moneyflow_ind_dc(trade_date=day)
            if daily.empty:
                print(f"日期 {day} 无数据")
            else:
                # Main-force amount = extra-large-order buys + large-order buys.
                daily['main_force_amount'] = daily['buy_elg_amount'] + daily['buy_lg_amount']
                frames.append(daily)
        except Exception as e:
            # A failing day is reported and skipped; later days still run.
            print(f"获取 {day} 的数据时出错: {e}")

    if frames:
        combined = pd.concat(frames, ignore_index=True)
        save_df_to_db(combined, table_name=table, if_exists='append')
        print(f"已将 {len(combined)} 条新记录保存到数据库")
    else:
        print("未获取到任何新数据")

    return load_df_from_db(table)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_extreme_sector_performance(rows_by_date, all_dates, indicator, correlation, horizon=10):
    """Build one record per trading day for the sector with the extreme indicator value.

    Args:
        rows_by_date (dict): Maps each trade date to that day's rows.
        all_dates (list): Trade dates in ascending order (keys of rows_by_date).
        indicator (str): Column whose extreme value selects the sector.
        correlation (int): > 0 picks the highest value, otherwise the lowest.
        horizon (int): Number of following trading days to record.

    Returns:
        list[dict]: One entry per day that has at least one following-day
        observation for the selected sector.
    """
    pick_lowest = correlation <= 0
    records = []
    # The last `horizon` days have no complete look-ahead window and are skipped.
    for idx in range(len(all_dates) - horizon):
        current_date = all_dates[idx]
        day_rows = rows_by_date[current_date]
        top_sectors = day_rows.sort_values(indicator, ascending=pick_lowest).head(1)['name'].tolist()

        for sector in top_sectors:
            sector_rows = day_rows[day_rows['name'] == sector]
            if sector_rows.empty:
                continue

            # Daily pct_change of the same sector on each of the next days;
            # None marks a day on which the sector has no row.
            future_changes = []
            for offset in range(1, horizon + 1):
                future_rows = rows_by_date[all_dates[idx + offset]]
                match = future_rows[future_rows['name'] == sector]
                future_changes.append(None if match.empty else match['pct_change'].values[0])

            if all(change is None for change in future_changes):
                continue

            record = {
                'date': current_date.strftime('%Y%m%d'),
                'sector': sector,
                indicator: sector_rows[indicator].values[0],
                'current_pct_change': sector_rows['pct_change'].values[0],
            }
            for day, change in enumerate(future_changes, start=1):
                record[f'day{day}_change'] = change
            record['avg_10day_change'] = np.nanmean([c for c in future_changes if c is not None])
            records.append(record)
    return records


def _win_rate(changes):
    """Return the percentage of non-missing values in *changes* that are positive."""
    observed = changes.dropna()
    return (observed > 0).mean() * 100


def _use_cjk_fonts():
    """Select fonts able to render the Chinese labels and keep minus signs readable."""
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']
    plt.rcParams['axes.unicode_minus'] = False


def _plot_strategy_chart(indicator, indicator_name, values):
    """Draw the average T..T+10 curve with both strategies highlighted; return the image path."""
    strategy_names = ['策略A: T+1买入,T+2卖出', '策略B: T+4买入,T+8卖出']
    strategy_image = f'result/{indicator}_strategy.png'
    days = range(11)  # T .. T+10

    fig = plt.figure(figsize=(14, 8))
    try:
        plt.plot(days, values, marker='o', color='blue', linewidth=2, label='平均表现')

        # Strategy A: buy on T+1, sell on T+2.
        plt.plot([1, 2], [values[1], values[2]], color='green', linewidth=4, alpha=0.7, label=strategy_names[0])
        plt.scatter([1, 2], [values[1], values[2]], color='green', s=100)

        # Strategy B: buy on T+4, sell on T+8 (drawn at T+4 to match its label).
        plt.plot([4, 8], [values[4], values[8]], color='red', linewidth=4, alpha=0.7, label=strategy_names[1])
        plt.scatter([4, 8], [values[4], values[8]], color='red', s=100)

        plt.axhline(y=0, color='gray', linestyle='--')
        plt.title(f'{indicator_name}极值后交易策略示意图', fontsize=14)
        plt.ylabel('指数变化率 (%)')
        plt.xlabel('交易日')
        plt.xticks(days, ['T'] + [f'T+{i}' for i in range(1, 11)])
        plt.grid(True)
        plt.legend()
        plt.savefig(strategy_image, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)
    return strategy_image


def _plot_average_performance(indicator, indicator_name, correlation, values):
    """Draw the plain average-performance line chart; return the image path."""
    output_image = f'result/{indicator}_performance.png'
    days = ['当天'] + [f'第{i}天' for i in range(1, 11)]
    direction = "最高" if correlation > 0 else "最低"

    fig = plt.figure(figsize=(14, 8))
    try:
        plt.plot(days, values, marker='o', linewidth=2)
        plt.axhline(y=0, color='r', linestyle='--')
        plt.title(f'{indicator_name}{direction}后的平均表现 (10天)', fontsize=14)
        plt.ylabel('指数变化率 (%)')
        plt.xlabel('时间')
        plt.grid(True)
        plt.xticks(rotation=45)  # rotated so the tick labels do not overlap
        # bbox_inches='tight' keeps the rotated labels inside the image.
        plt.savefig(output_image, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    return output_image


def analyze_money_flow():
    """
    Analyse how a sector performs in the 10 trading days after a flow extreme.

    For every trading day (except the last 10) the sector with the extreme
    value of each configured indicator is selected, and its daily
    ``pct_change`` on the following 10 trading days is recorded. Per
    indicator the function writes ``result/<indicator>_performance.xlsx``,
    prints average changes, up-day probabilities and two fixed strategy
    checks, and saves ``result/<indicator>_strategy.png`` and
    ``result/<indicator>_performance.png``.

    Probabilities are computed over days that actually have an observation;
    missing future days are not counted as losses.

    Returns:
        bool | None: ``True`` after all indicators were processed, ``None``
        when the ``sector_fund_flow`` table could not be loaded.
    """
    try:
        df = load_df_from_db('sector_fund_flow')
        print(f"成功从数据库加载资金流数据,共计{len(df)}条记录")
    except Exception as e:
        print(f"从数据库读取数据失败:{e}")
        return

    # Dates are parsed from 'YYYYMMDD'; rows that fail to parse are dropped.
    df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d', errors='coerce')
    df = df[~df['trade_date'].isna()]
    df = df.sort_values('trade_date')

    # Split once per date so the look-ups below do not rescan the whole frame.
    # The group keys are Timestamps, which always support strftime.
    rows_by_date = {date: day_rows for date, day_rows in df.groupby('trade_date')}
    all_dates = sorted(rows_by_date)

    # (column, correlation, display name); correlation > 0 selects the sector
    # with the highest value, otherwise the one with the lowest.
    flow_indicators = [
        ('main_force_amount', 1, '主力净额')
    ]

    os.makedirs('result', exist_ok=True)

    for indicator, correlation, indicator_name in flow_indicators:
        print(f"\n\n分析 {indicator_name} 与未来指数关系...")

        results_df = pd.DataFrame(
            _collect_extreme_sector_performance(rows_by_date, all_dates, indicator, correlation)
        )
        if results_df.empty:
            print(f"没有足够的数据来分析{indicator_name}与后续表现的关系")
            continue

        output_file = f'result/{indicator}_performance.xlsx'
        results_df.to_excel(output_file, index=False)
        print(f"{indicator_name}表现分析已保存至{output_file}")

        # Average change per look-ahead day.
        avg_performance = {f'day{day}': results_df[f'day{day}_change'].mean() for day in range(1, 11)}
        avg_performance['avg_10day'] = results_df['avg_10day_change'].mean()

        print(f"\n{indicator_name}极值行业的平均表现:")
        for day, perf in avg_performance.items():
            print(f"{day}: {perf:.4f}%")

        # Share of observed days with a positive change.
        success_rates = {f'day{day}': _win_rate(results_df[f'day{day}_change']) for day in range(1, 11)}

        print(f"\n{indicator_name}极值后上涨的概率:")
        for day, rate in success_rates.items():
            print(f"{day}: {rate:.2f}%")

        # ------------------ fixed strategy checks ------------------
        print("\n交易策略验证:")

        # NOTE(review): both formulas below combine *daily* pct_change values
        # and subtract the buy day's own change. Confirm this is the intended
        # return definition; it is kept exactly as written.

        # Strategy A: buy on T+1, sell on T+2.
        day1_to_day2_change = results_df['day2_change'] - results_df['day1_change']
        avg_change_1_to_2 = day1_to_day2_change.mean()
        win_rate_1_to_2 = _win_rate(day1_to_day2_change)
        print(f"策略A - T+1(第1日)买入,T+2(第2日)卖出的平均收益: {avg_change_1_to_2:.4f}%")
        print(f"策略A - T+1(第1日)买入,T+2(第2日)卖出的盈利概率: {win_rate_1_to_2:.2f}%")

        # Strategy B: buy on T+4, sell on T+8.
        day4_to_day8_change = (
            results_df['day8_change'] + results_df['day7_change'] + results_df['day6_change']
            + results_df['day5_change'] - results_df['day4_change']
        )
        avg_change_4_to_8 = day4_to_day8_change.mean()
        win_rate_4_to_8 = _win_rate(day4_to_day8_change)
        print(f"策略B - T+4(第4日)买入,T+8(第8日)卖出的平均收益: {avg_change_4_to_8:.4f}%")
        print(f"策略B - T+4(第4日)买入,T+8(第8日)卖出的盈利概率: {win_rate_4_to_8:.2f}%")

        # Both strategies run back to back.
        combined_change = day1_to_day2_change + day4_to_day8_change
        avg_combined_change = combined_change.mean()
        win_rate_combined = _win_rate(combined_change)
        print(f"组合策略 - 完整策略组合的平均总收益: {avg_combined_change:.4f}%")
        print(f"组合策略 - 完整策略至少盈利的概率: {win_rate_combined:.2f}%")

        # ------------------ charts ------------------
        _use_cjk_fonts()
        values = [results_df['current_pct_change'].mean()] + [avg_performance[f'day{i}'] for i in range(1, 11)]

        strategy_image = _plot_strategy_chart(indicator, indicator_name, values)
        print(f"交易策略示意图已保存至{strategy_image}")

        output_image = _plot_average_performance(indicator, indicator_name, correlation, values)
        print(f"{indicator_name}表现图表已保存至{output_image}")

    return True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Date range to download; the end date of None is forwarded unchanged.
    start_date, end_date = '20230912', None

    # Bring the sector money-flow table up to date, then analyse it.
    get_sector_moneyflow_data(start_date, end_date)
    analyze_money_flow()
|
97
utils.py
97
utils.py
@ -1,13 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import yaml
|
|
||||||
import tushare as ts
|
|
||||||
import pandas as pd
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from sqlalchemy import create_engine
|
|
||||||
|
import pandas as pd
|
||||||
|
import tushare as ts
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy import create_engine, Column, Integer, String, Float, text
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
# 模块级单例
|
# 模块级单例
|
||||||
_config = None
|
_config = None
|
||||||
_engine = None
|
_engine = None
|
||||||
|
_Session = None
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
@ -43,7 +48,6 @@ def load_config():
|
|||||||
|
|
||||||
return _config
|
return _config
|
||||||
|
|
||||||
|
|
||||||
def get_engine():
|
def get_engine():
|
||||||
"""获取单例数据库引擎"""
|
"""获取单例数据库引擎"""
|
||||||
global _engine
|
global _engine
|
||||||
@ -74,3 +78,86 @@ def get_trade_cal(start_date=None, end_date=None):
|
|||||||
pro = ts.pro_api()
|
pro = ts.pro_api()
|
||||||
trade_cal_df = pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
|
trade_cal_df = pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
|
||||||
return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
|
return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
|
||||||
|
|
||||||
|
def get_db_engine():
    """
    Return the cached SQLite engine, creating it on first use.

    The database path is read from the ``db_path`` config entry and defaults
    to ``data/market_data.db``. The parent directory is created when the
    path has one; a bare file name such as ``market.db`` is used as is.

    Returns:
        sqlalchemy.engine.Engine: The module-wide engine instance.
    """
    # NOTE(review): `_engine` is the same module-level cache that get_engine()
    # declares global; confirm the two accessors are meant to share one engine.
    global _engine
    if _engine is not None:
        return _engine

    config = load_config()
    db_path = config.get('db_path', 'data/market_data.db')

    # os.makedirs('') raises, so only create a directory when there is one.
    db_dir = os.path.dirname(db_path)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    _engine = create_engine(f'sqlite:///{db_path}', echo=False)
    return _engine
|
||||||
|
|
||||||
|
def get_session():
    """
    Return a new database session.

    The session factory is built once, bound to the engine returned by
    ``get_db_engine()``, and cached in the module-level ``_Session``.
    """
    global _Session
    factory = _Session
    if factory is None:
        factory = sessionmaker(bind=get_db_engine())
        _Session = factory
    return factory()
|
||||||
|
|
||||||
|
def init_db():
    """Create every table registered on ``Base.metadata`` in the database."""
    engine = get_db_engine()
    Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
def get_existing_trade_dates(table_name):
    """
    Return the distinct ``trade_date`` values stored in a table.

    Args:
        table_name (str): Name of the table to inspect. It is inserted into
            the SQL text as is.

    Returns:
        set: The trade dates found in the table. Any error while querying is
        printed and an empty set is returned instead.
    """
    statement = f"SELECT DISTINCT trade_date FROM {table_name}"
    db = get_db_engine()

    try:
        # text() plus an explicit connection is the SQLAlchemy 2.0+ way to
        # run a raw SQL string.
        from sqlalchemy import text
        with db.connect() as conn:
            rows = conn.execute(text(statement))
            return set(row[0] for row in rows)
    except Exception as e:
        print(f"获取已存在交易日期时出错: {e}")
        return set()
|
||||||
|
|
||||||
|
|
||||||
|
def load_df_from_db(table_name, conditions=None):
    """
    Load rows of a database table into a DataFrame.

    Args:
        table_name (str): Name of the table to read. It is inserted into the
            SQL text as is.
        conditions (str): Optional filter appended after ``WHERE``, e.g.
            "trade_date > '20230101'". It is inserted into the SQL text as
            is, so it must come from trusted code, not from user input.

    Returns:
        pandas.DataFrame: The query result.
    """
    query = f"SELECT * FROM {table_name}"
    if conditions:
        query += f" WHERE {conditions}"

    # Run the statement as text() on an explicit connection, the same way
    # get_existing_trade_dates does, instead of handing pandas a raw string
    # together with an Engine.
    with get_db_engine().connect() as connection:
        return pd.read_sql(text(query), connection)
|
||||||
|
|
||||||
|
|
||||||
|
def save_df_to_db(df, table_name, if_exists='append'):
    """
    Write a DataFrame to a database table.

    Args:
        df (pandas.DataFrame): Data to store; its index is not written.
        table_name (str): Name of the target table.
        if_exists (str): What to do when the table already exists:
            'fail', 'replace' or 'append'.
    """
    df.to_sql(table_name, get_db_engine(), if_exists=if_exists, index=False)
|
Loading…
Reference in New Issue
Block a user