diff --git a/mypy/constraints.py b/mypy/constraints.py index 6416791fa74a8..bf69da468a69c 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -411,18 +411,15 @@ def _infer_constraints( # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. - result = any_constraints( + if has_recursive_types(template) and not has_recursive_types(actual): + return handle_recursive_union(template, actual, direction) + return any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) for t_item in template.items ], eager=isinstance(actual, AnyType), ) - if result: - return result - elif has_recursive_types(template) and not has_recursive_types(actual): - return handle_recursive_union(template, actual, direction) - return [] # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op)) @@ -535,13 +532,12 @@ def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> l # Multiple sets of constraints that are all the same. Just pick any one of them. return valid_options[0] - if all(is_similar_constraints(valid_options[0], c) for c in valid_options[1:]): + all_similar = all(is_similar_constraints(valid_options[0], c) for c in valid_options[1:]) + if all_similar: # All options have same structure. In this case we can merge-in trivial # options (i.e. those that only have Any) and try again. - # TODO: More generally, if a given (variable, direction) pair appears in - # every option, combine the bounds with meet/join always, not just for Any. trivial_options = select_trivial(valid_options) - if trivial_options and len(trivial_options) < len(valid_options): + if 0 < len(trivial_options) < len(valid_options): merged_options = [] for option in valid_options: if option in trivial_options: @@ -563,6 +559,31 @@ def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> l if filtered_options != options: return any_constraints(filtered_options, eager=eager) + if ( + eager + and all_similar + and not any(isinstance(c.target, ErasedType) for group in valid_options for c in group) + ): + # Now we know all constraints might be satisfiable and have similar structure. + # Solver will apply meets and joins as necessary, but Any should be forced into + # union to survive during meet. + # If any targets are erased, fall back to empty, otherwise they will be discarded + # by solver, causing false early matches. + cmap: dict[TypeVarId, list[Constraint]] = {} + for option in valid_options: + for c in option: + cmap.setdefault(c.type_var, []).append(c) + out: list[Constraint] = [] + for group in cmap.values(): + if any(isinstance(get_proper_type(c.target), AnyType) for c in group): + group = [ + merge_with_any(c) + for c in group + if not isinstance(get_proper_type(c.target), AnyType) + ] + out.extend(dict.fromkeys(group)) + return out + # Otherwise, there are either no valid options or multiple, inconsistent valid # options. Give up and deduce nothing. return [] diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index 5a674cca09da3..d31b482db8e35 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1530,3 +1530,33 @@ def check3(use: bool, val: str) -> "str | Literal[False]": def check4(use: bool, val: str) -> "str | bool": return use and identity(val) [builtins fixtures/tuple.pyi] + +[case testDictOrLiteralInContext] +from typing import Union, Optional, Any + +P = dict[str, Union[Optional[str], dict[str, Optional[str]]]] + +def f(x: P) -> None: + pass + +def g(x: Union[dict[str, Any], None], s: Union[str, None]) -> None: + f(x or {'x': s}) +[builtins fixtures/dict.pyi] + +[case testInferConstrainedTypeVarInUnion] +from typing import Generic, TypeVar, Union + +_S_co = TypeVar("_S_co", str, int, covariant=True) +_S = TypeVar("_S", str, int) + +class HasFoo(Generic[_S_co]): + def foo(self) -> _S_co: ... + +def walk(path: Union[_S, HasFoo[_S]]) -> None: + ... + +class Path(HasFoo[str]): + def foo(self) -> str: ... + +walk(Path()) +[builtins fixtures/tuple.pyi]