Skip to content

fix(strategy): fix TopkDropoutStrategy to use actual sell order count#2223

Open
lingbai-kong wants to merge 1 commit into
microsoft:mainfrom
lingbai-kong:klb/fix-topk-dropout-strategy
Open

fix(strategy): fix TopkDropoutStrategy to use actual sell order count#2223
lingbai-kong wants to merge 1 commit into
microsoft:mainfrom
lingbai-kong:klb/fix-topk-dropout-strategy

Conversation

@lingbai-kong
Copy link
Copy Markdown
Contributor

Description

This PR fixes an issue in TopkDropoutStrategy where the number of stocks in the portfolio could exceed the topk setting. The root cause is that the strategy was using the length of the planned sell list (len(sell)) instead of the actual number of successfully executed sell orders (len(sell_order_list)) when calculating how many stocks to buy.

The Fix

Original code (qlib/qlib/contrib/strategy/signal_strategy.py):

# Get the stock list we really want to buy
buy = today[: len(sell) + self.topk - len(last)]

Fixed code:

# Get the stock list we really want to buy.
# Buy only enough to keep holdings within topk after the sell orders that were actually generated.
buy = today[: max(0, len(sell_order_list) + self.topk - len(last))]

Key changes:

  1. Move buy list calculation after the sell orders are actually generated
  2. Use len(sell_order_list) (actual executed sells) instead of len(sell) (planned sells)
  3. Add max(0, ...) to prevent negative index when appropriate

Motivation and Context

Fixes issue #809 (#809).

The problem occurs when some sell orders fail due to:

  • Tradability checks (is_stock_tradable() returns False)
  • hold_threshold constraints (stock not held long enough)

The original code assumed all planned sells would succeed, leading to buying more stocks than needed and exceeding the topk limit.

How Has This Been Tested?

Test Script

"""
Test script for TopkDropoutStrategy fix (Issue #809).
Compares original (buggy) vs fixed behavior.
Run with: conda activate qlib && python test_topk_fix_real.py
"""
import sys, os

# Remove local qlib source dir from sys.path to use conda-installed qlib (editable install)
_script_dir = os.path.dirname(os.path.abspath(__file__))
_local_qlib = os.path.join(_script_dir, 'qlib')
sys.path = [p for p in sys.path if os.path.abspath(p) not in (_script_dir, _local_qlib)]

import pandas as pd
import numpy as np
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
from qlib.contrib.strategy.signal_strategy import TopkDropoutStrategy


# ---------------------------------------------------------------------------
# Buggy version: simulates the ORIGINAL code before the fix
# Override generate_trade_decision to use len(sell) instead of len(sell_order_list)
# ---------------------------------------------------------------------------
class BuggyTopkDropoutStrategy(TopkDropoutStrategy):
    """TopkDropoutStrategy with the original bug: uses len(sell) to size buy list."""

    def generate_trade_decision(self, execute_result=None):
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
        pred_score = self.signal.get_signal(start_time=trade_start_time, end_time=trade_end_time)
        if isinstance(pred_score, pd.DataFrame):
            pred_score = pred_score.iloc[:, 0]
        if pred_score is None:
            return TradeDecisionWO([], self)
        if isinstance(self.method_sell, str) and self.method_sell not in ["top", "bottom", "random"]:
            raise NotImplementedError(f"This type of input is not supported")
        if isinstance(self.method_buy, str) and self.method_buy not in ["top", "bottom", "random"]:
            raise NotImplementedError(f"This type of input is not supported")

        # Helper functions (only_tradable=False path)
        def get_first_n(li, n):
            return list(li)[:n]

        def get_last_n(li, n):
            return list(li)[-n:]

        import copy
        current_temp = copy.deepcopy(self.trade_position)
        sell_order_list = []
        buy_order_list = []
        cash = current_temp.get_cash()
        current_stock_list = current_temp.get_stock_list()
        last = pred_score.reindex(current_stock_list).sort_values(ascending=False).index
        if self.method_buy == "top":
            today = get_first_n(
                pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index,
                self.n_drop + self.topk - len(last),
            )
        elif self.method_buy == "random":
            topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk)
            candi = list(filter(lambda x: x not in last, topk_candi))
            n = self.n_drop + self.topk - len(last)
            try:
                today = np.random.choice(candi, n, replace=False)
            except ValueError:
                today = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")
        comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index

        if self.method_sell == "bottom":
            sell = last[last.isin(get_last_n(comb, self.n_drop))]
        elif self.method_sell == "random":
            candi = last
            try:
                sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
            except ValueError:
                sell = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")

        # ── BUGGY LINE: uses len(sell) BEFORE sell orders are generated ──
        buy = today[: len(sell) + self.topk - len(last)]

        for code in current_stock_list:
            if not self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time,
                direction=None if self.forbid_all_trade_at_limit else OrderDir.SELL,
            ):
                continue
            if code in sell:
                time_per_step = self.trade_calendar.get_freq()
                if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
                    continue
                sell_amount = current_temp.get_stock_amount(code=code)
                sell_order = Order(
                    stock_id=code, amount=sell_amount,
                    start_time=trade_start_time, end_time=trade_end_time,
                    direction=Order.SELL,
                )
                if self.trade_exchange.check_order(sell_order):
                    sell_order_list.append(sell_order)
                    trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
                        sell_order, position=current_temp
                    )
                    cash += trade_val - trade_cost

        value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
        for code in buy:
            if not self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time,
                direction=None if self.forbid_all_trade_at_limit else OrderDir.BUY,
            ):
                continue
            buy_price = self.trade_exchange.get_deal_price(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
            )
            buy_amount = value / buy_price
            factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
            buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
            buy_order = Order(
                stock_id=code, amount=buy_amount,
                start_time=trade_start_time, end_time=trade_end_time,
                direction=Order.BUY,
            )
            buy_order_list.append(buy_order)
        return TradeDecisionWO(sell_order_list + buy_order_list, self)


