xref: /aosp_15_r20/external/pytorch/aten/src/ATen/LegacyVmapMode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/LocalDispatchKeySet.h>
4 
5 namespace at::impl {
6 
7 // VmapMode contains a thread local count of how many nested vmaps
8 // we are currently inside. That number is known as the `vmap level`.
9 // VmapMode is used in the implementation of the Python `torch.vmap` API.
10 //
11 // NOTE: this is NOT the C++ API for torch.vmap. That doesn't exist yet.
12 
13 struct TORCH_API VmapMode {
14   // Returns the vmap level, aka the count of how many nested vmaps we're in.
15   static int64_t current_vmap_level();
16 
17   // Increments the count of nested vmaps. If this causes the vmap level to be
18   // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
19   static int64_t increment_nesting();
20 
21   // Decrements the count of nested vmaps. If this causes the vmap level to be
22   // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
23   static int64_t decrement_nesting();
24 };
25 
26 } // namespace at::impl
27