1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
| from numbers import Number
import numpy as np import pandas as pd
from Strategy import Strategy, SmaCross from Utils import read_file, assert_msg
class ExchangeAPI:
    """Minimal simulated exchange holding a single all-in / all-out position.

    The broker is moved through ``data`` one bar at a time via :meth:`next`
    and fills every order at the current bar's ``Close`` price, charging a
    proportional ``commission`` on each trade.
    """

    def __init__(self, data, cash, commission):
        """
        :param data: price table exposing a ``Close`` column (a ``pd.DataFrame``
            when constructed by ``Backtest``).
        :param cash: starting cash; must be > 0.
        :param commission: proportional fee charged per trade, within [0, 0.05].
        """
        assert_msg(0 < cash, "初始现金数量必须大于 0,输入的现金数量:{}".format(cash))
        assert_msg(0 <= commission <= 0.05, "合理的手续费率一般不会超过 5%,输入的费率:{}".format(commission))
        # Attribute spelling kept as-is in case other code reads it directly.
        self._inital_cash = cash
        self._data = data
        self._commission = commission
        self._position = 0
        self._cash = cash
        self._i = 0  # index (row number) of the current bar

    @property
    def cash(self):
        """
        :return: current cash balance of the account
        """
        return self._cash

    @property
    def position(self):
        """
        :return: current position size (quantity of the asset held)
        """
        return self._position

    @property
    def initial_cash(self):
        """
        :return: cash the account was created with
        """
        return self._inital_cash

    @property
    def market_value(self):
        """
        :return: cash plus the position valued at the current price
        """
        return self._cash + self._position * self.current_price

    @property
    def current_price(self):
        """
        :return: ``Close`` price of the current bar

        ``self._i`` is a row *position*, while the index may hold non-integer
        labels such as dates, so access is explicitly positional via ``iloc``
        (plain ``Series[int]`` only falls back to positions on older pandas).
        """
        return self._data.Close.iloc[self._i]

    def buy(self):
        """Spend all available cash, net of commission, on the asset at the current price.

        The bought quantity is *added* to any existing position, so calling
        ``buy`` while already fully invested (zero cash) leaves the holdings
        intact instead of resetting them to zero.
        """
        self._position += float(self._cash * (1 - self._commission) / self.current_price)
        self._cash = 0.0

    def sell(self):
        """Liquidate the whole position at the current price, net of commission."""
        self._cash += float(self._position * self.current_price * (1 - self._commission))
        self._position = 0.0

    def next(self, tick):
        """Advance the broker to bar index ``tick``."""
        self._i = tick
class Backtest:
    """Drive a Strategy over historical OHLC data through a simulated broker."""

    def __init__(self,
                 data: pd.DataFrame,
                 strategy_type: 'type[Strategy]',
                 broker_type: 'type[ExchangeAPI]',
                 cash: float = 10000,
                 commission: float = .0):
        """
        :param data: OHLC price table with Open/High/Low/Close columns
            (a Volume column is added as NaN if absent).
        :param strategy_type: Strategy subclass to instantiate.
        :param broker_type: ExchangeAPI (sub)class to instantiate.
        :param cash: initial cash handed to the broker.
        :param commission: proportional commission rate handed to the broker.
        """
        assert_msg(issubclass(strategy_type, Strategy), "strategy_type 不是 Strategy 类型")
        assert_msg(issubclass(broker_type, ExchangeAPI), "broker_type 不是 ExchangeAPI 类型")
        assert_msg(isinstance(commission, Number), "commission 不是浮点数值类型")

        # Shallow copy so the column added below does not appear on the caller's frame.
        data = data.copy(False)

        if 'Volume' not in data:
            data['Volume'] = np.nan

        # Check every cell: ``.max()`` skips NaN, so a max-based check would
        # only detect columns that are entirely missing.
        assert_msg(not data[['Open', 'High', 'Low', 'Close']].isnull().values.any(),
                   ("部分 OHLC 包含缺失值,请去掉那些行或者通过差值填充"))

        if not data.index.is_monotonic_increasing:
            data = data.sort_index()

        self._data = data
        self._broker = broker_type(data, cash, commission)
        self._strategy = strategy_type(self._broker, self._data)
        self._results = None

    def run(self) -> pd.Series:
        """Run the strategy over every bar from row 100 to the end.

        The first 100 bars are skipped, presumably as a warm-up window for
        indicators prepared in ``strategy.init()`` (TODO confirm). If the data
        has 100 rows or fewer, no bar is processed.

        :return: summary Series produced by :meth:`_compute_result`.
        """
        strategy = self._strategy
        broker = self._broker

        strategy.init()

        start = 100
        end = len(self._data)

        for i in range(start, end):
            # Move the broker first so the strategy sees/trades bar i's price.
            broker.next(i)
            strategy.next(i)

        self._results = self._compute_result(broker)
        return self._results

    def _compute_result(self, broker):
        """Summarise the broker's initial value, final value and profit."""
        final_value = broker.market_value
        return pd.Series({
            '初始市值': broker.initial_cash,
            '结束市值': final_value,
            '收益': final_value - broker.initial_cash,
        })
def main():
    """Backtest the SMA-crossover strategy on Gemini BTC/USD data and print the summary."""
    prices = read_file('BTCUSD_GEMINI.csv')
    backtest = Backtest(prices, SmaCross, ExchangeAPI, cash=10000.0, commission=0.003)
    summary = backtest.run()
    print(summary)
# Run the demo backtest only when this file is executed as a script, not on import.
if __name__ == '__main__': main()
|