76 lines
2.2 KiB
Python
76 lines
2.2 KiB
Python
import os
import sys
from datetime import datetime, timedelta
from urllib.parse import quote

import pandas as pd
import tushare as ts
import yaml
from sqlalchemy import create_engine
|
|
|
|
# Module-level singletons, filled in lazily on first use.
_config = None  # parsed config.yaml contents, cached by load_config()
_engine = None  # SQLAlchemy engine, cached by get_engine()
|
|
|
|
|
|
def load_config():
    """Load and cache the YAML configuration stored in ``config.yaml``.

    On the first successful call the file is read (path relative to the
    current working directory) and the parsed mapping is cached in the
    module-level ``_config``; later calls return the cached object.

    If the file does not exist, a template with placeholder values for the
    ``tushare_token``, ``mysql`` and ``openai_api`` sections is written, a
    hint asking the user to fill it in is printed, and the process exits.

    Returns:
        dict: the parsed configuration.

    Raises:
        SystemExit: with status 1 when the template had to be created.
        ValueError: when the file is empty or its top level is not a mapping.
    """
    global _config
    if _config is not None:
        return _config

    # Relative path: resolved against the current working directory.
    config_path = 'config.yaml'
    if not os.path.exists(config_path):
        template = {
            'tushare_token': 'your_token_here',
            'mysql': {
                'user': 'user',
                'password': 'password',
                'host': '127.0.0.1',
                'port': 3306,
                'database': 'tushare',
                'charset': 'utf8',
            },
            'openai_api': {
                'api_key': 'sk-your-api-key-here',
                'base_url': 'https://api.tu-zi.com/v1',
                'model': 'gpt-4',
            },
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(template, f)
        print(f"请在 {config_path} 中填入您的 tushare token、MySQL 连接信息以及第三方 API 信息")
        # sys.exit does not depend on the `site` module (which provides the
        # interactive `exit` helper), so it also works under `python -S`
        # and in frozen executables.
        sys.exit(1)

    # Decode explicitly as UTF-8 so non-ASCII content (e.g. Chinese comments
    # added by the user) parses the same regardless of platform defaults.
    with open(config_path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail loudly here instead of
    # letting callers crash later on config['mysql'] with a TypeError.
    if not isinstance(loaded, dict):
        raise ValueError(f"{config_path} 为空或不是 YAML 映射")

    _config = loaded
    return _config
|
|
|
|
|
|
def get_engine():
    """Return the shared SQLAlchemy engine, creating it on the first call.

    Connection settings are read from the ``mysql`` section of the
    configuration returned by ``load_config()`` (keys ``user``, ``password``,
    ``host``, ``port``, ``database``, ``charset``). The engine is built with
    the ``mysql+pymysql`` dialect and cached in the module-level ``_engine``.

    Returns:
        The cached engine object produced by ``create_engine``.
    """
    global _engine
    if _engine is None:
        mysql = load_config()['mysql']
        # Percent-encode the credentials: characters such as '@', ':', '/',
        # '#' or '%' in a password would otherwise corrupt the URL. str()
        # covers values YAML parsed as numbers; safe='' also escapes '/' and
        # '+', so SQLAlchemy's unquote restores the exact original text.
        user = quote(str(mysql['user']), safe='')
        password = quote(str(mysql['password']), safe='')
        connection_string = (
            f"mysql+pymysql://{user}:{password}"
            f"@{mysql['host']}:{mysql['port']}/{mysql['database']}"
            f"?charset={mysql['charset']}&use_unicode=1"
        )
        _engine = create_engine(connection_string)
    return _engine
|
|
|
|
def get_trade_cal(start_date=None, end_date=None):
    """Return the open trading days within a date range.

    Args:
        start_date (str): start date, format 'YYYYMMDD'; defaults to 30 days
            before today.
        end_date (str): end date, format 'YYYYMMDD'; defaults to today.

    Returns:
        list[str]: the ``cal_date`` values of calendar rows whose ``is_open``
        flag equals 1, in the order returned by the tushare API.
    """
    # Take "now" once so both defaults are derived from the same instant.
    now = datetime.now()
    if start_date is None:
        start_date = (now - timedelta(days=30)).strftime('%Y%m%d')
    if end_date is None:
        end_date = now.strftime('%Y%m%d')

    # Use the token the user was asked to put in config.yaml. NOTE: with an
    # empty token, tushare's pro_api() falls back to a token previously saved
    # via ts.set_token(), so setups relying on that keep working.
    token = load_config().get('tushare_token') or ''
    pro = ts.pro_api(token)
    trade_cal_df = pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)

    # tushare documents is_open as a string field; normalise to numbers so
    # both '1' and 1 are recognised as open days.
    is_open = pd.to_numeric(trade_cal_df['is_open'], errors='coerce') == 1
    return trade_cal_df.loc[is_open, 'cal_date'].tolist()