Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fiddle/_src/arg_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class ArgFactory:
# of `ArgFactory(list)` is `ArgFactory`, but the declared type is `list`.
# (Without the overload, we'd get an annotation-type-mismatch error.)
@overload
def __new__(cls, factory: Callable[..., T]) -> T:
def __new__(cls, factory: Callable[..., T]) -> T: # pyrefly: ignore[invalid-overload]
...

def __new__(cls, /, *args, **kwargs):
Expand Down Expand Up @@ -406,7 +406,7 @@ def wrapper(*args, **kwargs):
# `default_factory` is `ArgFactory`, but the declared type is `list`.
# (Without the overload, we'd get an annotation-type-mismatch error.)
@overload
def default_factory(factory: Callable[..., T]) -> T:
def default_factory(factory: Callable[..., T]) -> T: # pyrefly: ignore[invalid-overload]
...


Expand Down
2 changes: 1 addition & 1 deletion fiddle/_src/building.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _build(value: Any, state: daglish.State) -> Any:
metadata: config_lib.BuildableTraverserMetadata = sub_traversal.metadata
arguments = metadata.arguments(sub_traversal.values)
is_built = True
return call_buildable(value, arguments, current_path=state.current_path)
return call_buildable(value, arguments, current_path=state.current_path) # pyrefly: ignore[bad-argument-type]
else:
return state.map_children(value)

Expand Down
2 changes: 1 addition & 1 deletion fiddle/_src/building_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Bar:
x: int


class NonBuildableLoggingTest(absltest.TestCase, unittest.TestCase):
class NonBuildableLoggingTest(absltest.TestCase, unittest.TestCase): # pyrefly: ignore[inconsistent-inheritance]

def setUp(self):
super().setUp()
Expand Down
2 changes: 1 addition & 1 deletion fiddle/_src/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def cast(
src_type,
new_type,
)
return new_type.__unflatten__(*buildable.__flatten__())
return new_type.__unflatten__(*buildable.__flatten__()) # pyrefly: ignore[missing-attribute]


