diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsDynamoDbRetryIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsDynamoDbRetryIntegration.java new file mode 100644 index 000000000..a386c5835 --- /dev/null +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsDynamoDbRetryIntegration.java @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.aws.codegen; + +import java.util.List; +import java.util.Set; +import software.amazon.smithy.aws.traits.ServiceTrait; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.RuntimeTypes; +import software.amazon.smithy.python.codegen.integrations.PythonIntegration; +import software.amazon.smithy.python.codegen.sections.InitRetryStrategyResolverSection; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.CodeInterceptor; +import software.amazon.smithy.utils.CodeSection; + +/** + * Injects DynamoDB's default retry options (max attempts 4, 25ms non-throttling + * base backoff). + */ +public final class AwsDynamoDbRetryIntegration implements PythonIntegration { + + private static final Set DYNAMODB_SDK_IDS = Set.of("DynamoDB", "DynamoDB Streams"); + + private static final double DYNAMODB_BASE_BACKOFF_SECONDS = 0.025; + private static final int DYNAMODB_MAX_ATTEMPTS = 4; + + private static boolean isDynamoDb(GenerationContext context) { + return context.settings() + .service(context.model()) + .getTrait(ServiceTrait.class) + .map(trait -> DYNAMODB_SDK_IDS.contains(trait.getSdkId())) + .orElse(false); + } + + @Override + public List> interceptors( + GenerationContext context + ) { + if (!isDynamoDb(context)) { + return List.of(); + } + return List.of(new DynamoDbRetryStrategyResolverInterceptor()); + } + + private static final class DynamoDbRetryStrategyResolverInterceptor + implements CodeInterceptor { + + @Override + public Class sectionType() { + return InitRetryStrategyResolverSection.class; + } + + @Override + public void write(PythonWriter writer, String previousText, InitRetryStrategyResolverSection section) { + writer.write( + "self._retry_strategy_resolver = $T(default_max_attempts=$L, default_backoff_scale=$L)", + RuntimeTypes.RETRY_STRATEGY_RESOLVER, + DYNAMODB_MAX_ATTEMPTS, + DYNAMODB_BASE_BACKOFF_SECONDS); + } + } +} diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsLongPollingIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsLongPollingIntegration.java new file mode 100644 index 000000000..4e74ede2c --- /dev/null +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsLongPollingIntegration.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.aws.codegen; + +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.aws.traits.ServiceTrait; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.python.codegen.integrations.PythonIntegration; + +/** + * Marks the known long-polling operations so the generic client generator + * applies long-polling retry behavior to them. + * + *

These operations are hard-coded until the {@code aws.api#longPoll} trait + * ships in service models. Once it ships, this can check for the trait instead. + */ +public final class AwsLongPollingIntegration implements PythonIntegration { + + private static final Map> LONG_POLLING_OPERATIONS = Map.of( + "SQS", + Set.of("ReceiveMessage"), + "SFN", + Set.of("GetActivityTask"), + "SWF", + Set.of("PollForActivityTask", "PollForDecisionTask")); + + @Override + public boolean isLongPollingOperation(Model model, ServiceShape service, OperationShape operation) { + return service.getTrait(ServiceTrait.class) + .map(trait -> LONG_POLLING_OPERATIONS.get(trait.getSdkId())) + .map(operations -> operations.contains(operation.getId().getName())) + .orElse(false); + } +} diff --git a/codegen/aws/core/src/main/resources/META-INF/services/software.amazon.smithy.python.codegen.integrations.PythonIntegration b/codegen/aws/core/src/main/resources/META-INF/services/software.amazon.smithy.python.codegen.integrations.PythonIntegration index a338df30c..7c2d33be1 100644 --- a/codegen/aws/core/src/main/resources/META-INF/services/software.amazon.smithy.python.codegen.integrations.PythonIntegration +++ b/codegen/aws/core/src/main/resources/META-INF/services/software.amazon.smithy.python.codegen.integrations.PythonIntegration @@ -8,3 +8,5 @@ software.amazon.smithy.python.aws.codegen.AwsProtocolsIntegration software.amazon.smithy.python.aws.codegen.AwsServiceIdIntegration software.amazon.smithy.python.aws.codegen.AwsUserAgentIntegration software.amazon.smithy.python.aws.codegen.AwsStandardRegionalEndpointsIntegration +software.amazon.smithy.python.aws.codegen.AwsDynamoDbRetryIntegration +software.amazon.smithy.python.aws.codegen.AwsLongPollingIntegration diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 6a3f65c4a..554d87f5a 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -22,6 +22,7 @@ import software.amazon.smithy.model.traits.StringTrait; import software.amazon.smithy.python.codegen.integrations.PythonIntegration; import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin; +import software.amazon.smithy.python.codegen.sections.InitRetryStrategyResolverSection; import software.amazon.smithy.python.codegen.writer.PythonWriter; import software.amazon.smithy.utils.SmithyInternalApi; @@ -86,13 +87,13 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): for plugin in client_plugins: plugin(self._config) - self._retry_strategy_resolver = $5T() + $5C """, configSymbol, pluginSymbol, writer.consumer(w -> writeConstructorDocs(w, serviceSymbol.getName())), writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)), - RuntimeTypes.RETRY_STRATEGY_RESOLVER); + writer.consumer(this::writeRetryStrategyResolverInit)); var topDownIndex = TopDownIndex.of(model); var eventStreamIndex = EventStreamIndex.of(model); @@ -113,6 +114,22 @@ private void writeDefaultPlugins(PythonWriter writer, Collection { writer.write(""" @@ -210,6 +227,7 @@ private void writeSharedOperationInit( } writer.putContext("operation", symbolProvider.toSymbol(operation)); + writer.putContext("isLongPolling", isLongPollingOperation(operation)); writer.addStdlibImport("copy", "deepcopy"); writer.write(""" @@ -240,7 +258,8 @@ private void writeSharedOperationInit( auth_scheme_resolver=config.auth_scheme_resolver, supported_auth_schemes=config.auth_schemes, endpoint_resolver=config.endpoint_resolver, - retry_strategy=retry_strategy, + retry_strategy=retry_strategy,${?isLongPolling} + is_long_polling=True,${/isLongPolling} ) """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)), diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/PythonIntegration.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/PythonIntegration.java index 8fe18b2e7..032f3b3ae 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/PythonIntegration.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/PythonIntegration.java @@ -8,6 +8,8 @@ import java.util.List; import software.amazon.smithy.codegen.core.SmithyIntegration; import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.python.codegen.GenerationContext; import software.amazon.smithy.python.codegen.PythonSettings; import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; @@ -44,6 +46,21 @@ default Model preprocessModel(Model model, PythonSettings settings) { return model; } + /** + * Determines whether the given operation is a long-polling operation, which + * must back off before returning even when the retry quota is exhausted. AWS + * integrations use this hook to identify these operations while the + * {@code aws.api#longPoll} trait is not yet shipped in service models. + * + * @param model Model the operation belongs to. + * @param service Service the operation belongs to. + * @param operation Operation to test. + * @return Returns true if the operation is a long-polling operation. + */ + default boolean isLongPollingOperation(Model model, ServiceShape service, OperationShape operation) { + return false; + } + /** * Writes out all extra files required by runtime plugins. */ diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/sections/InitRetryStrategyResolverSection.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/sections/InitRetryStrategyResolverSection.java new file mode 100644 index 000000000..390843b64 --- /dev/null +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/sections/InitRetryStrategyResolverSection.java @@ -0,0 +1,14 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.codegen.sections; + +import software.amazon.smithy.utils.CodeSection; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * A section that controls constructing the client's {@code RetryStrategyResolver}. + */ +@SmithyInternalApi +public record InitRetryStrategyResolverSection() implements CodeSection {} diff --git a/designs/retries.md b/designs/retries.md index 5d681f71d..29aba7024 100644 --- a/designs/retries.md +++ b/designs/retries.md @@ -32,12 +32,15 @@ class RetryStrategy(Protocol): """Upper limit on total attempt count (initial attempt plus retries).""" async def acquire_initial_retry_token( - self, *, token_scope: str | None = None + self, *, token_scope: str | None = None, is_long_polling: bool = False ) -> RetryToken: """Called before any retries (for the first attempt at the operation). :param token_scope: An arbitrary string accepted by the retry strategy to separate tokens into scopes. + :param is_long_polling: Whether the operation is a long-polling operation. + Long-polling operations must back off before returning even when the + retry quota is exhausted. :returns: A retry token, to be used for determining the retry delay, refreshing the token after a failure, and recording success after success. :raises RetryError: If the retry strategy has no available tokens. @@ -110,8 +113,9 @@ class HasFault(Protocol): `RetryStrategy` implementations MUST raise a `RetryError` if they receive an exception where `is_retry_safe` is `False` and SHOULD raise a `RetryError` if it -is `None`. `RetryStrategy` implementations SHOULD use a delay that is at least -as long as `retry_after` but MAY choose to wait longer. +is `None`. `RetryStrategy` implementations SHOULD take `retry_after` into account +when computing the delay, but MAY adjust it (for example, by clamping it to an +upper bound). ### Backoff Strategy diff --git a/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-881860b5bda049d1b000983dd2a3bd0e.json b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-881860b5bda049d1b000983dd2a3bd0e.json new file mode 100644 index 000000000..3750401e8 --- /dev/null +++ b/packages/smithy-aws-core/.changes/next-release/smithy-aws-core-feature-881860b5bda049d1b000983dd2a3bd0e.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Added support for the `x-amz-retry-after` response header in the `awsQuery` protocol." +} \ No newline at end of file diff --git a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py index 5a9abc058..af86fab12 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/_private/query/errors.py @@ -99,6 +99,7 @@ def create_aws_query_error( wrapper_elements: tuple[str, ...], status: int, context: TypedProperties, + retry_after: float | None = None, ) -> CallError: """Create a modeled or generic CallError from an awsQuery error response.""" code = _parse_aws_query_error_code(body, wrapper_elements) @@ -121,7 +122,10 @@ def create_aws_query_error( deserializer = XMLCodec().create_deserializer( body, wrapper_elements=wrapper_elements ) - return error_shape.deserialize(deserializer) + modeled_error = error_shape.deserialize(deserializer) + if retry_after is not None: + modeled_error.retry_after = retry_after + return modeled_error message = f"Unknown error for operation {operation.schema.id} - status: {status}" if code is not None: @@ -137,4 +141,5 @@ def create_aws_query_error( is_throttling_error=is_throttle, is_timeout_error=is_timeout, is_retry_safe=is_throttle or is_timeout or None, + retry_after=retry_after, ) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index afaac61eb..a4a418ed1 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -27,7 +27,11 @@ from smithy_http import tuples_to_fields from smithy_http.aio import HTTPRequest as _HTTPRequest from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse -from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol +from smithy_http.aio.protocols import ( + HttpBindingClientProtocol, + HttpClientProtocol, + parse_retry_after, +) from smithy_http.deserializers import HTTPResponseDeserializer from .._private.query.errors import ( @@ -353,6 +357,7 @@ async def _create_error( wrapper_elements=self._error_wrapper_elements(), status=response.status, context=context, + retry_after=parse_retry_after(response), ) def _action_name( diff --git a/packages/smithy-aws-core/tests/unit/test_query.py b/packages/smithy-aws-core/tests/unit/test_query.py index a6a069d3f..2c94ed980 100644 --- a/packages/smithy-aws-core/tests/unit/test_query.py +++ b/packages/smithy-aws-core/tests/unit/test_query.py @@ -2,13 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass from io import BytesIO +from typing import Any, cast +from unittest.mock import Mock +from smithy_aws_core._private.query.errors import create_aws_query_error from smithy_aws_core._private.query.serializers import QueryShapeSerializer +from smithy_core.documents import TypeRegistry from smithy_core.prelude import STRING -from smithy_core.schemas import Schema +from smithy_core.schemas import APIOperation, Schema from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType from smithy_core.traits import XMLFlattenedTrait, XMLNameTrait +from smithy_core.types import TypedProperties def test_query_list_serialization() -> None: @@ -280,3 +285,39 @@ def serialize_members(self, serializer: ShapeSerializer) -> None: Outer(inner=Inner("x")).serialize(serializer) assert params == [("inner.value", "x")] + + +def _error_test_operation() -> APIOperation[Any, Any]: + operation = Mock(spec=APIOperation) + operation.schema = Schema( + id=ShapeID("com.example#TestOp"), shape_type=ShapeType.OPERATION + ) + operation.error_schemas = [] + return cast("APIOperation[Any, Any]", operation) + + +def test_aws_query_error_sets_retry_after_on_generic_error() -> None: + error = create_aws_query_error( + body=b"", + operation=_error_test_operation(), + error_registry=TypeRegistry({}), + default_namespace="com.example", + wrapper_elements=("ErrorResponse", "Error"), + status=503, + context=TypedProperties(), + retry_after=1.5, + ) + assert error.retry_after == 1.5 + + +def test_aws_query_error_retry_after_none_by_default() -> None: + error = create_aws_query_error( + body=b"", + operation=_error_test_operation(), + error_registry=TypeRegistry({}), + default_namespace="com.example", + wrapper_elements=("ErrorResponse", "Error"), + status=503, + context=TypedProperties(), + ) + assert error.retry_after is None diff --git a/packages/smithy-core/.changes/next-release/smithy-core-breaking-dc093af32db648139ed3f64cf4251a40.json b/packages/smithy-core/.changes/next-release/smithy-core-breaking-dc093af32db648139ed3f64cf4251a40.json new file mode 100644 index 000000000..c16440e13 --- /dev/null +++ b/packages/smithy-core/.changes/next-release/smithy-core-breaking-dc093af32db648139ed3f64cf4251a40.json @@ -0,0 +1,4 @@ +{ + "type": "breaking", + "description": "Updated retry quota costs and added an `is_long_polling` argument to `acquire_initial_retry_token`." +} \ No newline at end of file diff --git a/packages/smithy-core/.changes/next-release/smithy-core-feature-aa633008af584d94ad03679f73a4d436.json b/packages/smithy-core/.changes/next-release/smithy-core-feature-aa633008af584d94ad03679f73a4d436.json new file mode 100644 index 000000000..0a37cc467 --- /dev/null +++ b/packages/smithy-core/.changes/next-release/smithy-core-feature-aa633008af584d94ad03679f73a4d436.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Added support for separate backoff for throttling errors, the `x-amz-retry-after` header, per-service retry defaults, and long-polling backoff." +} \ No newline at end of file diff --git a/packages/smithy-core/src/smithy_core/aio/client.py b/packages/smithy-core/src/smithy_core/aio/client.py index e84c9a94b..72d08012d 100644 --- a/packages/smithy-core/src/smithy_core/aio/client.py +++ b/packages/smithy-core/src/smithy_core/aio/client.py @@ -83,6 +83,9 @@ class ClientCall[I: SerializeableShape, O: DeserializeableShape]: retry_scope: str | None = None """The retry scope for the operation.""" + is_long_polling: bool = False + """Whether the operation is a long-polling operation.""" + def retryable(self) -> bool: # TODO: check to see if the stream is seekable return self.operation.input_stream_member is None @@ -331,7 +334,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( retry_strategy = call.retry_strategy retry_token = await retry_strategy.acquire_initial_retry_token( - token_scope=call.retry_scope + token_scope=call.retry_scope, is_long_polling=call.is_long_polling ) while True: @@ -353,7 +356,11 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( token_to_renew=retry_token, error=output_context.response, ) - except RetryError: + except RetryError as retry_error: + # Long-polling operations back off even when the retry quota + # is exhausted; the strategy surfaces that delay here. + if retry_error.retry_after is not None: + await sleep(retry_error.retry_after) raise output_context.response _LOGGER.debug( diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/retries.py b/packages/smithy-core/src/smithy_core/aio/interfaces/retries.py index fbe188b55..df5dc1274 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/retries.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/retries.py @@ -16,12 +16,15 @@ class RetryStrategy(Protocol): """Upper limit on total attempt count (initial attempt plus retries).""" async def acquire_initial_retry_token( - self, *, token_scope: str | None = None + self, *, token_scope: str | None = None, is_long_polling: bool = False ) -> RetryToken: """Create a base retry token for the start of a request. :param token_scope: An arbitrary string accepted by the retry strategy to separate tokens into scopes. + :param is_long_polling: Whether the operation is a long-polling operation. + Long-polling operations must back off before returning even when the + retry quota is exhausted. :returns: A retry token, to be used for determining the retry delay, refreshing the token after a failure, and recording success after success. :raises RetryError: If the retry strategy has no available tokens. diff --git a/packages/smithy-core/src/smithy_core/aio/retries.py b/packages/smithy-core/src/smithy_core/aio/retries.py index e3fa6340e..706694614 100644 --- a/packages/smithy-core/src/smithy_core/aio/retries.py +++ b/packages/smithy-core/src/smithy_core/aio/retries.py @@ -22,8 +22,31 @@ class RetryStrategyResolver: This resolver caches retry strategy instances based on their configuration to reuse existing instances of RetryStrategy with the same settings. Uses LRU cache for thread-safe caching. + + A service may supply its own defaults for these options. Precedence is considered on + a per-value basis: a customer-configured value always takes precedence, and a default + only fills in a value the customer did not set. For example, if the customer sets + ``max_attempts`` but not the backoff scale, the resulting strategy uses the customer's + max attempts and the service's default backoff scale. """ + def __init__( + self, + *, + default_max_attempts: int | None = None, + default_backoff_scale: float | None = None, + ): + """Initialize the resolver. + + :param default_max_attempts: The maximum number of attempts to use when the + customer did not configure one. Only applies to standard mode. + :param default_backoff_scale: The base backoff scale in seconds used for + non-throttling errors when the customer did not configure one. Only applies + to standard mode. + """ + self._default_max_attempts = default_max_attempts + self._default_backoff_scale = default_backoff_scale + async def resolve_retry_strategy( self, *, retry_strategy: RetryStrategy | RetryStrategyOptions | None ) -> RetryStrategy: @@ -48,7 +71,10 @@ async def resolve_retry_strategy( def _create_retry_strategy( self, retry_mode: RetryStrategyType, max_attempts: int | None ) -> RetryStrategy: - kwargs = {"max_attempts": max_attempts} + if max_attempts is None: + max_attempts = self._default_max_attempts + + kwargs: dict[str, Any] = {"max_attempts": max_attempts} filtered_kwargs: dict[str, Any] = { k: v for k, v in kwargs.items() if v is not None } @@ -56,6 +82,10 @@ def _create_retry_strategy( case "simple": return SimpleRetryStrategy(**filtered_kwargs) case "standard": + if self._default_backoff_scale is not None: + filtered_kwargs["default_backoff_scale"] = ( + self._default_backoff_scale + ) return StandardRetryStrategy(**filtered_kwargs) case _: raise ValueError(f"Unknown retry mode: {retry_mode}") @@ -80,11 +110,12 @@ def __init__( self.max_attempts = max_attempts async def acquire_initial_retry_token( - self, *, token_scope: str | None = None + self, *, token_scope: str | None = None, is_long_polling: bool = False ) -> SimpleRetryToken: """Create a base retry token for the start of a request. :param token_scope: This argument is ignored by this retry strategy. + :param is_long_polling: This argument is ignored by this retry strategy. """ retry_delay = self.backoff_strategy.compute_next_backoff_delay(0) return SimpleRetryToken(retry_count=0, retry_delay=retry_delay) @@ -123,18 +154,43 @@ def __deepcopy__(self, memo: Any) -> "SimpleRetryStrategy": class StandardRetryStrategy: + _RETRY_AFTER_MAX_ADDITIONAL: float = 5 + """Maximum number of seconds a server-directed backoff (e.g. x-amz-retry-after) + may exceed the normal computed backoff (t_i).""" + + _NON_THROTTLING_BACKOFF_SCALE: float = 0.05 + """Base backoff scale (seconds) for non-throttling errors (50ms).""" + + _THROTTLING_BACKOFF_SCALE: float = 1 + """Base backoff scale (seconds) for throttling errors (1000ms).""" + + _MAX_BACKOFF: float = 20 + """Upper bound (seconds) for the computed backoff, applied before jitter.""" + def __init__( self, *, backoff_strategy: retries_interface.RetryBackoffStrategy | None = None, + throttling_backoff_strategy: retries_interface.RetryBackoffStrategy + | None = None, + default_backoff_scale: float | None = None, max_attempts: int = 3, retry_quota: StandardRetryQuota | None = None, ): """Standard retry strategy using truncated binary exponential backoff with full jitter. - :param backoff_strategy: The backoff strategy used by returned tokens to compute - the retry delay. Defaults to :py:class:`ExponentialRetryBackoffStrategy`. + :param backoff_strategy: The backoff strategy used to compute the retry delay + for non-throttling errors. Defaults to a 50ms-base + :py:class:`ExponentialRetryBackoffStrategy`. + + :param throttling_backoff_strategy: The backoff strategy used to compute the + retry delay for throttling errors. Defaults to a 1000ms-base + :py:class:`ExponentialRetryBackoffStrategy`. + + :param default_backoff_scale: Overrides the base backoff scale (in seconds) of + the default non-throttling backoff strategy. Ignored when + ``backoff_strategy`` is provided. :param max_attempts: Upper limit on total number of attempts made, including initial attempt and retries. @@ -147,23 +203,41 @@ def __init__( f"max_attempts must be a non-negative integer, got {max_attempts}" ) + non_throttling_scale = ( + self._NON_THROTTLING_BACKOFF_SCALE + if default_backoff_scale is None + else default_backoff_scale + ) self.backoff_strategy = backoff_strategy or ExponentialRetryBackoffStrategy( - backoff_scale_value=1, - max_backoff=20, + backoff_scale_value=non_throttling_scale, + max_backoff=self._MAX_BACKOFF, jitter_type=ExponentialBackoffJitterType.FULL, ) + self.throttling_backoff_strategy = ( + throttling_backoff_strategy + or ExponentialRetryBackoffStrategy( + backoff_scale_value=self._THROTTLING_BACKOFF_SCALE, + max_backoff=self._MAX_BACKOFF, + jitter_type=ExponentialBackoffJitterType.FULL, + ) + ) self.max_attempts = max_attempts self._retry_quota = retry_quota or StandardRetryQuota() async def acquire_initial_retry_token( - self, *, token_scope: str | None = None + self, *, token_scope: str | None = None, is_long_polling: bool = False ) -> StandardRetryToken: """Create a base retry token for the start of a request. :param token_scope: This argument is ignored by this retry strategy. + :param is_long_polling: Whether the operation is a long-polling operation. + Long-polling operations back off before returning even when the retry + quota is exhausted. """ retry_delay = self.backoff_strategy.compute_next_backoff_delay(0) - return StandardRetryToken(retry_count=0, retry_delay=retry_delay) + return StandardRetryToken( + retry_count=0, retry_delay=retry_delay, is_long_polling=is_long_polling + ) async def refresh_retry_token_for_retry( self, @@ -192,21 +266,34 @@ async def refresh_retry_token_for_retry( f"Reached maximum number of allowed attempts: {self.max_attempts}" ) from error - # Acquire additional quota for this retry attempt - # (may raise a RetryError if none is available) - quota_acquired = self._retry_quota.acquire(error=error) + # Throttling errors use a larger base backoff than other errors. + backoff_strategy = ( + self.throttling_backoff_strategy + if error.is_throttling_error + else self.backoff_strategy + ) + t_i = backoff_strategy.compute_next_backoff_delay(retry_count) + + try: + quota_acquired = self._retry_quota.acquire(error=error) + except RetryError as quota_error: + if token_to_renew.is_long_polling: + raise RetryError(str(quota_error), retry_after=t_i) from error + raise if error.retry_after is not None: - retry_delay = error.retry_after - else: - retry_delay = self.backoff_strategy.compute_next_backoff_delay( - retry_count + # Bound a server-directed backoff to [t_i, t_i + 5] seconds. + retry_delay = max( + t_i, min(error.retry_after, self._RETRY_AFTER_MAX_ADDITIONAL + t_i) ) + else: + retry_delay = t_i return StandardRetryToken( retry_count=retry_count, retry_delay=retry_delay, quota_acquired=quota_acquired, + is_long_polling=token_to_renew.is_long_polling, ) else: raise RetryError(f"Error is not retryable: {error}") from error diff --git a/packages/smithy-core/src/smithy_core/exceptions.py b/packages/smithy-core/src/smithy_core/exceptions.py index 0a99976f9..dadb850e4 100644 --- a/packages/smithy-core/src/smithy_core/exceptions.py +++ b/packages/smithy-core/src/smithy_core/exceptions.py @@ -88,7 +88,16 @@ class DiscriminatorError(SmithyError): class RetryError(SmithyError): - """Base exception type for all exceptions raised in retry strategies.""" + """Base exception type for all exceptions raised in retry strategies. + + :param retry_after: An optional delay in seconds that the caller should wait before + giving up on retries. Long-polling operations use this to back off even when the + retry quota is exhausted. + """ + + def __init__(self, message: str = "", *, retry_after: float | None = None) -> None: + super().__init__(message) + self.retry_after = retry_after class ExpectationNotMetError(SmithyError): diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index 1b106c68c..a1d023f2e 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -204,9 +204,9 @@ class StandardRetryQuota: """Retry quota used by :py:class:`StandardRetryStrategy`.""" INITIAL_RETRY_TOKENS: int = 500 - RETRY_COST: int = 5 + RETRY_COST: int = 14 NO_RETRY_INCREMENT: int = 1 - TIMEOUT_RETRY_COST: int = 10 + THROTTLING_RETRY_COST: int = 5 def __init__(self, initial_capacity: int = INITIAL_RETRY_TOKENS): """Initialize retry quota with configurable capacity. @@ -224,11 +224,13 @@ def acquire(self, *, error: Exception) -> int: Otherwise, return the amount of capacity successfully allocated. """ - is_timeout = ( + is_throttling = ( isinstance(error, retries_interface.ErrorRetryInfo) - and error.is_timeout_error + and error.is_throttling_error + ) + capacity_amount = ( + self.THROTTLING_RETRY_COST if is_throttling else self.RETRY_COST ) - capacity_amount = self.TIMEOUT_RETRY_COST if is_timeout else self.RETRY_COST with self._lock: if capacity_amount > self._available_capacity: @@ -268,3 +270,6 @@ class StandardRetryToken: quota_acquired: int = 0 """The amount of quota acquired for this retry attempt.""" + + is_long_polling: bool = False + """Whether this token is for a long-polling operation.""" diff --git a/packages/smithy-core/tests/functional/test_retries.py b/packages/smithy-core/tests/functional/test_retries.py index 4889ebe37..8fbc3d1eb 100644 --- a/packages/smithy-core/tests/functional/test_retries.py +++ b/packages/smithy-core/tests/functional/test_retries.py @@ -60,7 +60,7 @@ async def test_standard_retry_eventually_succeeds(): assert result == "success" assert attempts == 3 - assert quota.available_capacity == 495 + assert quota.available_capacity == 486 async def test_standard_retry_fails_due_to_max_attempts(): @@ -70,11 +70,11 @@ async def test_standard_retry_fails_due_to_max_attempts(): with pytest.raises(CallError, match="502"): await retry_operation(strategy, [502, 502, 502]) - assert quota.available_capacity == 490 + assert quota.available_capacity == 472 async def test_retry_quota_exhausted_after_single_retry(): - quota = StandardRetryQuota(initial_capacity=5) + quota = StandardRetryQuota(initial_capacity=14) strategy = StandardRetryStrategy(max_attempts=3, retry_quota=quota) with pytest.raises(CallError, match="502"): @@ -94,26 +94,26 @@ async def test_retry_quota_prevents_retries_when_quota_zero(): async def test_retry_quota_stops_retries_when_exhausted(): - quota = StandardRetryQuota(initial_capacity=10) + quota = StandardRetryQuota(initial_capacity=20) strategy = StandardRetryStrategy(max_attempts=5, retry_quota=quota) - with pytest.raises(CallError, match="503"): - await retry_operation(strategy, [500, 502, 503]) + with pytest.raises(CallError, match="502"): + await retry_operation(strategy, [500, 502]) - assert quota.available_capacity == 0 + assert quota.available_capacity == 6 async def test_retry_quota_recovers_after_successful_responses(): - quota = StandardRetryQuota(initial_capacity=15) + quota = StandardRetryQuota(initial_capacity=30) strategy = StandardRetryStrategy(max_attempts=5, retry_quota=quota) # First operation: 2 retries then success await retry_operation(strategy, [500, 502, 200]) - assert quota.available_capacity == 10 + assert quota.available_capacity == 16 # Second operation: 1 retry then success await retry_operation(strategy, [500, 200]) - assert quota.available_capacity == 10 + assert quota.available_capacity == 16 async def test_retry_quota_shared_across_concurrent_operations(): @@ -136,7 +136,7 @@ async def test_retry_quota_shared_across_concurrent_operations(): assert result1 == ("success", 3) assert result2 == ("success", 2) - assert quota.available_capacity == 495 + assert quota.available_capacity == 486 async def test_retry_quota_handles_timeout_errors(): @@ -150,4 +150,4 @@ async def test_retry_quota_handles_timeout_errors(): assert result == "success" assert attempts == 3 - assert quota.available_capacity == 490 + assert quota.available_capacity == 486 diff --git a/packages/smithy-core/tests/unit/aio/test_retries.py b/packages/smithy-core/tests/unit/aio/test_retries.py index f35c50750..94d7076f9 100644 --- a/packages/smithy-core/tests/unit/aio/test_retries.py +++ b/packages/smithy-core/tests/unit/aio/test_retries.py @@ -8,8 +8,12 @@ StandardRetryStrategy, ) from smithy_core.exceptions import CallError, RetryError +from smithy_core.retries import ( + ExponentialBackoffJitterType as EBJT, +) from smithy_core.retries import ( ExponentialRetryBackoffStrategy, + StandardRetryQuota, ) @@ -96,14 +100,165 @@ async def test_standard_retry_does_not_retry(error: Exception | CallError) -> No await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) -async def test_standard_retry_after_overrides_backoff() -> None: - strategy = StandardRetryStrategy() - error = CallError(is_retry_safe=True, retry_after=5.5) +async def test_standard_retry_after_within_bounds_is_honored() -> None: + strategy = StandardRetryStrategy( + backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=1, jitter_type=EBJT.NONE + ) + ) + error = CallError(is_retry_safe=True, retry_after=3.0) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == 3.0 + + +async def test_standard_retry_after_floored_to_backoff() -> None: + strategy = StandardRetryStrategy( + backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=1, jitter_type=EBJT.NONE + ) + ) + error = CallError(is_retry_safe=True, retry_after=0.5) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == 1.0 + + +async def test_standard_retry_after_capped_at_backoff_plus_max() -> None: + strategy = StandardRetryStrategy( + backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=1, jitter_type=EBJT.NONE + ) + ) + error = CallError(is_retry_safe=True, retry_after=10.0) token = await strategy.acquire_initial_retry_token() token = await strategy.refresh_retry_token_for_retry( token_to_renew=token, error=error ) - assert token.retry_delay == 5.5 + assert token.retry_delay == 6.0 + + +async def test_standard_non_throttling_uses_default_backoff_scale() -> None: + strategy = StandardRetryStrategy( + backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=0.05, + jitter_type=EBJT.NONE, + ) + ) + error = CallError(is_retry_safe=True, is_throttling_error=False) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == pytest.approx(0.05) # type: ignore + + +async def test_standard_throttling_uses_throttling_backoff_scale() -> None: + strategy = StandardRetryStrategy( + throttling_backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=1, + jitter_type=EBJT.NONE, + ) + ) + error = CallError(is_retry_safe=True, is_throttling_error=True) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == pytest.approx(1.0) # type: ignore + + +async def test_dynamodb_profile_uses_25ms_scale_and_4_attempts() -> None: + strategy = StandardRetryStrategy( + backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=0.025, jitter_type=EBJT.NONE + ), + max_attempts=4, + ) + assert strategy.max_attempts == 4 + + error = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token() + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == pytest.approx(0.025) # type: ignore + token = await strategy.refresh_retry_token_for_retry( + token_to_renew=token, error=error + ) + assert token.retry_delay == pytest.approx(0.05) # type: ignore + + +async def test_long_polling_backs_off_when_quota_exhausted() -> None: + strategy = StandardRetryStrategy( + backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=0.05, jitter_type=EBJT.NONE + ), + retry_quota=StandardRetryQuota(initial_capacity=0), + max_attempts=5, + ) + error = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token(is_long_polling=True) + assert token.is_long_polling is True + with pytest.raises(RetryError) as exc_info: + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + assert exc_info.value.retry_after == pytest.approx(0.05) # type: ignore + + +async def test_long_polling_throttling_backs_off_with_throttling_scale() -> None: + strategy = StandardRetryStrategy( + throttling_backoff_strategy=ExponentialRetryBackoffStrategy( + backoff_scale_value=1, jitter_type=EBJT.NONE + ), + retry_quota=StandardRetryQuota(initial_capacity=0), + max_attempts=5, + ) + error = CallError(is_retry_safe=True, is_throttling_error=True) + token = await strategy.acquire_initial_retry_token(is_long_polling=True) + with pytest.raises(RetryError) as exc_info: + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + assert exc_info.value.retry_after == pytest.approx(1.0) # type: ignore + + +async def test_long_polling_no_backoff_when_max_attempts_reached() -> None: + strategy = StandardRetryStrategy( + retry_quota=StandardRetryQuota(initial_capacity=0), + max_attempts=1, + ) + error = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token(is_long_polling=True) + with pytest.raises(RetryError) as exc_info: + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + assert exc_info.value.retry_after is None + + +async def test_long_polling_no_backoff_for_non_retryable_error() -> None: + strategy = StandardRetryStrategy( + retry_quota=StandardRetryQuota(initial_capacity=0), + max_attempts=5, + ) + error = CallError(fault="client", is_retry_safe=False) + token = await strategy.acquire_initial_retry_token(is_long_polling=True) + with pytest.raises(RetryError) as exc_info: + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + assert exc_info.value.retry_after is None + + +async def test_non_long_polling_does_not_back_off_when_quota_exhausted() -> None: + strategy = StandardRetryStrategy( + retry_quota=StandardRetryQuota(initial_capacity=0), + max_attempts=5, + ) + error = CallError(is_retry_safe=True) + token = await strategy.acquire_initial_retry_token() + assert token.is_long_polling is False + with pytest.raises(RetryError) as exc_info: + await strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + assert exc_info.value.retry_after is None async def test_standard_retry_invalid_max_attempts() -> None: @@ -166,3 +321,69 @@ async def test_retry_strategy_resolver_rejects_invalid_type() -> None: match="retry_strategy must be RetryStrategy, RetryStrategyOptions, or None", ): await resolver.resolve_retry_strategy(retry_strategy="invalid") # type: ignore + + +async def test_resolver_service_defaults_applied_when_customer_unset() -> None: + resolver = RetryStrategyResolver( + default_max_attempts=4, default_backoff_scale=0.025 + ) + + strategy = await resolver.resolve_retry_strategy(retry_strategy=None) + + assert isinstance(strategy, StandardRetryStrategy) + assert strategy.max_attempts == 4 + delay = strategy.backoff_strategy.compute_next_backoff_delay(1) + assert 0 <= delay <= 0.025 + + +async def test_resolver_customer_max_attempts_overrides_default_keeps_backoff() -> None: + resolver = RetryStrategyResolver( + default_max_attempts=4, default_backoff_scale=0.025 + ) + + strategy = await resolver.resolve_retry_strategy( + retry_strategy=RetryStrategyOptions(max_attempts=10) + ) + + assert isinstance(strategy, StandardRetryStrategy) + assert strategy.max_attempts == 10 + delay = strategy.backoff_strategy.compute_next_backoff_delay(1) + assert 0 <= delay <= 0.025 + + +async def test_resolver_empty_options_still_get_service_defaults() -> None: + resolver = RetryStrategyResolver( + default_max_attempts=4, default_backoff_scale=0.025 + ) + + strategy = await resolver.resolve_retry_strategy( + retry_strategy=RetryStrategyOptions() + ) + + assert isinstance(strategy, StandardRetryStrategy) + assert strategy.max_attempts == 4 + delay = strategy.backoff_strategy.compute_next_backoff_delay(1) + assert 0 <= delay <= 0.025 + + +async def test_resolver_no_service_defaults_uses_strategy_defaults() -> None: + resolver = RetryStrategyResolver() + + strategy = await resolver.resolve_retry_strategy(retry_strategy=None) + + assert isinstance(strategy, StandardRetryStrategy) + assert strategy.max_attempts == 3 + delay = strategy.backoff_strategy.compute_next_backoff_delay(1) + assert 0 <= delay <= 0.05 + + +async def test_resolver_explicit_strategy_ignores_service_defaults() -> None: + resolver = RetryStrategyResolver( + default_max_attempts=4, default_backoff_scale=0.025 + ) + provided = StandardRetryStrategy(max_attempts=7) + + strategy = await resolver.resolve_retry_strategy(retry_strategy=provided) + + assert strategy is provided + assert strategy.max_attempts == 7 diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index 65f9a2c47..4dc32d221 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -58,20 +58,21 @@ def test_exponential_backoff_strategy( @pytest.fixture def retry_quota() -> StandardRetryQuota: - return StandardRetryQuota(initial_capacity=10) + return StandardRetryQuota(initial_capacity=28) def test_retry_quota_initial_state( retry_quota: StandardRetryQuota, ) -> None: - assert retry_quota.available_capacity == 10 + assert retry_quota.available_capacity == 28 def test_retry_quota_acquire_success( retry_quota: StandardRetryQuota, ) -> None: acquired = retry_quota.acquire(error=Exception()) - assert retry_quota.available_capacity == 10 - acquired + assert acquired == StandardRetryQuota.RETRY_COST + assert retry_quota.available_capacity == 28 - acquired def test_retry_quota_acquire_when_exhausted( @@ -81,7 +82,7 @@ def test_retry_quota_acquire_when_exhausted( retry_quota.acquire(error=Exception()) retry_quota.acquire(error=Exception()) - # Not enough capacity for another retry (need 5, only 0 left) + # Not enough capacity for another retry (need RETRY_COST, only 0 left) with pytest.raises(RetryError, match="Retry quota exceeded"): retry_quota.acquire(error=Exception()) @@ -91,16 +92,19 @@ def test_retry_quota_release_restores_capacity( ) -> None: acquired = retry_quota.acquire(error=Exception()) retry_quota.release(release_amount=acquired) - assert retry_quota.available_capacity == 10 + assert retry_quota.available_capacity == 28 def test_retry_quota_release_zero_adds_increment( retry_quota: StandardRetryQuota, ) -> None: retry_quota.acquire(error=Exception()) - assert retry_quota.available_capacity == 5 + assert retry_quota.available_capacity == 28 - StandardRetryQuota.RETRY_COST retry_quota.release(release_amount=0) - assert retry_quota.available_capacity == 6 + assert ( + retry_quota.available_capacity + == 28 - StandardRetryQuota.RETRY_COST + StandardRetryQuota.NO_RETRY_INCREMENT + ) def test_retry_quota_release_caps_at_max( @@ -110,13 +114,15 @@ def test_retry_quota_release_caps_at_max( retry_quota.acquire(error=Exception()) # Release more than we acquired. Should cap at initial capacity. retry_quota.release(release_amount=50) - assert retry_quota.available_capacity == 10 + assert retry_quota.available_capacity == 28 -def test_retry_quota_acquire_timeout_error( +def test_retry_quota_acquire_throttling_error( retry_quota: StandardRetryQuota, ) -> None: - timeout_error = CallError(is_timeout_error=True, is_retry_safe=True) - acquired = retry_quota.acquire(error=timeout_error) - assert acquired == StandardRetryQuota.TIMEOUT_RETRY_COST - assert retry_quota.available_capacity == 0 + throttling_error = CallError(is_throttling_error=True, is_retry_safe=True) + acquired = retry_quota.acquire(error=throttling_error) + assert acquired == StandardRetryQuota.THROTTLING_RETRY_COST + assert ( + retry_quota.available_capacity == 28 - StandardRetryQuota.THROTTLING_RETRY_COST + ) diff --git a/packages/smithy-http/.changes/next-release/smithy-http-feature-fe81da670b724feca76976bcf09fa736.json b/packages/smithy-http/.changes/next-release/smithy-http-feature-fe81da670b724feca76976bcf09fa736.json new file mode 100644 index 000000000..849574df6 --- /dev/null +++ b/packages/smithy-http/.changes/next-release/smithy-http-feature-fe81da670b724feca76976bcf09fa736.json @@ -0,0 +1,4 @@ +{ + "type": "feature", + "description": "Added support for the `x-amz-retry-after` response header." +} \ No newline at end of file diff --git a/packages/smithy-http/src/smithy_http/aio/protocols.py b/packages/smithy-http/src/smithy_http/aio/protocols.py index af32cee16..7d967dcf5 100644 --- a/packages/smithy-http/src/smithy_http/aio/protocols.py +++ b/packages/smithy-http/src/smithy_http/aio/protocols.py @@ -1,5 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import logging import os from collections.abc import AsyncIterable from inspect import iscoroutinefunction @@ -29,6 +30,30 @@ from ..serializers import HTTPRequestSerializer from .interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse +_LOGGER = logging.getLogger(__name__) + +_RETRY_AFTER_HEADER = "x-amz-retry-after" + + +def parse_retry_after(response: HTTPResponse) -> float | None: + """Parse the ``x-amz-retry-after`` header into a backoff duration in seconds. + + The header value is an integer number of milliseconds. Invalid or missing + values are ignored (return ``None``) so they fall back to exponential backoff. + """ + if _RETRY_AFTER_HEADER not in response.fields: + return None + raw = response.fields[_RETRY_AFTER_HEADER].as_string() + try: + milliseconds = int(raw) + except (ValueError, TypeError): + _LOGGER.debug("Ignoring invalid %s header value: %r", _RETRY_AFTER_HEADER, raw) + return None + if milliseconds < 0: + _LOGGER.debug("Ignoring negative %s header value: %r", _RETRY_AFTER_HEADER, raw) + return None + return milliseconds / 1000.0 + class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]): """An HTTP-based protocol.""" @@ -186,6 +211,8 @@ async def _create_error( operation=operation, response=response ) + retry_after = parse_retry_after(response) + if error_id is None and self._matches_content_type(response): if isinstance(response_body, bytearray): response_body = bytes(response_body) @@ -213,7 +240,10 @@ async def _create_error( response=response, body=response_body, ) - return error_shape.deserialize(deserializer) + modeled_error = error_shape.deserialize(deserializer) + if retry_after is not None: + modeled_error.retry_after = retry_after + return modeled_error message = ( f"Unknown error for operation {operation.schema.id} " @@ -234,6 +264,7 @@ async def _create_error( is_throttling_error=is_throttle, is_timeout_error=is_timeout, is_retry_safe=is_throttle or is_timeout or None, + retry_after=retry_after, ) def _matches_content_type(self, response: HTTPResponse) -> bool: diff --git a/packages/smithy-http/tests/unit/aio/test_protocols.py b/packages/smithy-http/tests/unit/aio/test_protocols.py index 4ae18ce67..c36a8e1c9 100644 --- a/packages/smithy-http/tests/unit/aio/test_protocols.py +++ b/packages/smithy-http/tests/unit/aio/test_protocols.py @@ -11,11 +11,11 @@ from smithy_core.interfaces import URI as URIInterface from smithy_core.schemas import APIOperation from smithy_core.shapes import ShapeID -from smithy_http import Fields -from smithy_http.aio import HTTPRequest +from smithy_http import Field, Fields +from smithy_http.aio import HTTPRequest, HTTPResponse from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface -from smithy_http.aio.protocols import HttpClientProtocol +from smithy_http.aio.protocols import HttpClientProtocol, parse_retry_after class MockProtocol(HttpClientProtocol): @@ -135,3 +135,37 @@ def test_http_protocol_joins_uris( updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint) actual = updated_request.destination assert actual == expected + + +@pytest.mark.parametrize( + "header_value, expected", + [ + ("1500", 1.5), + ("0", 0.0), + ("20", 0.02), + ("invalid", None), + ("1.5", None), + ("-100", None), + ("", None), + ], +) +def test_parse_retry_after(header_value: str, expected: float | None) -> None: + response = HTTPResponse( + status=500, + fields=Fields([Field(name="x-amz-retry-after", values=[header_value])]), + ) + assert parse_retry_after(response) == expected + + +def test_parse_retry_after_missing_header() -> None: + response = HTTPResponse(status=500, fields=Fields()) + assert parse_retry_after(response) is None + + +def test_parse_retry_after_ignores_standard_retry_after_header() -> None: + # The standard HTTP Retry-After header must be ignored. + response = HTTPResponse( + status=503, + fields=Fields([Field(name="Retry-After", values=["120"])]), + ) + assert parse_retry_after(response) is None