backtrader/get_kpl_list.py

250 lines
9.1 KiB
Python
Raw Normal View History

import os
import time
import openpyxl
import pandas as pd
from openpyxl.styles import PatternFill
from tqdm import tqdm
from utils import load_config, get_trade_cal
# 加载配置并初始化tushare
config = load_config()
import tushare as ts
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
def get_kpl_data(start_date=None, end_date=None):
"""
获取指定时间段内的打板数据
参数
start_date (str): 开始日期格式'YYYYMMDD'
end_date (str): 结束日期格式'YYYYMMDD'
返回
pandas.DataFrame: 所有打板数据
"""
# 获取目标交易日历
all_trade_dates = get_trade_cal(start_date, end_date)
# 检查是否已有现有数据
existing_data = pd.DataFrame()
existing_dates = set()
output_file = 'ori_kpl_list.xlsx'
if os.path.exists(output_file):
try:
print(f"检测到已有数据文件: {output_file}")
existing_data = pd.read_excel(output_file)
if not existing_data.empty and 'trade_date' in existing_data.columns:
# 确保trade_date是字符串类型
existing_data['trade_date'] = existing_data['trade_date'].astype(str)
# 提取已有数据的交易日期
existing_dates = set(existing_data['trade_date'].astype(str).unique())
print(f"已有数据包含 {len(existing_dates)} 个交易日")
except Exception as e:
print(f"读取现有数据时出错: {e}")
# 确定需要获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
if not dates_to_fetch:
print("所有数据均已存在,无需更新")
return existing_data
print(f"需要获取 {len(dates_to_fetch)} 个新交易日的数据")
# 获取新的打板数据
new_data = []
for trade_date in tqdm(dates_to_fetch):
try:
# 不指定字段参数,获取所有返回的字段
df = pro.kpl_list(trade_date=trade_date, tag='涨停')
if not df.empty:
# 确保新数据的trade_date也是字符串类型
if 'trade_date' in df.columns:
df['trade_date'] = df['trade_date'].astype(str)
new_data.append(df)
except Exception as e:
print(f"获取 {trade_date} 数据时出错: {e}")
time.sleep(1) # 出错时稍微多等待一下
# 合并所有数据
if new_data:
new_result = pd.concat(new_data, ignore_index=True)
print(f"成功获取 {len(new_data)} 个交易日的新数据,共 {len(new_result)} 条记录")
# 合并新旧数据
if not existing_data.empty:
result = pd.concat([existing_data, new_result], ignore_index=True)
print(f"合并后共有 {len(result)} 条记录")
else:
result = new_result
# 进行一次去重操作,以防万一
if 'ts_code' in result.columns and 'trade_date' in result.columns:
result = result.drop_duplicates(subset=['ts_code', 'trade_date'], keep='last')
print(f"去重后共有 {len(result)} 条记录")
# 确保trade_date是字符串类型后再排序
if 'trade_date' in result.columns:
result['trade_date'] = result['trade_date'].astype(str)
result = result.sort_values(by='trade_date', ascending=False) # 降序排列,最新的数据在前
print("数据已按交易日期排序")
return result
else:
print("未获取到任何新数据")
return existing_data
def analyze_kpl_data():
"""
分析涨停板数据统计每日各板块涨停数量并使用热力图风格展示
按照最近100个工作日的涨停总数对板块进行排序
删除100个工作日内没有涨停记录的板块
"""
print("开始分析涨停板数据...")
# 1. 从原始文件读取数据
try:
ori_data = pd.read_excel('ori_kpl_list.xlsx')
print(f"成功读取原始数据,共 {len(ori_data)} 条记录")
except Exception as e:
print(f"读取数据失败: {e}")
return
# 确保日期字段是字符串类型
ori_data['trade_date'] = ori_data['trade_date'].astype(str)
# 获取所有唯一的交易日期和板块
all_dates = sorted(ori_data['trade_date'].unique(), reverse=True) # 降序排列日期
if 'lu_desc' not in ori_data.columns:
print("错误: 原始数据中没有板块信息字段 'lu_desc'")
return
# 获取所有唯一的板块
all_sectors = ori_data['lu_desc'].dropna().unique()
print(f"数据包含 {len(all_dates)} 个交易日和 {len(all_sectors)} 个板块")
# 创建一个包含所有日期和板块的DataFrame用于统计
temp_result = pd.DataFrame(0, index=all_dates, columns=all_sectors)
# 按日期分组统计
for date in all_dates:
# 获取当日数据
daily_data = ori_data[ori_data['trade_date'] == date]
# 统计各板块涨停数量
sector_counts = daily_data.groupby('lu_desc').size()
# 更新临时结果DataFrame
for sector, count in sector_counts.items():
if sector in temp_result.columns:
temp_result.loc[date, sector] = count
# 计算最近100个工作日(或所有可用天数)的各板块涨停总数
recent_days = min(100, len(all_dates))
recent_dates = all_dates[:recent_days]
# 计算这些日期内每个板块的涨停总数
sector_totals = temp_result.loc[recent_dates].sum()
# 筛选出在100个工作日内有涨停记录的板块
active_sectors = sector_totals[sector_totals > 0].index.tolist()
# 按照涨停总数对活跃板块进行排序
sorted_sectors = sector_totals[active_sectors].sort_values(ascending=False).index.tolist()
print(f"已按最近{recent_days}个工作日的涨停总数对板块排序")
print(
f"共保留了{len(sorted_sectors)}个有涨停记录的板块,删除了{len(all_sectors) - len(sorted_sectors)}个无涨停记录的板块")
if sorted_sectors:
print("涨停数量前10的板块:")
for i, sector in enumerate(sorted_sectors[:min(10, len(sorted_sectors))], 1):
print(f"{i}. {sector}: {sector_totals[sector]}")
# 如果没有活跃板块,提前返回
if not sorted_sectors:
print("警告: 在指定时间段内没有板块有涨停记录")
return pd.DataFrame()
# 创建最终结果DataFrame只使用有涨停记录的排序后的板块
result = pd.DataFrame("", index=all_dates, columns=sorted_sectors)
# 填充数据,只填入非零值
for date in all_dates:
for sector in sorted_sectors:
count = temp_result.loc[date, sector]
if count > 0:
result.loc[date, sector] = count
# 保存结果到新的Excel文件
output_file = 'sector_limit_up_analysis.xlsx'
result.to_excel(output_file)
# 创建热力图色阶函数:从浅红色到深红色(FFFF0000)
def get_heatmap_color(value):
try:
value = int(value)
# 将值限制在0-20范围内
value = min(max(value, 0), 20)
# 计算颜色深度 - 值越大颜色越深
# 红色固定为FF绿色和蓝色从FF(浅)递减到00(深)
intensity = int(255 - (value / 20 * 255))
intensity_hex = format(intensity, '02X')
# 构建颜色代码: 红色固定为FF绿色和蓝色根据值变化
color_code = f"FF{intensity_hex}{intensity_hex}"
return color_code
except:
return None
# 使用openpyxl添加热力图风格
print("正在添加热力图样式...")
workbook = openpyxl.load_workbook(output_file)
worksheet = workbook.active
# 遍历所有数据单元格
for row in range(2, worksheet.max_row + 1): # 跳过标题行
for col in range(2, worksheet.max_column + 1): # 跳过索引列
cell = worksheet.cell(row=row, column=col)
if cell.value and str(cell.value).strip(): # 只处理非空单元格
# 获取相应的热力图颜色
color_code = get_heatmap_color(cell.value)
if color_code:
# 应用背景色
cell.fill = PatternFill(start_color=color_code, end_color=color_code, fill_type="solid")
# 保存格式化后的Excel
workbook.save(output_file)
print(f"分析完成,结果已保存到 {output_file}")
print(f"统计了 {len(result.columns)} 个活跃板块的涨停数据")
print(f"已使用红色热力图标记涨停数量0-20对应从浅红到深红")
return result
if __name__ == "__main__":
# 指定日期范围
start_date = '20220101'
end_date = None
# 获取打板数据
kpl_data = get_kpl_data(start_date, end_date)
# 保存到Excel
if not kpl_data.empty:
kpl_data.to_excel('ori_kpl_list.xlsx', index=False)
print(f"数据已保存到 ori_kpl_list.xlsx, 共 {len(kpl_data)} 条记录")
else:
print("没有数据可保存")
# 执行分析
analyze_kpl_data()