# ---------------------------------------------------------------------------
# Mock infrastructure (shared between both versions)
# ---------------------------------------------------------------------------

class MockTradeCalendar:
    def __init__(self, freq="day"):
        self.freq = freq

    def get_trade_step(self):
        return 0

    def get_freq(self):
        return self.freq

    def get_step_time(self, step=0, shift=0):
        return pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-01")


class MockExchange:
    def __init__(self, tradable_stocks=None):
        self.tradable_stocks = set(tradable_stocks or [])

    def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
        return stock_id in self.tradable_stocks

    def check_order(self, order):
        return order.stock_id in self.tradable_stocks

    def deal_order(self, order, position):
        trade_val = order.amount * 10.0
        trade_cost = trade_val * 0.001
        return trade_val, trade_cost, 10.0

    def get_factor(self, stock_id, start_time, end_time):
        return 1.0

    def get_deal_price(self, stock_id, start_time, end_time, direction):
        return 10.0

    def round_amount_by_trade_unit(self, amount, factor):
        return amount


class MockSignal:
    def __init__(self, scores):
        self.scores = scores

    def get_signal(self, start_time, end_time):
        return pd.Series(self.scores)


class MockPosition:
    def __init__(self, stocks, cash=100000):
        self._stocks = {s: 100 for s in stocks}
        self._cash = cash

    def get_stock_list(self):
        return list(self._stocks.keys())

    def get_cash(self):
        return self._cash

    def get_stock_amount(self, code):
        return self._stocks.get(code, 0)

    def get_stock_count(self, code, bar=None):
        return 100 if self._stocks.get(code, 0) > 0 else 0


# ---------------------------------------------------------------------------
# Test runner
# ---------------------------------------------------------------------------

def run_test(strategy_cls, topk, n_drop, current_stocks, scores, tradable_stocks):
    strategy = strategy_cls.__new__(strategy_cls)
    strategy.topk = topk
    strategy.n_drop = n_drop
    strategy.method_sell = "bottom"
    strategy.method_buy = "top"
    strategy.hold_thresh = 1
    strategy.only_tradable = False
    strategy.forbid_all_trade_at_limit = True
    strategy.risk_degree = 0.95

    mock_position = MockPosition(current_stocks)
    mock_account = type("Account", (), {"current_position": mock_position})()
    strategy.__dict__["level_infra"]   = {"trade_calendar": MockTradeCalendar()}
    strategy.__dict__["common_infra"]  = {"trade_account": mock_account}
    strategy.__dict__["_trade_exchange"] = MockExchange(tradable_stocks=tradable_stocks)
    strategy.signal = MockSignal(scores)

    decision = strategy.generate_trade_decision()

    sells = [o for o in decision.order_list if o.direction == Order.SELL]
    buys  = [o for o in decision.order_list if o.direction == Order.BUY]
    final_holdings = len(current_stocks) - len(sells) + len(buys)

    return len(sells), len(buys), final_holdings, [o.stock_id for o in sells], [o.stock_id for o in buys]


