import argparse
import bisect
import itertools
import os
import random

import numpy as np


"""Performance microbenchmark utils.

This module contains utilities for writing microbenchmark tests.
"""

# Reserved keywords in the benchmark suite
_reserved_keywords = {"probs", "total_samples", "tags"}
_supported_devices = {"cpu", "cuda"}


def shape_to_string(shape):
    return ", ".join([str(x) for x in shape])


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
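
# Illustrative usage (a sketch, not part of the original module; "--use_gpu" is a
# hypothetical flag). str2bool is meant to be used as an argparse "type" so that
# boolean flags accept values such as yes/no or 1/0:
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use_gpu", type=str2bool, default=False)
#   parser.parse_args(["--use_gpu", "yes"]).use_gpu  # -> True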


def numpy_random(dtype, *shapes):
    """Return a random numpy tensor of the provided dtype.
    Args:
        dtype: a numpy dtype
            (https://docs.scipy.org/doc/numpy/user/basics.types.html)
        shapes: ints defining the shape of the tensor
    Return:
        a numpy tensor of the given dtype
    """
    # TODO: consider more complex/custom dynamic ranges for
    # comprehensive test coverage.
    return np.random.rand(*shapes).astype(dtype)
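
# Illustrative usage (a sketch, not part of the original module):
#   numpy_random("float32", 2, 3)  # 2x3 float32 array with values in [0, 1)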


def set_omp_threads(num_threads):
    existing_value = os.environ.get("OMP_NUM_THREADS", "")
    if existing_value != "":
        print(
            f"Overwriting existing OMP_NUM_THREADS value: {existing_value}; Setting it to {num_threads}."
        )
    os.environ["OMP_NUM_THREADS"] = str(num_threads)


def set_mkl_threads(num_threads):
    existing_value = os.environ.get("MKL_NUM_THREADS", "")
    if existing_value != "":
        print(
            f"Overwriting existing MKL_NUM_THREADS value: {existing_value}; Setting it to {num_threads}."
        )
    os.environ["MKL_NUM_THREADS"] = str(num_threads)
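
# Illustrative usage (a sketch, not part of the original module). These helpers
# only set environment variables, so they generally need to run before the
# OpenMP/MKL runtimes read those variables (i.e. early in the process):
#   set_omp_threads(1)
#   set_mkl_threads(1)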


def cross_product(*inputs):
    """
    Return the Cartesian product of the input iterables as a list.
    For example, cross_product(A, B) returns [(x, y) for x in A for y in B].
    """
    return list(itertools.product(*inputs))
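
# Illustrative usage (a sketch, not part of the original module):
#   cross_product([1, 2], ["a", "b"])
#   # -> [(1, "a"), (1, "b"), (2, "a"), (2, "b")]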


def get_n_rand_nums(min_val, max_val, n):
    # Use a fixed seed so the sampled numbers are deterministic across runs.
    random.seed((1 << 32) - 1)
    return random.sample(range(min_val, max_val), n)


def generate_configs(**configs):
    """
    Given configs from users, generate the different combinations of
    those configs.
    For example, given M = (1, 2), N = (4, 5), and sample_func being cross_product,
    we will generate (({'M': 1}, {'N' : 4}),
                      ({'M': 1}, {'N' : 5}),
                      ({'M': 2}, {'N' : 4}),
                      ({'M': 2}, {'N' : 5}))
    """
    assert "sample_func" in configs, "Missing sample_func to generate configs"
    result = []
    for key, values in configs.items():
        if key == "sample_func":
            continue
        tmp_result = []
        for value in values:
            tmp_result.append({key: value})
        result.append(tmp_result)

    results = configs["sample_func"](*result)
    return results
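
# Illustrative usage (a sketch, not part of the original module):
#   generate_configs(M=(1, 2), N=(4, 5), sample_func=cross_product)
#   # -> [({'M': 1}, {'N': 4}), ({'M': 1}, {'N': 5}),
#   #     ({'M': 2}, {'N': 4}), ({'M': 2}, {'N': 5})]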


def cross_product_configs(**configs):
    """
    Given configs from users, generate the different combinations of
    those configs.
    For example, given M = (1, 2) and N = (4, 5),
    we will generate (({'M': 1}, {'N' : 4}),
                      ({'M': 1}, {'N' : 5}),
                      ({'M': 2}, {'N' : 4}),
                      ({'M': 2}, {'N' : 5}))
    """
    _validate(configs)
    configs_attrs_list = []
    for key, values in configs.items():
        tmp_results = [{key: value} for value in values]
        configs_attrs_list.append(tmp_results)

    # TODO(mingzhe0908) remove the conversion to list.
    # itertools.product produces an iterator that yields elements on the fly,
    # while converting to a list materializes everything at once.
    generated_configs = list(itertools.product(*configs_attrs_list))
    return generated_configs
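
# Illustrative usage (a sketch, not part of the original module):
#   cross_product_configs(M=[1, 2], device=["cpu", "cuda"])
#   # -> [({'M': 1}, {'device': 'cpu'}), ({'M': 1}, {'device': 'cuda'}),
#   #     ({'M': 2}, {'device': 'cpu'}), ({'M': 2}, {'device': 'cuda'})]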


def _validate(configs):
    """Validate inputs from users."""
    if "device" in configs:
        for v in configs["device"]:
            assert v in _supported_devices, f"Unsupported device: {v}"


def config_list(**configs):
    """Generate configs based on the list of input shapes.
    This function takes input shapes specified in a list by the user. All other
    parameters are cross-producted first, and each generated combination is then
    merged with each entry of the input shapes list.

    Reserved Args:
        attr_names(reserved): a list of names for input shapes.
        attrs(reserved): a list of values for each input shape.
        cross_product_configs: a dictionary of attributes which will be
                       cross-producted with the input shapes.
        tags(reserved): a tag used to filter inputs.

    Here is an example:
    attrs = [
        [1, 2],
        [4, 5],
    ],
    attr_names = ['M', 'N'],
    cross_product_configs={
        'device': ['cpu', 'cuda'],
    },

    we will generate [[{'M': 1}, {'N' : 2}, {'device' : 'cpu'}],
                      [{'M': 1}, {'N' : 2}, {'device' : 'cuda'}],
                      [{'M': 4}, {'N' : 5}, {'device' : 'cpu'}],
                      [{'M': 4}, {'N' : 5}, {'device' : 'cuda'}]]
    """
    generated_configs = []
    reserved_names = ["attrs", "attr_names", "tags"]
    if any(attr not in configs for attr in reserved_names):
        raise ValueError("Missing attrs, attr_names, or tags in configs")

    _validate(configs)

    cross_configs = None
    if "cross_product_configs" in configs:
        cross_configs = cross_product_configs(**configs["cross_product_configs"])

    for inputs in configs["attrs"]:
        tmp_result = [
            {configs["attr_names"][i]: input_value}
            for i, input_value in enumerate(inputs)
        ]
        # TODO(mingzhe0908):
        # If multiple 'tags' were provided, do they get concatenated?
        # If a config has both ['short', 'medium'], should it match
        # both the 'short' and the 'medium' tag filters?
        tmp_result.append({"tags": "_".join(configs["tags"])})
        if cross_configs:
            generated_configs += [tmp_result + list(config) for config in cross_configs]
        else:
            generated_configs.append(tmp_result)

    return generated_configs
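
# Illustrative usage (a sketch, not part of the original module). Note that the
# joined 'tags' entry is appended to each generated config:
#   config_list(
#       attrs=[[1, 2], [4, 5]],
#       attr_names=["M", "N"],
#       cross_product_configs={"device": ["cpu"]},
#       tags=["short"],
#   )
#   # -> [[{'M': 1}, {'N': 2}, {'tags': 'short'}, {'device': 'cpu'}],
#   #     [{'M': 4}, {'N': 5}, {'tags': 'short'}, {'device': 'cpu'}]]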


def attr_probs(**probs):
    """Return the inputs as a dictionary."""
    return probs


class RandomSample:
    def __init__(self, configs):
        self.saved_cum_distribution = {}
        self.configs = configs

    def _distribution_func(self, key, weights):
        """Cumulative distribution function used for randomly sampling inputs."""
        if key in self.saved_cum_distribution:
            return self.saved_cum_distribution[key]

        total = sum(weights)
        result = []
        cumsum = 0
        for w in weights:
            cumsum += w
            result.append(cumsum / total)
        self.saved_cum_distribution[key] = result
        return result

    def _random_sample(self, key, values, weights):
        """Randomly sample one of the values based on the weights."""
        # TODO(mingzhe09088): cache the results to avoid recalculation overhead
        assert len(values) == len(weights)
        _distribution_func_vals = self._distribution_func(key, weights)
        x = random.random()
        idx = bisect.bisect(_distribution_func_vals, x)

        assert idx <= len(values), "Wrong index value is returned"
        # Due to floating-point rounding, the last value in the cumulative sum can
        # be slightly smaller than 1, which would lead to idx == len(values).
        if idx == len(values):
            idx -= 1
        return values[idx]
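
    # Illustrative sketch of the sampling above (not part of the original module):
    # for weights [3, 1], _distribution_func returns [0.75, 1.0]; a uniform draw of
    # x = 0.5 gives bisect.bisect([0.75, 1.0], 0.5) == 0, so the first value is
    # three times as likely to be picked as the second.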

    def get_one_set_of_inputs(self):
        tmp_attr_list = []
        for key, values in self.configs.items():
            if key in _reserved_keywords:
                continue
            value = self._random_sample(key, values, self.configs["probs"][str(key)])
            tmp_results = {key: value}
            tmp_attr_list.append(tmp_results)
        return tmp_attr_list


def random_sample_configs(**configs):
    """
    This function randomly samples <total_samples> values from the given inputs based on
    their weights.
    Here is an example showing the expected inputs and outputs of this function:
    M = [1, 2],
    N = [4, 5],
    K = [7, 8],
    probs = attr_probs(
        M = [0.7, 0.2],
        N = [0.5, 0.2],
        K = [0.6, 0.2],
    ),
    total_samples=10,
    this function will generate
    [
        [{'K': 7}, {'M': 1}, {'N': 4}],
        [{'K': 7}, {'M': 2}, {'N': 5}],
        [{'K': 8}, {'M': 2}, {'N': 4}],
        ...
    ]
    Note:
    The weights in probs don't have to be normalized probabilities; the
    implementation normalizes them. (In the current implementation probs is
    required; a ValueError is raised if it is missing.)
    TODO (mingzhe09088):
    (1): add a lambda that accepts or rejects a config as a sample. For example, for
    matmul with M, N, and K, it could reject configs with M * N * K > 1e8 to filter
    out very slow benchmarks.
    (2): make sure each sample is unique. If the number of samples is larger than the
    total number of combinations, just return the cross product. Otherwise, if the
    number of samples is close to the number of combinations, it is numerically safer
    to generate the configs you don't want and remove them.
    """
    if "probs" not in configs:
        raise ValueError(
            "probs is missing. Consider adding probs or using other config functions"
        )

    configs_attrs_list = []
    randomsample = RandomSample(configs)
    for _ in range(configs["total_samples"]):
        tmp_attr_list = randomsample.get_one_set_of_inputs()
        tmp_attr_list.append({"tags": "_".join(configs["tags"])})
        configs_attrs_list.append(tmp_attr_list)
    return configs_attrs_list
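
# Illustrative usage (a sketch, not part of the original module; the sampled
# values vary from call to call):
#   random_sample_configs(
#       M=[1, 2],
#       N=[4, 5],
#       probs=attr_probs(M=[0.7, 0.3], N=[0.5, 0.5]),
#       total_samples=2,
#       tags=["short"],
#   )
#   # -> e.g. [[{'M': 1}, {'N': 4}, {'tags': 'short'}],
#   #          [{'M': 2}, {'N': 5}, {'tags': 'short'}]]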


def op_list(**configs):
    """Generate a list of ops organized in a specific format.
    It takes two parameters, "attr_names" and "attrs".
    attrs stores the name and function of operators.
    Args:
        configs: key-value pairs including the name and function of
        operators. attrs and attr_names must be present in configs.
    Return:
        a sequence of dictionaries which store the name and function
        of ops in a special format
    Example:
    attrs = [
        ["abs", torch.abs],
        ["abs_", torch.abs_],
    ]
    attr_names = ["op_name", "op"].

    With those two inputs,
    we will generate [{"op_name": "abs", "op": torch.abs},
                      {"op_name": "abs_", "op": torch.abs_}]
    """
    generated_configs = []
    if "attrs" not in configs:
        raise ValueError("Missing attrs in configs")
    for inputs in configs["attrs"]:
        tmp_result = {
            configs["attr_names"][i]: input_value
            for i, input_value in enumerate(inputs)
        }
        generated_configs.append(tmp_result)
    return generated_configs


def get_operator_range(chars_range):
    """Generate the set of characters from chars_range, endpoints inclusive."""
    if chars_range == "None" or chars_range is None:
        return None

    if all(item not in chars_range for item in [",", "-"]):
        raise ValueError(
            "The correct format for operator_range is "
            "<start>-<end>, or <point>, <start>-<end>"
        )

    ops_start_chars_set = set()
    ranges = chars_range.split(",")
    for item in ranges:
        if len(item) == 1:
            ops_start_chars_set.add(item.lower())
            continue
        start, end = item.split("-")
        ops_start_chars_set.update(
            chr(c).lower() for c in range(ord(start), ord(end) + 1)
        )
    return ops_start_chars_set
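
# Illustrative usage (a sketch, not part of the original module):
#   get_operator_range("a,c-e")  # -> {'a', 'c', 'd', 'e'}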


def process_arg_list(arg_list):
    if arg_list == "None":
        return None

    return [fr.strip() for fr in arg_list.split(",") if len(fr.strip()) > 0]
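
# Illustrative usage (a sketch, not part of the original module):
#   process_arg_list("add, mul, ")  # -> ['add', 'mul']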