Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 222 additions & 10 deletions api/analyzers/python/ts_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@
logger = logging.getLogger(__name__)


# Inclusive cap on how many same-named method definitions the guarded
# name-based fallback may link a single ``receiver.method()`` call to when the
# receiver's type cannot be resolved statically. Common method names
# (``get`` / ``add`` / ``run`` / ``close`` ...) are defined on dozens of
# classes; binding a call to all of them is a false-CALLS factory, so when the
# candidate count exceeds this value we emit no edge at all.
_NAME_FALLBACK_MAX_CANDIDATES = 5


# ---------------------------------------------------------------------------
# Symbol table data model
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -393,26 +401,155 @@ def resolve(
file_path: Path,
project_root: Path,
node: Node,
) -> list[tuple[File, Node]]:
) -> list[tuple[File, Node, str]]:
"""Resolve ``node`` (an identifier or dotted attribute) to definitions.

Returns a list of ``(File, def_node)`` tuples matching the shape
produced by ``AbstractAnalyzer.resolve``.
Returns a list of ``(File, def_node, resolution)`` tuples. ``resolution``
is ``"static_exact"`` for direct lookups (module top-level, imports,
dotted-attribute walks, unique bare names, and ``self``/``cls`` methods)
and ``"static_name"`` for the guarded receiver-agnostic method fallback.
"""
self._ensure_built(files, project_root)
parts = _node_to_dotted_parts(node)
if not parts:
return []
current_module = self._path_to_module.get(file_path)
candidate_defs = self._lookup(current_module, parts)
out: list[tuple[File, Node]] = []
for d in candidate_defs:

call_expr = _call_function_expr(node)
if call_expr is not None:
results = self._resolve_call(current_module, call_expr, node)
else:
# Non-call reference (type annotations, base classes, ...). Exact
# resolution only: no receiver reconstruction and no name-based
# method fallback, so type edges stay precisely as before.
parts = _node_to_dotted_parts(node)
results = (
[(d, "static_exact") for d in self._lookup(current_module, parts)]
if parts
else []
)

out: list[tuple[File, Node, str]] = []
for d, resolution in results:
f = files.get(d.file_path)
if f is None:
continue
out.append((f, d.node))
out.append((f, d.node, resolution))
return out

def _resolve_call(
    self,
    current_module: Optional[str],
    call_expr: Node,
    site_node: Node,
) -> list[tuple[_Definition, str]]:
    """Map a call's function expression to ``(_Definition, resolution)`` pairs.

    Three strategies are tried in order and the first one that yields
    anything wins:

    1. ``self.x()`` / ``cls.x()`` bound to the enclosing class's own method
       (tagged ``"static_exact"``).
    2. ``self._lookup`` on the dotted parts (tagged ``"static_exact"``).
    3. The guarded, receiver-agnostic method-name fallback (tagged
       ``"static_name"``), only for ``identifier.method()`` call shapes.

    Returns an empty list when nothing can be resolved.
    """
    dotted = _node_to_dotted_parts(call_expr)
    if not dotted:
        return []
    receiver_head, callee_name = dotted[0], dotted[-1]
    head_is_self_like = receiver_head in ("self", "cls")

    # Strategy 1: the enclosing class's own method. An inherited method is
    # not found on the class itself, so it drops through to the later
    # strategies.
    if head_is_self_like and len(dotted) >= 2:
        own_method = self._resolve_self_method(
            current_module, call_expr, site_node, receiver_head, callee_name
        )
        if own_method is not None:
            return [(own_method, "static_exact")]

    # Strategy 2: whatever the regular symbol lookup can place exactly.
    exact_hits = self._lookup(current_module, dotted)
    if exact_hits:
        return [(hit, "static_exact") for hit in exact_hits]

    # Strategy 3: name-based guess. Only a plain ``identifier.method()`` is
    # eligible; chained calls, subscripts and dotted receivers offer no head
    # worth trusting, so no binding is guessed for them.
    if call_expr.type != "attribute":
        return []
    receiver_node = call_expr.child_by_field_name("object")
    simple_receiver = receiver_node is not None and receiver_node.type == "identifier"
    if not simple_receiver or receiver_head == "super":
        return []
    # A head that is itself a known symbol (import alias, top-level name,
    # module prefix) means the tail walk failed for a real reason: do not
    # guess. ``self``/``cls`` get here only after strategy 1 missed and are
    # always allowed through.
    if not head_is_self_like and self._head_is_resolvable(current_module, receiver_head):
        return []
    return [(hit, "static_name") for hit in self._name_fallback(callee_name)]

