import os import yaml import tushare as ts import pandas as pd from datetime import datetime, timedelta from sqlalchemy import create_engine def load_config(): config_path = 'config.yaml' if not os.path.exists(config_path): config = { 'tushare_token': 'your_token_here', 'mysql': { 'user': 'user', 'password': 'password', 'host': '127.0.0.1', 'port': 3306, 'database': 'tushare', 'charset': 'utf8' } } with open(config_path, 'w') as f: yaml.dump(config, f) print(f"请在 {config_path} 中填入您的 tushare token 和 MySQL 连接信息") exit(1) with open(config_path, 'r') as f: config = yaml.safe_load(f) return config config = load_config() ts.set_token(config['tushare_token']) pro = ts.pro_api() def get_trans_data(stock_code, start_date, end_date, data_type='daily'): # 确保日期格式正确 start_date = pd.to_datetime(start_date).strftime('%Y%m%d') end_date = pd.to_datetime(end_date).strftime('%Y%m%d') # 创建保存目录 save_dir = os.path.join('data', stock_code) os.makedirs(save_dir, exist_ok=True) # 生成文件名 file_name = f"{data_type}_{start_date}_{end_date}.csv" file_path = os.path.join(save_dir, file_name) # 检查文件是否已存在 if os.path.exists(file_path): print(f"{file_path} 已存在,跳过下载") return True # 根据数据类型选择相应的API if data_type == 'daily': df = pro.daily(ts_code=stock_code, start_date=start_date, end_date=end_date) elif data_type == 'weekly': df = pro.weekly(ts_code=stock_code, start_date=start_date, end_date=end_date) elif data_type == 'monthly': df = pro.monthly(ts_code=stock_code, start_date=start_date, end_date=end_date) elif data_type == 'money_flow': df = pro.moneyflow(ts_code=stock_code, start_date=start_date, end_date=end_date) elif data_type == 'daily_basic': df = pro.daily_basic(ts_code=stock_code, start_date=start_date, end_date=end_date) else: print(f"不支持的数据类型: {data_type}") return False # 如果数据为空,返回 if df.empty: print(f"没有找到 {stock_code} 从 {start_date} 到 {end_date} 的 {data_type} 数据") return False # 保存数据 df.to_csv(file_path, index=False) print(f"数据已保存到 {file_path}") return True def get_stock_basic(): # 获取当前时间作为文件名的一部分 current_time = datetime.now().strftime('%Y%m%d') # 创建保存目录 save_dir = 'data' os.makedirs(save_dir, exist_ok=True) # 生成文件名 file_name = f"stock_basic_{current_time}.csv" file_path = os.path.join(save_dir, file_name) # 下载股票基本信息,包含所有可用字段 fields = 'ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type' df = pro.stock_basic(exchange='', list_status='L', fields=fields) # 如果数据为空,返回 if df.empty: print("没有找到股票基本信息数据") return # 保存数据 df.to_csv(file_path, index=False) print(f"股票基本信息数据已保存到 {file_path}") def load_share_data(stock_code, data_type='daily', start_date=None, end_date=None): data_dir = os.path.join('data', stock_code) # 如果未指定日期,尝试加载任何可用数据 if start_date is None or end_date is None: files = [f for f in os.listdir(data_dir) if f.startswith(f"{data_type}_")] if not files: raise ValueError(f"找不到{stock_code}的{data_type}数据文件") # 使用找到的第一个文件 file_path = os.path.join(data_dir, files[0]) else: # 格式化日期 start_date = pd.to_datetime(start_date).strftime('%Y%m%d') end_date = pd.to_datetime(end_date).strftime('%Y%m%d') file_name = f"{data_type}_{start_date}_{end_date}.csv" file_path = os.path.join(data_dir, file_name) if not os.path.exists(file_path): print(f"找不到{file_path}数据文件,执行下载") get_trans_data(stock_code, start_date, end_date, data_type) print(f"加载数据: {file_path}") # 读取CSV文件 df = pd.read_csv(file_path) # 确保日期列是datetime类型并按日期升序排序 df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d') df = df.sort_values('trade_date') # 重命名列以符合backtrader期望的格式 df = df.rename(columns={ 'trade_date': 'datetime', 'open': 'open', 'high': 'high', 'low': 'low', 'close': 'close', 'vol': 'volume' }) # 确保所有必需的列都存在 required_columns = ['datetime', 'open', 'high', 'low', 'close', 'volume'] for col in required_columns: if col not in df.columns: raise ValueError(f"缺少必需的列: {col}") # 将datetime设为索引 df = df.set_index('datetime') return df def create_engine_from_config(config): mysql = config['mysql'] connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1" return create_engine(connection_string) if __name__ == "__main__": stock_codes = ['002112.SZ', '601801.SH', '002737.SZ', '600970.SH'] # 示例股票代码 start_date = '20230101' # 手动输入的开始日期 end_date = '20250403' # 手动输入的结束日期 data_types = ['daily', 'weekly', 'monthly', 'money_flow', 'daily_basic'] # 可选数据类型 # 下载股票基本信息 get_stock_basic() # 下载其他数据 for stock_code in stock_codes: for data_type in data_types: get_trans_data(stock_code, start_date, end_date, data_type)