1# Copyright (c) Facebook, Inc. and its affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6from contextlib import contextmanager 7 8from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers 9 10 11_enabled = False 12 13 14@contextmanager 15def _enable_layers(dims): 16 global _enabled 17 assert not _enabled 18 input = sorted((d._level, d.size) for d in dims if not isinstance(d, int)) 19 n = len(input) 20 try: 21 _vmap_add_layers(input) 22 _enabled = True 23 yield 24 finally: 25 _enabled = False 26 _vmap_remove_layers(n) 27