def _resolve_self_method(
    self,
    current_module: Optional[str],
    call_expr: Node,
    site_node: Node,
    receiver: str,
    method_name: str,
) -> Optional[_Definition]:
    """Look up ``self.method`` / ``cls.method`` on the class enclosing the call.

    Returns the method's definition, or ``None`` whenever the binding cannot
    be trusted: the call is not a plain ``identifier.attr`` shape, the
    current module is unknown, the call site is not inside a class, or the
    nearest enclosing function does not declare ``receiver`` as its first
    parameter (for instance a ``@staticmethod`` mentioning ``self``, or a
    nested function with a different first parameter).
    """
    if call_expr.type != "attribute":
        return None
    object_node = call_expr.child_by_field_name("object")
    if object_node is None or object_node.type != "identifier":
        return None
    if not current_module or current_module not in self._modules:
        return None

    owner_class = _enclosing_class_node(site_node)
    if owner_class is None:
        return None

    # Soundness check: the innermost function around the call site (searched
    # only up to the owning class) has to bind ``receiver`` as parameter #1.
    enclosing_func = _enclosing_function_node(site_node, owner_class)
    if enclosing_func is None:
        return None
    if _first_parameter_name(enclosing_func) != receiver:
        return None

    class_name_node = owner_class.child_by_field_name("name")
    if class_name_node is None:
        return None
    owner_name = class_name_node.text.decode("utf-8")
    methods_of_owner = self._modules[current_module].class_methods.get(owner_name, {})
    return methods_of_owner.get(method_name)

def _head_is_resolvable(self, current_module: Optional[str], head: str) -> bool:
"""Whether the receiver ``head`` names something we can place statically.

Distinguishes "head not found at all" (an unknown/local receiver such as
an unannotated parameter, which qualifies for the name fallback) from
"head is a known symbol whose tail walk merely failed" (does not).
"""
if current_module and current_module in self._modules:
mi = self._modules[current_module]
if head in mi.imports or head in mi.top_level:
return True
if head in self._modules:
return True
# A project-wide top-level definition (class/func/var) by this name.
for d in self._by_name.get(head, ()):
if d.kind != "method":
return True
return False

def _name_fallback(self, method_name: str) -> list[_Definition]:
    """Collect the method definitions called ``method_name``, within the cap.

    Only entries of kind ``"method"`` are considered. If none exist, or if
    there are more than ``_NAME_FALLBACK_MAX_CANDIDATES`` of them, an empty
    list is returned (the over-cap case is logged at debug level).
    Otherwise the candidates come back ordered by file path and then by
    byte offset.
    """
    same_named = self._by_name.get(method_name, ())
    methods = [defn for defn in same_named if defn.kind == "method"]
    found = len(methods)
    if found == 0:
        return []
    if found > _NAME_FALLBACK_MAX_CANDIDATES:
        logger.debug(
            "ts_resolver: skipping name fallback for %r (%d candidates > %d)",
            method_name, found, _NAME_FALLBACK_MAX_CANDIDATES,
        )
        return []
    # A fixed order keeps the emitted edges identical between runs/workers.
    return sorted(methods, key=lambda defn: (str(defn.file_path), defn.node.start_byte))

def _lookup(self, current_module: Optional[str], parts: list[str]) -> list[_Definition]:
if not parts:
return []
Expand Down Expand Up @@ -530,6 +667,81 @@ def _strip_decorator(def_node: Node) -> Node:
return def_node


def _enclosing_class_node(node: Node) -> Optional[Node]:
"""Return the nearest enclosing ``class_definition`` ancestor of ``node``.

Used to bind ``self``/``cls`` to the class that owns the containing method.
Walking strictly upward keeps nested classes/functions correct: the call
site's first ``class_definition`` ancestor is the class whose ``self`` is in
scope.
"""
cur = node.parent
while cur is not None:
if cur.type == "class_definition":
return cur
cur = cur.parent
return None


def _call_function_expr(node: Node) -> Optional[Node]:
"""If ``node`` sits in a call's *function* position, return that expression.

Handles both shapes reaching the resolver: the full call-function node
(``Foo.bar`` / ``helper``) and the bare method-name identifier (``bar``)
that ``PythonAnalyzer._extract_call_target`` produces for ``recv.method()``
in production. Returns ``None`` for non-call references (type annotations,
base classes, plain name reads) so they keep exact-only resolution.
"""
parent = node.parent
if parent is None:
return None
if parent.type == "call" and parent.child_by_field_name("function") == node:
return node
if parent.type == "attribute" and parent.child_by_field_name("attribute") == node:
grand = parent.parent
if (
grand is not None
and grand.type == "call"
and grand.child_by_field_name("function") == parent
):
return parent
return None


