xref: /aosp_15_r20/external/pytorch/torch/utils/benchmark/utils/fuzzer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import functools
import itertools as it
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch


__all__ = [
    "Fuzzer",
    "FuzzedParameter", "ParameterAlias",
    "FuzzedTensor",
]


_DISTRIBUTIONS = (
    "loguniform",
    "uniform",
)


class FuzzedParameter:
    """Specification for a parameter to be generated during fuzzing."""
    def __init__(
        self,
        name: str,
        minval: Optional[Union[int, float]] = None,
        maxval: Optional[Union[int, float]] = None,
        distribution: Optional[Union[str, Dict[Any, float]]] = None,
        strict: bool = False,
    ):
        """
        Args:
            name:
                A string name with which to identify the parameter.
                FuzzedTensors can reference this string in their
                specifications.
            minval:
                The lower bound for the generated value. See the description
                of `distribution` for type behavior.
            maxval:
                The upper bound for the generated value. Type behavior is
                identical to `minval`.
            distribution:
                Specifies the distribution from which this parameter should
                be drawn. There are three possibilities:
                    - "loguniform"
                        Samples between `minval` and `maxval` (inclusive) such
                        that the probabilities are uniform in log space. As a
                        concrete example, if minval=1 and maxval=100, a sample
                        is as likely to fall in [1, 10) as it is [10, 100].
                    - "uniform"
                        Samples are chosen with uniform probability between
                        `minval` and `maxval` (inclusive). If either `minval`
                        or `maxval` is a float then the distribution is the
                        continuous uniform distribution; otherwise samples
                        are constrained to the integers.
                    - dict:
                        If a dict is passed, the keys are taken to be the
                        candidate values for the parameter and the values are
                        interpreted as their probabilities, which must sum to one.
                If a dict is passed, `minval` and `maxval` must not be set.
                Otherwise, they must be set.
            strict:
                If a parameter is strict, it will not be included in the
                iterative resampling process which Fuzzer uses to find a
                valid parameter configuration. This allows an author to
                prevent skew from resampling for a given parameter (for
                instance, a low size limit could inadvertently bias towards
                Tensors with fewer dimensions) at the cost of more iterations
                when generating parameters.
        """
        self._name = name
        self._minval = minval
        self._maxval = maxval
        self._distribution = self._check_distribution(distribution)
        self.strict = strict

    @property
    def name(self):
        return self._name

    def sample(self, state):
        if self._distribution == "loguniform":
            return self._loguniform(state)

        if self._distribution == "uniform":
            return self._uniform(state)

        if isinstance(self._distribution, dict):
            return self._custom_distribution(state)

    def _check_distribution(self, distribution):
        if not isinstance(distribution, dict):
            assert distribution in _DISTRIBUTIONS
        else:
            assert not any(i < 0 for i in distribution.values()), "Probabilities cannot be negative"
            assert abs(sum(distribution.values()) - 1) <= 1e-5, "Distribution is not normalized"
            assert self._minval is None
            assert self._maxval is None

        return distribution

    def _loguniform(self, state):
        import numpy as np
        output = int(2 ** state.uniform(
            low=np.log2(self._minval) if self._minval is not None else None,
            high=np.log2(self._maxval) if self._maxval is not None else None,
        ))
        if self._minval is not None and output < self._minval:
            return self._minval
        if self._maxval is not None and output > self._maxval:
            return self._maxval
        return output

    def _uniform(self, state):
        if isinstance(self._minval, int) and isinstance(self._maxval, int):
            return int(state.randint(low=self._minval, high=self._maxval + 1))
        return state.uniform(low=self._minval, high=self._maxval)

    def _custom_distribution(self, state):
        import numpy as np
        # If we directly pass the keys to `choice`, numpy will convert
        # them to numpy dtypes.
        index = state.choice(
            np.arange(len(self._distribution)),
            p=tuple(self._distribution.values()))
        return list(self._distribution.keys())[index]

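# A hedged usage sketch (not in the original file): the three `distribution`
# forms documented in FuzzedParameter above could be specified roughly as
# follows. The parameter names "k", "p", and "dim" are illustrative assumptions.
#
#   FuzzedParameter("k", minval=1, maxval=1024, distribution="loguniform")
#   FuzzedParameter("p", minval=0.0, maxval=1.0, distribution="uniform")
#   FuzzedParameter("dim", distribution={2: 0.4, 3: 0.4, 4: 0.2})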

class ParameterAlias:
    """Indicates that a parameter should alias the value of another parameter.

    When used in conjunction with a custom distribution, this allows fuzzed
    tensors to represent a broader range of behaviors. For example, the
    following sometimes produces Tensors which broadcast:

    Fuzzer(
        parameters=[
            FuzzedParameter("x_len", 4, 1024, distribution="uniform"),

            # `y` will either be size one, or match the size of `x`.
            FuzzedParameter("y_len", distribution={
                1: 0.5,
                ParameterAlias("x_len"): 0.5
            }),
        ],
        tensors=[
            FuzzedTensor("x", size=("x_len",)),
            FuzzedTensor("y", size=("y_len",)),
        ],
    )

    Chains of aliases are allowed, but may not contain cycles.
    """
    def __init__(self, alias_to):
        self.alias_to = alias_to

    def __repr__(self):
        return f"ParameterAlias[alias_to: {self.alias_to}]"


def dtype_size(dtype):
    if dtype == torch.bool:
        return 1
    if dtype.is_floating_point or dtype.is_complex:
        return int(torch.finfo(dtype).bits / 8)
    return int(torch.iinfo(dtype).bits / 8)


def prod(values, base=1):
    """np.prod can overflow, so for sizes the product should be done in Python.

    Even though np.prod type promotes to int64, it can still overflow in which
    case the negative value will pass the size check and OOM when attempting to
    actually allocate the Tensor.
    """
    return functools.reduce(lambda x, y: int(x) * int(y), values, base)

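# Illustration (an assumption, not from the original file) of the overflow that
# `prod` avoids: on a typical 64-bit build `np.prod` wraps around for large
# shapes, while `prod` stays exact because Python ints are arbitrary precision.
#
#   >>> import numpy as np
#   >>> np.prod([2 ** 21, 2 ** 21, 2 ** 21])   # wraps to a negative int64
#   -9223372036854775808
#   >>> prod([2 ** 21, 2 ** 21, 2 ** 21])      # exact: 2 ** 63
#   9223372036854775808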

class FuzzedTensor:
    def __init__(
        self,
        name: str,
        size: Tuple[Union[str, int], ...],
        steps: Optional[Tuple[Union[str, int], ...]] = None,
        probability_contiguous: float = 0.5,
        min_elements: Optional[int] = None,
        max_elements: Optional[int] = None,
        max_allocation_bytes: Optional[int] = None,
        dim_parameter: Optional[str] = None,
        roll_parameter: Optional[str] = None,
        dtype=torch.float32,
        cuda=False,
        tensor_constructor: Optional[Callable] = None
    ):
        """
        Args:
            name:
                A string identifier for the generated Tensor.
            size:
                A tuple of integers or strings specifying the size of the generated
                Tensor. String values will be replaced with a concrete int during the
                generation process, while ints are simply passed as literals.
            steps:
                An optional tuple with the same length as `size`. This indicates
                that a larger Tensor should be allocated, and then sliced to
                produce the generated Tensor. For instance, if size is (4, 8)
                and steps is (1, 4), then a tensor `t` of size (4, 32) will be
                created and then `t[:, ::4]` will be used. (Allowing one to test
                Tensors with strided memory.)
            probability_contiguous:
                A number between zero and one representing the chance that the
                generated Tensor has a contiguous memory layout. This is achieved by
                randomly permuting the shape of a Tensor, calling `.contiguous()`,
                and then permuting back. This is applied before `steps`, which can
                also cause a Tensor to be non-contiguous.
            min_elements:
                The minimum number of elements that this Tensor must have for a
                set of parameters to be valid. (Otherwise they are resampled.)
            max_elements:
                Like `min_elements`, but setting an upper bound.
            max_allocation_bytes:
                Like `max_elements`, but for the size of Tensor that must be
                allocated prior to slicing for `steps` (if applicable). For
                example, a FloatTensor with size (1024, 1024) and steps (4, 4)
                would have 1M elements, but would require a 64 MB allocation.
            dim_parameter:
                The length of `size` and `steps` will be truncated to this value.
                This allows Tensors of varying dimensions to be generated by the
                Fuzzer.
            dtype:
                The PyTorch dtype of the generated Tensor.
            cuda:
                Whether to place the Tensor on a GPU.
            tensor_constructor:
                Callable which will be used instead of the default Tensor
                construction method. This allows the author to enforce properties
                of the Tensor (e.g. it can only have certain values). The dtype and
                concrete shape of the Tensor to be created will be passed, and
                concrete values of all parameters will be passed as kwargs. Note
                that transformations to the result (permuting, slicing) will be
                performed by the Fuzzer; the tensor_constructor is only responsible
                for creating an appropriately sized Tensor.
        """
        self._name = name
        self._size = size
        self._steps = steps
        self._probability_contiguous = probability_contiguous
        self._min_elements = min_elements
        self._max_elements = max_elements
        self._max_allocation_bytes = max_allocation_bytes
        self._dim_parameter = dim_parameter
        self._dtype = dtype
        self._cuda = cuda
        self._tensor_constructor = tensor_constructor

    @property
    def name(self):
        return self._name

    @staticmethod
    def default_tensor_constructor(size, dtype, **kwargs):
        if dtype.is_floating_point or dtype.is_complex:
            return torch.rand(size=size, dtype=dtype, device="cpu")
        else:
            return torch.randint(1, 127, size=size, dtype=dtype, device="cpu")

    def _make_tensor(self, params, state):
        import numpy as np
        size, steps, allocation_size = self._get_size_and_steps(params)
        constructor = (
            self._tensor_constructor or
            self.default_tensor_constructor
        )

        raw_tensor = constructor(size=allocation_size, dtype=self._dtype, **params)
        if self._cuda:
            raw_tensor = raw_tensor.cuda()

        # Randomly permute the Tensor and call `.contiguous()` to force re-ordering
        # of the memory, and then permute it back to the original shape.
        dim = len(size)
        order = np.arange(dim)
        if state.rand() > self._probability_contiguous:
            while dim > 1 and np.all(order == np.arange(dim)):
                order = state.permutation(raw_tensor.dim())

            raw_tensor = raw_tensor.permute(tuple(order)).contiguous()
            raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))

        slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
        tensor = raw_tensor[slices]

        properties = {
            "numel": int(tensor.numel()),
            "order": order,
            "steps": steps,
            "is_contiguous": tensor.is_contiguous(),
            "dtype": str(self._dtype),
        }

        return tensor, properties

    def _get_size_and_steps(self, params):
        dim = (
            params[self._dim_parameter]
            if self._dim_parameter is not None
            else len(self._size)
        )

        def resolve(values, dim):
            """Resolve values into concrete integers."""
            values = tuple(params.get(i, i) for i in values)
            if len(values) > dim:
                values = values[:dim]
            if len(values) < dim:
                values = values + tuple(1 for _ in range(dim - len(values)))
            return values

        size = resolve(self._size, dim)
        steps = resolve(self._steps or (), dim)
        allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps))
        return size, steps, allocation_size

    def satisfies_constraints(self, params):
        size, _, allocation_size = self._get_size_and_steps(params)
        # Product is computed in Python to avoid integer overflow.
        num_elements = prod(size)
        assert num_elements >= 0

        allocation_bytes = prod(allocation_size, base=dtype_size(self._dtype))

        def nullable_greater(left, right):
            if left is None or right is None:
                return False
            return left > right

        return not any((
            nullable_greater(num_elements, self._max_elements),
            nullable_greater(self._min_elements, num_elements),
            nullable_greater(allocation_bytes, self._max_allocation_bytes),
        ))

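# A hedged sketch (not in the original file) of a FuzzedTensor that combines
# fuzzed sizes, `steps`, and a custom constructor. The parameter names "rows"
# and "cols" and the helper `symmetric_constructor` are illustrative
# assumptions; the constructor assumes a floating point dtype.
#
#   def symmetric_constructor(size, dtype, **params):
#       # Same call pattern as `default_tensor_constructor`, rescaled to [-1, 1).
#       return 2 * torch.rand(size=size, dtype=dtype) - 1
#
#   FuzzedTensor(
#       "x",
#       size=("rows", "cols"),             # resolved from FuzzedParameters
#       steps=(1, 2),                      # allocate (rows, 2 * cols), then slice x[:, ::2]
#       probability_contiguous=0.75,
#       max_allocation_bytes=128 * 1024 ** 2,
#       tensor_constructor=symmetric_constructor,
#   )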

class Fuzzer:
    def __init__(
        self,
        parameters: List[Union[FuzzedParameter, List[FuzzedParameter]]],
        tensors: List[Union[FuzzedTensor, List[FuzzedTensor]]],
        constraints: Optional[List[Callable]] = None,
        seed: Optional[int] = None
    ):
        """
        Args:
            parameters:
                List of FuzzedParameters which provide specifications
                for generated parameters. Iterable elements will be
                unpacked, though arbitrary nested structures will not.
            tensors:
                List of FuzzedTensors which define the Tensors which
                will be created each step based on the parameters for
                that step. Iterable elements will be unpacked, though
                arbitrary nested structures will not.
            constraints:
                List of callables. Each will be called with the candidate
                params dict, and if any of them returns False the current
                set of parameters will be rejected.
            seed:
                Seed for the RandomState used by the Fuzzer. This will
                also be used to set the PyTorch random seed so that random
                ops will create reproducible Tensors.
        """
        import numpy as np
        if seed is None:
            seed = np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64)
        self._seed = seed
        self._parameters = Fuzzer._unpack(parameters, FuzzedParameter)
        self._tensors = Fuzzer._unpack(tensors, FuzzedTensor)
        self._constraints = constraints or ()

        p_names = {p.name for p in self._parameters}
        t_names = {t.name for t in self._tensors}
        name_overlap = p_names.intersection(t_names)
        if name_overlap:
            raise ValueError(f"Duplicate names in parameters and tensors: {name_overlap}")

        self._rejections = 0
        self._total_generated = 0

    @staticmethod
    def _unpack(values, cls):
        return tuple(it.chain(
            *[[i] if isinstance(i, cls) else i for i in values]
        ))

    def take(self, n):
        import numpy as np
        state = np.random.RandomState(self._seed)
        torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64))
        for _ in range(n):
            params = self._generate(state)
            tensors = {}
            tensor_properties = {}
            for t in self._tensors:
                tensor, properties = t._make_tensor(params, state)
                tensors[t.name] = tensor
                tensor_properties[t.name] = properties
            yield tensors, tensor_properties, params

    @property
    def rejection_rate(self):
        if not self._total_generated:
            return 0.
        return self._rejections / self._total_generated

    def _generate(self, state):
        strict_params: Dict[str, Union[float, int, ParameterAlias]] = {}
        for _ in range(1000):
            candidate_params: Dict[str, Union[float, int, ParameterAlias]] = {}
            for p in self._parameters:
                if p.strict:
                    if p.name in strict_params:
                        candidate_params[p.name] = strict_params[p.name]
                    else:
                        candidate_params[p.name] = p.sample(state)
                        strict_params[p.name] = candidate_params[p.name]
                else:
                    candidate_params[p.name] = p.sample(state)

            candidate_params = self._resolve_aliases(candidate_params)

            self._total_generated += 1
            if not all(f(candidate_params) for f in self._constraints):
                self._rejections += 1
                continue

            if not all(t.satisfies_constraints(candidate_params) for t in self._tensors):
                self._rejections += 1
                continue

            return candidate_params
        raise ValueError("Failed to generate a set of valid parameters.")

    @staticmethod
    def _resolve_aliases(params):
        params = dict(params)
        alias_count = sum(isinstance(v, ParameterAlias) for v in params.values())

        keys = list(params.keys())
        while alias_count:
            for k in keys:
                v = params[k]
                if isinstance(v, ParameterAlias):
                    params[k] = params[v.alias_to]
            alias_count_new = sum(isinstance(v, ParameterAlias) for v in params.values())
            if alias_count == alias_count_new:
                raise ValueError(f"ParameterAlias cycle detected\n{params}")

            alias_count = alias_count_new

        return params

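# A hedged end-to-end sketch (not part of the original module), loosely
# following the ParameterAlias docstring example. The seed, bounds, and names
# below are illustrative assumptions; it only runs when this file is executed
# directly.
if __name__ == "__main__":
    _fuzzer = Fuzzer(
        parameters=[
            FuzzedParameter("x_len", minval=4, maxval=1024, distribution="loguniform"),
            # `y` is either length one (so it broadcasts) or matches `x`.
            FuzzedParameter("y_len", distribution={1: 0.5, ParameterAlias("x_len"): 0.5}),
        ],
        tensors=[
            FuzzedTensor("x", size=("x_len",), max_elements=1024 ** 2),
            FuzzedTensor("y", size=("y_len",), max_elements=1024 ** 2),
        ],
        seed=0,
    )
    for _tensors, _properties, _params in _fuzzer.take(3):
        print(_params, _tensors["x"].shape, _tensors["y"].shape)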