xref: /aosp_15_r20/external/pytorch/functorch/dim/batch_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Facebook, Inc. and its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6from contextlib import contextmanager
7
8from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
9
10
11_enabled = False
12
13
14@contextmanager
15def _enable_layers(dims):
16    global _enabled
17    assert not _enabled
18    input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
19    n = len(input)
20    try:
21        _vmap_add_layers(input)
22        _enabled = True
23        yield
24    finally:
25        _enabled = False
26        _vmap_remove_layers(n)
27