def register_supported_cast(src_type, dst_type):
Expand Down
26 changes: 13 additions & 13 deletions fiddle/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _buildable_flatten(
metadata = BuildableTraverserMetadata(
fn_or_cls=buildable.__fn_or_cls__,
argument_names=keys,
argument_tags=argument_tags,
argument_tags=argument_tags, # pyrefly: ignore[bad-argument-type]
argument_history=argument_history,
)
return values, metadata
Expand All @@ -88,7 +88,7 @@ def _buildable_path_elements(
buildable: Buildable, include_defaults: bool = False
) -> Tuple[daglish.PathElement]:
"""Implement Buildable.__path_elements__ method."""
return tuple(
return tuple( # pyrefly: ignore[bad-return]
daglish.Attr(name) if isinstance(name, str) else daglish.Index(name)
for name in ordered_arguments(
buildable, include_defaults=include_defaults
Expand Down Expand Up @@ -198,7 +198,7 @@ def without_history(self) -> BuildableTraverserMetadata:

def arguments(self, values: Iterable[Any]) -> Dict[str, Any]:
"""Returns a dictionary combining ``self.argument_names`` with ``values``."""
return dict(zip(self.argument_names, values))
return dict(zip(self.argument_names, values)) # pyrefly: ignore[bad-return]

def tags(self) -> Dict[str, set[tag_type.TagType]]:
return collections.defaultdict(
Expand Down Expand Up @@ -265,7 +265,7 @@ def __init__(
f'Unexpected type received for the argument name: {key!r}'
)

for name, tags in tag_type.find_tags_from_annotations(fn_or_cls).items():
for name, tags in tag_type.find_tags_from_annotations(fn_or_cls).items(): # pyrefly: ignore[bad-argument-type]
self.__argument_tags__[name].update(tags)
self.__argument_history__.add_updated_tags(
name, self.__argument_tags__[name]
Expand Down Expand Up @@ -445,15 +445,15 @@ def __delitem__(self, key: Any):
new_placeholders = old_placeholders.copy()
# Traverse from largest index to maintain order of undeleted indices.
for index in indices[::-1]:
if index < var_positional_start:
if index < var_positional_start: # pyrefly: ignore[unsupported-operation]
k = self.__signature_info__.index_to_key(index, self.__arguments__)
if k in self.__arguments__:
self._arguments_del_value(k)
else:
del new_placeholders[index]

# Delete var-positional args and compact the *args list.
for index in range(var_positional_start, len(old_placeholders)):
for index in range(var_positional_start, len(old_placeholders)): # pyrefly: ignore[bad-argument-type]
if index < len(new_placeholders):
if new_placeholders[index] != old_placeholders[index]:
new_value = self.__arguments__[new_placeholders[index].index]
Expand All @@ -463,7 +463,7 @@ def __delitem__(self, key: Any):

def _set_item_by_index(self, key: int, value: Any):
"""Set positional arguments by index."""
key = self.__signature_info__.index_to_key(key, self.__arguments__)
key = self.__signature_info__.index_to_key(key, self.__arguments__) # pyrefly: ignore[bad-assignment]
positional_num = self.__signature_info__.var_positional_start
if positional_num is None:
# *args does not exist
Expand Down Expand Up @@ -555,7 +555,7 @@ def __dir__(self) -> Collection[str]:
set_argument_names = self.__arguments__.keys()
valid_param_names = set(self.__signature_info__.valid_param_names)
all_names = valid_param_names.union(set_argument_names)
return all_names
return all_names # pyrefly: ignore[bad-return]

# Buildable are mutable so do not make this `@functools.cached_property`
def _fn_or_cls_name_repr(self) -> str:
Expand All @@ -581,7 +581,7 @@ def _params_name_tags_and_value(
for name in param_names:
tags = self.__argument_tags__.get(name, set())
value = self.__arguments__.get(name, NO_VALUE)
yield name, tags, value
yield name, tags, value # pyrefly: ignore[invalid-yield]

def __repr__(self):
formatted_fn_or_cls = self._fn_or_cls_name_repr()
Expand Down Expand Up @@ -825,7 +825,7 @@ def tagged_value_fn(
if tags:
msg += ' Unset tags: ' + str(tags)
raise tag_type.TaggedValueNotFilledError(msg)
return value
return value # pyrefly: ignore[bad-return]


class TaggedValueCls(Generic[T], Config[T]):
Expand All @@ -851,7 +851,7 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T:
'Unexpected __fn_or_cls__ in TaggedValueCls; found:'
f'{self.__fn_or_cls__}'
)
return self.__fn_or_cls__(tags=self.tags, *args, **kwargs)
return self.__fn_or_cls__(tags=self.tags, *args, **kwargs) # pyrefly: ignore[unexpected-keyword]


def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
Expand Down Expand Up @@ -941,11 +941,11 @@ def ordered_arguments(

if include_var_keyword:
for name, value in buildable.__arguments__.items():
param = buildable.__signature_info__.parameters.get(name)
param = buildable.__signature_info__.parameters.get(name) # pyrefly: ignore[bad-argument-type]
if param is None or param.kind == param.VAR_KEYWORD:
result[name] = value

if not include_positional:
result = {k: v for k, v in result.items() if isinstance(k, str)}

return result
return result # pyrefly: ignore[bad-return]
4 changes: 2 additions & 2 deletions fiddle/_src/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def raise_error():

@dataclasses.dataclass
class GenericClass(Generic[_T]):
x: _T = 1
x: _T = 1 # pyrefly: ignore[bad-assignment]


class ConfigTest(parameterized.TestCase):
Expand Down Expand Up @@ -921,7 +921,7 @@ def test_history_tracking(self):
cfg.__argument_history__['arg1'][1].sequence_id)

def test_custom_location_history_tracking(self):
with history.custom_location(lambda: 'abc:123'):
with history.custom_location(lambda: 'abc:123'): # pyrefly: ignore[bad-argument-type]
cfg = fdl.Config(SampleClass, 'arg1')
cfg.arg2 = 'arg2'
self.assertEqual(
Expand Down
8 changes: 4 additions & 4 deletions fiddle/_src/daglish.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def follow(self, container: Union[List[Any], Tuple[Any, ...]]) -> Any:

def __lt__(self, other: PathElement) -> bool:
if type(self) is type(other):
return self.index < other.index
return self.index < other.index # pyrefly: ignore[missing-attribute]
else:
return super().__lt__(other)

Expand All @@ -80,7 +80,7 @@ def follow(self, container: Dict[Any, Any]) -> Any:

def __lt__(self, other: PathElement) -> bool:
if type(self) is type(other):
return self.key < other.key
return self.key < other.key # pyrefly: ignore[missing-attribute]
else:
return super().__lt__(other)

Expand All @@ -99,7 +99,7 @@ def follow(self, container: Any) -> Any:

def __lt__(self, other: PathElement) -> bool:
if type(self) is type(other):
return self.name < other.name
return self.name < other.name # pyrefly: ignore[missing-attribute]
else:
return super().__lt__(other)

Expand Down Expand Up @@ -273,7 +273,7 @@ def is_traversable_type(self, node_type: Type[Any]) -> bool:
dict,
flatten_fn=lambda x: (tuple(x.values()), tuple(x.keys())),
unflatten_fn=lambda values, keys: dict(zip(keys, values)),
path_elements_fn=lambda x: [Key(key) for key in x.keys()])
path_elements_fn=lambda x: [Key(key) for key in x.keys()]) # pyrefly: ignore[bad-argument-type]


def flatten_defaultdict(node):
Expand Down
10 changes: 5 additions & 5 deletions fiddle/_src/daglish_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def test_follow_path(self):
self.assertIs(daglish.follow_path(root, path5), root[2])

path6 = (daglish.Index(1), daglish.Key("a"))
self.assertIs(daglish.follow_path(root, path6), root[1]["a"])
self.assertIs(daglish.follow_path(root, path6), root[1]["a"]) # pyrefly: ignore[bad-index]

path7 = (daglish.Index(2), daglish.Index(2), daglish.BuildableFnOrCls())
self.assertIs(
daglish.follow_path(root, path7), fdl.get_callable(root[2][2]))
daglish.follow_path(root, path7), fdl.get_callable(root[2][2])) # pyrefly: ignore[bad-argument-type, bad-index]

bad_path_1 = (daglish.Key("a"), daglish.Key("b"))
with self.assertRaisesRegex(
Expand Down Expand Up @@ -313,15 +313,15 @@ def test_register_node_traverser_non_type_error(self):
cast(Any, 42),
flatten_fn=lambda x: (tuple(x), None),
unflatten_fn=lambda x, _: list(x),
path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x))))
path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x)))) # pyrefly: ignore[bad-argument-type]