if __name__ == "__main__":
    scores = {'D': 0.9, 'E': 0.8, 'F': 0.7, 'A': 0.3, 'B': 0.2, 'C': 0.1}

    test_cases = [
        ("Tradability Check (B not tradable)", ['A', 'B', 'C'], ['A', 'C', 'D', 'E', 'F']),
        ("Partial Tradable (only A)",         ['A', 'B', 'C'], ['A', 'D', 'E', 'F']),
        ("No Tradable (extreme)",             ['A', 'B', 'C'], ['D', 'E', 'F']),
        ("All Tradable (normal)",             ['A', 'B', 'C'], ['A', 'B', 'C', 'D', 'E', 'F']),
    ]

    print()
    print("=" * 80)
    print("TopkDropoutStrategy Fix Test - Issue #809")
    print("Comparing ORIGINAL (buggy) vs FIXED behavior")
    print("=" * 80)

    all_ok = True
    for label, stocks, tradable in test_cases:
        # ── Original (buggy) ──
        o_sell, o_buy, o_hold, o_sell_ids, o_buy_ids = run_test(
            BuggyTopkDropoutStrategy, 3, 3, stocks, scores, tradable,
        )
        # ── Fixed ──
        f_sell, f_buy, f_hold, f_sell_ids, f_buy_ids = run_test(
            TopkDropoutStrategy, 3, 3, stocks, scores, tradable,
        )

        passed = (f_hold == 3)
        if not passed:
            all_ok = False
        status = "PASS" if passed else "FAIL"

        print()
        print(f"[{label}]")
        print(f"  {'ORIGINAL':>10}  sells={o_sell} {o_sell_ids}, buys={o_buy} {o_buy_ids}  "
              f"--> holdings: 3-{o_sell}+{o_buy}={o_hold}")
        print(f"  {'FIXED':>10}  sells={f_sell} {f_sell_ids}, buys={f_buy} {f_buy_ids}  "
              f"--> holdings: 3-{f_sell}+{f_buy}={f_hold}")
        print(f"  -> {status}")

    print()
    print("=" * 80)
    print("SUMMARY: {}".format("ALL PASSED" if all_ok else "SOME FAILED"))
    if all_ok:
        print("Fixed version correctly maintains holdings at topk=3 in all scenarios.")
    print("=" * 80)
    print()

Test Results (ORIGINAL vs FIXED comparison)

================================================================================
TopkDropoutStrategy Fix Test - Issue #809
Comparing ORIGINAL (buggy) vs FIXED behavior
================================================================================

[Tradability Check (B not tradable)]
    ORIGINAL  sells=2 ['A', 'C'], buys=3 ['D', 'E', 'F']  --> holdings: 3-2+3=4
       FIXED  sells=2 ['A', 'C'], buys=2 ['D', 'E']      --> holdings: 3-2+2=3
  -> PASS

[Partial Tradable (only A)]
    ORIGINAL  sells=1 ['A'], buys=3 ['D', 'E', 'F']  --> holdings: 3-1+3=5
       FIXED  sells=1 ['A'], buys=1 ['D']            --> holdings: 3-1+1=3
  -> PASS

[No Tradable (extreme)]
    ORIGINAL  sells=0 [], buys=3 ['D', 'E', 'F']  --> holdings: 3-0+3=6
       FIXED  sells=0 [], buys=0 []               --> holdings: 3-0+0=3
  -> PASS

[All Tradable (normal)]
    ORIGINAL  sells=3 ['A', 'B', 'C'], buys=3 ['D', 'E', 'F']  --> holdings: 3-3+3=3
       FIXED  sells=3 ['A', 'B', 'C'], buys=3 ['D', 'E', 'F']  --> holdings: 3-3+3=3
  -> PASS

================================================================================
SUMMARY: ALL PASSED
Fixed version correctly maintains holdings at topk=3 in all scenarios.
================================================================================

Analysis

The comparison clearly demonstrates the bug and its fix:

Scenario Original (buggy) Fixed Topk Exceeded?
B not tradable 4 holdings 3 holdings 3 ❌→✅
Only A tradable 5 holdings 3 holdings 3 ❌→✅
No tradable 6 holdings 3 holdings 3 ❌→✅
All tradable 3 holdings 3 holdings 3 ✅→✅

In 3 out of 4 test scenarios, the original code causes the portfolio to exceed topk. The fixed code correctly maintains topk = 3 in all scenarios.

Screenshots of Test Results (if appropriate):

  1. Pipeline test: N/A
  2. Own test: See test output above. Test script: test_topk_fix_real.py

Types of changes

  • Fix bugs
  • Add new feature
  • Update documentation

Fix issue microsoft#809 where TopkDropoutStrategy uses len(sell) (planned sell list)
instead of len(sell_order_list) (actual executed sell orders) when
calculating buy list size. This causes portfolio size to exceed topk
when some sell orders fail due to tradability or hold_threshold checks.

Changes:
- Move buy list calculation after sell orders are generated
- Use len(sell_order_list) instead of len(sell)
- Add max(0, ...) to prevent negative buy list length
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant