diff --git a/custom_worker_tuner/README.md b/custom_worker_tuner/README.md new file mode 100644 index 00000000..a03416c6 --- /dev/null +++ b/custom_worker_tuner/README.md @@ -0,0 +1,88 @@ +# Custom Worker Tuner + +A `CustomSlotSupplier` is a sample that lets you gate slot grants on whatever you want. +This sample gates on a fake DB pool: the worker only polls for a new +activity when the pool has a free connection. + +**Note:** This sample is illustrative only. It shouldn't be used for production grade use-cases. + +## What this sample is +db_pool.py - A fixed-capacity fake pool backed by a `BoundedSemaphore`. Two methods: `acquire(blocking=True)` (claim a slot, returns False if full when non-blocking), `release()` (return a slot) +supplier.py - The custom slot supplier. `reserve_slot` blocks on `connection_pool.acquire()` until a slot is free; `try_reserve_slot` does the same non-blocking. `release_slot` calls `connection_pool.release()` +shared.py - A RunBatch workflow that runs N do_work activities in parallel. The activity just sleeps +worker.py - Wires `FakeDatabaseConnectionPool` + `PoolSlotSupplier` into a WorkerTuner +starter.py - Drives load + +The flow: + +When the pool is at capacity, `reserve_slot` blocks until a +connection frees up. The excess work piles up on the Temporal server, not +inside the worker. + +## Run + +In three terminals from `samples-python/`: + +```bash +temporal server start-dev # terminal 1 +uv run custom_worker_tuner/worker.py # terminal 2 +uv run custom_worker_tuner/starter.py # terminal 3 +``` + +## What you'll see + +The worker prints one line per slot lifecycle event: + +``` +TIME EVENT COUNT QUEUE DETAIL +(COUNT shows before→after / capacity; QUEUE = tasks parked waiting) +───────────────────────────────────────────────────────────────── +12:30:32.591 reserve 0→ 1/10 0 ready to poll +12:30:32.591 reserve 1→ 2/10 0 ready to poll +12:30:32.592 reserve 2→ 3/10 0 ready to poll +12:30:32.592 reserve 3→ 4/10 0 ready to poll +12:30:32.592 reserve 4→ 5/10 0 ready to poll +12:30:32.592 reserve 5→ 6/10 0 ready to poll +12:30:40.501 reserve 6→ 7/10 0 eager dispatch +12:30:40.502 reserve 7→ 8/10 0 eager dispatch +12:30:40.502 reserve 8→ 9/10 0 eager dispatch +12:30:40.505 release 9→ 8/10 0 no task arrived +12:30:40.506 release 8→ 7/10 0 no task arrived +12:30:40.506 release 7→ 6/10 0 no task arrived +12:30:40.510 used 6→ 6/10 0 activity running +12:30:40.510 reserve 6→ 7/10 0 eager dispatch +12:30:40.511 reserve 7→ 8/10 0 eager dispatch +12:30:40.511 reserve 8→ 9/10 0 eager dispatch +12:30:40.514 reserve 9→10/10 0 ready to poll +12:30:40.520 release 10→ 9/10 0 no task arrived +12:30:40.520 release 9→ 8/10 0 no task arrived +12:30:40.520 release 8→ 7/10 0 no task arrived +12:30:40.520 used 7→ 7/10 0 activity running +12:30:40.520 reserve 7→ 8/10 0 eager dispatch +12:30:40.520 reserve 8→ 9/10 0 eager dispatch +12:30:40.520 reserve 9→10/10 0 eager dispatch +12:30:40.525 release 10→10/10 0 no task arrived +12:30:40.525 release 10→ 9/10 0 no task arrived +12:30:40.525 release 9→ 8/10 0 no task arrived +12:30:40.528 reserve 7→ 8/10 0 ready to poll +12:30:40.530 used 8→ 8/10 0 activity running +12:30:40.535 reserve 8→ 9/10 0 eager dispatch +12:30:40.537 reserve 9→10/10 0 eager dispatch +12:30:40.539 used 10→10/10 1 activity running +12:30:40.540 used 10→10/10 1 activity running +12:30:40.541 used 10→10/10 1 activity running +``` + +Under load, with more activities than capacity, COUNT pins at +10/10 — that's the supplier refusing to poll past the gate. +we chose 10 because default there are 5 pollers for python sdk + +## Knobs + +worker.py: + +CAPACITY — pool capacity (the gate) + +starter.py: + +WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load diff --git a/custom_worker_tuner/__init__.py b/custom_worker_tuner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/custom_worker_tuner/db_pool.py b/custom_worker_tuner/db_pool.py new file mode 100644 index 00000000..15a196c6 --- /dev/null +++ b/custom_worker_tuner/db_pool.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import asyncio +import logging + +logger = logging.getLogger(__name__) + + +class FakeDatabaseConnectionPool: + """Pretend connection pool with a fixed capacity, backed by an asyncio.Semaphore.""" + + def __init__(self, allowed_connections: int, name: str = "db") -> None: + self.allowed_connections = allowed_connections + self.name = name + self._connection_pool = asyncio.Semaphore(allowed_connections) + logger.info( + "FakeDatabaseConnectionPool ready: name=%s allowed_connections=%d", + name, + allowed_connections, + ) + + async def acquire(self) -> None: + """Claim a connection, awaiting until one is free.""" + await self._connection_pool.acquire() + + def try_acquire(self) -> bool: + """Non-blocking claim, try_reserve_slot will call this + if the pool is full - it will return false + if it is not full - total pool connections - 1 and slot granted to activity + """ + if self._connection_pool.locked(): + return False + self._connection_pool._value -= 1 + return True + + def release(self) -> None: + """Return a connection to the pool.""" + self._connection_pool.release() + + @property + def in_use(self) -> int: + """Derived from the semaphore — single source of truth.""" + return self.allowed_connections - self._connection_pool._value + + @property + def queued(self) -> int: + """How many tasks are parked waiting for a free slot.""" + waiters = self._connection_pool._waiters + if not waiters: + return 0 + return sum(1 for w in waiters if not w.done()) diff --git a/custom_worker_tuner/shared.py b/custom_worker_tuner/shared.py new file mode 100644 index 00000000..61cb8c3b --- /dev/null +++ b/custom_worker_tuner/shared.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import timedelta + +from temporalio import activity, workflow + +TASK_QUEUE = "custom-worker-tuner" + + +@dataclass +class BatchInput: + activities: int + seconds: float + + +@activity.defn +async def do_work(seconds: float) -> None: + """Sleep, simulating an I/O-bound activity.""" + await asyncio.sleep(seconds) + + +@workflow.defn +class RunBatch: + """Runs N do_work activities in parallel.""" + + @workflow.run + async def run(self, inp: BatchInput) -> None: + await asyncio.gather( + *( + workflow.execute_activity( + do_work, + inp.seconds, + start_to_close_timeout=timedelta(minutes=2), + ) + for _ in range(inp.activities) + ) + ) diff --git a/custom_worker_tuner/starter.py b/custom_worker_tuner/starter.py new file mode 100644 index 00000000..84ed770b --- /dev/null +++ b/custom_worker_tuner/starter.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import asyncio +import time +import uuid + +from temporalio.client import Client +from temporalio.envconfig import ClientConfig + +from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch + +# Tweak these to push more or less load. +WORKFLOWS = 10 +ACTIVITIES_PER_WORKFLOW = 20 +SECONDS_PER_ACTIVITY = 2.0 + + +async def main() -> None: + config = ClientConfig.load_client_connect_config() + config.setdefault("target_host", "localhost:7233") + client = await Client.connect(**config) + run_id = uuid.uuid4().hex[:8] + inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY) + total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW + + print( + f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s" + ) + t0 = time.perf_counter() + + handles = await asyncio.gather( + *( + client.start_workflow( + RunBatch.run, + inp, + id=f"batch-{run_id}-{i}", + task_queue=TASK_QUEUE, + ) + for i in range(WORKFLOWS) + ) + ) + await asyncio.gather(*(h.result() for h in handles)) + + wall = time.perf_counter() - t0 + print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/custom_worker_tuner/supplier.py b/custom_worker_tuner/supplier.py new file mode 100644 index 00000000..fa8961c1 --- /dev/null +++ b/custom_worker_tuner/supplier.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import itertools +import logging + +from temporalio.worker import ( + CustomSlotSupplier, + SlotMarkUsedContext, + SlotPermit, + SlotReleaseContext, + SlotReserveContext, +) + +from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool + +logger = logging.getLogger(__name__) + +_slot_id_gen = itertools.count(1) + + +class _Permit(SlotPermit): + """SlotPermit subclass that just carries a sequential id for logs.""" + + def __init__(self, slot_id: int) -> None: + super().__init__() + self.slot_id = slot_id + + +class PoolSlotSupplier(CustomSlotSupplier): + """Hands out slots only when the backing pool has a free connection.""" + + def __init__(self, connection_pool: FakeDatabaseConnectionPool) -> None: + self.connection_pool = connection_pool + logger.info("PoolSlotSupplier ready: connection_pool=%s", connection_pool.name) + + async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit: + """Block until the pool has capacity, then grant a slot.""" + await self.connection_pool.acquire() + after = self.connection_pool.in_use + slot_id = next(_slot_id_gen) + self._log("reserve", slot_id, "ready to poll", after - 1, after) + return _Permit(slot_id) + + def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None: + """Eager path: try to claim a slot without blocking.""" + if self.connection_pool.try_acquire(): + after = self.connection_pool.in_use + slot_id = next(_slot_id_gen) + self._log("reserve", slot_id, "eager dispatch", after - 1, after) + return _Permit(slot_id) + return None + + def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None: + slot_id = getattr(ctx.permit, "slot_id", "?") + in_use = self.connection_pool.in_use + self._log("used", slot_id, "activity running", in_use, in_use) + + def release_slot(self, ctx: SlotReleaseContext) -> None: + slot_id = getattr(ctx.permit, "slot_id", "?") + detail = "no task arrived" if ctx.slot_info is None else "activity done" + before = self.connection_pool.in_use + self.connection_pool.release() + after = self.connection_pool.in_use + self._log("release", slot_id, detail, before, after) + + def _log(self, event: str, slot_id, note: str, before: int, after: int) -> None: + cap = self.connection_pool.allowed_connections + count = f"{before:>2}→{after:>2}/{cap}" + queued = self.connection_pool.queued + logger.info(f"{event:<8} {count} {queued:>5} {note}") diff --git a/custom_worker_tuner/worker.py b/custom_worker_tuner/worker.py new file mode 100644 index 00000000..1a0d5669 --- /dev/null +++ b/custom_worker_tuner/worker.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import asyncio +import logging + +from temporalio.client import Client +from temporalio.envconfig import ClientConfig +from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner + +from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool +from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work +from custom_worker_tuner.supplier import PoolSlotSupplier + +CAPACITY = 10 # number of pool connections (and concurrent activities) +LOG_LEVEL = "INFO" + + +async def main() -> None: + logging.basicConfig( + level=getattr(logging, LOG_LEVEL.upper(), logging.INFO), + format="%(asctime)s.%(msecs)03d %(message)s", + datefmt="%H:%M:%S", + ) + + config = ClientConfig.load_client_connect_config() + config.setdefault("target_host", "localhost:7233") + client = await Client.connect(**config) + + pool = FakeDatabaseConnectionPool(allowed_connections=CAPACITY, name="db") + supplier = PoolSlotSupplier(pool) + tuner = WorkerTuner.create_composite( + workflow_supplier=FixedSizeSlotSupplier(100), + activity_supplier=supplier, + local_activity_supplier=FixedSizeSlotSupplier(100), + nexus_supplier=FixedSizeSlotSupplier(100), + ) + + worker = Worker( + client, + task_queue=TASK_QUEUE, + workflows=[RunBatch], + activities=[do_work], + tuner=tuner, + ) + + print(f"\nworker started — capacity={CAPACITY}\n") + print("TIME EVENT COUNT QUEUE DETAIL") + print("(COUNT shows before→after / capacity; QUEUE = tasks parked waiting)") + print("─" * 65) + await worker.run() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass