From 842142d00891375eda46bca4f17ded015c13e414 Mon Sep 17 00:00:00 2001 From: Leandro Lucarella Date: Sat, 26 Nov 2022 23:36:04 +0100 Subject: [PATCH] Make Select.ready() an async iterator The async iterator yields a set of receivers that are ready to be consumed. Users need to consume() explicitly from the receivers that are ready and are not automatically consumed() by the select object if they were not consumed in the select loop. If a receiver is stopped, then it will be automatically removed from the select loop. Signed-off-by: Leandro Lucarella --- RELEASE_NOTES.md | 20 ++++ src/frequenz/channels/util/_select.py | 156 +++++++------------------- tests/test_select.py | 46 ++++---- tests/utils/test_file_watcher.py | 38 ++++--- tests/utils/test_timer.py | 16 ++- 5 files changed, 119 insertions(+), 157 deletions(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 7185e48e..7ab7964b 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -20,6 +20,26 @@ * Now exceptions are not raised in Receiver.ready() but in Receiver.consume() (receive() or the async iterator `anext`). +* `Select` constructor now takes a variable number of receivers: + + ```py + select = Select(recv1, recv2) + ``` + +* `Select.ready()` is now an async iterator and yields a set of receivers that are ready to be consumed. Receivers must be explicitly consumed and if a ready receiver is not consumed, the ready message won't be discarded by select any more, it will wait indefinitely until it is consumed. + + Example: + + ```py + select = Select(recv1, recv2) + async for ready_set in select.ready(): + if recv1 in ready_set: + msg = recv1.consume() + # do whatever with msg, consume() can also raise an error as normal + if recv2 in ready_set: + msg = recv2.consume() + ``` + ## New Features * New exceptions were added: diff --git a/src/frequenz/channels/util/_select.py b/src/frequenz/channels/util/_select.py index 9c99c5dd..0be618ed 100644 --- a/src/frequenz/channels/util/_select.py +++ b/src/frequenz/channels/util/_select.py @@ -8,52 +8,19 @@ is closed in case of `Receiver` class. """ +from __future__ import annotations + import asyncio import logging -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, TypeVar +from collections.abc import AsyncIterator +from typing import Any, Dict, Set, TypeVar from .._base_classes import Receiver -from .._exceptions import ReceiverStoppedError logger = logging.Logger(__name__) T = TypeVar("T") -@dataclass -class _Selected: - """A wrapper class for holding values in `Select`. - - Using this wrapper class allows `Select` to inform user code when a - receiver gets closed. - """ - - inner: Optional[Any] - - -@dataclass -class _ReadyReceiver: - """A class for tracking receivers that have a message ready to be read. - - Used to make sure that receivers are not consumed from until messages are accessed - by user code, at which point, it will be converted into a `_Selected` object. - - When a channel has closed, `recv` should be `None`. - """ - - recv: Optional[Receiver[Any]] - - def get(self) -> _Selected: - """Consume a message from the receiver and return a `_Selected` object. - - Returns: - An instance of `_Selected` holding a value from the receiver. - """ - if self.recv is None: - return _Selected(None) - return _Selected(self.recv.consume()) # pylint: disable=protected-access - - class Select: """Select the next available message from a group of Receivers. @@ -87,24 +54,20 @@ class Select: ``` """ - def __init__(self, **kwargs: Receiver[Any]) -> None: + def __init__(self, *receivers: Receiver[Any]) -> None: """Create a `Select` instance. Args: - **kwargs: sequence of receivers + *receivers: A set of receivers to select from. """ - self._receivers = kwargs + self._receivers: Dict[str, Receiver[Any]] = { + f"0x{id(r):x}": r for r in receivers + } self._pending: Set[asyncio.Task[bool]] = set() for name, recv in self._receivers.items(): self._pending.add(asyncio.create_task(recv.ready(), name=name)) - self._ready_count = 0 - self._prev_ready_count = 0 - self._result: Dict[str, Optional[_ReadyReceiver]] = { - name: None for name in self._receivers - } - def __del__(self) -> None: """Cleanup any pending tasks.""" for task in self._pending: @@ -117,82 +80,41 @@ async def stop(self) -> None: await asyncio.gather(*self._pending, return_exceptions=True) self._pending = set() - async def ready(self) -> bool: + async def ready(self) -> AsyncIterator[Set[Receiver[Any]]]: """Wait until there is a message in any of the receivers. Returns `True` if there is a message available, and `False` if all receivers have closed. - Returns: - Whether there are further messages or not. - """ - # This function will change radically soon - # pylint: disable=too-many-nested-blocks - if self._ready_count > 0: - if self._ready_count == self._prev_ready_count: - dropped_names: List[str] = [] - for name, value in self._result.items(): - if value is not None: - dropped_names.append(name) - if value.recv is not None: - try: - value.recv.consume() - except ReceiverStoppedError: - pass - self._result[name] = None - self._ready_count = 0 - self._prev_ready_count = 0 - logger.warning( - "Select.ready() dropped data from receiver(s): %s, " - "because no messages have been fetched since the last call to ready().", - dropped_names, - ) - else: - self._prev_ready_count = self._ready_count - return True - if len(self._pending) == 0: - return False - - # once all the pending messages have been consumed, reset the - # `_prev_ready_count` as well, and wait for new messages. - self._prev_ready_count = 0 - - done, self._pending = await asyncio.wait( - self._pending, return_when=asyncio.FIRST_COMPLETED - ) - for task in done: - name = task.get_name() - recv = self._receivers[name] - receiver_active = task.result() - if receiver_active: - ready_recv = recv - else: - ready_recv = None - self._ready_count += 1 - self._result[name] = _ReadyReceiver(ready_recv) - # if channel or Receiver is closed - # don't add a task for it again. - if not receiver_active: - continue - self._pending.add(asyncio.create_task(recv.ready(), name=name)) - return True - - def __getattr__(self, name: str) -> Optional[Any]: - """Return the latest unread message from a `Receiver`, if available. - - Args: - name: Name of the channel. - - Returns: - Latest unread message for the specified `Receiver`, or `None`. + Yields: + A set with the receivers that are ready to be consumed. Raises: - KeyError: when the name was not specified when creating the - `Select` instance. + BaseException: if the receivers raise any exceptions. + + # noqa: DAR401 exc (https://github.com/terrencepreilly/darglint/issues/181) """ - result = self._result[name] - if result is None: - return result - self._result[name] = None - self._ready_count -= 1 - return result.get() + while self._pending: + done, self._pending = await asyncio.wait( + self._pending, return_when=asyncio.FIRST_COMPLETED + ) + ready_set: Set[Receiver[Any]] = set() + for task in done: + name = task.get_name() + recv = self._receivers[name] + # This will raise if there was an exception in the task + # Colloect or not collect exceptions + exc = task.exception() + if exc is not None: + raise exc + ready_set.add(recv) + + yield ready_set + + for task in done: + receiver_active = task.result() + if not receiver_active: + continue + name = task.get_name() + recv = self._receivers[name] + self._pending.add(asyncio.create_task(recv.ready(), name=name)) diff --git a/tests/test_select.py b/tests/test_select.py index fa372544..53b734dd 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -6,7 +6,7 @@ import asyncio from typing import List -from frequenz.channels import Anycast, Sender +from frequenz.channels import Anycast, ReceiverStoppedError, Sender from frequenz.channels.util import Select @@ -29,27 +29,39 @@ async def send(ch1: Sender[int], ch2: Sender[int], ch3: Sender[int]) -> None: senders = asyncio.create_task( send(chan1.new_sender(), chan2.new_sender(), chan3.new_sender()), ) - select = Select( - ch1=chan1.new_receiver(), - ch2=chan2.new_receiver(), - ch3=chan3.new_receiver(), - ) + + recv1 = chan1.new_receiver() + recv2 = chan2.new_receiver() + recv3 = chan3.new_receiver() + + select = Select(recv1, recv2, recv3) # only check for messages from all iterators but `ch3`. # it ensures iterators are not blocking channels in case they # are not being read from. results: List[int] = [] - while await select.ready(): - if item := select.ch1: - if val := item.inner: - results.append(val) - else: + async for ready_set in select.ready(): + if recv1 in ready_set: + try: + msg = recv1.consume() + except ReceiverStoppedError: results.append(-1) - elif item := select.ch2: - if val := item.inner: - results.append(val) else: + results.append(msg) + + if recv2 in ready_set: + try: + msg = recv2.consume() + except ReceiverStoppedError: results.append(-2) + else: + results.append(msg) + + if recv3 in ready_set: + try: + _ = recv3.consume() + except ReceiverStoppedError: + pass await senders expected_results = [ @@ -68,9 +80,3 @@ async def send(ch1: Sender[int], ch2: Sender[int], ch3: Sender[int]) -> None: -2, # marks end of messages from channel 2 ] assert results == expected_results - got_err = False - try: - item = select.unknown_channel - except KeyError: - got_err = True - assert got_err diff --git a/tests/utils/test_file_watcher.py b/tests/utils/test_file_watcher.py index 2a640230..90c0247f 100644 --- a/tests/utils/test_file_watcher.py +++ b/tests/utils/test_file_watcher.py @@ -22,12 +22,16 @@ async def test_file_watcher(tmp_path: pathlib.Path) -> None: number_of_writes = 0 expected_number_of_writes = 3 - select = Select(timer=Timer(0.1), file_watcher=file_watcher) - while await select.ready(): - if msg := select.timer: - filename.write_text(f"{msg.inner}") - elif msg := select.file_watcher: - assert msg.inner == filename + timer = Timer(0.1) + + select = Select(timer, file_watcher) + async for ready_set in select.ready(): + if timer in ready_set: + msg = timer.consume() + filename.write_text(f"{msg}") + if file_watcher in ready_set: + fname = file_watcher.consume() + assert fname == filename number_of_writes += 1 # After receiving a write 3 times, unsubscribe from the writes channel if number_of_writes == expected_number_of_writes: @@ -48,16 +52,22 @@ async def test_file_watcher_change_types(tmp_path: pathlib.Path) -> None: paths=[str(tmp_path)], event_types={FileWatcher.EventType.DELETE} ) - select = Select( - write_timer=Timer(0.1), deletion_timer=Timer(0.5), watcher=file_watcher - ) + write_timer = Timer(0.1) + deletion_timer = Timer(0.5) + watcher = file_watcher + + select = Select(write_timer, deletion_timer, watcher) number_of_receives = 0 - while await select.ready(): - if msg := select.write_timer: - filename.write_text(f"{msg.inner}") - elif _ := select.deletion_timer: + async for ready_set in select.ready(): + if write_timer in ready_set: + msg = write_timer.consume() + filename.write_text(f"{msg}") + if deletion_timer in ready_set: + _ = deletion_timer.consume() # We need to consume the message os.remove(filename) - elif _ := select.watcher: + if watcher in ready_set: + fname = watcher.consume() + assert fname == filename number_of_receives += 1 break assert number_of_receives == 1 diff --git a/tests/utils/test_timer.py b/tests/utils/test_timer.py index bdcdb704..b01b059f 100644 --- a/tests/utils/test_timer.py +++ b/tests/utils/test_timer.py @@ -55,6 +55,7 @@ class _TestCase: assert fail_count < len(test_cases) +# pylint: disable=too-many-locals async def test_timer_reset() -> None: """Ensure timer resets function as expected.""" chan1 = Anycast[int]() @@ -68,18 +69,21 @@ async def send(ch1: Sender[int]) -> None: await asyncio.sleep(reset_delta) await ch1.send(ctr) + senders = asyncio.create_task(send(chan1.new_sender())) + timer = Timer(timer_delta) + msg_recv = chan1.new_receiver() - senders = asyncio.create_task(send(chan1.new_sender())) - select = Select(msg=chan1.new_receiver(), timer=timer) + select = Select(msg_recv, timer) start_ts = datetime.now(timezone.utc) stop_ts: Optional[datetime] = None - while await select.ready(): - if select.msg: + async for ready_set in select.ready(): + if msg_recv in ready_set: + _ = msg_recv.consume() # We need to consume the message timer.reset() - elif event_ts := select.timer: - stop_ts = event_ts.inner + if timer in ready_set: + stop_ts = timer.consume() break assert stop_ts is not None