def _enclosing_function_node(node: Node, stop_at: Node) -> Optional[Node]:
"""Nearest ``function_definition`` ancestor of ``node`` below ``stop_at``.

``stop_at`` is the owning class; the search never crosses it so a method's
own ``function_definition`` (not the class) is returned.
"""
cur = node.parent
while cur is not None and cur != stop_at:
if cur.type == "function_definition":
return cur
cur = cur.parent
return None


def _first_parameter_name(func_node: Node) -> Optional[str]:
"""Return the textual name of a function's first positional parameter."""
params = func_node.child_by_field_name("parameters")
if params is None:
return None
for child in params.named_children:
if child.type == "identifier":
return child.text.decode("utf-8")
# ``self: Foo`` / ``self=...`` -- unwrap to the bound identifier.
if child.type in ("typed_parameter", "default_parameter", "typed_default_parameter"):
inner = child.child_by_field_name("name")
if inner is None:
inner = next(
(c for c in child.named_children if c.type == "identifier"), None
)
return inner.text.decode("utf-8") if inner is not None else None
return None
return None


def _node_to_dotted_parts(node: Node) -> list[str]:
"""Reduce a tree-sitter Python expression to its dotted name parts.

Expand Down
11 changes: 9 additions & 2 deletions api/analyzers/source_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,22 @@ def _resolve_file(file_path: Path) -> Path:
file = self.files[file_path]
for _, entity in file.entities.items():
for key, resolved_set in entity.resolved_symbols.items():
for resolved in resolved_set:
# Deterministic order so edge writes are stable across
# worker counts (Phase B promises bit-identical output).
for resolved, resolution in sorted(
resolved_set.items(), key=lambda kv: kv[0].id
):
if key == "base_class":
graph.connect_entities("EXTENDS", entity.id, resolved.id)
elif key == "implement_interface":
graph.connect_entities("IMPLEMENTS", entity.id, resolved.id)
elif key == "extend_interface":
graph.connect_entities("EXTENDS", entity.id, resolved.id)
elif key == "call":
graph.connect_entities("CALLS", entity.id, resolved.id)
graph.connect_entities(
"CALLS", entity.id, resolved.id,
{"resolution": resolution},
)
elif key == "return_type":
graph.connect_entities("RETURNS", entity.id, resolved.id)
elif key == "parameters":
Expand Down
29 changes: 22 additions & 7 deletions api/analyzers/tree_sitter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,13 @@ def resolve_symbol(
path: Path,
key: str,
symbol: Node,
) -> list[Entity]:
"""Dispatch a captured symbol to type or callable resolution."""
) -> list:
"""Dispatch a captured symbol to type or callable resolution.

Returns bare ``Entity`` objects for type resolution and
``(Entity, resolution)`` tuples for callable resolution; callers
normalize both shapes (see ``Entity.resolved_symbol``).
"""
if key in self.type_resolution_keys:
return self.resolve_type(files, lsp, file_path, path, symbol)
if key in self.method_resolution_keys:
Expand Down Expand Up @@ -79,7 +84,10 @@ def resolve_type(
target = self._extract_type_target(node)
if target is None:
return res
for file, resolved_node in self.resolve(files, lsp, file_path, path, target):
# ``resolve`` may yield 2-tuples (LSP/jedi) or 3-tuples (static
# resolver, carrying a resolution kind). Type edges ignore the
# resolution kind, so unpack tolerantly.
for file, resolved_node, *_ in self.resolve(files, lsp, file_path, path, target):
type_dec = self.find_parent(resolved_node, self.type_definition_node_types)
if type_dec in file.entities:
res.append(file.entities[type_dec])
Expand All @@ -92,16 +100,23 @@ def resolve_method(
file_path: Path,
path: Path,
node: Node,
) -> list[Entity]:
"""Resolve a call reference to matching callable-definition entities."""
) -> list[tuple[Entity, str]]:
"""Resolve a call reference to matching callable-definition entities.

Returns ``(entity, resolution)`` pairs. ``resolution`` is the kind
reported by the resolver (``static_exact`` / ``static_name`` for the
static tree-sitter resolver) and defaults to ``"lsp"`` for resolvers
that yield bare ``(file, node)`` pairs.
"""
res = []
target = self._extract_call_target(node)
if target is None:
return res
for file, resolved_node in self.resolve(files, lsp, file_path, path, target):
for file, resolved_node, *rest in self.resolve(files, lsp, file_path, path, target):
resolution = rest[0] if rest else "lsp"
method_dec = self.find_parent(resolved_node, self.callable_definition_node_types)
if method_dec and method_dec.type in self.callable_exclude_node_types:
continue
if method_dec in file.entities:
res.append(file.entities[method_dec])
res.append((file.entities[method_dec], resolution))
return res
Loading
Loading