Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 52 additions & 41 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover
if self.on_exit is not None:
self.on_exit(self)

async def prefetcher(
async def prefetcher( # noqa: C901
self,
queue: "asyncio.Queue[bytes | AckableMessage]",
finish_event: asyncio.Event,
Expand All @@ -396,48 +396,59 @@ async def prefetcher(
"""
fetched_tasks: int = 0
iterator = self.broker.listen()
current_message: asyncio.Task[bytes | AckableMessage] = asyncio.create_task(
iterator.__anext__(), # type: ignore
)
current_message: asyncio.Task[bytes | AckableMessage] | None = None

while True:
if finish_event.is_set():
break
try:
await self.sem_prefetch.acquire()
if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
):
logger.info("Max number of tasks executed.")
break
# Here we wait for the message to be fetched,
# but we make it with timeout so it can be interrupted
done, _ = await asyncio.wait({current_message}, timeout=0.3)
# If the message is not fetched, we release the semaphore
# and continue the loop. So it will check if finished event was set.
if not done:
self.sem_prefetch.release()
continue
# We're done, so now we need to check
# whether task has returned an error.
message = current_message.result()
current_message = asyncio.create_task(iterator.__anext__()) # type: ignore
fetched_tasks += 1
await queue.put(message)
# Custom hooks for OTel and any future instrumentations
for middleware in reversed(self.broker.middlewares):
if hasattr(middleware, "on_prefetch_queue_add"):
await maybe_awaitable(
middleware.on_prefetch_queue_add(), # type: ignore
try:
while not finish_event.is_set():
try:
await self.sem_prefetch.acquire()
if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
):
logger.info("Max number of tasks executed.")
break
if current_message is None:
current_message = asyncio.create_task(
iterator.__anext__(), # type: ignore
)
except (asyncio.CancelledError, StopAsyncIteration):
break
# We don't want to fetch new messages if we are shutting down.
logger.info("Stopping prefetching messages...")
current_message.cancel()
await queue.put(QUEUE_DONE)
self.sem_prefetch.release()
# Here we wait for the message to be fetched,
# but we make it with timeout so it can be interrupted
done, _ = await asyncio.wait({current_message}, timeout=0.3)
# If the message is not fetched, we release the semaphore
# and continue the loop. So it will check if finished event was set.
if not done:
self.sem_prefetch.release()
continue
# We're done, so now we need to check
# whether task has returned an error.
message = current_message.result()
current_message = None
fetched_tasks += 1
await queue.put(message)
# Custom hooks for OTel and any future instrumentations
for middleware in reversed(self.broker.middlewares):
if hasattr(middleware, "on_prefetch_queue_add"):
await maybe_awaitable(
middleware.on_prefetch_queue_add(), # type: ignore
)
except (asyncio.CancelledError, StopAsyncIteration):
break
finally:
# We don't want to fetch new messages if we are shutting down.
logger.info("Stopping prefetching messages...")
# Short window to deliver, then forward or cancel.
if current_message is not None:
await asyncio.wait({current_message}, timeout=0.3)
if not current_message.done():
current_message.cancel()
elif (
not current_message.cancelled()
and current_message.exception() is None
):
await queue.put(current_message.result())
await queue.put(QUEUE_DONE)
self.sem_prefetch.release()

async def runner(
self,
Expand Down
25 changes: 25 additions & 0 deletions tests/receiver/test_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,28 @@ async def test_no_semaphore_without_max_async_tasks() -> None:
"""Test that semaphore is None when max_async_tasks is not set."""
receiver = get_receiver(max_async_tasks=None)
assert receiver.sem is None


async def test_prefetcher_does_not_pop_message_past_max_tasks() -> None:
"""Test not pulling a message without the intention of running it."""
broker = AsyncQueueBroker()

@broker.task
async def noop() -> None:
return None

for _ in range(6):
await noop.kiq()

assert broker.queue.qsize() == 6

receiver = Receiver(
broker,
executor=ThreadPoolExecutor(max_workers=1),
max_async_tasks=1,
max_tasks_to_execute=5,
)

await receiver.listen(asyncio.Event())

assert broker.queue.qsize() == 1
Loading