1 #pragma once 2 3 #include <c10/core/impl/LocalDispatchKeySet.h> 4 5 namespace at::impl { 6 7 // VmapMode contains a thread local count of how many nested vmaps 8 // we are currently inside. That number is known as the `vmap level`. 9 // VmapMode is used in the implementation of the Python `torch.vmap` API. 10 // 11 // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. 12 13 struct TORCH_API VmapMode { 14 // Returns the vmap level, aka the count of how many nested vmaps we're in. 15 static int64_t current_vmap_level(); 16 17 // Increment the count of nested vmaps. If this causes the vmap level to be 18 // greater than 0, then it enables DispatchKey::VmapMode on all tensors. 19 static int64_t increment_nesting(); 20 21 // Decrements the count of nested vmaps. If this causes the vmap level to be 22 // equal to 0, then it disables DispatchKey::VmapMode on all tensors. 23 static int64_t decrement_nesting(); 24 }; 25 26 } // namespace at::impl 27