xref: /aosp_15_r20/external/pytorch/aten/src/ATen/LegacyVmapMode.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/LegacyVmapMode.h>
2 
3 namespace at::impl {
4 
// Per-thread vmap nesting counter. It starts at zero and is read and
// modified only through the VmapMode member functions defined in this file.
thread_local int64_t VmapMode_current_vmap_level{0};
6 
current_vmap_level()7 int64_t VmapMode::current_vmap_level() {
8   return VmapMode_current_vmap_level;
9 }
10 
increment_nesting()11 int64_t VmapMode::increment_nesting() {
12   VmapMode_current_vmap_level++;
13   if (VmapMode_current_vmap_level == 1) {
14     c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
15   }
16   return VmapMode_current_vmap_level;
17 }
18 
decrement_nesting()19 int64_t VmapMode::decrement_nesting() {
20   VmapMode_current_vmap_level--;
21   if (VmapMode_current_vmap_level == 0) {
22     c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
23   }
24   return VmapMode_current_vmap_level;
25 }
26 } // namespace at::impl
27