xref: /aosp_15_r20/external/pytorch/torch/package/analyze/trace_dependencies.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import sys
3from typing import Any, Callable, Iterable, List, Tuple
4
5
6__all__ = ["trace_dependencies"]
7
8
def _module_of(obj):
    """Return ``obj.__module__`` if it is a non-empty string, otherwise ``None``.

    Not every object has a ``__module__`` attribute (module objects and
    instances of built-in types such as ``int`` do not). This is called from
    inside a profile hook, where an escaping exception would abort the program
    being traced, so the lookup must never raise ``AttributeError``.
    """
    module = getattr(obj, "__module__", None)
    return module if isinstance(module, str) and module else None


def _profiler_to_restore(previous):
    """Return the profiler to reinstall once tracing finishes.

    ``previous`` is the value ``sys.getprofile()`` returned before tracing.
    If it may not be a plain Python callable (e.g. a profiler installed through
    the C API — NOTE(review): version dependent), re-registering it through
    ``sys.setprofile`` would raise ``TypeError`` at the next profiled event, so
    fall back to clearing the profiler instead.
    """
    return previous if callable(previous) else None


def trace_dependencies(
    callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]]
) -> List[str]:
    """Trace the execution of a callable in order to determine which modules it uses.

    A profile hook is installed with ``sys.setprofile``. For every Python-level
    function call made while ``callable`` runs, the module of the called function
    is inferred heuristically by looking up the function's name in:

      1) the global namespace of the called frame;
      2) the local namespace of the called frame;
      3) the attributes of the frame's local ``self`` (to handle instance
         method calls).

    Calls into C functions ("c_call" events) are not recorded.

    Args:
        callable: The callable to execute and trace. (The parameter name shadows
            the ``callable`` builtin; it is kept for keyword callers.)
        inputs: The inputs to use during tracing. ``callable`` is invoked as
            ``callable(*inp)`` for each ``inp`` in ``inputs``. The modules used by
            'callable' across all invocations are union-ed to determine all
            modules used by the callable for the purposes of packaging.

    Returns:
        A list of the names of all modules used during callable execution, in no
        particular order.

    Raises:
        Whatever ``callable`` raises. The profiler that was active before the call
        is restored before the exception propagates.
    """
    modules_used = set()

    def record_used_modules(frame, event, arg):
        # Only Python-level function calls matter; "return" and the
        # "c_*" events are ignored.
        if event != "call":
            return

        # Name of the function whose frame is being entered.
        name = frame.f_code.co_name
        local_ns = frame.f_locals
        module = None

        # Same precedence as before: globals, then locals, then ``self``.
        if name in frame.f_globals:
            module = _module_of(frame.f_globals[name])
        elif name in local_ns:
            module = _module_of(local_ns[name])
        elif "self" in local_ns:
            try:
                method = getattr(local_ns["self"], name, None)
            except Exception:
                # getattr may run a property or __getattr__ that raises something
                # other than AttributeError. Letting it escape the profile hook
                # would abort the traced program, so treat the module as unknown.
                method = None
            module = _module_of(method)

        # If a module was found, add it to the set of used modules.
        if module is not None:
            modules_used.add(module)

    # Decide what to put back *before* the hook is installed. Calling a
    # Python-level helper while the hook is active would itself be recorded
    # as a dependency.
    restore_to = _profiler_to_restore(sys.getprofile())
    try:
        # Attach record_used_modules as the profiler function.
        sys.setprofile(record_used_modules)

        # Execute the callable with all inputs.
        for inp in inputs:
            callable(*inp)
    finally:
        # Detach our hook and reinstate whatever profiler was active before.
        sys.setprofile(restore_to)

    return list(modules_used)
65