xref: /aosp_15_r20/external/pytorch/aten/src/ATen/LegacyVmapMode.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/LegacyVmapMode.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker namespace at::impl {
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker thread_local int64_t VmapMode_current_vmap_level = 0;
6*da0073e9SAndroid Build Coastguard Worker 
current_vmap_level()7*da0073e9SAndroid Build Coastguard Worker int64_t VmapMode::current_vmap_level() {
8*da0073e9SAndroid Build Coastguard Worker   return VmapMode_current_vmap_level;
9*da0073e9SAndroid Build Coastguard Worker }
10*da0073e9SAndroid Build Coastguard Worker 
increment_nesting()11*da0073e9SAndroid Build Coastguard Worker int64_t VmapMode::increment_nesting() {
12*da0073e9SAndroid Build Coastguard Worker   VmapMode_current_vmap_level++;
13*da0073e9SAndroid Build Coastguard Worker   if (VmapMode_current_vmap_level == 1) {
14*da0073e9SAndroid Build Coastguard Worker     c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, true);
15*da0073e9SAndroid Build Coastguard Worker   }
16*da0073e9SAndroid Build Coastguard Worker   return VmapMode_current_vmap_level;
17*da0073e9SAndroid Build Coastguard Worker }
18*da0073e9SAndroid Build Coastguard Worker 
decrement_nesting()19*da0073e9SAndroid Build Coastguard Worker int64_t VmapMode::decrement_nesting() {
20*da0073e9SAndroid Build Coastguard Worker   VmapMode_current_vmap_level--;
21*da0073e9SAndroid Build Coastguard Worker   if (VmapMode_current_vmap_level == 0) {
22*da0073e9SAndroid Build Coastguard Worker     c10::impl::tls_set_dispatch_key_included(DispatchKey::VmapMode, false);
23*da0073e9SAndroid Build Coastguard Worker   }
24*da0073e9SAndroid Build Coastguard Worker   return VmapMode_current_vmap_level;
25*da0073e9SAndroid Build Coastguard Worker }
26*da0073e9SAndroid Build Coastguard Worker } // namespace at::impl
27