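"""Utilities for recording and replaying operator inputs for the dynamo
microbenchmarks (benchmarks/dynamo/microbenchmarks/operator_inp_utils.py in the
PyTorch tree).

OperatorInputsMode logs a serialized (args, kwargs) signature for each dispatched
operator that consumes and produces tensors; OperatorInputsLoader parses those
logs back into concrete tensors so individual operators can be benchmarked in
isolation.
"""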
import functools
import logging
import math
import os
from collections import Counter, defaultdict
from functools import partial
from typing import Any, Dict, Generator, Iterable, Tuple

import torch
from torch.testing import make_tensor
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


log = logging.getLogger(__name__)

OP_INP_DIRECTORY = os.path.join(os.path.dirname(__file__), "operator_inp_logs")

TIMM_DIR = os.path.join(OP_INP_DIRECTORY, "timm_train")
HF_DIR = os.path.join(OP_INP_DIRECTORY, "hf_train")
TORCHBENCH_DIR = os.path.join(OP_INP_DIRECTORY, "torchbench_train")

aten = torch.ops.aten
tensor_type = torch._C.TensorType.get()

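# Short dtype abbreviations used in the serialized argument strings and in the
# operator input log files (e.g. "f32" stands for torch.float32).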
dtype_abbrs = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}


def truncate_inp(arg):
    if arg in dtype_abbrs:
        return dtype_abbrs[arg]
    elif isinstance(arg, torch.device):
        return arg.type
    else:
        return arg


# Serialize Function Call
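# For example (illustrative): repr(FuncCallWrapper("T", [32, 64], torch.float32))
# evaluates to the string "T([32, 64], f32)", which deserialize_args can later
# turn back into a tensor.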
class FuncCallWrapper:
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = tree_map(truncate_inp, args)
        self.kwargs = tree_map(truncate_inp, kwargs) if kwargs is not None else {}

    def __repr__(self):
        args = ", ".join([repr(arg) for arg in self.args])
        kwargs = "".join(
            [f", {str(key)}={value}" for key, value in self.kwargs.items()]
        )
        out = f"{self.call}({args}{kwargs})".strip('"')
        # f-strings introduce quotation marks we don't want
        for key in dtype_abbrs_parsing:
            out = out.replace(f"'{key}'", key)
        return out


def serialize_sparse_tensor(e):
    if isinstance(e, torch._subclasses.FakeTensor):
        return FuncCallWrapper("ST", list(e.shape), e.dtype, e.layout, e.is_coalesced())
    else:
        return FuncCallWrapper(
            "ST", list(e.shape), e.dtype, e.layout, e.is_coalesced(), e._nnz()
        )


def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None):
    raise NotImplementedError


def deserialize_tensor(size, dtype, stride=None):
    if stride is not None:
        out = torch.empty_strided(size, stride, dtype=dtype)
    else:
        out = torch.empty(size, dtype=dtype)
    try:
        out.copy_(make_tensor(size, dtype=dtype, device="cpu"))
    except Exception as e:
        print(e)
        return out
    return out


def serialize_tensor(e):
    if not e.is_contiguous():
        return FuncCallWrapper("T", list(e.shape), e.dtype, stride=e.stride())
    else:
        return FuncCallWrapper("T", list(e.shape), e.dtype)


def serialize_torch_args(e):
    if isinstance(e, torch.Tensor):
        if e.is_sparse:
            return serialize_sparse_tensor(e)
        return serialize_tensor(e)
    else:
        return truncate_inp(e)


def contains_tensor(elems):
    for elem in pytree.tree_leaves(elems):
        if isinstance(elem, torch.Tensor):
            return True
    return False


def skip_args(elems):
    for i in pytree.tree_leaves(elems):
        # these argument types only show up in constructors and similar ops
        if isinstance(i, (torch.memory_format, torch.storage.UntypedStorage)):
            return True
    return False


def contains_tensor_types(type):
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


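# Heuristic used by log_to_file to filter out operators that do no real compute:
# constructors and *_like factories, and single-output ops that merely alias
# (view) one of their inputs.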
@functools.lru_cache(None)
def non_compute_operator(op):
    schema = op._schema

    # skip constructors
    if not any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return True
    if "_like" in op.name():
        return True

    # allow in-place writes
    if schema.is_mutable:
        return False

    tensor_inps = [arg for arg in schema.arguments if arg.type is tensor_type]
    tensor_outputs = [ret for ret in schema.returns if ret.type is tensor_type]

    # treat ops whose single tensor output aliases an input as non-compute
    if len(tensor_outputs) != 1:
        return False

    for inp in tensor_inps:
        if inp.alias_info and tensor_outputs[0].alias_info:
            if inp.alias_info.before_set.intersection(
                tensor_outputs[0].alias_info.after_set
            ):
                return True

    return False


class OperatorInputsMode(TorchDispatchMode):
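    """TorchDispatchMode that records, for every dispatched operator that both
    consumes and produces tensors, a serialized (args, kwargs) signature and a
    count of how often it was seen.

    Minimal usage sketch (the model and inputs below are placeholders, not part
    of this module):

        mode = OperatorInputsMode()
        with mode:
            model(*example_inputs).sum().backward()
        mode.log_to_file("my_model_train.txt")

    log_to_file then writes entries of the form (illustrative):

        Operator: aten.add.Tensor
        cnt: 3, ((T([8, 16], f32), T([8, 16], f32)), {})
    """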
    def __init__(self, func_db=None):
        self.func_db = defaultdict(Counter) if func_db is None else func_db

    def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        arg_meta, kwarg_meta = tree_map(serialize_torch_args, (args, kwargs))

        out = func_overload(*args, **kwargs)

        inps = (args, kwargs)
        if contains_tensor(inps) and not skip_args(inps) and contains_tensor(out):
            serialized_str = repr((arg_meta, kwarg_meta))
            self.func_db[str(func_overload)][serialized_str] += 1

        return out

    def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
        sorted_operators = sorted(self.func_db.keys())
        with open(output_filename, "w") as f:
            for operator in sorted_operators:
                if skip_non_compute_operators and non_compute_operator(eval(operator)):
                    continue
                f.write(f"Operator: {operator}\n")
                operator_inputs = self.func_db[operator]
                for inps, count in operator_inputs.items():
                    f.write(f"cnt: {count}, ")
                    # repr will add quotation marks around the dtype strings
                    for dtype_abbr in dtype_abbrs.values():
                        inps = inps.replace("'" + dtype_abbr + "'", dtype_abbr)
                    f.write(inps)
                    f.write("\n")


def map_to_device(e, device):
    if isinstance(e, torch.Tensor):
        return e.to(device)
    elif isinstance(e, torch.device):
        return device
    elif isinstance(e, str):
        # remap device strings, but pass through other string args unchanged
        if e == "cuda" or e == "cpu":
            return device.type
        return e
    else:
        return e


def map_to_dtype(e, dtype):
    if isinstance(e, torch.Tensor) and e.is_floating_point():
        return e.to(dtype)
    elif isinstance(e, torch.dtype):
        return dtype
    else:
        return e


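# Illustrative round trip, assuming a signature recorded by OperatorInputsMode:
#     deserialize_args("((T([8, 16], f32), T([8, 16], f32)), {})")
# returns ((tensor, tensor), {}) with freshly generated random data.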
def deserialize_args(inps):
    inps = inps.strip().strip("'")
    global_vals = {
        "T": deserialize_tensor,
        "ST": deserialize_sparse_tensor,
        "th": torch,
        "inf": math.inf,
        "torch": torch,
        **dtype_abbrs_parsing,
    }
    # f-strings introduce quotation marks we don't want
    for key in dtype_abbrs_parsing:
        inps = inps.replace(f"'{key}'", key)
    return eval(inps.strip().strip("'").strip('"'), global_vals)


class OperatorInputsLoader:
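    """Parses operator input log files (as written by OperatorInputsMode.log_to_file)
    and regenerates concrete (args, kwargs) so individual operators can be re-run."""
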
    def __init__(self, json_file_path):
        self.operator_db = defaultdict(Counter)

        with open(json_file_path) as f:
            lines = f.readlines()

        i = 0
        while i < len(lines):
            op_line = lines[i].strip("\n")
            assert "Operator: " in op_line, op_line
            operator = op_line[len("Operator: ") :]
            operator = (
                operator if operator != "aten.sum.SymInt" else "aten.sum.dim_IntList"
            )
            op_inps = Counter()
            i += 1
            while i < len(lines) and "Operator: " not in lines[i]:
                line = lines[i]
                cnt = eval(line[len("cnt: ") : line.find(",")])
                inps = line[line.find(",") + 2 :].strip("'")
                op_inps[inps] += cnt
                i += 1
            self.operator_db[operator] = op_inps

    def get_inputs_for_operator(
        self, operator, dtype=None, device="cuda"
    ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]:
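        """Yield the recorded (args, kwargs) pairs for the given operator overload,
        optionally remapping floating-point tensors to dtype and moving tensors to
        device."""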
        assert (
            str(operator) in self.operator_db
        ), f"Could not find {operator}, must provide overload"

        if "embedding" in str(operator):
            log.warning("Embedding inputs NYI, input data cannot be randomized")
            yield
            return

        # line[1] is the number of times these inputs occurred; it is ignored for now
        for line in self.operator_db[str(operator)].items():
            inps = line[0]

            args, kwargs = deserialize_args(inps)

            # Backward passes require some inputs to be float16 and some to be float32,
            # so we record in half precision and upcast to float when a dtype is specified
            if dtype and dtype != torch.float16:
                to_dtype = partial(map_to_dtype, dtype=dtype)
                args, kwargs = tree_map(to_dtype, (args, kwargs))

            if device:
                to_device = partial(map_to_device, device=torch.device(device))
                args, kwargs = tree_map(to_device, (args, kwargs))

            yield args, kwargs

    def get_all_ops(self):
        for key in self.operator_db.keys():
            try:
                op = eval(key)
            except AttributeError as ae:
                log.warning("Failed to evaluate op name into an OpOverload: %s", ae)
                continue
            yield op

    def get_call_frequency(self, op):
        assert (
            str(op) in self.operator_db
        ), f"Could not find {op}, must provide overload"

        count = 0
        for counter in self.operator_db[str(op)].values():
            count += counter
        return count

    def merge(self, other):
        for operator, counter_dict in other.operator_db.items():
            for inps, cnt in counter_dict.items():
                self.operator_db[operator][inps] += cnt

    @staticmethod
    def get_timm_loader():
        return OperatorInputsLoader._load_directory(TIMM_DIR)

    @staticmethod
    def get_huggingface_loader():
        return OperatorInputsLoader._load_directory(HF_DIR)

    @staticmethod
    def get_torchbench_loader():
        return OperatorInputsLoader._load_directory(TORCHBENCH_DIR)

    @staticmethod
    def _load_directory(inp_dir):
        assert os.path.isdir(inp_dir), inp_dir
        union = None
        for inp in os.listdir(inp_dir):
            if inp[-4:] != ".txt":
                continue
            path = os.path.join(inp_dir, inp)
            if union is None:
                union = OperatorInputsLoader(path)
            else:
                union.merge(OperatorInputsLoader(path))
        return union

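
# Illustrative usage of the loader (the directory contents and the chosen
# operator are examples; the corresponding log files must exist):
#
#     loader = OperatorInputsLoader.get_timm_loader()
#     for args, kwargs in loader.get_inputs_for_operator(
#         torch.ops.aten.mm.default, dtype=torch.float32, device="cuda"
#     ):
#         torch.ops.aten.mm.default(*args, **kwargs)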