def test_register_node_traverser_existing_registration_error(self):
with self.assertRaises(ValueError):
daglish.register_node_traverser(
list,
flatten_fn=lambda x: (tuple(x), None),
unflatten_fn=lambda x, _: list(x),
path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x))))
path_elements_fn=lambda x: (daglish.Index(i) for i in range(len(x)))) # pyrefly: ignore[bad-argument-type]

def test_node_traverser_registry_with_fallback(self):
registry = daglish.NodeTraverserRegistry(use_fallback=True)
Expand Down Expand Up @@ -674,7 +674,7 @@ def traverse(value, state: daglish.State):
return state.map_children(value)

x = [1, 2, 3]
x.append(x)
x.append(x) # pyrefly: ignore[bad-argument-type]
with self.assertRaisesRegex(
ValueError,
"Fiddle detected a cycle while traversing a value: "
Expand Down
20 changes: 10 additions & 10 deletions fiddle/_src/diffing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def ignoring_paths(self, paths=Iterable[daglish.Path]) -> 'Diff':
Returns:
A new `Diff` without changes that relate to the given `paths`.
"""
paths = set(paths)
paths = set(paths) # pyrefly: ignore[bad-argument-type]

def _ignore_fn(change: DiffOperation):
return any(
Expand All @@ -123,7 +123,7 @@ class SetValue(DiffOperation):

The target's parent may not be a sequence (list or tuple).
"""
target: daglish.Path
target: daglish.Path # pyrefly: ignore[bad-override]
new_value: Union[Reference, Any]

def apply(self, parent: Any, child: daglish.PathElement):
Expand All @@ -142,13 +142,13 @@ class ModifyValue(DiffOperation):

The target's parent may not be a tuple.
"""
target: daglish.Path
target: daglish.Path # pyrefly: ignore[bad-override]
new_value: Union[Reference, Any]

def apply(self, parent: Any, child: daglish.PathElement):
"""Replaces `child.follow(parent)` with self.new_value."""
if isinstance(child, daglish.BuildableFnOrCls):
mutate_buildable.update_callable(parent, self.new_value)
mutate_buildable.update_callable(parent, self.new_value) # pyrefly: ignore[bad-argument-type]
elif isinstance(child, daglish.Attr):
setattr(parent, child.name, self.new_value)
elif isinstance(child, daglish.Index):
Expand All @@ -165,7 +165,7 @@ class DeleteValue(DiffOperation):

The target's parent may not be a sequence (list or tuple).
"""
target: daglish.Path
target: daglish.Path # pyrefly: ignore[bad-override]

def apply(self, parent: Any, child: daglish.PathElement):
"""Deletes `child.follow(parent)`."""
Expand All @@ -183,7 +183,7 @@ class AddTag(DiffOperation):

The target's parent must be a `fdl.Buildable`.
"""
target: daglish.Path
target: daglish.Path # pyrefly: ignore[bad-override]
tag: tag_type.TagType

def apply(self, parent: Any, child: daglish.PathElement):
Expand All @@ -199,7 +199,7 @@ class RemoveTag(DiffOperation):

The target's parent must be a `fdl.Buildable`.
"""
target: daglish.Path
target: daglish.Path # pyrefly: ignore[bad-override]
tag: tag_type.TagType

def apply(self, parent: Any, child: daglish.PathElement):
Expand Down Expand Up @@ -656,7 +656,7 @@ def record_tag_diffs(self, old_path: daglish.Path,
empty_set = set([]) # Default value for dict.get.
tag_name = lambda tag: tag.__name__ # For sorting.
for arg_name in sorted(set(old_arg_tags) | set(new_arg_tags)):
target = old_path + (daglish.Attr(arg_name),)
target = old_path + (daglish.Attr(arg_name),) # pyrefly: ignore[bad-argument-type]
old_tags = old_arg_tags.get(arg_name, empty_set)
new_tags = new_arg_tags.get(arg_name, empty_set)
for removed_tag in sorted(old_tags - new_tags, key=tag_name):
Expand Down Expand Up @@ -1014,7 +1014,7 @@ def add_reference_target(path, value):
if isinstance(change, RemoveTag):
tagging.add_tag(
daglish.follow_path(root, change.target[:-1]),
change.target[-1].name,
change.target[-1].name, # pyrefly: ignore[missing-attribute]
change.tag,
)
daglish_legacy.traverse_with_path(add_reference_target,
Expand Down Expand Up @@ -1062,7 +1062,7 @@ def _add_path_to_skeleton(skeleton, path, skip_leaf=False):
raise ValueError(f'Unuspported PathElement {path[0]}')

# Recurse to the child element.
child = _add_path_to_skeleton(path[0].follow(skeleton), path[1:], skip_leaf)
child = _add_path_to_skeleton(path[0].follow(skeleton), path[1:], skip_leaf) # pyrefly: ignore[bad-argument-type]
if isinstance(path[0], daglish.Attr):
assert isinstance(skeleton, config_lib.Config)
setattr(skeleton, path[0].name, child)
Expand Down
16 changes: 8 additions & 8 deletions fiddle/_src/diffing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,8 +791,8 @@ def test_resolve_ref_from_new_shared_value_to_new_shared_value(self):
diff_z = resolved_diff.changes[0]
self.assertEqual(diff_z.target, parse_path('.z'))
self.assertIsInstance(diff_z, diffing.SetValue)
self.assertIs(diff_z.new_value[0], resolved_diff.new_shared_values[0])
self.assertIs(diff_z.new_value[1], resolved_diff.new_shared_values[1])
self.assertIs(diff_z.new_value[0], resolved_diff.new_shared_values[0]) # pyrefly: ignore[bad-index]
self.assertIs(diff_z.new_value[1], resolved_diff.new_shared_values[1]) # pyrefly: ignore[bad-index]
self.assertIs(resolved_diff.new_shared_values[1][0],
resolved_diff.new_shared_values[0])

Expand Down Expand Up @@ -826,12 +826,12 @@ def test_resolve_diff_multiple_references(self):
diff_1_x = resolved_diff.changes[0]
self.assertEqual(diff_1_x.target, parse_path("[1]['x']"))
self.assertIsInstance(diff_1_x, diffing.ModifyValue)
self.assertIs(diff_1_x.new_value, old[1]['y'])
self.assertIs(diff_1_x.new_value, old[1]['y']) # pyrefly: ignore[bad-index]

diff_1_y = resolved_diff.changes[1]
self.assertEqual(diff_1_y.target, parse_path("[1]['y']"))
self.assertIsInstance(diff_1_y, diffing.ModifyValue)
self.assertIs(diff_1_y.new_value, old[1]['x'])
self.assertIs(diff_1_y.new_value, old[1]['x']) # pyrefly: ignore[bad-index]

diff_1_z = resolved_diff.changes[2]
self.assertEqual(diff_1_z.target, parse_path("[1]['z']"))
Expand Down Expand Up @@ -965,10 +965,10 @@ def test_apply_diff_with_multiple_references(self):

# Manually apply the same changes described by the diff:
new = copy.deepcopy(old)
new[1]['x'], new[1]['y'] = new[1]['y'], new[1]['x']
new[1]['z'] = new[2]
new[2].x = new[3]
new[2].z = [new[0], new[3]]
new[1]['x'], new[1]['y'] = new[1]['y'], new[1]['x'] # pyrefly: ignore[bad-index, unsupported-operation]
new[1]['z'] = new[2] # pyrefly: ignore[unsupported-operation]
new[2].x = new[3] # pyrefly: ignore[missing-attribute]
new[2].z = [new[0], new[3]] # pyrefly: ignore[missing-attribute]

diffing.apply_diff(cfg_diff, old)
self.assertEqual(old, new)
Expand Down
Loading
Loading