// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/WrapDimUtils.h>
#include <torch/csrc/utils/python_raii.h>
#include <torch/python.h>

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/TensorWrapper.h>
#include <c10/core/AutogradState.h>

#include <iostream>

// This file contains functorch's Python bindings.

namespace torch::functorch::impl {

using namespace at::functorch;

static bool has_level(const Tensor& self, int64_t level) {
  const auto* batched = maybeGetBatchedImpl(self);
  if (!batched) {
    return false;
  }
  return batched->level() >= level;
}

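// Wraps `self` in a BatchedTensor whose batch dimension is `batch_dim` and
// whose vmap level is `level`.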
Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
  return addBatchDim(self, batch_dim, level);
}

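// Wraps `self` in a FunctionalTensorWrapper and tags the wrapper with the
// given functionalize `level`.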
Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) {
  auto t = at::functionalization::impl::to_functional_tensor(self);
  at::functionalization::impl::unsafeGetFunctionalWrapper(t)->set_level(level);
  return t;
}

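// Asserts that `wrapped` is a FunctionalTensorWrapper whose inner value is
// exactly the (non-functional) tensor `unwrapped`, i.e. that they share the
// same TensorImpl.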
void _assert_wrapped_functional(
    const Tensor& unwrapped,
    const Tensor& wrapped) {
  TORCH_INTERNAL_ASSERT(
      at::functionalization::impl::isFunctionalTensor(wrapped));
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(unwrapped));
  auto wrapped_impl =
      at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
  auto& wrapped_inner = wrapped_impl->value();
  TORCH_INTERNAL_ASSERT(
      unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl())
}

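// Propagates mutations made through the FunctionalTensorWrapper `wrapped`
// back onto the original input `unwrapped`: pending updates are committed,
// metadata changes (resizes / strides) are replayed, and the data is copied
// back whenever the two tensors are not literally the same TensorImpl.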
void _propagate_functional_input_mutation(
    const Tensor& unwrapped,
    const Tensor& wrapped) {
  TORCH_INTERNAL_ASSERT(
      at::functionalization::impl::isFunctionalTensor(wrapped));
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(unwrapped));
  auto wrapped_impl =
      at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
  // Ensure that the input is up to date by committing any pending updates to
  // the alias.
  wrapped_impl->sync_();
  auto& wrapped_inner = wrapped_impl->value();
  // It would probably be more reasonable to check that the two tensors are
  // aliased, but we can't do that unless we give BatchedTensorImpl a notion of
  // storage.
  if (unwrapped.unsafeGetTensorImpl() != wrapped_inner.unsafeGetTensorImpl()) {
    if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) {
      // Functions might resize zero-sized inputs, which we need to reflect
      // here.
      unwrapped.resize__symint(wrapped_inner.sym_sizes());
    }
    // If the input tensor's metadata was mutated, then use as_strided_()
    // to propagate the metadata change.
    if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) {
      unwrapped.as_strided__symint(
          wrapped_inner.sym_sizes(), wrapped_inner.sym_strides());
    }
    unwrapped.copy_(wrapped_inner);
  }
}

static std::pair<Tensor, int64_t> remove_existing_batch_dim(
    const BatchedTensorImpl* batched,
    int64_t level) {
  TORCH_INTERNAL_ASSERT(batched->level() == level);
  return std::make_pair(batched->value(), batched->bdim());
}

// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of other existing dimensions.
// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
// When we do, replace the following with it.
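// For example, given a tensor of shape (A, B, C), _movedim(t, 0, 2) permutes
// it to shape (B, C, A).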
static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
  auto logical_dim = self.dim();
  src = at::maybe_wrap_dim(src, logical_dim);
  dst = at::maybe_wrap_dim(dst, logical_dim);
  if (src == dst) {
    return self;
  }
  VmapDimVector permutation;
  permutation.reserve(logical_dim);
  for (int64_t dim = 0; dim < logical_dim; dim++) {
    if (dim == src) {
      continue;
    }
    permutation.push_back(dim);
  }
  permutation.insert(permutation.begin() + dst, src);
  return self.permute(permutation);
}

// Removes the batch dim with level `level` from `self`. If this causes the
// last batch dim to be removed from a BatchedTensor, then this returns a
// regular Tensor.
//
// If the `level` of the batch dim to remove does not exist in `self`, then we
// add the batch dim in. This can happen if `self` didn't interact with a tensor
// inside the vmap level, for example,
//     self = torch.randn(3)
//     y = torch.randn(5)
//     out = vmap(lambda x: vmap(lambda y: x)(y))(self)
//     assert out.shape == (3, 5)
// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
// corresponding to the *outer* vmap level and it doesn't have any dimensions
// that correspond to the inner vmap level so we need to create one for the
// user.
//
// `out_dim` controls where we should put the batch dimension in the output
// tensor.
Tensor _remove_batch_dim(
    const Tensor& self,
    int64_t level,
    int64_t batch_size,
    int64_t out_dim) {
  TORCH_CHECK(
      out_dim == 0 || !self.key_set().has(DispatchKey::BatchedNestedTensor),
      "Nested tensors can only be vmapped over dim=0, but got dim=",
      out_dim);
  if (!has_level(self, level)) {
    auto self_sizes = self.sizes();
    VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
    expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
    auto result = self.expand(expanded_sizes);
    return result;
  }

  // Must be batched if has_level(self, /*any_level*/)
  const auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched != nullptr);

  auto [self_without_bdim, newly_exposed_logical_dim] =
      remove_existing_batch_dim(batched, level);
  auto result = _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
  return result;
}

Tensor _unwrap_functional_tensor(const Tensor& self, bool add_back_views) {
  // We only ever call this after popping out of a functionalize() call, in
  // which case the current tensors should always be wrapped in a
  // FunctionalTensorWrapper.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto functional =
      at::functionalization::impl::unsafeGetFunctionalWrapper(self);

  // When regenerating the (potentially mutated) input tensors, the
  // functionalization pass regenerates them through a series of view_copy() op
  // calls. Functorch wants to turn those back into view ops, though. Ensure
  // that the input is up to date by committing any pending updates to the
  // alias.
  at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
      add_back_views);
  bool any_updates = functional->apply_updates();
  if (any_updates) {
    functional->regenerate_from_base();
  }
  return functional->value();
}

Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
  // NB: different behavior inside??
  // return self;
  // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self));
  // TORCH_INTERNAL_ASSERT(self.has_storage());
  return makeTensorWrapper(self, level);
}

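// If `self` is a TensorWrapper created at exactly `level`, returns the
// wrapped value; otherwise returns `self` unchanged.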
Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
  auto* result = maybeGetTensorWrapper(self);
  if (!result) {
    return self;
  }
  TORCH_INTERNAL_ASSERT(result->level().has_value());
  if (result->level() == level) {
    return result->value();
  }
  return self;
}

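// Debugging helper: returns the wrapper level of `tensor`, 0 if it is not a
// TensorWrapper, and -1 if the wrapper is dead.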
int64_t dlevel(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (!wrapped) {
    return 0;
  }
  if (!wrapped->is_alive()) {
    return -1;
  }
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  return wrapped->level().value();
}

bool dump_tensor(const Tensor& self) {
  dumpTensorCout(self);
  return true;
}

RandomnessType get_randomness_enum(const std::string& randomness) {
  if (randomness == "error") {
    return RandomnessType::Error;
  } else if (randomness == "same") {
    return RandomnessType::Same;
  } else if (randomness == "different") {
    return RandomnessType::Different;
  } else {
    TORCH_CHECK(
        false, "randomness argument must be error, same, or different.");
  }
}

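// Each *_increment_nesting function below pushes a new transform
// (DynamicLayer) onto the dynamic layer stack and returns the new layer's id;
// the matching *_decrement_nesting function pops it and returns the id of the
// popped layer.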
int64_t _grad_increment_nesting() {
  // See NOTE [grad and vjp interaction with no_grad]
  bool prev_grad_mode = c10::GradMode::is_enabled();
  return initAndPushDynamicLayer(
      TransformType::Grad, std::nullopt, std::nullopt, prev_grad_mode);
}

int64_t _grad_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad);
  return layer.layerId();
}

int64_t _jvp_increment_nesting() {
  // See NOTE [grad and vjp interaction with no_grad]
  bool prev_fwd_grad_mode =
      c10::AutogradState::get_tls_state().get_fw_grad_mode();
  return initAndPushDynamicLayer(
      TransformType::Jvp,
      std::nullopt,
      std::nullopt,
      std::nullopt,
      prev_fwd_grad_mode);
}

int64_t _jvp_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp);
  return layer.layerId();
}

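// Vmap layers additionally record the batch size and the randomness policy
// ("error", "same", or "different") for the new level.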
int64_t _vmap_increment_nesting(
    c10::SymInt batch_size,
    const std::string& randomness) {
  return initAndPushDynamicLayer(
      TransformType::Vmap,
      std::move(batch_size),
      get_randomness_enum(randomness));
}

int64_t _vmap_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap);
  return layer.layerId();
}

int64_t _func_increment_nesting(bool reapply_views) {
  return initAndPushDynamicLayer(
      TransformType::Functionalize,
      std::nullopt,
      std::nullopt,
      std::nullopt,
      std::nullopt,
      /*functionalize_add_back_views=*/reapply_views);
}

int64_t _func_decrement_nesting() {
  auto layer = popDynamicLayerAndDeleteMetadata();
  TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Functionalize);
  return layer.layerId();
}

static bool is_batchedtensor(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  return batched != nullptr;
}

static bool is_legacy_batchedtensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
}

static bool is_gradtrackingtensor(const Tensor& tensor) {
  auto* wrapped = maybeGetTensorWrapper(tensor);
  return wrapped != nullptr;
}

static bool is_functionaltensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(
      c10::DispatchKey::Functionalize);
}

static Tensor get_unwrapped(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->value();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    return wrapped->value();
  }
  if (at::functionalization::impl::isFunctionalTensor(tensor)) {
    auto* functional =
        at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
    return functional->value();
  }
  TORCH_CHECK(false, "No wrappers present!");
}

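// Returns the level of the outermost functorch wrapper on `tensor`: the vmap
// level for BatchedTensors, the grad/jvp level for TensorWrappers (-2 if the
// wrapper has no level), the functionalize level for functional tensors, and
// -1 if `tensor` is not wrapped at all.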
static int64_t maybe_get_level(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->level();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
  if (wrapped) {
    if (wrapped->level()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      return *wrapped->level();
    }
    // TODO: this is a weird special case...
    return -2;
  }
  if (at::functionalization::impl::isFunctionalTensor(tensor)) {
    auto* functional =
        at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
    return functional->level();
  }
  return -1;
}

static int64_t maybe_get_bdim(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->bdim();
  }
  return -1;
}

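// currentLevel() requires that at least one transform (dynamic layer) is
// active; maybe_current_level() returns nullopt instead when the stack is
// empty.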
static int64_t currentLevel() {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t current_level = maybe_layer->layerId();
  return current_level;
}

static std::optional<int64_t> maybe_current_level() {
  auto maybe_layer = maybeCurrentDynamicLayer();
  if (maybe_layer.has_value()) {
    int64_t current_level = maybe_layer->layerId();
    return current_level;
  }
  return std::nullopt;
}

static void tls_set_vmap_excluded(bool excluded) {
  c10::impl::tls_set_dispatch_key_excluded(
      c10::DispatchKey::FuncTorchBatched, excluded);
}

static void _set_dynamic_layer_keys_included(bool value) {
  return setDynamicLayerFrontBackKeysIncluded(value);
}

static void dump_dls() {
  std::cout << getDynamicLayerStack() << std::endl;
}

static void dump_local_tls() {
  auto tls = c10::impl::tls_local_dispatch_key_set();
  std::cout << "[Local Include] " << tls.included_ << std::endl;
  std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
}

namespace {

// Pop the DynamicLayer stack until it's at the given depth.
void popDynamicLayerStackToDepth(size_t depth) {
  while (at::functorch::getDynamicLayerStack().size() > depth) {
    const auto top = popDynamicLayer();
    switch (top.key()) {
      case at::functorch::TransformType::Vmap:
        _vmap_decrement_nesting();
        break;
      case at::functorch::TransformType::Grad:
        _grad_decrement_nesting();
        break;
      case at::functorch::TransformType::Jvp:
        _jvp_decrement_nesting();
        break;
      case at::functorch::TransformType::Functionalize:
        _func_decrement_nesting();
        break;
      case at::functorch::TransformType::Torch:
        popDynamicLayerAndDeleteMetadata();
        break;
    }
  }
}

} // anonymous namespace

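// If `tensor` is batched at exactly `level`, returns its underlying value and
// the position of the batch dimension; otherwise returns `tensor` itself and
// nullopt.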
static std::tuple<Tensor, std::optional<int64_t>> unwrapBatched(
    const Tensor& tensor,
    int64_t level) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (!batched) {
    return std::make_tuple(tensor, std::nullopt);
  }
  if (batched->level() == level) {
    return std::make_tuple(batched->value(), batched->bdim());
  }
  return std::make_tuple(tensor, std::nullopt);
}

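// Registers functorch's Python bindings on a `_functorch` submodule of the
// given module (in practice torch._C, so they are reachable from Python as
// torch._C._functorch.<name>).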
void initFuncTorchBindings(PyObject* module) {
  auto _C = py::handle(module).cast<py::module>();
  auto m = _C.def_submodule("_functorch");

  m.def("_add_batch_dim", &_add_batch_dim, "add batch dim");
  m.def("_remove_batch_dim", &_remove_batch_dim, "remove batch dim");
  m.def("_unwrap_batched", &unwrapBatched);
  m.def(
      "_wrap_functional_tensor",
      &_wrap_functional_tensor,
      "add functional tensor");
  m.def(
      "_assert_wrapped_functional",
      &_assert_wrapped_functional,
      "assert wrapped functional");
  m.def(
      "_propagate_functional_input_mutation",
      &_propagate_functional_input_mutation,
      "propagate functional input mutations");
  m.def(
      "_unwrap_functional_tensor",
      &_unwrap_functional_tensor,
      "remove functional tensor");
  m.def("_vmap_increment_nesting", &_vmap_increment_nesting);
  m.def("_vmap_decrement_nesting", &_vmap_decrement_nesting);
  m.def(
      "_func_increment_nesting",
      &_func_increment_nesting,
      "functionalization start");
  m.def(
      "_func_decrement_nesting",
      &_func_decrement_nesting,
      "functionalization end");
  m.def("_grad_increment_nesting", &_grad_increment_nesting);
  m.def("_grad_decrement_nesting", &_grad_decrement_nesting);
  m.def("_jvp_increment_nesting", &_jvp_increment_nesting);
  m.def("_jvp_decrement_nesting", &_jvp_decrement_nesting);
  m.def("_wrap_for_grad", &_wrap_for_grad, "wrap as gradtrackingtensor");
  m.def(
      "_unwrap_for_grad", &_unwrap_for_grad, "unwrap from gradtrackingtensor");
  m.def(
      "_set_vmap_fallback_warning_enabled",
      &at::functorch::setVmapFallbackWarningEnabled,
      "Set vmap fallback warnings");
  m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
  m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
  m.def(
      "set_inplace_requires_grad_allowed",
      &at::functorch::setInplaceRequiresGradAllowed);
  m.def(
      "get_inplace_requires_grad_allowed",
      &at::functorch::getInplaceRequiresGradAllowed);
  m.def(
      "set_single_level_autograd_function_allowed",
      &at::functorch::setSingleLevelAutogradFunctionAllowed);
  m.def(
      "get_single_level_autograd_function_allowed",
      &at::functorch::getSingleLevelAutogradFunctionAllowed);
  m.def("unwrap_if_dead", &unwrapIfDead);
  m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
  m.def("dlevel", &dlevel, "dlevel");
  m.def("dump_tensor", &dump_tensor, "dump_tensor");
  m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
  m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof);
  // various debugging things. Maybe we should offer these as first-class APIs
  // on Tensors?
  m.def("is_batchedtensor", &is_batchedtensor);
  m.def("is_legacy_batchedtensor", &is_legacy_batchedtensor);
  m.def("is_gradtrackingtensor", &is_gradtrackingtensor);
  m.def("is_functionaltensor", &is_functionaltensor);
  m.def("get_unwrapped", &get_unwrapped);
  m.def("maybe_get_level", &maybe_get_level);
  m.def("maybe_get_bdim", &maybe_get_bdim);
  m.def("maybe_current_level", &maybe_current_level);
  m.def("current_level", &currentLevel);
  m.def("tls_set_vmap_excluded", &tls_set_vmap_excluded);
  m.def("_set_dynamic_layer_keys_included", &_set_dynamic_layer_keys_included);
  m.def("dump_dls", &dump_dls);
  m.def("dump_local_tls", &dump_local_tls);
  m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) {
    return maybe_get_level(tensor) != -1;
  });
  m.def(
      "get_interpreter_stack", []() -> std::optional<std::vector<Interpreter>> {
        const auto& stack = getDynamicLayerStack();
        if (stack.empty()) {
          return std::nullopt;
        }
        std::vector<Interpreter> result;
        result.reserve(stack.size());
        for (auto i : stack) {
          result.push_back(i.interpreter());
        }
        return result;
      });
  m.def("peek_interpreter_stack", []() -> std::optional<Interpreter> {
    const auto& stack = getDynamicLayerStack();
    if (stack.empty()) {
      return std::nullopt;
    }
    auto result = stack.back().interpreter();
    return result;
  });
  m.def("get_dynamic_layer_stack_depth", []() -> size_t {
    return getDynamicLayerStack().size();
  });
  m.def(
      "pop_dynamic_layer_stack_and_undo_to_depth",
      &popDynamicLayerStackToDepth);
  m.def("pop_dynamic_layer_stack", &popDynamicLayer);
  m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
    return pushDynamicLayer(std::move(layer));
  });
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<DynamicLayer>(m, "DynamicLayer");

  py::enum_<TransformType>(m, "TransformType")
      .value("Torch", TransformType::Torch)
      .value("Grad", TransformType::Grad)
      .value("Jvp", TransformType::Jvp)
      .value("Functionalize", TransformType::Functionalize)
      .value("Vmap", TransformType::Vmap);
  py::enum_<RandomnessType>(m, "RandomnessType")
      .value("Error", RandomnessType::Error)
      .value("Same", RandomnessType::Same)
      .value("Different", RandomnessType::Different);
  py::class_<Interpreter>(m, "CInterpreter")
      .def("key", &Interpreter::key)
      .def("level", &Interpreter::level);
  py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &GradInterpreterPtr::key)
      .def("level", &GradInterpreterPtr::level)
      .def("lift", &GradInterpreterPtr::lift)
      .def("prevGradMode", &GradInterpreterPtr::prevGradMode);
  py::class_<JvpInterpreterPtr>(m, "CJvpInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &JvpInterpreterPtr::key)
      .def("level", &JvpInterpreterPtr::level)
      .def("lift", &JvpInterpreterPtr::lift)
      .def("prevFwdGradMode", &JvpInterpreterPtr::prevFwdGradMode);
  py::class_<VmapInterpreterPtr>(m, "CVmapInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &VmapInterpreterPtr::key)
      .def("level", &VmapInterpreterPtr::level)
      .def("batchSize", &VmapInterpreterPtr::batchSize)
      .def("randomness", &VmapInterpreterPtr::randomness);
  py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
      .def(py::init<const Interpreter*>())
      .def("key", &FunctionalizeInterpreterPtr::key)
      .def("level", &FunctionalizeInterpreterPtr::level)
      .def(
          "functionalizeAddBackViews",
          &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
}

} // namespace torch::functorch::impl