From b652322061a90915981c242e49382ffcbf8ac490 Mon Sep 17 00:00:00 2001 From: Qihang Zhang Date: Sat, 19 Apr 2025 01:41:15 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=97=91=EF=B8=8F=20remove(get=5Fkpl=5Flist?= =?UTF-8?q?.py):=20=E5=88=A0=E9=99=A4=E6=B6=A8=E5=81=9C=E6=9D=BF=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=8E=B7=E5=8F=96=E6=97=A7=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✨ feat(main_force_strategy.py): 新增主力资金流向策略分析功能 🔧 refactor(utils.py): 增加数据库操作支持和辅助函数 --- get_kpl_list.py | 343 ----------------------------------------- main_force_strategy.py | 313 +++++++++++++++++++++++++++++++++++++ utils.py | 99 +++++++++++- 3 files changed, 406 insertions(+), 349 deletions(-) delete mode 100644 get_kpl_list.py create mode 100644 main_force_strategy.py diff --git a/get_kpl_list.py b/get_kpl_list.py deleted file mode 100644 index 35b408f..0000000 --- a/get_kpl_list.py +++ /dev/null @@ -1,343 +0,0 @@ -import os -import time - -import openpyxl -import pandas as pd -from openpyxl.styles import PatternFill -from tqdm import tqdm - -from utils import load_config, get_trade_cal - -# 加载配置并初始化tushare -config = load_config() -import tushare as ts - -ts.set_token(config['tushare_token']) -pro = ts.pro_api() - - -def get_kpl_data(start_date=None, end_date=None): - """ - 获取指定时间段内的打板数据 - 参数: - start_date (str): 开始日期,格式'YYYYMMDD' - end_date (str): 结束日期,格式'YYYYMMDD' - 返回: - pandas.DataFrame: 所有打板数据 - """ - # 获取目标交易日历 - all_trade_dates = get_trade_cal(start_date, end_date) - - # 检查是否已有现有数据 - existing_data = pd.DataFrame() - existing_dates = set() - output_file = 'ori_kpl_list.xlsx' - - if os.path.exists(output_file): - try: - print(f"检测到已有数据文件: {output_file}") - existing_data = pd.read_excel(output_file) - if not existing_data.empty and 'trade_date' in existing_data.columns: - # 确保trade_date是字符串类型 - existing_data['trade_date'] = existing_data['trade_date'].astype(str) - # 提取已有数据的交易日期 - existing_dates = set(existing_data['trade_date'].astype(str).unique()) - print(f"已有数据包含 {len(existing_dates)} 个交易日") - except Exception as e: - print(f"读取现有数据时出错: {e}") - - # 确定需要获取的日期 - dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] - - if not dates_to_fetch: - print("所有数据均已存在,无需更新") - return existing_data - - print(f"需要获取 {len(dates_to_fetch)} 个新交易日的数据") - - # 获取新的打板数据 - new_data = [] - for trade_date in tqdm(dates_to_fetch): - try: - # 不指定字段参数,获取所有返回的字段 - df = pro.kpl_list(trade_date=trade_date, tag='涨停') - if not df.empty: - # 确保新数据的trade_date也是字符串类型 - if 'trade_date' in df.columns: - df['trade_date'] = df['trade_date'].astype(str) - new_data.append(df) - except Exception as e: - print(f"获取 {trade_date} 数据时出错: {e}") - time.sleep(1) # 出错时稍微多等待一下 - - # 合并所有数据 - if new_data: - new_result = pd.concat(new_data, ignore_index=True) - print(f"成功获取 {len(new_data)} 个交易日的新数据,共 {len(new_result)} 条记录") - - # 合并新旧数据 - if not existing_data.empty: - result = pd.concat([existing_data, new_result], ignore_index=True) - print(f"合并后共有 {len(result)} 条记录") - else: - result = new_result - - # 进行一次去重操作,以防万一 - if 'ts_code' in result.columns and 'trade_date' in result.columns: - result = result.drop_duplicates(subset=['ts_code', 'trade_date'], keep='last') - print(f"去重后共有 {len(result)} 条记录") - - # 确保trade_date是字符串类型后再排序 - if 'trade_date' in result.columns: - result['trade_date'] = result['trade_date'].astype(str) - result = result.sort_values(by='trade_date', ascending=False) # 降序排列,最新的数据在前 - print("数据已按交易日期排序") - - return result - else: - print("未获取到任何新数据") - return existing_data - - -def analyze_kpl_data(): - """ - 分析涨停板数据,统计每日各板块涨停数量,并使用热力图风格展示 - 按照最近100个工作日的涨停总数对板块进行排序 - 删除100个工作日内没有涨停记录的板块 - """ - print("开始分析涨停板数据...") - - # 1. 从原始文件读取数据 - try: - ori_data = pd.read_excel('ori_kpl_list.xlsx') - print(f"成功读取原始数据,共 {len(ori_data)} 条记录") - except Exception as e: - print(f"读取数据失败: {e}") - return - - # 确保日期字段是字符串类型 - ori_data['trade_date'] = ori_data['trade_date'].astype(str) - - # 获取所有唯一的交易日期和板块 - all_dates = sorted(ori_data['trade_date'].unique(), reverse=True) # 降序排列日期 - - if 'lu_desc' not in ori_data.columns: - print("错误: 原始数据中没有板块信息字段 'lu_desc'") - return - - # 获取所有唯一的板块 - all_sectors = ori_data['lu_desc'].dropna().unique() - print(f"数据包含 {len(all_dates)} 个交易日和 {len(all_sectors)} 个板块") - - # 创建一个包含所有日期和板块的DataFrame用于统计 - temp_result = pd.DataFrame(0, index=all_dates, columns=all_sectors) - - # 按日期分组统计 - for date in all_dates: - # 获取当日数据 - daily_data = ori_data[ori_data['trade_date'] == date] - - # 统计各板块涨停数量 - sector_counts = daily_data.groupby('lu_desc').size() - - # 更新临时结果DataFrame - for sector, count in sector_counts.items(): - if sector in temp_result.columns: - temp_result.loc[date, sector] = count - - # 计算最近100个工作日(或所有可用天数)的各板块涨停总数 - recent_days = min(100, len(all_dates)) - recent_dates = all_dates[:recent_days] - - # 计算这些日期内每个板块的涨停总数 - sector_totals = temp_result.loc[recent_dates].sum() - - # 筛选出在100个工作日内有涨停记录的板块 - active_sectors = sector_totals[sector_totals > 0].index.tolist() - - # 按照涨停总数对活跃板块进行排序 - sorted_sectors = sector_totals[active_sectors].sort_values(ascending=False).index.tolist() - - print(f"已按最近{recent_days}个工作日的涨停总数对板块排序") - print( - f"共保留了{len(sorted_sectors)}个有涨停记录的板块,删除了{len(all_sectors) - len(sorted_sectors)}个无涨停记录的板块") - - if sorted_sectors: - print("涨停数量前10的板块:") - for i, sector in enumerate(sorted_sectors[:min(10, len(sorted_sectors))], 1): - print(f"{i}. {sector}: {sector_totals[sector]}只") - - # 如果没有活跃板块,提前返回 - if not sorted_sectors: - print("警告: 在指定时间段内没有板块有涨停记录") - return pd.DataFrame() - - # 创建最终结果DataFrame,只使用有涨停记录的排序后的板块 - result = pd.DataFrame("", index=all_dates, columns=sorted_sectors) - - # 填充数据,只填入非零值 - for date in all_dates: - for sector in sorted_sectors: - count = temp_result.loc[date, sector] - if count > 0: - result.loc[date, sector] = count - - # 保存结果到新的Excel文件 - output_file = 'sector_limit_up_analysis.xlsx' - result.to_excel(output_file) - - # 创建热力图色阶函数:从浅红色到深红色(FFFF0000) - def get_heatmap_color(value): - try: - value = int(value) - # 将值限制在0-20范围内 - value = min(max(value, 0), 20) - - # 计算颜色深度 - 值越大颜色越深 - # 红色固定为FF,绿色和蓝色从FF(浅)递减到00(深) - intensity = int(255 - (value / 20 * 255)) - intensity_hex = format(intensity, '02X') - - # 构建颜色代码: 红色固定为FF,绿色和蓝色根据值变化 - color_code = f"FF{intensity_hex}{intensity_hex}" - - return color_code - except: - return None - - # 使用openpyxl添加热力图风格 - print("正在添加热力图样式...") - workbook = openpyxl.load_workbook(output_file) - worksheet = workbook.active - - # 遍历所有数据单元格 - for row in range(2, worksheet.max_row + 1): # 跳过标题行 - for col in range(2, worksheet.max_column + 1): # 跳过索引列 - cell = worksheet.cell(row=row, column=col) - if cell.value and str(cell.value).strip(): # 只处理非空单元格 - # 获取相应的热力图颜色 - color_code = get_heatmap_color(cell.value) - if color_code: - # 应用背景色 - cell.fill = PatternFill(start_color=color_code, end_color=color_code, fill_type="solid") - - # 保存格式化后的Excel - workbook.save(output_file) - - print(f"分析完成,结果已保存到 {output_file}") - print(f"统计了 {len(result.columns)} 个活跃板块的涨停数据") - print(f"已使用红色热力图标记涨停数量:0-20对应从浅红到深红") - - return result - - -def get_sector_moneyflow_data(start_date=None, end_date=None): - """ - 获取指定时间段内的板块资金流向数据 - - 参数: - start_date (str): 开始日期,格式'YYYYMMDD' - end_date (str): 结束日期,格式'YYYYMMDD' - - 返回: - pandas.DataFrame: 所有板块资金流向数据 - """ - # 获取目标交易日历 - all_trade_dates = get_trade_cal(start_date, end_date) - - # 检查是否已有现有数据 - existing_data = pd.DataFrame() - existing_dates = set() - output_file = 'ori_sector_moneyflow.xlsx' - - if os.path.exists(output_file): - try: - print(f"检测到已有数据文件: {output_file}") - existing_data = pd.read_excel(output_file) - if not existing_data.empty and 'trade_date' in existing_data.columns: - # 确保trade_date是字符串类型 - existing_data['trade_date'] = existing_data['trade_date'].astype(str) - # 提取已有数据的交易日期 - existing_dates = set(existing_data['trade_date'].astype(str).unique()) - print(f"已有数据包含 {len(existing_dates)} 个交易日") - except Exception as e: - print(f"读取现有数据时出错: {e}") - - # 确定需要获取的日期 - dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] - if not dates_to_fetch: - print("所有数据均已存在,无需更新") - return existing_data - - print(f"需要获取 {len(dates_to_fetch)} 个新交易日的数据") - - # 获取新的板块资金流向数据 - new_data = [] - - for trade_date in tqdm(dates_to_fetch): - try: - # 获取当日所有板块资金流向 - df = pro.moneyflow_ind_dc(trade_date=trade_date) - if not df.empty: - # 确保新数据的trade_date也是字符串类型 - if 'trade_date' in df.columns: - df['trade_date'] = df['trade_date'].astype(str) - new_data.append(df) - except Exception as e: - print(f"获取 {trade_date} 数据时出错: {e}") - time.sleep(1) # 出错时稍微多等待一下 - - # 合并所有数据 - if new_data: - new_result = pd.concat(new_data, ignore_index=True) - print(f"成功获取 {len(new_data)} 个交易日的新数据,共 {len(new_result)} 条记录") - - # 合并新旧数据 - if not existing_data.empty: - result = pd.concat([existing_data, new_result], ignore_index=True) - print(f"合并后共有 {len(result)} 条记录") - else: - result = new_result - - # 进行一次去重操作,以防万一 - if 'name' in result.columns and 'trade_date' in result.columns: - result = result.drop_duplicates(subset=['name', 'trade_date'], keep='last') - print(f"去重后共有 {len(result)} 条记录") - - # 确保trade_date是字符串类型后再排序 - if 'trade_date' in result.columns: - result['trade_date'] = result['trade_date'].astype(str) - result = result.sort_values(by='trade_date', ascending=False) # 降序排列,最新的数据在前 - print("数据已按交易日期排序") - - return result - else: - print("未获取到任何新数据") - return existing_data - - -if __name__ == "__main__": - # 指定日期范围 - start_date = '20250101' - end_date = None - - # 获取打板数据 - kpl_data = get_kpl_data(start_date, end_date) - # 保存到Excel - if not kpl_data.empty: - kpl_data.to_excel('ori_kpl_list.xlsx', index=False) - print(f"数据已保存到 ori_kpl_list.xlsx, 共 {len(kpl_data)} 条记录") - else: - print("没有数据可保存") - - # 获取板块资金流向数据 - sector_moneyflow_data = get_sector_moneyflow_data(start_date, end_date) - # 保存到Excel - if not sector_moneyflow_data.empty: - sector_moneyflow_data.to_excel('ori_sector_moneyflow.xlsx', index=False) - print(f"资金流向数据已保存到 ori_sector_moneyflow.xlsx, 共 {len(sector_moneyflow_data)} 条记录") - else: - print("没有资金流向数据可保存") - - # 执行分析 - # analyze_kpl_data() \ No newline at end of file diff --git a/main_force_strategy.py b/main_force_strategy.py new file mode 100644 index 0000000..b4f75f2 --- /dev/null +++ b/main_force_strategy.py @@ -0,0 +1,313 @@ +import os +import time + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from tqdm import tqdm + +from utils import load_config, get_trade_cal +from utils import save_df_to_db, load_df_from_db, get_existing_trade_dates + +# 加载配置并初始化tushare +config = load_config() +import tushare as ts +import seaborn as sns + +ts.set_token(config['tushare_token']) +pro = ts.pro_api() + + +def get_sector_moneyflow_data(start_date=None, end_date=None): + """ + 获取指定时间段内的板块资金流向数据,使用数据库缓存 + + 参数: + start_date (str): 开始日期,格式'YYYYMMDD' + end_date (str): 结束日期,格式'YYYYMMDD' + + 返回: + pandas.DataFrame: 所有板块资金流向数据 + """ + # 获取目标交易日历 + all_trade_dates = get_trade_cal(start_date, end_date) + + # 从数据库获取已有的交易日期 + existing_dates = get_existing_trade_dates('sector_fund_flow') + + # 筛选出需要新获取的日期 + new_dates = [date for date in all_trade_dates if date not in existing_dates] + + if not new_dates: + print("所有数据已在数据库中,无需更新") + return load_df_from_db('sector_fund_flow') + + print(f"需要获取 {len(new_dates)} 个新交易日的数据") + + # 获取新日期的数据 + all_new_data = [] + + # 使用tqdm显示进度 + for trade_date in tqdm(new_dates): + try: + # 从tushare获取当日板块资金流向数据 + df = pro.moneyflow_ind_dc(trade_date=trade_date) + + # 如果有数据,添加到列表 + if not df.empty: + # 计算主力资金 = 超大单买入 + 大单买入 + df['main_force_amount'] = df['buy_elg_amount'] + df['buy_lg_amount'] + all_new_data.append(df) + else: + print(f"日期 {trade_date} 无数据") + + except Exception as e: + print(f"获取 {trade_date} 的数据时出错: {e}") + + # 如果有新数据,合并并保存到数据库 + if all_new_data: + # 将所有新数据合并为一个DataFrame + new_df = pd.concat(all_new_data, ignore_index=True) + + # 保存到数据库 + save_df_to_db(new_df, table_name='sector_fund_flow', if_exists='append') + + print(f"已将 {len(new_df)} 条新记录保存到数据库") + else: + print("未获取到任何新数据") + + return load_df_from_db('sector_fund_flow') + + +def analyze_money_flow(): + """ + 分析各类资金流向指标对行业在随后1-10天表现的影响 + 包括期望收益分析和特定交易策略验证 + """ + # 读取资金流数据 + try: + df = load_df_from_db('sector_fund_flow') + print(f"成功从数据库加载资金流数据,共计{len(df)}条记录") + except Exception as e: + print(f"从数据库读取数据失败:{e}") + return + + # 将日期格式转换为datetime - 如果存储在数据库中的是字符串格式 + df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d', errors='coerce') + df = df[~df['trade_date'].isna()] + # 按日期排序 + df = df.sort_values('trade_date') + # 获取所有交易日期 + all_dates = df['trade_date'].unique() + + # 定义要分析的资金流指标 + # 格式: (指标名, 排序方向, 关联性) + # 关联性: 正相关=1, 负相关=-1 (用于确定是取最高还是最低) + flow_indicators = [ + ('main_force_amount', 1, '主力净额') + ] + + # 确保结果目录存在 + os.makedirs('result', exist_ok=True) + + # 为每个指标进行分析 + for indicator, correlation, indicator_name in flow_indicators: + print(f"\n\n分析 {indicator_name} 与未来指数关系...") + + # 创建结果数据结构 + results = [] + + # 遍历每个交易日期(除了最后10天) + for i in range(len(all_dates) - 10): + current_date = all_dates[i] + # 获取当前日期的数据 + current_day_data = df[df['trade_date'] == current_date] + + # 确定排序方向和选择逻辑 + sort_ascending = correlation < 0 # 负相关时升序(最小值), 正相关时降序(最大值) + + # 找出该指标排名靠前的行业 + if correlation > 0: + # 正相关,找最高值 + top_sectors = current_day_data.sort_values(indicator, ascending=False).head(1)['name'].tolist() + else: + # 负相关,找最低值 + top_sectors = current_day_data.sort_values(indicator, ascending=True).head(1)['name'].tolist() + + # 分析每个行业在随后1-10天的表现 + for sector in top_sectors: + # 获取该行业当天的指数变化和指标值 + sector_current = current_day_data[current_day_data['name'] == sector] + if sector_current.empty: + continue + current_pct_change = sector_current['pct_change'].values[0] + current_indicator_value = sector_current[indicator].values[0] + + # 分析随后1-10天的表现 + future_changes = [] + for day_offset in range(1, 11): + if i + day_offset < len(all_dates): + future_date = all_dates[i + day_offset] + future_data = df[(df['trade_date'] == future_date) & (df['name'] == sector)] + if not future_data.empty: + future_changes.append(future_data['pct_change'].values[0]) + else: + future_changes.append(None) + else: + future_changes.append(None) + + # 如果至少有一个未来日期有数据 + if any(x is not None for x in future_changes): + result_entry = { + 'date': current_date.strftime('%Y%m%d'), # 将日期格式化为YYYYMMDD字符串 + 'sector': sector, + f'{indicator}': current_indicator_value, + 'current_pct_change': current_pct_change, + } + # 添加1-10天的变化 + for day in range(1, 11): + result_entry[f'day{day}_change'] = future_changes[day - 1] + # 计算平均变化 + result_entry['avg_10day_change'] = np.nanmean([x for x in future_changes if x is not None]) + results.append(result_entry) + + # 转换为DataFrame + results_df = pd.DataFrame(results) + if results_df.empty: + print(f"没有足够的数据来分析{indicator_name}与后续表现的关系") + continue + + # 保存结果 + output_file = f'result/{indicator}_performance.xlsx' + results_df.to_excel(output_file, index=False) + print(f"{indicator_name}表现分析已保存至{output_file}") + + # 分析整体表现 + avg_performance = {} + for day in range(1, 11): + avg_performance[f'day{day}'] = results_df[f'day{day}_change'].mean() + avg_performance['avg_10day'] = results_df['avg_10day_change'].mean() + + print(f"\n{indicator_name}极值行业的平均表现:") + for day, perf in avg_performance.items(): + print(f"{day}: {perf:.4f}%") + + # 分析期望值(正指数变化的百分比) + success_rates = {} + for day in range(1, 11): + success_rates[f'day{day}'] = (results_df[f'day{day}_change'] > 0).mean() * 100 + + print(f"\n{indicator_name}极值后上涨的概率:") + for day, rate in success_rates.items(): + print(f"{day}: {rate:.2f}%") + + # ------------------ 验证特定交易策略 ------------------ + print("\n交易策略验证:") + + # T+1买入,T+2卖出 + day1_to_day2_change = results_df['day2_change'] - results_df['day1_change'] + avg_change_1_to_2 = day1_to_day2_change.mean() + win_rate_1_to_2 = (day1_to_day2_change > 0).mean() * 100 + print(f"策略A - T+1(第1日)买入,T+2(第2日)卖出的平均收益: {avg_change_1_to_2:.4f}%") + print(f"策略A - T+1(第1日)买入,T+2(第2日)卖出的盈利概率: {win_rate_1_to_2:.2f}%") + + # T+4买入,T+8卖出 + day4_to_day8_change = results_df['day8_change'] + results_df['day7_change'] + results_df['day6_change'] + \ + results_df['day5_change'] - results_df['day4_change'] + avg_change_3_to_8 = day4_to_day8_change.mean() + win_rate_3_to_8 = (day4_to_day8_change > 0).mean() * 100 + print(f"策略B - T+4(第4日)买入,T+8(第8日)卖出的平均收益: {avg_change_3_to_8:.4f}%") + print(f"策略B - T+4(第4日)买入,T+8(第8日)卖出的盈利概率: {win_rate_3_to_8:.2f}%") + + # 分析策略组合效果 + # 模拟完整策略:T+1买入,T+2卖出,T+3买入,T+8卖出 + combined_change = day1_to_day2_change + day4_to_day8_change + avg_combined_change = combined_change.mean() + win_rate_combined = (combined_change > 0).mean() * 100 + print(f"组合策略 - 完整策略组合的平均总收益: {avg_combined_change:.4f}%") + print(f"组合策略 - 完整策略至少盈利的概率: {win_rate_combined:.2f}%") + + # 绘制策略示意图 + plt.figure(figsize=(14, 8)) + try: + plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei'] + title = f'{indicator_name}极值后交易策略示意图' + xlabel = '交易日' + ylabel = '指数变化率 (%)' + strategy_names = ['策略A: T+1买入,T+2卖出', '策略B: T+4买入,T+8卖出'] + except: + title = f'Trading Strategy after {indicator_name} Extreme Value' + xlabel = 'Trading Day' + ylabel = 'Index Change Rate (%)' + strategy_names = ['Strategy A: Buy T+1, Sell T+2', 'Strategy B: Buy T+3, Sell T+8'] + + days = range(11) # 0-10天 + values = [results_df['current_pct_change'].mean()] + [avg_performance[f'day{i}'] for i in range(1, 11)] + plt.plot(days, values, marker='o', color='blue', linewidth=2, label='平均表现') + + # 标记策略A: T+1买入,T+2卖出 + plt.plot([1, 2], [values[1], values[2]], color='green', linewidth=4, alpha=0.7, label=strategy_names[0]) + plt.scatter([1, 2], [values[1], values[2]], color='green', s=100) + + # 标记策略B: T+4买入,T+8卖出 + plt.plot([3, 8], [values[3], values[8]], color='red', linewidth=4, alpha=0.7, label=strategy_names[1]) + plt.scatter([3, 8], [values[3], values[8]], color='red', s=100) + + plt.axhline(y=0, color='gray', linestyle='--') + plt.title(title, fontsize=14) + plt.ylabel(ylabel) + plt.xlabel(xlabel) + plt.xticks(days, ['T'] + [f'T+{i}' for i in range(1, 11)]) + plt.grid(True) + plt.legend() + + # 保存策略图表 + strategy_image = f'result/{indicator}_strategy.png' + plt.savefig(strategy_image, dpi=300, bbox_inches='tight') + print(f"交易策略示意图已保存至{strategy_image}") + + # 绘制折线图显示未来10天的平均表现 + plt.figure(figsize=(14, 8)) + # 设置中文字体 + try: + plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei'] + plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 + days = ['当天'] + [f'第{i}天' for i in range(1, 11)] + direction = "最高" if correlation > 0 else "最低" + title = f'{indicator_name}{direction}后的平均表现 (10天)' + xlabel = '时间' + ylabel = '指数变化率 (%)' + except: + # 如果没有中文字体,使用英文 + days = ['Current'] + [f'Day+{i}' for i in range(1, 11)] + direction = "Highest" if correlation > 0 else "Lowest" + title = f'Average Performance After {indicator_name} {direction} (10 Days)' + xlabel = 'Time' + ylabel = 'Index Change Rate (%)' + + values = [results_df['current_pct_change'].mean()] + [avg_performance[f'day{i}'] for i in range(1, 11)] + plt.plot(days, values, marker='o', linewidth=2) + plt.axhline(y=0, color='r', linestyle='--') + plt.title(title, fontsize=14) + plt.ylabel(ylabel) + plt.xlabel(xlabel) + plt.grid(True) + plt.xticks(rotation=45) # 旋转x轴标签以避免重叠 + + # 保存图表 + output_image = f'result/{indicator}_performance.png' + plt.savefig(output_image, dpi=300, bbox_inches='tight') # 添加bbox_inches参数确保所有标签都显示 + print(f"{indicator_name}表现图表已保存至{output_image}") + + return True + + +if __name__ == "__main__": + # 指定日期范围 + start_date = '20230912' + end_date = None + + # 获取板块资金流向数据 + get_sector_moneyflow_data(start_date, end_date) + + analyze_money_flow() \ No newline at end of file diff --git a/utils.py b/utils.py index 600a27d..96c51fb 100644 --- a/utils.py +++ b/utils.py @@ -1,13 +1,18 @@ import os -import yaml -import tushare as ts -import pandas as pd from datetime import datetime, timedelta -from sqlalchemy import create_engine + +import pandas as pd +import tushare as ts +import yaml +from sqlalchemy import create_engine, Column, Integer, String, Float, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker # 模块级单例 _config = None _engine = None +_Session = None +Base = declarative_base() def load_config(): @@ -43,7 +48,6 @@ def load_config(): return _config - def get_engine(): """获取单例数据库引擎""" global _engine @@ -73,4 +77,87 @@ def get_trade_cal(start_date=None, end_date=None): pro = ts.pro_api() trade_cal_df = pro.trade_cal(exchange='', start_date=start_date, end_date=end_date) - return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist() \ No newline at end of file + return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist() + +def get_db_engine(): + """获取SQLite数据库引擎,如果不存在则创建""" + global _engine + if _engine is not None: + return _engine + + config = load_config() + db_path = config.get('db_path', 'data/market_data.db') + + # 确保数据库目录存在 + os.makedirs(os.path.dirname(db_path), exist_ok=True) + + # 创建SQLite数据库引擎 + _engine = create_engine(f'sqlite:///{db_path}', echo=False) + return _engine + +def get_session(): + """获取数据库会话""" + global _Session + if _Session is None: + _Session = sessionmaker(bind=get_db_engine()) + return _Session() + +def init_db(): + """初始化数据库表结构""" + Base.metadata.create_all(get_db_engine()) + +def get_existing_trade_dates(table_name): + """ + 从数据库中获取已有的交易日期 + + 参数: + table_name (str): 数据表名称 + + 返回: + set: 已存在于数据库中的交易日期集合 + """ + engine = get_db_engine() + query = f"SELECT DISTINCT trade_date FROM {table_name}" + + try: + # SQLAlchemy 2.0+ 版本的执行方式 + from sqlalchemy import text + with engine.connect() as connection: + result = connection.execute(text(query)) + return {row[0] for row in result} + except Exception as e: + print(f"获取已存在交易日期时出错: {e}") + return set() + + +def load_df_from_db(table_name, conditions=None): + """ + 从数据库中加载数据 + + 参数: + table_name (str): 表名 + conditions (str): 过滤条件,如 "trade_date > '20230101'" + + 返回: + pandas.DataFrame: 查询结果 + """ + engine = get_db_engine() + query = f"SELECT * FROM {table_name}" + + if conditions: + query += f" WHERE {conditions}" + + return pd.read_sql(query, engine) + + +def save_df_to_db(df, table_name, if_exists='append'): + """ + 保存DataFrame到数据库 + + 参数: + df (pandas.DataFrame): 要保存的数据 + table_name (str): 表名 + if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append' + """ + engine = get_db_engine() + df.to_sql(table_name, engine, if_exists=if_exists, index=False) \ No newline at end of file