diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 9ca39d45..11da324e 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -2,16 +2,90 @@ ## Summary - +The minimum Python supported version was bumped to 3.11 and the `Select` class replaced by the new `select()` function. ## Upgrading * The minimum supported Python version was bumped to 3.11, downstream projects will need to upgrade too to use this version. +* The `Select` class was replaced by a new `select()` function, with the following improvements: + + * Type-safe: proper type hinting by using the new helper type guard `selected_from()`. + * Fixes potential starvation issues. + * Simplifies the interface by providing values one-by-one. + * Guarantees there are no dangling tasks left behind when used as an async context manager. + + This new function is an [async iterator](https://docs.python.org/3.11/library/collections.abc.html#collections.abc.AsyncIterator), and makes sure no dangling tasks are left behind after a select loop is done. + + Example: + ```python + timer1 = Timer.periodic(datetime.timedelta(seconds=1)) + timer2 = Timer.timeout(datetime.timedelta(seconds=0.5)) + + async for selected in selector(timer1, timer2): + if selected_from(selected, timer1): + # Beware: `selected.value` might raise an exception, you can always + # check for exceptions with `selected.exception` first or use + # a try-except block. You can also quickly check if the receiver was + # stopped and let any other unexpected exceptions bubble up. + if selected.was_stopped(): + print("timer1 was stopped") + continue + print(f"timer1: now={datetime.datetime.now()} drift={selected.value}") + timer2.stop() + elif selected_from(selected, timer2): + # Explicitly handling of exceptions + match selected.exception: + case ReceiverStoppedError(): + print("timer2 was stopped") + case Exception() as exception: + print(f"timer2: exception={exception}") + case None: + # All good, no exception, we can use `selected.value` safely + print( + f"timer2: now={datetime.datetime.now()} " + f"drift={selected.value}" + ) + case _ as unhanded: + assert_never(unhanded) + else: + # This is not necessary, as select() will check for exhaustiveness, but + # it is good practice to have it in case you forgot to handle a new + # receiver added to `select()` at a later point in time. + assert False + ``` + ## New Features - +* A new `select()` function was added, please look at the *Upgrading* section for details. + +* A new `Event` utility receiver was added. + + This receiver can be made ready manually. It is mainly useful for testing but can also become handy in scenarios where a simple, on-off signal needs to be sent to a select loop for example. + + Example: + + ```python + import asyncio + from frequenz.channels import Receiver + from frequenz.channels.util import Event, select, selected_from + + other_receiver: Receiver[int] = ... + exit_event = Event() + + async def exit_after_10_seconds() -> None: + asyncio.sleep(10) + exit_event.set() + + asyncio.ensure_future(exit_after_10_seconds()) -## Bug Fixes + async for selected in selector(exit_event, other_receiver): + if selected_from(selected, exit_event): + break + if selected_from(selected, other_receiver): + print(selected.value) + else: + assert False, "Unknow receiver selected" + ``` - +* The `Timer` class now has more descriptive `__str__` and `__repr__` methods. diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index 1d314087..cbdf3d1e 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -28,7 +28,7 @@ class Anycast(Generic[T]): thread-safe. When there are multiple channel receivers, they can be awaited - simultaneously using [Select][frequenz.channels.util.Select], + simultaneously using [select][frequenz.channels.util.select], [Merge][frequenz.channels.util.Merge] or [MergeNamed][frequenz.channels.util.MergeNamed]. diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index 55eeb382..33f61a80 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -38,7 +38,7 @@ class Broadcast(Generic[T]): are thread-safe. Because of this, `Broadcast` channels are thread-safe. When there are multiple channel receivers, they can be awaited - simultaneously using [Select][frequenz.channels.util.Select], + simultaneously using [select][frequenz.channels.util.select], [Merge][frequenz.channels.util.Merge] or [MergeNamed][frequenz.channels.util.MergeNamed]. diff --git a/src/frequenz/channels/util/__init__.py b/src/frequenz/channels/util/__init__.py index 72369a7b..515e1ac2 100644 --- a/src/frequenz/channels/util/__init__.py +++ b/src/frequenz/channels/util/__init__.py @@ -5,6 +5,9 @@ A module with several utilities to work with channels: +* [Event][frequenz.channels.util.Event]: + A [receiver][frequenz.channels.Receiver] that can be made ready through an event. + * [FileWatcher][frequenz.channels.util.FileWatcher]: A [receiver][frequenz.channels.Receiver] that watches for file events. @@ -20,15 +23,22 @@ * [Timer][frequenz.channels.util.Timer]: A [receiver][frequenz.channels.Receiver] that ticks at certain intervals. -* [Select][frequenz.channels.util.Select]: A helper to select the next - available message for each [receiver][frequenz.channels.Receiver] in a group - of receivers. +* [select][frequenz.channels.util.select]: Iterate over the values of all + [receivers][frequenz.channels.Receiver] as new values become available. """ +from ._event import Event from ._file_watcher import FileWatcher from ._merge import Merge from ._merge_named import MergeNamed -from ._select import Select +from ._select import ( + Selected, + SelectError, + SelectErrorGroup, + UnhandledSelectedError, + select, + selected_from, +) from ._timer import ( MissedTickPolicy, SkipMissedAndDrift, @@ -38,13 +48,19 @@ ) __all__ = [ + "Event", "FileWatcher", "Merge", "MergeNamed", "MissedTickPolicy", - "Timer", - "Select", + "SelectError", + "SelectErrorGroup", + "Selected", "SkipMissedAndDrift", "SkipMissedAndResync", + "Timer", "TriggerAllMissed", + "UnhandledSelectedError", + "select", + "selected_from", ] diff --git a/src/frequenz/channels/util/_event.py b/src/frequenz/channels/util/_event.py new file mode 100644 index 00000000..3c24a8ee --- /dev/null +++ b/src/frequenz/channels/util/_event.py @@ -0,0 +1,161 @@ +# License: MIT +# Copyright © 2023 Frequenz Energy-as-a-Service GmbH + +"""A receiver that can be made ready through an event.""" + + +import asyncio as _asyncio + +from frequenz.channels import _base_classes, _exceptions + + +class Event(_base_classes.Receiver[None]): + """A receiver that can be made ready through an event. + + The receiver (the [`ready()`][frequenz.channels.util.Event.ready] method) will wait + until [`set()`][frequenz.channels.util.Event.set] is called. At that point the + receiver will wait again after the event is + [`consume()`][frequenz.channels.Receiver.consume]d. + + The receiver can be completely stopped by calling + [`stop()`][frequenz.channels.Receiver.stop]. + + Example: + ```python + import asyncio + from frequenz.channels import Receiver + from frequenz.channels.util import Event, select, selected_from + + other_receiver: Receiver[int] = ... + exit_event = Event() + + async def exit_after_10_seconds() -> None: + asyncio.sleep(10) + exit_event.set() + + asyncio.ensure_future(exit_after_10_seconds()) + + async for selected in select(exit_event, other_receiver): + if selected_from(selected, exit_event): + break + if selected_from(selected, other_receiver): + print(selected.value) + else: + assert False, "Unknow receiver selected" + ``` + """ + + def __init__(self, name: str | None = None) -> None: + """Create a new instance. + + Args: + name: The name of the receiver. If `None` the `id(self)` will be used as + the name. This is only for debugging purposes, it will be shown in the + string representation of the receiver. + """ + self._event: _asyncio.Event = _asyncio.Event() + """The event that is set when the receiver is ready.""" + + self._name: str = name or str(id(self)) + """The name of the receiver. + + This is for debugging purposes, it will be shown in the string representation + of the receiver. + """ + + self._is_set: bool = False + """Whether the receiver is ready to be consumed. + + This is used to differentiate between when the receiver was stopped (the event + is triggered too) but still there is an event to be consumed and when it was + stopped but was not explicitly set(). + """ + + self._is_stopped: bool = False + """Whether the receiver is stopped.""" + + @property + def name(self) -> str: + """The name of this receiver. + + This is for debugging purposes, it will be shown in the string representation + of this receiver. + + Returns: + The name of this receiver. + """ + return self._name + + @property + def is_set(self) -> bool: + """Whether this receiver is set (ready). + + Returns: + Whether this receiver is set (ready). + """ + return self._is_set + + @property + def is_stopped(self) -> bool: + """Whether this receiver is stopped. + + Returns: + Whether this receiver is stopped. + """ + return self._is_stopped + + def stop(self) -> None: + """Stop this receiver.""" + self._is_stopped = True + self._event.set() + + def set(self) -> None: + """Trigger the event (make the receiver ready).""" + self._is_set = True + self._event.set() + + async def ready(self) -> bool: + """Wait until this receiver is ready. + + Returns: + Whether this receiver is still running. + """ + if self._is_stopped: + return False + await self._event.wait() + return not self._is_stopped + + def consume(self) -> None: + """Consume the event. + + This makes this receiver wait again until the event is set again. + + Raises: + ReceiverStoppedError: If this receiver is stopped. + """ + if not self._is_set and self._is_stopped: + raise _exceptions.ReceiverStoppedError(self) + + assert self._is_set, "calls to `consume()` must be follow a call to `ready()`" + + self._is_set = False + self._event.clear() + + def __str__(self) -> str: + """Return a string representation of this receiver. + + Returns: + A string representation of this receiver. + """ + return f"{type(self).__name__}({self._name!r})" + + def __repr__(self) -> str: + """Return a string representation of this receiver. + + Returns: + A string representation of this receiver. + """ + return ( + f"<{type(self).__name__} name={self._name!r} is_set={self.is_set!r} " + f"is_stopped={self.is_stopped!r}>" + ) diff --git a/src/frequenz/channels/util/_select.py b/src/frequenz/channels/util/_select.py index 66868122..9978fea7 100644 --- a/src/frequenz/channels/util/_select.py +++ b/src/frequenz/channels/util/_select.py @@ -9,198 +9,390 @@ """ import asyncio -import logging -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, TypeVar +from typing import Any, AsyncIterator, Generic, TypeGuard, TypeVar from .._base_classes import Receiver from .._exceptions import ReceiverStoppedError -logger = logging.Logger(__name__) -T = TypeVar("T") +_T = TypeVar("_T") -@dataclass -class _Selected: - """A wrapper class for holding values in `Select`. +class Selected(Generic[_T]): + """A result of a [`select`][frequenz.channels.util.select] iteration. - Using this wrapper class allows `Select` to inform user code when a - receiver gets closed. - """ + The selected receiver is consumed immediately and the received value is stored in + the instance, unless there was an exception while receiving the value, in which case + the exception is stored instead. - inner: Optional[Any] + `Selected` instances should be used in conjunction with the + [`selected_from()`][frequenz.channels.util.selected_from] function to determine + which receiver was selected. + Please see [`select`][frequenz.channels.util.select] for an example. + """ -@dataclass -class _ReadyReceiver: - """A class for tracking receivers that have a message ready to be read. + class _EmptyResult: + """A sentinel value to distinguish between None and empty result. - Used to make sure that receivers are not consumed from until messages are accessed - by user code, at which point, it will be converted into a `_Selected` object. + We need a sentinel because a result can also be `None`. + """ - When a channel has closed, `recv` should be `None`. - """ + def __repr__(self) -> str: + return "" - recv: Optional[Receiver[Any]] + def __init__(self, receiver: Receiver[_T]) -> None: + """Create a new instance. - def get(self) -> _Selected: - """Consume a message from the receiver and return a `_Selected` object. + The receiver is consumed immediately when creating the instance and the received + value is stored in the instance for later use as + [`value`][frequenz.channels.util.Selected.value]. If there was an exception + while receiving the value, then the exception is stored in the instance instead + (as [`exception`][frequenz.channels.util.Selected.exception]). - Returns: - An instance of `_Selected` holding a value from the receiver. + Args: + receiver: The receiver that was selected. """ - if self.recv is None: - return _Selected(None) - return _Selected(self.recv.consume()) # pylint: disable=protected-access + self._recv: Receiver[_T] = receiver + """The receiver that was selected.""" + self._value: _T | Selected._EmptyResult = Selected._EmptyResult() + """The value that was received. -class Select: - """Select the next available message from a group of Receivers. + If there was an exception while receiving the value, then this will be `None`. + """ + self._exception: Exception | None = None + """The exception that was raised while receiving the value (if any).""" - If `Select` was created with more `Receiver` than what are read in - the if-chain after each call to - [ready()][frequenz.channels.util.Select.ready], messages coming in the - additional receivers are dropped, and a warning message is logged. + try: + self._value = receiver.consume() + except Exception as exc: # pylint: disable=broad-except + self._exception = exc - [Receiver][frequenz.channels.Receiver]s also function as `Receiver`. + self._handled: bool = False + """Flag to indicate if this selected has been handled in the if-chain.""" - When Select is no longer needed, then it should be stopped using - `self.stop()` method. This would cleanup any internal pending async tasks. + @property + def value(self) -> _T: + """The value that was received, if any. - Example: - For example, if there are two receivers that you want to - simultaneously wait on, this can be done with: + Returns: + The value that was received. - ```python - from frequenz.channels import Broadcast - - channel1 = Broadcast[int]("input-chan-1") - channel2 = Broadcast[int]("input-chan-2") - receiver1 = channel1.new_receiver() - receiver2 = channel2.new_receiver() - - select = Select(name1 = receiver1, name2 = receiver2) - while await select.ready(): - if msg := select.name1: - if val := msg.inner: - # do something with `val` - pass - else: - # handle closure of receiver. - pass - elif msg := select.name2: - # do something with `msg.inner` - pass - ``` - """ + Raises: + Exception: If there was an exception while receiving the value. Normally + this should be an [`frequenz.channels.Error`][frequenz.channels.Error] + instance, but catches all exceptions in case some receivers can raise + anything else. - def __init__(self, **kwargs: Receiver[Any]) -> None: - """Create a `Select` instance. + # noqa: DAR401 _exception + """ + if self._exception is not None: + raise self._exception + assert not isinstance(self._value, Selected._EmptyResult) + return self._value - Args: - **kwargs: sequence of receivers + @property + def exception(self) -> Exception | None: + """The exception that was raised while receiving the value (if any). + + Returns: + The exception that was raised while receiving the value (if any). """ - self._receivers = kwargs - self._pending: Set[asyncio.Task[bool]] = set() - - for name, recv in self._receivers.items(): - self._pending.add(asyncio.create_task(recv.ready(), name=name)) - - self._ready_count = 0 - self._prev_ready_count = 0 - self._result: Dict[str, Optional[_ReadyReceiver]] = { - name: None for name in self._receivers - } - - def __del__(self) -> None: - """Cleanup any pending tasks.""" - for task in self._pending: - if not task.done() and task.get_loop().is_running(): - task.cancel() - - async def stop(self) -> None: - """Stop the `Select` instance and cleanup any pending tasks.""" - for task in self._pending: - task.cancel() - await asyncio.gather(*self._pending, return_exceptions=True) - self._pending = set() + return self._exception - async def ready(self) -> bool: - """Wait until there is a message in any of the receivers. + def was_stopped(self) -> bool: + """Check if the selected receiver was stopped. - Returns `True` if there is a message available, and `False` if all - receivers have closed. + Check if the selected receiver raised + a [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError] while + consuming a value. Returns: - Whether there are further messages or not. + Whether the receiver was stopped. """ - # This function will change radically soon - # pylint: disable=too-many-nested-blocks - if self._ready_count > 0: - if self._ready_count == self._prev_ready_count: - dropped_names: List[str] = [] - for name, value in self._result.items(): - if value is not None: - dropped_names.append(name) - if value.recv is not None: - try: - value.recv.consume() - except ReceiverStoppedError: - pass - self._result[name] = None - self._ready_count = 0 - self._prev_ready_count = 0 - logger.warning( - "Select.ready() dropped data from receiver(s): %s, " - "because no messages have been fetched since the last call to ready().", - dropped_names, - ) - else: - self._prev_ready_count = self._ready_count - return True - if len(self._pending) == 0: - return False + return isinstance(self._exception, ReceiverStoppedError) - # once all the pending messages have been consumed, reset the - # `_prev_ready_count` as well, and wait for new messages. - self._prev_ready_count = 0 + def __str__(self) -> str: + """Return a string representation of this instance. - done, self._pending = await asyncio.wait( - self._pending, return_when=asyncio.FIRST_COMPLETED + Returns: + A string representation of this instance. + """ + return ( + f"{type(self).__name__}({self._recv}) -> " + f"{self._exception or self._value})" ) - for task in done: - name = task.get_name() - recv = self._receivers[name] - receiver_active = task.result() - if receiver_active: - ready_recv = recv - else: - ready_recv = None - self._ready_count += 1 - self._result[name] = _ReadyReceiver(ready_recv) - # if channel or Receiver is closed - # don't add a task for it again. - if not receiver_active: - continue - self._pending.add(asyncio.create_task(recv.ready(), name=name)) - return True - - def __getattr__(self, name: str) -> Optional[Any]: - """Return the latest unread message from a `Receiver`, if available. - Args: - name: Name of the channel. + def __repr__(self) -> str: + """Return a the internal representation of this instance. Returns: - Latest unread message for the specified `Receiver`, or `None`. + A string representation of this instance. + """ + return ( + f"{type(self).__name__}({self._recv=}, {self._value=}, " + f"{self._exception=}, {self._handled=})" + ) - Raises: - KeyError: when the name was not specified when creating the - `Select` instance. + +# It would have been nice to be able to make this a method of selected, but sadly +# `TypeGuard`s can't be used as methods. For more information see: +# https://github.com/microsoft/pyright/discussions/3125 +def selected_from( + selected: Selected[Any], receiver: Receiver[_T] +) -> TypeGuard[Selected[_T]]: + """Check if the given receiver was selected by [`select`][frequenz.channels.util.select]. + + This function is used in conjunction with the + [`Selected`][frequenz.channels.util.Selected] class to determine which receiver was + selected in `select()` iteration. + + It also works as a [type guard][typing.TypeGuard] to narrow the type of the + `Selected` instance to the type of the receiver. + + Please see [`select`][frequenz.channels.util.select] for an example. + + Args: + selected: The result of a `select()` iteration. + receiver: The receiver to check if it was the source of a select operation. + + Returns: + Whether the given receiver was selected. + """ + if handled := selected._recv is receiver: # pylint: disable=protected-access + selected._handled = True # pylint: disable=protected-access + return handled + + +class SelectError(BaseException): + """A base exception for [`select`][frequenz.channels.util.select]. + + This exception is raised when a `select()` iteration fails. It is raised as + a single exception when one receiver fails during normal operation (while calling + `ready()` for example). It is raised as a group exception + ([`SelectErrorGroup`][frequenz.channels.util.SelectErrorGroup]) when a `select` loop + is cleaning up after it's done. + """ + + +class UnhandledSelectedError(SelectError, Generic[_T]): + """A receiver was not handled in a [`select()`][frequenz.channels.util.select] loop. + + This exception is raised when a `select()` iteration finishes without a call to + [`selected_from()`][frequenz.channels.util.selected_from] for the selected receiver. + """ + + def __init__(self, selected: Selected[_T]) -> None: + """Create a new instance. + + Args: + selected: The selected receiver that was not handled. """ - result = self._result[name] - if result is None: - return result - self._result[name] = None - self._ready_count -= 1 - return result.get() + recv = selected._recv # pylint: disable=protected-access + super().__init__(f"Selected receiver {recv} was not handled in the if-chain") + self.selected = selected + + +class SelectErrorGroup(BaseExceptionGroup[BaseException], SelectError): + """An exception group for [`select()`][frequenz.channels.util.select] operation. + + This exception group is raised when a [`select()`] loops fails while cleaning up + runing tasts to check for ready receivers. + """ + + +# Typing for select() is tricky. We had the idea of using a declarative design for +# select, something like: +# +# ```python +# class MySelector(Selector): +# receiver1: x.new_receiver() +# receiver2: y.new_receiver() +# +# async for selected in MySelector: +# if selected.receiver is receiver1: +# # Do something with selected.value +# elif selected.receiver is receiver1: +# # Do something with selected.value +# ``` +# +# This is similar to `Enum`, but `Enum` has special support in `mypy` that we can't +# have. +# +# With the current implementation, the typing could be slightly improved by using +# `TypeVarTuple`, but we are not because "transformations" are not supported yet, see: +# https://github.com/python/typing/issues/1216 +# +# Also support for `TypeVarTuple` in general is still experimental (and very incomplete +# in `mypy`). +# +# With this we would also probably be able to properly type `select` and *maybe* even be +# able to leverage the exhaustiveness checking of `mypy` to make sure the selected value +# is narrowed down to the correct type to make sure all receivers are handled, with the +# help of `assert_never` as described in: +# https://docs.python.org/3.11/library/typing.html#typing.assert_never +# +# We also explored the possibility of using `match` to perform exhaustiveness checking, +# but we couldn't find a way to make it work with `match`, and `match` is not yet +# checked for exhaustiveness by `mypy` anyway, see: +# https://github.com/python/mypy/issues/13597 + + +async def select(*receivers: Receiver[Any]) -> AsyncIterator[Selected[Any]]: + """Iterate over the values of all receivers as they receive new values. + + This function is used to iterate over the values of all receivers as they receive + new values. It is used in conjunction with the + [`Selected`][frequenz.channels.util.Selected] class and the + [`selected_from()`][frequenz.channels.util.selected_from] function to determine + which function to determine which receiver was selected in a select operation. + + An exhaustiveness check is performed at runtime to make sure all selected receivers + are handled in the if-chain, so you should call `selected_from()` with all the + receivers passed to `select()` inside the select loop, even if you plan to ignore + a value, to signal `select()` that you are purposefully ignoring the value. + + Note: + The `select()` function is intended to be used in cases where the set of + receivers is static and known beforehand. If you need to dynamically add/remove + receivers from a select loop, there are a few alternatives. Depending on your + use case, one or the other could work better for you: + + * Use [`Merge`][frequenz.channels.util.Merge] or + [`MergeNamed`][frequenz.channels.util.MergeNamed]: this is useful when you + have and unknown number of receivers of the same type that can be handled as + a group. + * Use tasks to manage each recever individually: this is better if there are no + relationships between the receivers. + * Break the `select()` loop and start a new one with the new set of receivers + (this should be the last resort, as it has some performance implications + because the loop needs to be restarted). + + Example: + ```python + import datetime + from typing import assert_never + + from frequenz.channels import ReceiverStoppedError + from frequenz.channels.util import select, selected_from, Timer + + timer1 = Timer.periodic(datetime.timedelta(seconds=1)) + timer2 = Timer.timeout(datetime.timedelta(seconds=0.5)) + + async for selected in select(timer1, timer2): + if selected_from(selected, timer1): + # Beware: `selected.value` might raise an exception, you can always + # check for exceptions with `selected.exception` first or use + # a try-except block. You can also quickly check if the receiver was + # stopped and let any other unexpected exceptions bubble up. + if selected.was_stopped: + print("timer1 was stopped") + continue + print(f"timer1: now={datetime.datetime.now()} drift={selected.value}") + timer2.stop() + elif selected_from(selected, timer2): + # Explicitly handling of exceptions + match selected.exception: + case ReceiverStoppedError(): + print("timer2 was stopped") + case Exception() as exception: + print(f"timer2: exception={exception}") + case None: + # All good, no exception, we can use `selected.value` safely + print( + f"timer2: now={datetime.datetime.now()} drift={selected.value}" + ) + case _ as unhanded: + assert_never(unhanded) + else: + # This is not necessary, as select() will check for exhaustiveness, but + # it is good practice to have it in case you forgot to handle a new + # receiver added to `select()` at a later point in time. + assert False + ``` + + Args: + *receivers: The receivers to select from. + + Yields: + The currently selected item. + + Raises: + UnhandledSelectedError: If a selected receiver was not handled in the if-chain. + SelectErrorGroup: If there is an error while finishing the select operation and + receivers fail while cleaning up. + SelectError: If there is an error while selecting receivers during normal + operation. For example if a receiver raises an exception in the `ready()` + method. Normal errors while receiving values are not raised, but reported + via the `Selected` instance. + """ + receivers_map: dict[str, Receiver[Any]] = {str(hash(r)): r for r in receivers} + pending: set[asyncio.Task[bool]] = set() + + try: + for name, recv in receivers_map.items(): + pending.add(asyncio.create_task(recv.ready(), name=name)) + + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + receiver_active: bool = True + name = task.get_name() + recv = receivers_map[name] + if exception := task.exception(): + match exception: + case asyncio.CancelledError(): + # If the receiver was cancelled, then it means we want to + # exit the select loop, so we handle the receiver but we + # don't add it back to the pending list. + receiver_active = False + case _ as exc: + raise SelectError(f"Error while selecting {recv}") from exc + + selected = Selected(recv) + yield selected + if not selected._handled: # pylint: disable=protected-access + raise UnhandledSelectedError(selected) + + receiver_active = task.result() + if not receiver_active: + continue + + # Add back the receiver to the pending list + name = task.get_name() + recv = receivers_map[name] + pending.add(asyncio.create_task(recv.ready(), name=name)) + finally: + await _stop_pending_tasks(pending) + + +async def _stop_pending_tasks(pending: set[asyncio.Task[bool]]) -> None: + """Stop all pending tasks. + + Args: + pending: The pending tasks to stop. + + Raises: + SelectErrorGroup: If the receivers raise any exceptions. + """ + if pending: + for task in pending: + task.cancel() + done, pending = await asyncio.wait(pending) + assert not pending + exceptions: list[BaseException] = [] + for task in done: + if task.cancelled(): + continue + if exception := task.exception(): + exceptions.append(exception) + if exceptions: + # If the select loop is interrupted by a break or exception, then this + # exception will be actually swallowed, as the select() async generator + # will be collected by the asyncio loop. This shouldn't be too bad as + # errors produced by receivers will be re-raised when trying to use them + # again. + raise SelectErrorGroup("Some receivers failed when select()ing", exceptions) diff --git a/src/frequenz/channels/util/_timer.py b/src/frequenz/channels/util/_timer.py index d49a3664..02c764d6 100644 --- a/src/frequenz/channels/util/_timer.py +++ b/src/frequenz/channels/util/_timer.py @@ -66,6 +66,14 @@ def calculate_next_tick_time( """ return 0 # dummy value to avoid darglint warnings + def __repr__(self) -> str: + """Return a string representation of the instance. + + Returns: + The string representation of the instance. + """ + return f"{type(self).__name__}()" + class TriggerAllMissed(MissedTickPolicy): """A policy that triggers all the missed ticks immediately until it catches up. @@ -242,6 +250,22 @@ def calculate_next_tick_time( return now + interval return scheduled_tick_time + interval + def __str__(self) -> str: + """Return a string representation of the instance. + + Returns: + The string representation of the instance. + """ + return f"{type(self).__name__}({self.delay_tolerance})" + + def __repr__(self) -> str: + """Return a string representation of the instance. + + Returns: + The string representation of the instance. + """ + return f"{type(self).__name__}({self.delay_tolerance=})" + class Timer(Receiver[timedelta]): """A timer receiver that triggers every `interval` time. @@ -283,29 +307,28 @@ class Timer(Receiver[timedelta]): print(f"The timer has triggered {drift=}") ``` - But you can also use [`Select`][frequenz.channels.util.Select] to combine it - with other receivers, and even start it (semi) manually: + But you can also use a [`select`][frequenz.channels.util.select] to combine + it with other receivers, and even start it (semi) manually: ```python import logging - from frequenz.channels.util import Select + from frequenz.channels.util import select, selected_from from frequenz.channels import Broadcast timer = Timer.timeout(timedelta(seconds=1.0), auto_start=False) chan = Broadcast[int]("input-chan") - receiver1 = chan.new_receiver() + battery_data = chan.new_receiver() timer = Timer.timeout(timedelta(seconds=1.0), auto_start=False) # Do some other initialization, the timer will start automatically if # a message is awaited (or manually via `reset()`). - select = Select(bat_1=receiver1, timer=timer) - while await select.ready(): - if msg := select.bat_1: - if val := msg.inner: - battery_soc = val - else: + async for selected in select(battery_data, timer): + if selected_from(selected, battery_data): + if selected.was_closed(): logging.warning("battery channel closed") - elif drift := select.timer: + continue + battery_soc = selected.value + elif selected_from(selected, timer): # Print some regular battery data print(f"Battery is charged at {battery_soc}%") ``` @@ -313,7 +336,7 @@ class Timer(Receiver[timedelta]): Example: Timeout example ```python import logging - from frequenz.channels.util import Select + from frequenz.channels.util import select, selected_from from frequenz.channels import Broadcast def process_data(data: int): @@ -325,22 +348,21 @@ def do_heavy_processing(data: int): timer = Timer.timeout(timedelta(seconds=1.0), auto_start=False) chan1 = Broadcast[int]("input-chan-1") chan2 = Broadcast[int]("input-chan-2") - receiver1 = chan1.new_receiver() - receiver2 = chan2.new_receiver() - select = Select(bat_1=receiver1, heavy_process=receiver2, timeout=timer) - while await select.ready(): - if msg := select.bat_1: - if val := msg.inner: - process_data(val) - timer.reset() - else: + battery_data = chan1.new_receiver() + heavy_process = chan2.new_receiver() + async for selected in select(battery_data, heavy_process, timer): + if selected_from(selected, battery_data): + if selected.was_closed(): logging.warning("battery channel closed") - if msg := select.heavy_process: - if val := msg.inner: - do_heavy_processing(val) - else: + continue + process_data(selected.value) + timer.reset() + elif selected_from(selected, heavy_process): + if selected.was_closed(): logging.warning("processing channel closed") - elif drift := select.timeout: + continue + do_heavy_processing(selected.value) + elif selected_from(selected, timer): logging.warning("No data received in time") ``` @@ -681,3 +703,22 @@ def _now(self) -> int: The current monotonic clock time in microseconds. """ return _to_microseconds(self._loop.time()) + + def __str__(self) -> str: + """Return a string representation of the timer. + + Returns: + The string representation of the timer. + """ + return f"{type(self).__name__}({self.interval})" + + def __repr__(self) -> str: + """Return a string representation of the timer. + + Returns: + The string representation of the timer. + """ + return ( + f"{type(self).__name__}<{self.interval=}, {self.missed_tick_policy=}, " + f"{self.loop=}, {self.is_running=}>" + ) diff --git a/tests/test_select.py b/tests/test_select.py deleted file mode 100644 index fa372544..00000000 --- a/tests/test_select.py +++ /dev/null @@ -1,76 +0,0 @@ -# License: MIT -# Copyright © 2022 Frequenz Energy-as-a-Service GmbH - -"""Tests for the Select implementation.""" - -import asyncio -from typing import List - -from frequenz.channels import Anycast, Sender -from frequenz.channels.util import Select - - -async def test_select() -> None: - """Ensure select receives messages in order.""" - chan1 = Anycast[int]() - chan2 = Anycast[int]() - chan3 = Anycast[int]() - - async def send(ch1: Sender[int], ch2: Sender[int], ch3: Sender[int]) -> None: - for ctr in range(5): - await ch1.send(ctr + 1) - await ch2.send(ctr + 101) - await ch3.send(ctr + 201) - await chan1.close() - await ch2.send(1000) - await chan2.close() - await chan3.close() - - senders = asyncio.create_task( - send(chan1.new_sender(), chan2.new_sender(), chan3.new_sender()), - ) - select = Select( - ch1=chan1.new_receiver(), - ch2=chan2.new_receiver(), - ch3=chan3.new_receiver(), - ) - - # only check for messages from all iterators but `ch3`. - # it ensures iterators are not blocking channels in case they - # are not being read from. - results: List[int] = [] - while await select.ready(): - if item := select.ch1: - if val := item.inner: - results.append(val) - else: - results.append(-1) - elif item := select.ch2: - if val := item.inner: - results.append(val) - else: - results.append(-2) - await senders - - expected_results = [ - 1, - 101, - 2, - 102, - 3, - 103, - 4, - 104, - 5, - 105, - -1, # marks end of messages from channel 1 - 1000, - -2, # marks end of messages from channel 2 - ] - assert results == expected_results - got_err = False - try: - item = select.unknown_channel - except KeyError: - got_err = True - assert got_err diff --git a/tests/utils/test_event.py b/tests/utils/test_event.py new file mode 100644 index 00000000..0cda9d23 --- /dev/null +++ b/tests/utils/test_event.py @@ -0,0 +1,59 @@ +# License: MIT +# Copyright © 2023 Frequenz Energy-as-a-Service GmbH + +"""Tests for the select implementation.""" + +import asyncio as _asyncio + +import pytest as _pytest + +from frequenz.channels import ReceiverStoppedError +from frequenz.channels.util import Event + + +async def test_event() -> None: + """Test the event implementation.""" + event = Event() + assert not event.is_set + assert not event.is_stopped + + is_ready = False + + async def wait_for_event() -> None: + nonlocal is_ready + await event.ready() + is_ready = True + + event_task = _asyncio.create_task(wait_for_event()) + + await _asyncio.sleep(0) # Yield so the wait_for_event task can run. + + assert not is_ready + assert not event.is_set + assert not event.is_stopped + + event.set() + + await _asyncio.sleep(0) # Yield so the wait_for_event task can run. + assert is_ready + assert event.is_set + assert not event.is_stopped + + event.consume() + assert not event.is_set + assert not event.is_stopped + assert event_task.done() + assert event_task.result() is None + assert not event_task.cancelled() + + event.stop() + assert not event.is_set + assert event.is_stopped + + await event.ready() + with _pytest.raises(ReceiverStoppedError): + event.consume() + assert event.is_stopped + assert not event.is_set + + await event_task diff --git a/tests/utils/test_integration.py b/tests/utils/test_integration.py index 25700784..e61cb620 100644 --- a/tests/utils/test_integration.py +++ b/tests/utils/test_integration.py @@ -3,15 +3,13 @@ """Integration tests for the `util` module.""" -from __future__ import annotations - import os import pathlib from datetime import timedelta import pytest -from frequenz.channels.util import FileWatcher, Select, Timer +from frequenz.channels.util import FileWatcher, Timer, select, selected_from @pytest.mark.integration @@ -19,29 +17,26 @@ async def test_file_watcher(tmp_path: pathlib.Path) -> None: """Ensure file watcher is returning paths on file events. Args: - tmp_path (pathlib.Path): A tmp directory to run the file watcher on. - Created by pytest. + tmp_path: A tmp directory to run the file watcher on. Created by pytest. """ filename = tmp_path / "test-file" - file_watcher = FileWatcher(paths=[str(tmp_path)]) number_of_writes = 0 expected_number_of_writes = 3 - select = Select( - timer=Timer.timeout(timedelta(seconds=0.1)), - file_watcher=file_watcher, - ) - while await select.ready(): - if msg := select.timer: - filename.write_text(f"{msg.inner}") - elif msg := select.file_watcher: + file_watcher = FileWatcher(paths=[str(tmp_path)]) + timer = Timer.timeout(timedelta(seconds=0.1)) + + async for selected in select(file_watcher, timer): + if selected_from(selected, timer): + filename.write_text(f"{selected.value}") + elif selected_from(selected, file_watcher): event_type = ( FileWatcher.EventType.CREATE if number_of_writes == 0 else FileWatcher.EventType.MODIFY ) - assert msg.inner == FileWatcher.Event(type=event_type, path=filename) + assert selected.value == FileWatcher.Event(type=event_type, path=filename) number_of_writes += 1 # After receiving a write 3 times, unsubscribe from the writes channel if number_of_writes == expected_number_of_writes: @@ -58,19 +53,15 @@ async def test_file_watcher_deletes(tmp_path: pathlib.Path) -> None: the file doesn't exist. Args: - tmp_path (pathlib.Path): A tmp directory to run the file watcher on. - Created by pytest. + tmp_path: A tmp directory to run the file watcher on. Created by pytest. """ filename = tmp_path / "test-file" file_watcher = FileWatcher( paths=[str(tmp_path)], event_types={FileWatcher.EventType.DELETE} ) + write_timer = Timer.timeout(timedelta(seconds=0.1)) + deletion_timer = Timer.timeout(timedelta(seconds=0.25)) - select = Select( - write_timer=Timer.timeout(timedelta(seconds=0.1)), - deletion_timer=Timer.timeout(timedelta(seconds=0.25)), - watcher=file_watcher, - ) number_of_write = 0 number_of_deletes = 0 number_of_events = 0 @@ -91,19 +82,19 @@ async def test_file_watcher_deletes(tmp_path: pathlib.Path) -> None: # W: Write # D: Delete # E: FileWatcher Event - while await select.ready(): - if msg := select.write_timer: + async for selected in select(file_watcher, write_timer, deletion_timer): + if selected_from(selected, write_timer): if number_of_write >= 2 and number_of_events == 0: continue - filename.write_text(f"{msg.inner}") + filename.write_text(f"{selected.value}") number_of_write += 1 - elif _ := select.deletion_timer: + elif selected_from(selected, deletion_timer): # Avoid removing the file twice if not pathlib.Path(filename).is_file(): continue os.remove(filename) number_of_deletes += 1 - elif _ := select.watcher: + elif selected_from(selected, file_watcher): number_of_events += 1 if number_of_events >= 2: break diff --git a/tests/utils/test_select.py b/tests/utils/test_select.py new file mode 100644 index 00000000..a9a46921 --- /dev/null +++ b/tests/utils/test_select.py @@ -0,0 +1,55 @@ +# License: MIT +# Copyright © 2023 Frequenz Energy-as-a-Service GmbH + +"""Tests for the select implementation.""" + +from unittest import mock + +import pytest + +from frequenz.channels import Receiver, ReceiverStoppedError +from frequenz.channels.util import Selected, selected_from + + +class TestSelected: + """Tests for the Selected class.""" + + def test_with_value(self) -> None: + """Test selected from a receiver with a value.""" + recv = mock.MagicMock(spec=Receiver[int]) + recv.consume.return_value = 42 + selected = Selected[int](recv) + + assert selected_from(selected, recv) + assert selected.value == 42 + assert selected.exception is None + assert not selected.was_stopped() + + def test_with_exception(self) -> None: + """Test selected from a receiver with an exception.""" + recv = mock.MagicMock(spec=Receiver[int]) + exception = Exception("test") + recv.consume.side_effect = exception + selected = Selected[int](recv) + + assert selected_from(selected, recv) + with pytest.raises(Exception, match="test"): + _ = selected.value + assert selected.exception is exception + assert not selected.was_stopped() + + def test_with_stopped(self) -> None: + """Test selected from a stopped receiver.""" + recv = mock.MagicMock(spec=Receiver[int]) + exception = ReceiverStoppedError[int](recv) + recv.consume.side_effect = exception + selected = Selected[int](recv) + + assert selected_from(selected, recv) + with pytest.raises( + ReceiverStoppedError, + match=r"Receiver was stopped", + ): + _ = selected.value + assert selected.exception is exception + assert selected.was_stopped() diff --git a/tests/utils/test_select_integration.py b/tests/utils/test_select_integration.py new file mode 100644 index 00000000..7d2ff997 --- /dev/null +++ b/tests/utils/test_select_integration.py @@ -0,0 +1,448 @@ +# License: MIT +# Copyright © 2023 Frequenz Energy-as-a-Service GmbH + +"""Integration tests for Select function. + +These tests are actually a bit in the middle between unit and integration, because we +are using a fake loop to make the tests faster, but we are still testing more than one +class at a time. +""" + +import asyncio +from collections.abc import AsyncIterator, Iterator +from typing import Any + +import async_solipsism +import pytest + +from frequenz.channels import Receiver, ReceiverStoppedError +from frequenz.channels.util import ( + Event, + Selected, + UnhandledSelectedError, + select, + selected_from, +) + + +@pytest.mark.integration +class TestSelect: + """Tests for the select function.""" + + recv1: Event + recv2: Event + recv3: Event + loop: async_solipsism.EventLoop + + @pytest.fixture(autouse=True) + def event_loop( + self, request: pytest.FixtureRequest + ) -> Iterator[async_solipsism.EventLoop]: + """Replace the loop with one that doesn't interact with the outside world.""" + loop = async_solipsism.EventLoop() + request.cls.loop = loop + yield loop + loop.close() + + @pytest.fixture() + async def start_run_ordered_sequence(self) -> AsyncIterator[asyncio.Task[None]]: + """Start the run_ordered_sequence method and wait for it to finish. + + Yields: + The task running the run_ordered_sequence method. + """ + sequence_task = asyncio.create_task(self.run_ordered_sequence()) + yield sequence_task + await sequence_task + + def setup_method(self) -> None: + """Set up the test.""" + self.recv1 = Event("recv1") + self.recv2 = Event("recv2") + self.recv3 = Event("recv3") + + def assert_received_from( + self, + selected: Selected[Any], + receiver: Receiver[None], + *, + at_time: float, + expected_pending_tasks: int = -2, + ) -> None: + """Assert that the selected event was received from the given receiver. + + It also asserts that: + + * The receiver didn't raise an exception. + * The receiver wasn't stopped. + * The select loop is still running. + * It happened at the given time. + + Args: + selected: The selected event. + receiver: The receiver from which the event was received. + at_time: The time at which the event was received. + expected_pending_tasks: Check that a number of tasks are pending. If the + number is negative, a > check is performed with the absolute value. If + it is 0, no check is performed. + """ + assert selected_from(selected, receiver) + assert selected.value is None + assert selected.exception is None + assert not selected.was_stopped() + if expected_pending_tasks > 0: + assert len(asyncio.all_tasks(self.loop)) == expected_pending_tasks + elif expected_pending_tasks < 0: + assert len(asyncio.all_tasks(self.loop)) > expected_pending_tasks + assert self.loop.time() == at_time + + def assert_receiver_stopped( + self, + selected: Selected[Any], + receiver: Receiver[None], + *, + at_time: float, + expected_pending_tasks: int = -2, + ) -> None: + """Assert that the selected event came from a stopped receiver. + + It also asserts that: + + * The amount of pending tasks is as expected. + * It happened at the given time. + + Args: + selected: The selected event. + receiver: The receiver from which the event was received. + at_time: The time at which the event was received. + expected_pending_tasks: Check that a number of tasks are pending. If the + number is negative, a > check is performed with the absolute value. If + it is 0, no check is performed. + """ + assert selected_from(selected, receiver) + assert selected.was_stopped() + assert isinstance(selected.exception, ReceiverStoppedError) + assert selected.exception.receiver is receiver + if expected_pending_tasks > 0: + assert len(asyncio.all_tasks(self.loop)) == expected_pending_tasks + elif expected_pending_tasks < 0: + assert len(asyncio.all_tasks(self.loop)) > expected_pending_tasks + assert self.loop.time() == at_time + + # We use the loop time (and the sleeps in the run_ordered_sequence method) mainly to + # ensure we are processing the events in the correct order and we are really + # following the sequence of events we expect. + + async def run_ordered_sequence(self) -> None: + """Run the sequence of events to be tested.""" + print("time = 0") + self.recv1.set() + await asyncio.sleep(1) + + print("time = 1") + self.recv2.set() + await asyncio.sleep(1) + + print("time = 2") + self.recv3.set() + await asyncio.sleep(1) + + print("time = 3") + self.recv1.set() + await asyncio.sleep(1) + + print("time = 4") + self.recv1.set() + await asyncio.sleep(1) + + print("time = 5") + self.recv3.set() + await asyncio.sleep(1) + + print("time = 6") + self.recv2.set() + await asyncio.sleep(1) + + print("time = 7") + self.recv1.stop() + await asyncio.sleep(1) + + print("time = 8") + self.recv2.set() + await asyncio.sleep(1) + + print("time = 9") + self.recv3.stop() + await asyncio.sleep(1) + + print("time = 10") + self.recv2.set() + await asyncio.sleep(1) + + print("time = 11") + self.recv2.stop() + + # pylint: disable=redefined-outer-name + async def test_select_receives_in_order( + self, + start_run_ordered_sequence: asyncio.Task[ # pylint: disable=unused-argument + None + ], + ) -> None: + """Test that the select loop receives events in the correct order.""" + select_iter = select(self.recv1, self.recv2, self.recv3) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv1, at_time=0) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv2, at_time=1) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv3, at_time=2) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv1, at_time=3) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv1, at_time=4) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv3, at_time=5) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv2, at_time=6) + + selected = await anext(select_iter) + self.assert_receiver_stopped(selected, self.recv1, at_time=7) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv2, at_time=8) + + selected = await anext(select_iter) + self.assert_receiver_stopped(selected, self.recv3, at_time=9) + + selected = await anext(select_iter) + self.assert_received_from(selected, self.recv2, at_time=10) + + selected = await anext(select_iter) + self.assert_receiver_stopped( + selected, self.recv2, at_time=11, expected_pending_tasks=1 + ) + + with pytest.raises(StopAsyncIteration): + selected = await anext(select_iter) + + assert len(asyncio.all_tasks()) == 1 # Only the test task should be alive + + async def test_break( + self, + start_run_ordered_sequence: asyncio.Task[ # pylint: disable=unused-argument + None + ], + ) -> None: + """Test that break works.""" + selected: Selected[Any] | None = None + async for selected in select(self.recv1, self.recv2, self.recv3): + if selected_from(selected, self.recv1): + continue + if selected_from(selected, self.recv2): + continue + if selected_from(selected, self.recv3): + break + + assert selected is not None + self.assert_received_from(selected, self.recv3, at_time=2) + + async for selected in select(self.recv1, self.recv2, self.recv3): + if selected_from(selected, self.recv1): + continue + if selected_from(selected, self.recv2): + break + if selected_from(selected, self.recv3): + continue + + assert selected is not None + self.assert_received_from(selected, self.recv2, at_time=6) + + async for selected in select(self.recv1, self.recv2, self.recv3): + if selected_from(selected, self.recv1): + continue + if selected_from(selected, self.recv2): + continue + if selected_from(selected, self.recv3): + break + + assert selected is not None + self.assert_receiver_stopped(selected, self.recv3, at_time=9) + + assert self.recv1.is_stopped + assert self.recv3.is_stopped + + async for selected in select(self.recv2): + if selected_from(selected, self.recv2): + continue + + self.assert_receiver_stopped( + selected, self.recv2, at_time=11, expected_pending_tasks=1 + ) + + assert len(asyncio.all_tasks()) == 1 # Only the test task should be alive + + async def test_missed_select_from( + self, + start_run_ordered_sequence: asyncio.Task[ # pylint: disable=unused-argument + None + ], + ) -> None: + """Test that a missed `select_from` is detected.""" + selected: Selected[Any] | None = None + with pytest.raises(UnhandledSelectedError) as excinfo: + async for selected in select(self.recv1, self.recv2, self.recv3): + if selected_from(selected, self.recv1): + continue + if selected_from(selected, self.recv2): + continue + + assert False, "Should not reach this point" + + assert selected is not None + assert excinfo.value.selected is selected + self.assert_received_from( + selected, self.recv3, at_time=2, expected_pending_tasks=2 + ) + + # The test task and the run_ordered_sequence tasks should still be alive + assert len(asyncio.all_tasks()) == 2 + assert start_run_ordered_sequence in asyncio.all_tasks() + + @pytest.fixture() + async def start_run_multiple_ready(self) -> AsyncIterator[asyncio.Task[None]]: + """Start the run_multiple_ready method and wait for it to finish. + + Yields: + The task running the run_multiple_ready method. + """ + sequence_task = asyncio.create_task(self.run_multiple_ready()) + yield sequence_task + await sequence_task + + async def run_multiple_ready(self) -> None: + """Run a sequence of events with multiple receivers ready.""" + print("time = 0") + self.recv1.set() + self.recv2.set() + self.recv3.set() + await asyncio.sleep(1) + + print("time = 1") + self.recv2.set() + self.recv3.set() + await asyncio.sleep(1) + + print("time = 2") + self.recv1.set() + self.recv3.set() + await asyncio.sleep(1) + + print("time = 3") + self.recv1.set() + self.recv2.set() + await asyncio.sleep(1) + + print("time = 4") + + async def test_multiple_ready( + self, + start_run_multiple_ready: asyncio.Task[None], # pylint: disable=unused-argument + ) -> None: + """Test that multiple ready receviers are handled properly. + + Also test that the loop waits forever if there are no more receivers ready. + """ + received: set[str] = set() + last_time: float = self.loop.time() + try: + async with asyncio.timeout(15): + async for selected in select(self.recv1, self.recv2, self.recv3): + now = self.loop.time() + if now != last_time: # Only check when there was a jump in time + match now: + case 1: + assert received == { + self.recv1.name, + self.recv2.name, + self.recv3.name, + } + case 2: + assert received == { + self.recv2.name, + self.recv3.name, + } + case 3: + assert received == { + self.recv1.name, + self.recv3.name, + } + # case 4 needs to be checked after the timeout, as there + # are no ready receivers after time == 3. + case _: + assert False, "Should not reach this point" + received.clear() + last_time = now + + if selected_from(selected, self.recv1): + received.add(self.recv1.name) + elif selected_from(selected, self.recv2): + received.add(self.recv2.name) + elif selected_from(selected, self.recv3): + received.add(self.recv3.name) + else: + assert False, "Should not reach this point" + except asyncio.TimeoutError: + assert self.loop.time() == 15 + # This happened after time == 3, but the loop never resumes becuase + # there is nothing ready, so we need to check it after the timeout. + assert received == { + self.recv1.name, + self.recv2.name, + } + else: + assert False, "Should have timed out" + + assert len(asyncio.all_tasks()) == 1 # The test task should still be alive + + def test_tasks_are_cleaned_up_with_break(self) -> None: + """Test that the tasks are cleaned up properly. + + In this test we use a real event loop instead of relying what is provided by + pytest to make absolutely sure that the tasks are cleaned up properly with + a real loop. + """ + loop = asyncio.new_event_loop() + + async def run() -> None: + task = loop.create_task(self.run_multiple_ready()) + async for selected in select(self.recv1, self.recv2, self.recv3): + if selected_from(selected, self.recv1): + continue + if selected_from(selected, self.recv2): + continue + if selected_from(selected, self.recv3): + break + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # The loop might take a few "yields" to process all pending tasks and ensure + # the finalized of select() was called + iterations = 0 + while len(asyncio.all_tasks(loop)) > 1 and iterations < 5: + await asyncio.sleep(0) + + assert len(asyncio.all_tasks(loop)) == 1 + + loop.run_until_complete(run())