import math
from enum import Enum, auto
from typing import List, override

from internal_types.types import OHLC, BidAsk, Instrument, Position
from strategy.strategy import Strategy
from utils.utils import SEC_1_HOUR, SMA, Portfolio


class State(Enum):
    """Lifecycle state of an ``SMACrossover`` strategy."""

    POS_0 = auto()  # constructed; ``warmup`` has not completed yet
    POS_1 = auto()  # warmed up; ``process_ohlc`` may be called


class Cross(Enum):
    """Relative order of the short SMA versus the long SMA."""

    UNSPECIFIED = auto()  # never returned by ``SMACrossover.__cross``
    GOLDEN = auto()  # short SMA strictly above the long SMA
    DEATH = auto()  # short SMA at or below the long SMA


class SMACrossover(Strategy):
    """Simple-moving-average crossover strategy on a single instrument.

    Two SMAs of OHLC close prices are maintained.  Whenever their relative
    order flips between two consecutive bars, the *desired* position is
    re-targeted to the full unit limit: long after a golden cross, and short
    (or flat when ``long_only`` is set) after a death cross.  Orders become
    real positions only once reported back through ``order_filled``.
    """

    def __init__(self, init_balance: float, instr: Instrument, interval_sec: int,
                 short_window_sec: int = 12 * SEC_1_HOUR,
                 long_window_sec: int = 26 * SEC_1_HOUR,
                 *, long_only: bool = False):
        """Create the strategy; ``warmup`` must be called before live data.

        Args:
            init_balance: starting cash used by ``net_liquid_value``.
            instr: the only instrument this strategy trades.
            interval_sec: bar interval forwarded to both ``SMA`` instances.
            short_window_sec: window length of the short SMA.
            long_window_sec: window length of the long SMA.
            long_only: when True, a death cross targets a flat position
                instead of a short one.
        """
        self.state = State.POS_0
        self.init_balance = init_balance
        self.instr = instr
        self.long_only = long_only
        self.short_sma = SMA(interval_sec, short_window_sec)
        self.long_sma = SMA(interval_sec, long_window_sec)
        self.portfolio = Portfolio()  # positions actually filled
        self.desired_portfolio = Portfolio()  # positions the strategy asked for

    def __cross(self) -> Cross:
        """Classify the current SMA order; a tie counts as ``DEATH``."""
        return Cross.GOLDEN if self.short_sma.val() > self.long_sma.val() else Cross.DEATH

    def __unit_limit(self, price_at: float) -> int:
        """Largest whole number of ``instr`` units affordable at ``price_at``.

        Uses the net liquidation value of the *filled* portfolio, divided by
        the price and the instrument multiplier, rounded down.
        """
        return math.floor(self.net_liquid_value(price_at) / price_at / self.instr.multiplier)

    def warmup(self, warmup_historical_data: List[OHLC]):
        """Prime both SMAs with historical bars and mark the strategy ready.

        Raises:
            ValueError: if the data was too short for both SMAs to produce
                a value.
        """
        for ohlc in warmup_historical_data:
            self.short_sma.append(ohlc.close)
            self.long_sma.append(ohlc.close)
        if not (self.short_sma.has_val() and self.long_sma.has_val()):
            raise ValueError('need at least `long_window_sec` of OHLC to warmup')
        self.state = State.POS_1

    @override
    def unfilled_positions(self, instr: Instrument) -> Position:
        """Return the part of the desired position not yet filled for ``instr``.

        The gap is desired outstanding shares minus filled outstanding shares;
        how it is turned into a ``Position`` is delegated to
        ``Portfolio.consolidate_last_x_shares``.
        """
        x = self.desired_portfolio.outstanding_shares(instr) - self.portfolio.outstanding_shares(instr)
        return self.desired_portfolio.consolidate_last_x_shares(instr, x)

    @override
    def order_filled(self, new_pos: Position):
        """Record a fill in the actual (filled) portfolio."""
        self.portfolio.add_position(new_pos)

    @override
    def process_bid_ask(self, bid_ask: BidAsk):
        """Bid/ask handling is not implemented for this strategy.

        Raises:
            NotImplementedError: always.  An explicit raise is used instead of
                ``assert False`` so the guard survives ``python -O``.
        """
        raise NotImplementedError('todo')

    @override
    def process_ohlc(self, ohlc: OHLC):
        """Feed one bar; re-target the desired position when the SMAs cross.

        Bars for instruments other than ``self.instr`` are ignored.

        Raises:
            RuntimeError: if ``warmup`` has not completed.
        """
        if self.state == State.POS_0:
            raise RuntimeError('strategy wasn\'t warmed up yet')
        if ohlc.instr != self.instr:
            return

        prev_cross = self.__cross()
        self.short_sma.append(ohlc.close)
        self.long_sma.append(ohlc.close)
        curr_cross = self.__cross()
        if prev_cross == curr_cross:
            return  # no crossover on this bar

        outstanding_shares = self.desired_portfolio.outstanding_shares(self.instr)
        unit_limit = self.__unit_limit(ohlc.close)
        if curr_cross == Cross.GOLDEN:
            desired_shares = unit_limit
        elif self.long_only:
            desired_shares = 0  # exit to flat instead of going short
        else:
            desired_shares = -unit_limit
        # Order only the difference between the new target and what was
        # already requested, so repeated crosses do not stack positions.
        quantity = desired_shares - outstanding_shares
        self.desired_portfolio.add_position(Position(self.instr, quantity, ohlc.close))

    @override
    def net_liquid_value(self, at_price: float) -> float:
        """Initial balance plus gains on ``self.instr`` (filled positions) at ``at_price``."""
        return self.init_balance + self.portfolio.total_gains(self.instr, at_price)