From b0744de42ea20584f0342bec0f62bac7a6114d13 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Thu, 2 Jul 2026 05:39:09 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 941648921 --- fiddle/_src/arg_factory.py | 4 ++-- fiddle/_src/building.py | 2 +- fiddle/_src/building_test.py | 2 +- fiddle/_src/casting.py | 2 +- fiddle/_src/config.py | 26 +++++++++++++------------- fiddle/_src/config_test.py | 4 ++-- fiddle/_src/daglish.py | 8 ++++---- fiddle/_src/daglish_test.py | 10 +++++----- fiddle/_src/diffing.py | 20 ++++++++++---------- fiddle/_src/diffing_test.py | 16 ++++++++-------- fiddle/_src/graphviz.py | 10 +++++----- fiddle/_src/history_test.py | 2 +- fiddle/_src/partial.py | 2 +- fiddle/_src/selectors.py | 6 +++--- fiddle/_src/signatures.py | 10 +++++----- fiddle/_src/signatures_test.py | 2 +- fiddle/_src/tag_type.py | 2 +- fiddle/_src/tagging.py | 6 +++--- fiddle/_src/tagging_test.py | 4 ++-- 19 files changed, 69 insertions(+), 69 deletions(-) diff --git a/fiddle/_src/arg_factory.py b/fiddle/_src/arg_factory.py index 038bf346..789b6862 100644 --- a/fiddle/_src/arg_factory.py +++ b/fiddle/_src/arg_factory.py @@ -195,7 +195,7 @@ class ArgFactory: # of `ArgFactory(list)` is `ArgFactory`, but the declared type is `list`. # (Without the overload, we'd get an annotation-type-mismatch error.) @overload - def __new__(cls, factory: Callable[..., T]) -> T: + def __new__(cls, factory: Callable[..., T]) -> T: # pyrefly: ignore[invalid-overload] ... def __new__(cls, /, *args, **kwargs): @@ -406,7 +406,7 @@ def wrapper(*args, **kwargs): # `default_factory` is `ArgFactory`, but the declared type is `list`. # (Without the overload, we'd get an annotation-type-mismatch error.) @overload -def default_factory(factory: Callable[..., T]) -> T: +def default_factory(factory: Callable[..., T]) -> T: # pyrefly: ignore[invalid-overload] ... diff --git a/fiddle/_src/building.py b/fiddle/_src/building.py index 61fc29ce..521688a3 100644 --- a/fiddle/_src/building.py +++ b/fiddle/_src/building.py @@ -177,7 +177,7 @@ def _build(value: Any, state: daglish.State) -> Any: metadata: config_lib.BuildableTraverserMetadata = sub_traversal.metadata arguments = metadata.arguments(sub_traversal.values) is_built = True - return call_buildable(value, arguments, current_path=state.current_path) + return call_buildable(value, arguments, current_path=state.current_path) # pyrefly: ignore[bad-argument-type] else: return state.map_children(value) diff --git a/fiddle/_src/building_test.py b/fiddle/_src/building_test.py index 15e8cfb1..1de1225a 100644 --- a/fiddle/_src/building_test.py +++ b/fiddle/_src/building_test.py @@ -35,7 +35,7 @@ class Bar: x: int -class NonBuildableLoggingTest(absltest.TestCase, unittest.TestCase): +class NonBuildableLoggingTest(absltest.TestCase, unittest.TestCase): # pyrefly: ignore[inconsistent-inheritance] def setUp(self): super().setUp() diff --git a/fiddle/_src/casting.py b/fiddle/_src/casting.py index f9d44749..91246877 100644 --- a/fiddle/_src/casting.py +++ b/fiddle/_src/casting.py @@ -58,7 +58,7 @@ def cast( src_type, new_type, ) - return new_type.__unflatten__(*buildable.__flatten__()) + return new_type.__unflatten__(*buildable.__flatten__()) # pyrefly: ignore[missing-attribute] def register_supported_cast(src_type, dst_type): diff --git a/fiddle/_src/config.py b/fiddle/_src/config.py index d7987154..19131ffe 100644 --- a/fiddle/_src/config.py +++ b/fiddle/_src/config.py @@ -78,7 +78,7 @@ def _buildable_flatten( metadata = BuildableTraverserMetadata( fn_or_cls=buildable.__fn_or_cls__, argument_names=keys, - argument_tags=argument_tags, + argument_tags=argument_tags, # pyrefly: ignore[bad-argument-type] argument_history=argument_history, ) return values, metadata @@ -88,7 +88,7 @@ def _buildable_path_elements( buildable: Buildable, include_defaults: bool = False ) -> Tuple[daglish.PathElement]: """Implement Buildable.__path_elements__ method.""" - return tuple( + return tuple( # pyrefly: ignore[bad-return] daglish.Attr(name) if isinstance(name, str) else daglish.Index(name) for name in ordered_arguments( buildable, include_defaults=include_defaults @@ -198,7 +198,7 @@ def without_history(self) -> BuildableTraverserMetadata: def arguments(self, values: Iterable[Any]) -> Dict[str, Any]: """Returns a dictionary combining ``self.argument_names`` with ``values``.""" - return dict(zip(self.argument_names, values)) + return dict(zip(self.argument_names, values)) # pyrefly: ignore[bad-return] def tags(self) -> Dict[str, set[tag_type.TagType]]: return collections.defaultdict( @@ -265,7 +265,7 @@ def __init__( f'Unexpected type received for the argument name: {key!r}' ) - for name, tags in tag_type.find_tags_from_annotations(fn_or_cls).items(): + for name, tags in tag_type.find_tags_from_annotations(fn_or_cls).items(): # pyrefly: ignore[bad-argument-type] self.__argument_tags__[name].update(tags) self.__argument_history__.add_updated_tags( name, self.__argument_tags__[name] @@ -445,7 +445,7 @@ def __delitem__(self, key: Any): new_placeholders = old_placeholders.copy() # Traverse from largest index to maintain order of undeleted indices. for index in indices[::-1]: - if index < var_positional_start: + if index < var_positional_start: # pyrefly: ignore[unsupported-operation] k = self.__signature_info__.index_to_key(index, self.__arguments__) if k in self.__arguments__: self._arguments_del_value(k) @@ -453,7 +453,7 @@ def __delitem__(self, key: Any): del new_placeholders[index] # Delete var-positional args and compact the *args list. - for index in range(var_positional_start, len(old_placeholders)): + for index in range(var_positional_start, len(old_placeholders)): # pyrefly: ignore[bad-argument-type] if index < len(new_placeholders): if new_placeholders[index] != old_placeholders[index]: new_value = self.__arguments__[new_placeholders[index].index] @@ -463,7 +463,7 @@ def __delitem__(self, key: Any): def _set_item_by_index(self, key: int, value: Any): """Set positional arguments by index.""" - key = self.__signature_info__.index_to_key(key, self.__arguments__) + key = self.__signature_info__.index_to_key(key, self.__arguments__) # pyrefly: ignore[bad-assignment] positional_num = self.__signature_info__.var_positional_start if positional_num is None: # *args does not exist @@ -555,7 +555,7 @@ def __dir__(self) -> Collection[str]: set_argument_names = self.__arguments__.keys() valid_param_names = set(self.__signature_info__.valid_param_names) all_names = valid_param_names.union(set_argument_names) - return all_names + return all_names # pyrefly: ignore[bad-return] # Buildable are mutable so do not make this `@functools.cached_property` def _fn_or_cls_name_repr(self) -> str: @@ -581,7 +581,7 @@ def _params_name_tags_and_value( for name in param_names: tags = self.__argument_tags__.get(name, set()) value = self.__arguments__.get(name, NO_VALUE) - yield name, tags, value + yield name, tags, value # pyrefly: ignore[invalid-yield] def __repr__(self): formatted_fn_or_cls = self._fn_or_cls_name_repr() @@ -825,7 +825,7 @@ def tagged_value_fn( if tags: msg += ' Unset tags: ' + str(tags) raise tag_type.TaggedValueNotFilledError(msg) - return value + return value # pyrefly: ignore[bad-return] class TaggedValueCls(Generic[T], Config[T]): @@ -851,7 +851,7 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T: 'Unexpected __fn_or_cls__ in TaggedValueCls; found:' f'{self.__fn_or_cls__}' ) - return self.__fn_or_cls__(tags=self.tags, *args, **kwargs) + return self.__fn_or_cls__(tags=self.tags, *args, **kwargs) # pyrefly: ignore[unexpected-keyword] def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str): @@ -941,11 +941,11 @@ def ordered_arguments( if include_var_keyword: for name, value in buildable.__arguments__.items(): - param = buildable.__signature_info__.parameters.get(name) + param = buildable.__signature_info__.parameters.get(name) # pyrefly: ignore[bad-argument-type] if param is None or param.kind == param.VAR_KEYWORD: result[name] = value if not include_positional: result = {k: v for k, v in result.items() if isinstance(k, str)} - return result + return result # pyrefly: ignore[bad-return] diff --git a/fiddle/_src/config_test.py b/fiddle/_src/config_test.py index 56a77408..bdfc6ddb 100644 --- a/fiddle/_src/config_test.py +++ b/fiddle/_src/config_test.py @@ -121,7 +121,7 @@ def raise_error(): @dataclasses.dataclass class GenericClass(Generic[_T]): - x: _T = 1 + x: _T = 1 # pyrefly: ignore[bad-assignment] class ConfigTest(parameterized.TestCase): @@ -921,7 +921,7 @@ def test_history_tracking(self): cfg.__argument_history__['arg1'][1].sequence_id) def test_custom_location_history_tracking(self): - with history.custom_location(lambda: 'abc:123'): + with history.custom_location(lambda: 'abc:123'): # pyrefly: ignore[bad-argument-type] cfg = fdl.Config(SampleClass, 'arg1') cfg.arg2 = 'arg2' self.assertEqual( diff --git a/fiddle/_src/daglish.py b/fiddle/_src/daglish.py index 3f9a4bc7..16c37d9e 100644 --- a/fiddle/_src/daglish.py +++ b/fiddle/_src/daglish.py @@ -61,7 +61,7 @@ def follow(self, container: Union[List[Any], Tuple[Any, ...]]) -> Any: def __lt__(self, other: PathElement) -> bool: if type(self) is type(other): - return self.index < other.index + return self.index < other.index # pyrefly: ignore[missing-attribute] else: return super().__lt__(other) @@ -80,7 +80,7 @@ def follow(self, container: Dict[Any, Any]) -> Any: def __lt__(self, other: PathElement) -> bool: if type(self) is type(other): - return self.key < other.key + return self.key < other.key # pyrefly: ignore[missing-attribute] else: return super().__lt__(other) @@ -99,7 +99,7 @@ def follow(self, container: Any) -> Any: def __lt__(self, other: PathElement) -> bool: if type(self) is type(other): - return self.name < other.name + return self.name < other.name # pyrefly: ignore[missing-attribute] else: return super().__lt__(other) @@ -273,7 +273,7 @@ def is_traversable_type(self, node_type: Type[Any]) -> bool: dict, flatten_fn=lambda x: (tuple(x.values()), tuple(x.keys())), unflatten_fn=lambda values, keys: dict(zip(keys, values)), - path_elements_fn=lambda x: [Key(key) for key in x.keys()]) + path_elements_fn=lambda x: [Key(key) for key in x.keys()]) # pyrefly: ignore[bad-argument-type] def flatten_defaultdict(node): diff --git a/fiddle/_src/daglish_test.py b/fiddle/_src/daglish_test.py index 96f62a04..72a2b2ba 100644 --- a/fiddle/_src/daglish_test.py +++ b/fiddle/_src/daglish_test.py @@ -215,11 +215,11 @@ def test_follow_path(self): self.assertIs(daglish.follow_path(root, path5), root[2]) path6 = (daglish.Index(1), daglish.Key("a")) - self.assertIs(daglish.follow_path(root, path6), root[1]["a"]) + self.assertIs(daglish.follow_path(root, path6), root[1]["a"]) # pyrefly: ignore[bad-index] path7 = (daglish.Index(2), daglish.Index(2), daglish.BuildableFnOrCls()) self.assertIs( - daglish.follow_path(root, path7), fdl.get_callable(root[2][2])) + daglish.follow_path(root, path7), fdl.get_callable(root[2][2])) # pyrefly: ignore[bad-argument-type, bad-index] bad_path_1 = (daglish.Key("a"), daglish.Key("b")) with self.assertRaisesRegex( @@ -313,7 +313,7 @@ def test_register_node_traverser_non_type_error(self): cast(Any, 42), flatten_fn=lambda x: (tuple(x), None), unflatten_fn=lambda x, _: list(x), - path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x)))) + path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x)))) # pyrefly: ignore[bad-argument-type] def test_register_node_traverser_existing_registration_error(self): with self.assertRaises(ValueError): @@ -321,7 +321,7 @@ def test_register_node_traverser_existing_registration_error(self): list, flatten_fn=lambda x: (tuple(x), None), unflatten_fn=lambda x, _: list(x), - path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x)))) + path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x)))) # pyrefly: ignore[bad-argument-type] def test_node_traverser_registry_with_fallback(self): registry = daglish.NodeTraverserRegistry(use_fallback=True) @@ -674,7 +674,7 @@ def traverse(value, state: daglish.State): return state.map_children(value) x = [1, 2, 3] - x.append(x) + x.append(x) # pyrefly: ignore[bad-argument-type] with self.assertRaisesRegex( ValueError, "Fiddle detected a cycle while traversing a value: " diff --git a/fiddle/_src/diffing.py b/fiddle/_src/diffing.py index 1ec8ec42..75079c14 100644 --- a/fiddle/_src/diffing.py +++ b/fiddle/_src/diffing.py @@ -107,7 +107,7 @@ def ignoring_paths(self, paths=Iterable[daglish.Path]) -> 'Diff': Returns: A new `Diff` without changes that relate to the given `paths`. """ - paths = set(paths) + paths = set(paths) # pyrefly: ignore[bad-argument-type] def _ignore_fn(change: DiffOperation): return any( @@ -123,7 +123,7 @@ class SetValue(DiffOperation): The target's parent may not be a sequence (list or tuple). """ - target: daglish.Path + target: daglish.Path # pyrefly: ignore[bad-override] new_value: Union[Reference, Any] def apply(self, parent: Any, child: daglish.PathElement): @@ -142,13 +142,13 @@ class ModifyValue(DiffOperation): The target's parent may not be a tuple. """ - target: daglish.Path + target: daglish.Path # pyrefly: ignore[bad-override] new_value: Union[Reference, Any] def apply(self, parent: Any, child: daglish.PathElement): """Replaces `child.follow(parent)` with self.new_value.""" if isinstance(child, daglish.BuildableFnOrCls): - mutate_buildable.update_callable(parent, self.new_value) + mutate_buildable.update_callable(parent, self.new_value) # pyrefly: ignore[bad-argument-type] elif isinstance(child, daglish.Attr): setattr(parent, child.name, self.new_value) elif isinstance(child, daglish.Index): @@ -165,7 +165,7 @@ class DeleteValue(DiffOperation): The target's parent may not be a sequence (list or tuple). """ - target: daglish.Path + target: daglish.Path # pyrefly: ignore[bad-override] def apply(self, parent: Any, child: daglish.PathElement): """Deletes `child.follow(parent)`.""" @@ -183,7 +183,7 @@ class AddTag(DiffOperation): The target's parent must be a `fdl.Buildable`. """ - target: daglish.Path + target: daglish.Path # pyrefly: ignore[bad-override] tag: tag_type.TagType def apply(self, parent: Any, child: daglish.PathElement): @@ -199,7 +199,7 @@ class RemoveTag(DiffOperation): The target's parent must be a `fdl.Buildable`. """ - target: daglish.Path + target: daglish.Path # pyrefly: ignore[bad-override] tag: tag_type.TagType def apply(self, parent: Any, child: daglish.PathElement): @@ -656,7 +656,7 @@ def record_tag_diffs(self, old_path: daglish.Path, empty_set = set([]) # Default value for dict.get. tag_name = lambda tag: tag.__name__ # For sorting. for arg_name in sorted(set(old_arg_tags) | set(new_arg_tags)): - target = old_path + (daglish.Attr(arg_name),) + target = old_path + (daglish.Attr(arg_name),) # pyrefly: ignore[bad-argument-type] old_tags = old_arg_tags.get(arg_name, empty_set) new_tags = new_arg_tags.get(arg_name, empty_set) for removed_tag in sorted(old_tags - new_tags, key=tag_name): @@ -1014,7 +1014,7 @@ def add_reference_target(path, value): if isinstance(change, RemoveTag): tagging.add_tag( daglish.follow_path(root, change.target[:-1]), - change.target[-1].name, + change.target[-1].name, # pyrefly: ignore[missing-attribute] change.tag, ) daglish_legacy.traverse_with_path(add_reference_target, @@ -1062,7 +1062,7 @@ def _add_path_to_skeleton(skeleton, path, skip_leaf=False): raise ValueError(f'Unuspported PathElement {path[0]}') # Recurse to the child element. - child = _add_path_to_skeleton(path[0].follow(skeleton), path[1:], skip_leaf) + child = _add_path_to_skeleton(path[0].follow(skeleton), path[1:], skip_leaf) # pyrefly: ignore[bad-argument-type] if isinstance(path[0], daglish.Attr): assert isinstance(skeleton, config_lib.Config) setattr(skeleton, path[0].name, child) diff --git a/fiddle/_src/diffing_test.py b/fiddle/_src/diffing_test.py index 96b52660..7b5c8404 100644 --- a/fiddle/_src/diffing_test.py +++ b/fiddle/_src/diffing_test.py @@ -791,8 +791,8 @@ def test_resolve_ref_from_new_shared_value_to_new_shared_value(self): diff_z = resolved_diff.changes[0] self.assertEqual(diff_z.target, parse_path('.z')) self.assertIsInstance(diff_z, diffing.SetValue) - self.assertIs(diff_z.new_value[0], resolved_diff.new_shared_values[0]) - self.assertIs(diff_z.new_value[1], resolved_diff.new_shared_values[1]) + self.assertIs(diff_z.new_value[0], resolved_diff.new_shared_values[0]) # pyrefly: ignore[bad-index] + self.assertIs(diff_z.new_value[1], resolved_diff.new_shared_values[1]) # pyrefly: ignore[bad-index] self.assertIs(resolved_diff.new_shared_values[1][0], resolved_diff.new_shared_values[0]) @@ -826,12 +826,12 @@ def test_resolve_diff_multiple_references(self): diff_1_x = resolved_diff.changes[0] self.assertEqual(diff_1_x.target, parse_path("[1]['x']")) self.assertIsInstance(diff_1_x, diffing.ModifyValue) - self.assertIs(diff_1_x.new_value, old[1]['y']) + self.assertIs(diff_1_x.new_value, old[1]['y']) # pyrefly: ignore[bad-index] diff_1_y = resolved_diff.changes[1] self.assertEqual(diff_1_y.target, parse_path("[1]['y']")) self.assertIsInstance(diff_1_y, diffing.ModifyValue) - self.assertIs(diff_1_y.new_value, old[1]['x']) + self.assertIs(diff_1_y.new_value, old[1]['x']) # pyrefly: ignore[bad-index] diff_1_z = resolved_diff.changes[2] self.assertEqual(diff_1_z.target, parse_path("[1]['z']")) @@ -965,10 +965,10 @@ def test_apply_diff_with_multiple_references(self): # Manually apply the same changes described by the diff: new = copy.deepcopy(old) - new[1]['x'], new[1]['y'] = new[1]['y'], new[1]['x'] - new[1]['z'] = new[2] - new[2].x = new[3] - new[2].z = [new[0], new[3]] + new[1]['x'], new[1]['y'] = new[1]['y'], new[1]['x'] # pyrefly: ignore[bad-index, unsupported-operation] + new[1]['z'] = new[2] # pyrefly: ignore[unsupported-operation] + new[2].x = new[3] # pyrefly: ignore[missing-attribute] + new[2].z = [new[0], new[3]] # pyrefly: ignore[missing-attribute] diffing.apply_diff(cfg_diff, old) self.assertEqual(old, new) diff --git a/fiddle/_src/graphviz.py b/fiddle/_src/graphviz.py index 9ba81a2a..463822c5 100644 --- a/fiddle/_src/graphviz.py +++ b/fiddle/_src/graphviz.py @@ -351,10 +351,10 @@ def _render_config(self, config: config_lib.Buildable, bgcolor: str) -> str: # Generate the arguments table. if config.__arguments__: label = self._render_dict( - config.__arguments__, + config.__arguments__, # pyrefly: ignore[bad-argument-type] header=header, key_format_fn=str, - tags=config.__argument_tags__) + tags=config.__argument_tags__) # pyrefly: ignore[bad-argument-type] else: table = self.tag('table') italics = self.tag('i') @@ -412,7 +412,7 @@ def _render_value(self, value: Any, color=_DEFAULT_HEADER_COLOR) -> str: elif isinstance(value, _ChangedValue): return self._render_changed_value(value) elif isinstance(value, _ChangedBuildable): - return self._render_changed_buildable(value, color) + return self._render_changed_buildable(value, color) # pyrefly: ignore[bad-argument-type] elif isinstance(value, dict): return self._render_dict( value, header=self._header_row(type(value).__name__, bgcolor=color)) @@ -470,7 +470,7 @@ def _render_nested_value(self, value: Any): self._dot.edge(f'{self._current_id}:{port}:c', f'{node_id}:c', **edge_attrs) # Return a table with a single colored cell, using the port name from above. - style = self._config_header_style(value) + style = self._config_header_style(value) # pyrefly: ignore[bad-argument-type] table = self.tag('table', style=style) tr = self.tag('tr') td = self.tag('td', port=port, bgcolor=self._color(value), style=style) @@ -542,7 +542,7 @@ def _render_dict( continue key_str = html.escape(key_format_fn(key)) value_str = self._render_nested_value(value) - key_tags = tags.get(key, ()) + key_tags = tags.get(key, ()) # pyrefly: ignore[no-matching-overload] if key_tags: key_str = self._render_tags(key_str, key_tags) rows.append(tr([key_td(key_str), value_td(value_str)])) diff --git a/fiddle/_src/history_test.py b/fiddle/_src/history_test.py index 3caec264..3f4a15d6 100644 --- a/fiddle/_src/history_test.py +++ b/fiddle/_src/history_test.py @@ -71,7 +71,7 @@ def test_entry_deletion(self): def test_updating_tags(self): tag_set = {SampleTag, AdditionalTag} - entry = history.update_tags("z", tag_set) + entry = history.update_tags("z", tag_set) # pyrefly: ignore[bad-argument-type] self.assertEqual(entry.param_name, "z") self.assertEqual(entry.kind, history.ChangeKind.UPDATE_TAGS) self.assertIsNot(tag_set, entry.new_value) # Must not be the same! diff --git a/fiddle/_src/partial.py b/fiddle/_src/partial.py index b8d4b569..47154e99 100644 --- a/fiddle/_src/partial.py +++ b/fiddle/_src/partial.py @@ -135,7 +135,7 @@ def is_arg_factory(value): # If there are nested structures containing _BuiltArgFactory objects, # then promote them. - args = [_promote_arg_factory(arg) for arg in args] + args = [_promote_arg_factory(arg) for arg in args] # pyrefly: ignore[bad-assignment] kwargs = {name: _promote_arg_factory(arg) for name, arg in kwargs.items()} # Split the keyword args into those that should be handled by functools vs. diff --git a/fiddle/_src/selectors.py b/fiddle/_src/selectors.py index bc35f087..d9c450be 100644 --- a/fiddle/_src/selectors.py +++ b/fiddle/_src/selectors.py @@ -122,7 +122,7 @@ def _matches(self, node: config_lib.Buildable) -> bool: self.match_subclasses # and isinstance(self.fn_or_cls, type) # and isinstance(config_lib.get_callable(node), type) # - and issubclass(config_lib.get_callable(node), self.fn_or_cls)) + and issubclass(config_lib.get_callable(node), self.fn_or_cls)) # pyrefly: ignore[bad-argument-type] if not is_subclass: return False @@ -208,7 +208,7 @@ def __iter__(self) -> Iterator[Any]: if isinstance(value, config_lib.Buildable): for name, tags in value.__argument_tags__.items(): if any(issubclass(tag, self.tag) for tag in tags): - yield getattr(value, name, tagging.NO_VALUE) + yield getattr(value, name, tagging.NO_VALUE) # pyrefly: ignore[no-matching-overload] def replace(self, value: Any, deepcopy: bool = True) -> None: @@ -217,7 +217,7 @@ def replace(self, value: Any, deepcopy: bool = True) -> None: for name, tags in node_value.__argument_tags__.items(): if any(issubclass(tag, self.tag) for tag in tags): to_set = value if not deepcopy else copy.deepcopy(value) - setattr(node_value, name, to_set) + setattr(node_value, name, to_set) # pyrefly: ignore[bad-argument-type] def get(self, name: str) -> Iterator[Any]: raise NotImplementedError( diff --git a/fiddle/_src/signatures.py b/fiddle/_src/signatures.py index 0cfc7a8d..8fad2cb1 100644 --- a/fiddle/_src/signatures.py +++ b/fiddle/_src/signatures.py @@ -198,14 +198,14 @@ def signature_binding( # Use the index as key for positional only arguments if param.kind == param.POSITIONAL_ONLY: value = arguments.pop(param.name) - arguments[index] = value + arguments[index] = value # pyrefly: ignore[unsupported-operation] if param.kind == param.VAR_POSITIONAL: values = arguments.pop(param.name) for i, value in enumerate(values): - arguments[index + i] = value + arguments[index + i] = value # pyrefly: ignore[unsupported-operation] if param.kind == param.VAR_KEYWORD: arguments.update(arguments.pop(param.name)) - return arguments + return arguments # pyrefly: ignore[bad-return] def get_default(self, argument: Union[int, str], missing: Any) -> Any: """Get default value for the argument, return missing if not found. @@ -325,7 +325,7 @@ def replace_varargs_handle(self, key: Union[int, slice]) -> Any: if isinstance(key, slice): key = slice(replace_fn(key.start), replace_fn(key.stop), key.step) else: - key = replace_fn(key) + key = replace_fn(key) # pyrefly: ignore[bad-assignment] assert isinstance( key, (int, slice) ), f'Key must be an int or slice, got {key}.' @@ -351,7 +351,7 @@ def parameters(self) -> Mapping[str, inspect.Parameter]: @property def valid_param_names(self) -> Tuple[str]: - return tuple( + return tuple( # pyrefly: ignore[bad-return] name for name, param in self.signature.parameters.items() if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY) diff --git a/fiddle/_src/signatures_test.py b/fiddle/_src/signatures_test.py index 9e3da88a..b24d86ff 100644 --- a/fiddle/_src/signatures_test.py +++ b/fiddle/_src/signatures_test.py @@ -274,7 +274,7 @@ def test_replace_varargs_handle(self): self.signature_positional.replace_varargs_handle(slc), slice(3, -2) ) self.assertEqual( - self.signature_positional.replace_varargs_handle(signatures.VARARGS), 3 + self.signature_positional.replace_varargs_handle(signatures.VARARGS), 3 # pyrefly: ignore[bad-argument-type] ) def test_index_to_key(self): diff --git a/fiddle/_src/tag_type.py b/fiddle/_src/tag_type.py index 569d25ab..32c7a888 100644 --- a/fiddle/_src/tag_type.py +++ b/fiddle/_src/tag_type.py @@ -74,7 +74,7 @@ def __call__(cls, *args, **kwds): @property def description(cls) -> str: """A string describing the semantics and intended usecases for this tag.""" - return cls.__doc__ + return cls.__doc__ # pyrefly: ignore[bad-return] @property def name(cls) -> str: diff --git a/fiddle/_src/tagging.py b/fiddle/_src/tagging.py index 05688ce7..0f8a2457 100644 --- a/fiddle/_src/tagging.py +++ b/fiddle/_src/tagging.py @@ -104,7 +104,7 @@ def new(cls, default: Any = NO_VALUE) -> Any: Returns: A TaggedValue tagged with the tag `cls`. """ - return TaggedValue(tags=(cls,), default=default) + return TaggedValue(tags=(cls,), default=default) # pyrefly: ignore[bad-argument-type] if not typing.TYPE_CHECKING: new = auto_config.AutoConfigClassMethod( @@ -141,7 +141,7 @@ def TaggedValue( # pylint: disable=invalid-name result.value = default for tag in tags: add_tag(result, 'value', tag) - return result + return result # pyrefly: ignore[bad-return] def set_tagged(root: config.Buildable, *, tag: TagType, value: Any) -> None: @@ -157,7 +157,7 @@ def set_tagged(root: config.Buildable, *, tag: TagType, value: Any) -> None: if isinstance(node, config.Buildable): for key, tags in node.__argument_tags__.items(): if any(issubclass(t, tag) for t in tags): - setattr(node, key, value) + setattr(node, key, value) # pyrefly: ignore[bad-argument-type] def list_tags( diff --git a/fiddle/_src/tagging_test.py b/fiddle/_src/tagging_test.py index dc38cdfb..50cff179 100644 --- a/fiddle/_src/tagging_test.py +++ b/fiddle/_src/tagging_test.py @@ -503,7 +503,7 @@ def test_tagging_with_out_of_range_index(self): for fn in (fdl.add_tag, fdl.set_tags, fdl.remove_tag): with self.assertRaisesRegex(IndexError, ".*is out of range"): - fn(cfg, 3, Tag1) + fn(cfg, 3, Tag1) # pyrefly: ignore[bad-argument-type] for fn in (fdl.get_tags, fdl.clear_tags): with self.assertRaisesRegex(IndexError, ".*is out of range"): @@ -514,7 +514,7 @@ def test_tagging_with_negative_index(self): for fn in (fdl.add_tag, fdl.set_tags, fdl.remove_tag): with self.assertRaisesRegex(IndexError, "Cannot use negative index"): - fn(cfg, -1, Tag1) + fn(cfg, -1, Tag1) # pyrefly: ignore[bad-argument-type] for fn in (fdl.get_tags, fdl.clear_tags): with self.assertRaisesRegex(IndexError, "Cannot use negative index"):