xref: /aosp_15_r20/external/pytorch/torch/_strobelight/compile_time_profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: disallow-untyped-defs
2
3import logging
4import os
5from datetime import datetime
6from socket import gethostname
7from typing import Any, Optional
8
9from torch._strobelight.cli_function_profiler import StrobelightCLIFunctionProfiler
10
11
logger = logging.getLogger("strobelight_compile_time_profiler")

# Messages from this profiler get their own handler and fixed layout.
# Propagation is switched off so they are not duplicated by ancestor loggers' handlers.
formatter = logging.Formatter(
    "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
)
console_handler = logging.StreamHandler()  # default stream is sys.stderr
console_handler.setFormatter(formatter)

logger.setLevel(logging.INFO)
logger.addHandler(console_handler)
logger.propagate = False
23
24
class StrobelightCompileTimeProfiler:
    """Profile compile-time phases with a strobelight function profiler.

    All state is stored in class attributes and every method is a classmethod.
    Call ``enable()`` once to construct the profiler, then route work through
    ``profile_compile_time()``. Only the outermost call is profiled; nested
    (recursive) calls run directly and are counted in ``ignored_profile_runs``.
    """

    # Outermost runs for which the profiler reported a non-None profile_result.
    success_profile_count: int = 0
    # Outermost runs for which the profiler reported profile_result=None.
    failed_profile_count: int = 0
    # Nested calls that were executed without profiling.
    ignored_profile_runs: int = 0
    # True while an outermost profile_compile_time call is running.
    inside_profile_compile_time: bool = False
    enabled: bool = False
    # A unique identifier that is used as a sample tag in the strobelight profile to
    # associate all compile time profiles together.
    identifier: Optional[str] = None

    # Phase name of the outermost call currently being profiled (None when idle).
    current_phase: Optional[str] = None

    # Profiler built by enable(); must provide profile(func, *args, **kwargs)
    # and a profile_result attribute.
    profiler: Optional[Any] = None

    max_stack_length: int = int(
        os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 127)
    )
    # Maximum profile duration, passed to the profiler as max_profile_duration_sec.
    max_profile_time: int = int(
        os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30)
    )
    # Collect sample each x cycles.
    sample_each: int = int(
        float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
    )

    @classmethod
    def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
        """Enable compile-time profiling and construct the profiler.

        Does nothing if profiling is already enabled. When the default
        StrobelightCLIFunctionProfiler is requested but ``strobeclient`` is not
        on PATH, profiling stays disabled.

        Args:
            profiler_class: Class used to build the profiler. It should have a
                public API similar to that of StrobelightCLIFunctionProfiler;
                other classes are passed for meta-internal fbcode targets.
        """
        if cls.enabled:
            logger.info("compile time strobelight profiling already enabled")
            return

        if profiler_class is StrobelightCLIFunctionProfiler:
            import shutil

            if not shutil.which("strobeclient"):
                logger.info(
                    "strobeclient not found, can't enable compile time strobelight profiling, seems "
                    "like you are not on a FB machine."
                )
                return

        cls._cls_init()
        cls.profiler = profiler_class(
            sample_each=cls.sample_each,
            max_profile_duration_sec=cls.max_profile_time,
            stack_max_len=cls.max_stack_length,
            async_stack_max_len=cls.max_stack_length,
            run_user_name="pt2-profiler/"
            + os.environ.get("USER", os.environ.get("USERNAME", "")),
            sample_tags={cls.identifier},
        )
        # Flip the flag only after the profiler exists. If the constructor
        # raises, the class must not be left "enabled" with profiler=None.
        cls.enabled = True
        logger.info("compile time strobelight profiling enabled")

    @classmethod
    def _cls_init(cls) -> None:
        """Create this run's unique sample tag and log a Scuba link filtered on it.

        The identifier is the current local time, the process id and the
        hostname concatenated without separators.
        """
        cls.identifier = "{date}{pid}{hostname}".format(
            date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"),
            pid=os.getpid(),
            hostname=gethostname(),
        )

        logger.info("Unique sample tag for this run is: %s", cls.identifier)
        logger.info(
            "You can use the following link to access the strobelight profile at the end of the run: %s",
            (
                "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experime"
                "ntal%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22no"
                "w%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%"
                "22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22n"
                "amespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%"
                "2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22"
                "%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C"
                "%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Ang"
                "eles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2"
                "C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack"
                "_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param"
                "_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22"
                "edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22orde"
                "r%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[["
                "%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%"
                f"22%3A[%22[%5C%22{cls.identifier}%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b"
                "_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&vi"
                "ew=GraphProfilerView&&normalized=1712358002&pool=uber"
            ),
        )

    @classmethod
    def _log_stats(cls) -> None:
        """Log how many outermost profiling runs produced a profile result."""
        logger.info(
            "%s strobelight success runs out of %s non-recursive compilation events.",
            cls.success_profile_count,
            cls.success_profile_count + cls.failed_profile_count,
        )

    # TODO use threadlevel meta data to tags to record phases.
    @classmethod
    def profile_compile_time(
        cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
    ) -> Any:
        """Run ``func(*args, **kwargs)``, profiling it when profiling is enabled.

        The work is always executed, even when profiling is disabled, the
        profiler is missing, or the call is nested inside another profiled
        phase. In those cases ``func``'s own return value is returned.

        Args:
            func: Callable to execute.
            phase_name: Label for this phase; used in log messages.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The value returned by ``cls.profiler.profile(func, ...)`` for the
            outermost profiled call, otherwise ``func(*args, **kwargs)``.
        """
        if not cls.enabled:
            return func(*args, **kwargs)

        if cls.profiler is None:
            # Never drop the caller's work just because profiling is unavailable.
            logger.error("profiler is not set")
            return func(*args, **kwargs)

        if cls.inside_profile_compile_time:
            cls.ignored_profile_runs += 1
            logger.info(
                "profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
                phase_name,
                cls.current_phase,
            )
            return func(*args, **kwargs)

        cls.inside_profile_compile_time = True
        cls.current_phase = phase_name
        try:
            work_result = cls.profiler.profile(func, *args, **kwargs)
        finally:
            # Always clear the re-entrancy state. Otherwise a single exception
            # would make every later call look recursive and skip profiling.
            cls.inside_profile_compile_time = False
            cls.current_phase = None

        if cls.profiler.profile_result is not None:
            cls.success_profile_count += 1
        else:
            cls.failed_profile_count += 1

        cls._log_stats()
        return work_result
157