1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6
7 #include <ATen/FunctionalTensorWrapper.h>
8 #include <ATen/WrapDimUtils.h>
9 #include <torch/csrc/utils/python_raii.h>
10 #include <torch/python.h>
11
12 #include <ATen/functorch/BatchRulesHelper.h>
13 #include <ATen/functorch/BatchedFallback.h>
14 #include <ATen/functorch/BatchedTensorImpl.h>
15 #include <ATen/functorch/DynamicLayer.h>
16 #include <ATen/functorch/Interpreter.h>
17 #include <ATen/functorch/LegacyVmapTransforms.h>
18 #include <ATen/functorch/PlumbingHelper.h>
19 #include <ATen/functorch/TensorWrapper.h>
20 #include <c10/core/AutogradState.h>
21
22 #include <iostream>
23
24 // This file contains functorch's Python bindings.
25
26 namespace torch::functorch::impl {
27
28 using namespace at::functorch;
29
has_level(const Tensor & self,int64_t level)30 static bool has_level(const Tensor& self, int64_t level) {
31 const auto* batched = maybeGetBatchedImpl(self);
32 if (!batched) {
33 return false;
34 }
35 return batched->level() >= level;
36 }
37
// Binding entry point that forwards to addBatchDim with the same `batch_dim`
// and `level`.
Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
  auto with_bdim = addBatchDim(self, batch_dim, level);
  return with_bdim;
}
41
// Converts `self` with to_functional_tensor and stamps the resulting
// functional wrapper with `level` before returning it.
Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) {
  namespace func_impl = at::functionalization::impl;
  auto functional_tensor = func_impl::to_functional_tensor(self);
  auto* wrapper = func_impl::unsafeGetFunctionalWrapper(functional_tensor);
  wrapper->set_level(level);
  return functional_tensor;
}
47
_assert_wrapped_functional(const Tensor & unwrapped,const Tensor & wrapped)48 void _assert_wrapped_functional(
49 const Tensor& unwrapped,
50 const Tensor& wrapped) {
51 TORCH_INTERNAL_ASSERT(
52 at::functionalization::impl::isFunctionalTensor(wrapped));
53 TORCH_INTERNAL_ASSERT(
54 !at::functionalization::impl::isFunctionalTensor(unwrapped));
55 auto wrapped_impl =
56 at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
57 auto& wrapped_inner = wrapped_impl->value();
58 TORCH_INTERNAL_ASSERT(
59 unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl())
60 }
61
_propagate_functional_input_mutation(const Tensor & unwrapped,const Tensor & wrapped)62 void _propagate_functional_input_mutation(
63 const Tensor& unwrapped,
64 const Tensor& wrapped) {
65 TORCH_INTERNAL_ASSERT(
66 at::functionalization::impl::isFunctionalTensor(wrapped));
67 TORCH_INTERNAL_ASSERT(
68 !at::functionalization::impl::isFunctionalTensor(unwrapped));
69 auto wrapped_impl =
70 at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped);
71 // Ensure that the input is up to date by committing any pending updates to
72 // the alias.
73 wrapped_impl->sync_();
74 auto& wrapped_inner = wrapped_impl->value();
75 // It would probably be more reasonable to check that the two tensors are
76 // aliased, but we can't do that unless we give BatchedTensorImpl a notion of
77 // storage.
78 if (unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) {
79 } else {
80 if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) {
81 // Functions might resize zero-sized inputs, which we need to reflect
82 // ehre.
83 unwrapped.resize__symint(wrapped_inner.sym_sizes());
84 }
85 // If the input tensor's metadata was mutated, then use as_strided_()
86 // to propagate the metadata change.
87 if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) {
88 unwrapped.as_strided__symint(
89 wrapped_inner.sym_sizes(), wrapped_inner.sym_strides());
90 }
91 unwrapped.copy_(wrapped_inner);
92 }
93 }
94
remove_existing_batch_dim(const BatchedTensorImpl * batched,int64_t level)95 static std::pair<Tensor, int64_t> remove_existing_batch_dim(
96 const BatchedTensorImpl* batched,
97 int64_t level) {
98 TORCH_INTERNAL_ASSERT(batched->level() == level);
99 return std::make_pair(batched->value(), batched->bdim());
100 }
101
102 // Poor man's version of np.moveaxis. Moves the dimension at `dst` to `src`
103 // while preserving the order of other existing dimensions.
104 // We should probably add np.moveaxis (it is more general) to PyTorch. (#36048)
105 // When we do, replace the following with it.
_movedim(const Tensor & self,int64_t src,int64_t dst)106 static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
107 auto logical_dim = self.dim();
108 src = at::maybe_wrap_dim(src, logical_dim);
109 dst = at::maybe_wrap_dim(dst, logical_dim);
110 if (src == dst) {
111 return self;
112 }
113 VmapDimVector permutation;
114 permutation.reserve(logical_dim);
115 for (int64_t dim = 0; dim < logical_dim; dim++) {
116 if (dim == src) {
117 continue;
118 }
119 permutation.push_back(dim);
120 }
121 permutation.insert(permutation.begin() + dst, src);
122 return self.permute(permutation);
123 }
124
125 // Removes the batch dim with level `level` from `self`. If this causes the
126 // last batch dim to be removed from a BatchedTensor, then this returns a
127 // regular Tensor.
128 //
129 // If the `level` of the batch dim to remove does not exist in `self`, then we
130 // add the batch dim in. This can happen if `self` didn't interact with a tensor
131 // inside the vmap level, for example,
132 // self = torch.randn(3)
133 // y = torch.randn(5)
134 // out = vmap(lambda x: vmap(lambda y: x)(y))(self)
135 // assert out.shape == (3, 5)
136 // Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension
137 // corresponding to the *outer* vmap level and it doesn't have any dimensions
138 // that correspond to the inner vmap level so we need to create one for the
139 // user.
140 //
141 // `out_dim` controls where we should put the batch dimension in the output
142 // tensor.
_remove_batch_dim(const Tensor & self,int64_t level,int64_t batch_size,int64_t out_dim)143 Tensor _remove_batch_dim(
144 const Tensor& self,
145 int64_t level,
146 int64_t batch_size,
147 int64_t out_dim) {
148 TORCH_CHECK(
149 out_dim == 0 || !self.key_set().has(DispatchKey::BatchedNestedTensor),
150 "Nested tensors can only be vmapped over dim=0, but got dim=",
151 out_dim);
152 if (!has_level(self, level)) {
153 auto self_sizes = self.sizes();
154 VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
155 expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
156 auto result = self.expand(expanded_sizes);
157 return result;
158 }
159
160 // Must be batched if has_level(self, /*any_level*/)
161 const auto* batched = maybeGetBatchedImpl(self);
162 TORCH_INTERNAL_ASSERT(batched != nullptr);
163
164 auto [self_without_bdim, newly_exposed_logical_dim] =
165 remove_existing_batch_dim(batched, level);
166 auto result = _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim);
167 return result;
168 }
169
// Returns the value held inside the functional tensor `self`, after applying
// any pending updates. `add_back_views` is forwarded to
// FunctionalizationReapplyViewsGuard for the duration of the call.
Tensor _unwrap_functional_tensor(const Tensor& self, bool add_back_views) {
  namespace func_impl = at::functionalization::impl;

  // We only ever call that after popping out of a functionalize() call, in
  // which case the current tensors should always be wrapped in a
  // FunctionalTensorWrapper.
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto* wrapper = func_impl::unsafeGetFunctionalWrapper(self);

  // when regenerating the (potentially mutated) input tensors, the
  // functionalization pass regenerates them through a series of view_copy() op
  // calls. Functorch wants to turn those back into view ops though. Ensure that
  // the input is up to date by committing any pending updates to the alias.
  func_impl::FunctionalizationReapplyViewsGuard guard(add_back_views);
  if (wrapper->apply_updates()) {
    wrapper->regenerate_from_base();
  }
  return wrapper->value();
}
190
// Wraps `self` at `level` by delegating to makeTensorWrapper (exposed to
// Python as "wrap as gradtrackingtensor").
Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
  // NB: different behavior inside??
  // return self;
  // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self));
  // TORCH_INTERNAL_ASSERT(self.has_storage());
  auto wrapped = makeTensorWrapper(self, level);
  return wrapped;
}
198
// Strips one TensorWrapper from `self`, but only when the wrapper sits at
// exactly `level`. Tensors that are not TensorWrappers, or whose wrapper is at
// a different level, are returned unchanged.
Tensor _unwrap_for_grad(const Tensor& self, int64_t level) {
  auto* result = maybeGetTensorWrapper(self);
  if (result == nullptr) {
    return self;
  }
  TORCH_INTERNAL_ASSERT(result->level().has_value());
  const bool at_requested_level = (result->level() == level);
  if (!at_requested_level) {
    return self;
  }
  return result->value();
}
210
dlevel(const Tensor & tensor)211 int64_t dlevel(const Tensor& tensor) {
212 auto* wrapped = maybeGetTensorWrapper(tensor);
213 if (!wrapped) {
214 return 0;
215 }
216 if (!wrapped->is_alive()) {
217 return -1;
218 }
219 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
220 return wrapped->level().value();
221 }
222
dump_tensor(const Tensor & self)223 bool dump_tensor(const Tensor& self) {
224 dumpTensorCout(self);
225 return true;
226 }
227
get_randomness_enum(const std::string & randomness)228 RandomnessType get_randomness_enum(const std::string& randomness) {
229 if (randomness == "error") {
230 return RandomnessType::Error;
231 } else if (randomness == "same") {
232 return RandomnessType::Same;
233 } else if (randomness == "different") {
234 return RandomnessType::Different;
235 } else {
236 TORCH_CHECK(
237 false, "randomness argument must be error, same, or different.");
238 }
239 }
240
_grad_increment_nesting()241 int64_t _grad_increment_nesting() {
242 // See NOTE [grad and vjp interaction with no_grad]
243 bool prev_grad_mode = c10::GradMode::is_enabled();
244 return initAndPushDynamicLayer(
245 TransformType::Grad, std::nullopt, std::nullopt, prev_grad_mode);
246 }
247
_grad_decrement_nesting()248 int64_t _grad_decrement_nesting() {
249 auto layer = popDynamicLayerAndDeleteMetadata();
250 TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad);
251 return layer.layerId();
252 }
253
_jvp_increment_nesting()254 int64_t _jvp_increment_nesting() {
255 // See NOTE [grad and vjp interaction with no_grad]
256 bool prev_fwd_grad_mode =
257 c10::AutogradState::get_tls_state().get_fw_grad_mode();
258 return initAndPushDynamicLayer(
259 TransformType::Jvp,
260 std::nullopt,
261 std::nullopt,
262 std::nullopt,
263 prev_fwd_grad_mode);
264 }
265
_jvp_decrement_nesting()266 int64_t _jvp_decrement_nesting() {
267 auto layer = popDynamicLayerAndDeleteMetadata();
268 TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp);
269 return layer.layerId();
270 }
271
_vmap_increment_nesting(c10::SymInt batch_size,const std::string & randomness)272 int64_t _vmap_increment_nesting(
273 c10::SymInt batch_size,
274 const std::string& randomness) {
275 return initAndPushDynamicLayer(
276 TransformType::Vmap,
277 std::move(batch_size),
278 get_randomness_enum(randomness));
279 }
280
_vmap_decrement_nesting()281 int64_t _vmap_decrement_nesting() {
282 auto layer = popDynamicLayerAndDeleteMetadata();
283 TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap);
284 return layer.layerId();
285 }
286
_func_increment_nesting(bool reapply_views)287 int64_t _func_increment_nesting(bool reapply_views) {
288 return initAndPushDynamicLayer(
289 TransformType::Functionalize,
290 std::nullopt,
291 std::nullopt,
292 std::nullopt,
293 std::nullopt,
294 /*functionalize_add_back_views=*/reapply_views);
295 }
296
_func_decrement_nesting()297 int64_t _func_decrement_nesting() {
298 auto layer = popDynamicLayerAndDeleteMetadata();
299 TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Functionalize);
300 return layer.layerId();
301 }
302
is_batchedtensor(const Tensor & tensor)303 static bool is_batchedtensor(const Tensor& tensor) {
304 auto* batched = maybeGetBatchedImpl(tensor);
305 return batched != nullptr;
306 }
307
is_legacy_batchedtensor(const Tensor & tensor)308 static bool is_legacy_batchedtensor(const Tensor& tensor) {
309 return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
310 }
311
is_gradtrackingtensor(const Tensor & tensor)312 static bool is_gradtrackingtensor(const Tensor& tensor) {
313 auto* wrapped = maybeGetTensorWrapper(tensor);
314 return wrapped != nullptr;
315 }
316
is_functionaltensor(const Tensor & tensor)317 static bool is_functionaltensor(const Tensor& tensor) {
318 return tensor.unsafeGetTensorImpl()->key_set().has(
319 c10::DispatchKey::Functionalize);
320 }
321
// Peels exactly one wrapper off `tensor` and returns the value inside it.
// Wrapper kinds are tried in this order: BatchedTensor, TensorWrapper,
// functional tensor. Fails a TORCH_CHECK when `tensor` is none of these.
static Tensor get_unwrapped(const Tensor& tensor) {
  if (auto* batched = maybeGetBatchedImpl(tensor)) {
    return batched->value();
  }
  if (auto* wrapped = maybeGetTensorWrapper(tensor)) {
    return wrapped->value();
  }
  namespace func_impl = at::functionalization::impl;
  if (func_impl::isFunctionalTensor(tensor)) {
    return func_impl::unsafeGetFunctionalWrapper(tensor)->value();
  }
  TORCH_CHECK(false, "No wrappers present!");
}
338
maybe_get_level(const Tensor & tensor)339 static int64_t maybe_get_level(const Tensor& tensor) {
340 auto* batched = maybeGetBatchedImpl(tensor);
341 if (batched) {
342 return batched->level();
343 }
344 auto* wrapped = maybeGetTensorWrapper(tensor);
345 if (wrapped) {
346 if (wrapped->level()) {
347 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
348 return *wrapped->level();
349 }
350 // TODO: this is a weird special case...
351 return -2;
352 }
353 if (at::functionalization::impl::isFunctionalTensor(tensor)) {
354 auto* functional =
355 at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
356 return functional->level();
357 }
358 return -1;
359 }
360
maybe_get_bdim(const Tensor & tensor)361 static int64_t maybe_get_bdim(const Tensor& tensor) {
362 auto* batched = maybeGetBatchedImpl(tensor);
363 if (batched) {
364 return batched->bdim();
365 }
366 return -1;
367 }
368
currentLevel()369 static int64_t currentLevel() {
370 auto maybe_layer = maybeCurrentDynamicLayer();
371 TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
372 int64_t current_level = maybe_layer->layerId();
373 return current_level;
374 }
375
maybe_current_level()376 static std::optional<int64_t> maybe_current_level() {
377 auto maybe_layer = maybeCurrentDynamicLayer();
378 if (maybe_layer.has_value()) {
379 int64_t current_level = maybe_layer->layerId();
380 return current_level;
381 }
382 return nullopt;
383 }
384
tls_set_vmap_excluded(bool excluded)385 static void tls_set_vmap_excluded(bool excluded) {
386 c10::impl::tls_set_dispatch_key_excluded(
387 c10::DispatchKey::FuncTorchBatched, excluded);
388 }
389
_set_dynamic_layer_keys_included(bool value)390 static void _set_dynamic_layer_keys_included(bool value) {
391 return setDynamicLayerFrontBackKeysIncluded(value);
392 }
393
dump_dls()394 static void dump_dls() {
395 std::cout << getDynamicLayerStack() << std::endl;
396 }
397
dump_local_tls()398 static void dump_local_tls() {
399 auto tls = c10::impl::tls_local_dispatch_key_set();
400 std::cout << "[Local Include] " << tls.included_ << std::endl;
401 std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
402 }
403
404 namespace {
405
406 // Pop the DynamicLayer stack until it's at the given depth.
popDynamicLayerStackToDepth(size_t depth)407 void popDynamicLayerStackToDepth(size_t depth) {
408 while (at::functorch::getDynamicLayerStack().size() > depth) {
409 const auto top = popDynamicLayer();
410 switch (top.key()) {
411 case at::functorch::TransformType::Vmap:
412 _vmap_decrement_nesting();
413 break;
414 case at::functorch::TransformType::Grad:
415 _grad_decrement_nesting();
416 break;
417 case at::functorch::TransformType::Jvp:
418 _jvp_decrement_nesting();
419 break;
420 case at::functorch::TransformType::Functionalize:
421 _func_decrement_nesting();
422 break;
423 case at::functorch::TransformType::Torch:
424 popDynamicLayerAndDeleteMetadata();
425 break;
426 }
427 }
428 }
429
430 } // anonymous namespace
431
unwrapBatched(const Tensor & tensor,int64_t level)432 static std::tuple<Tensor, std::optional<int64_t>> unwrapBatched(
433 const Tensor& tensor,
434 int64_t level) {
435 auto* batched = maybeGetBatchedImpl(tensor);
436 if (!batched) {
437 return std::make_tuple(tensor, std::nullopt);
438 }
439 if (batched->level() == level) {
440 return std::make_tuple(batched->value(), batched->bdim());
441 }
442 return std::make_tuple(tensor, std::nullopt);
443 }
444
initFuncTorchBindings(PyObject * module)445 void initFuncTorchBindings(PyObject* module) {
446 auto _C = py::handle(module).cast<py::module>();
447 auto m = _C.def_submodule("_functorch");
448
449 m.def("_add_batch_dim", &_add_batch_dim, "add batch dim");
450 m.def("_remove_batch_dim", &_remove_batch_dim, "remove batch dim");
451 m.def("_unwrap_batched", &unwrapBatched);
452 m.def(
453 "_wrap_functional_tensor",
454 &_wrap_functional_tensor,
455 "add functional tensor");
456 m.def(
457 "_assert_wrapped_functional",
458 &_assert_wrapped_functional,
459 "assert wrapped functional");
460 m.def(
461 "_propagate_functional_input_mutation",
462 &_propagate_functional_input_mutation,
463 "propagate functional input mutations");
464 m.def(
465 "_unwrap_functional_tensor",
466 &_unwrap_functional_tensor,
467 "remove functional tensor");
468 m.def("_vmap_increment_nesting", &_vmap_increment_nesting);
469 m.def("_vmap_decrement_nesting", &_vmap_decrement_nesting);
470 m.def(
471 "_func_increment_nesting",
472 &_func_increment_nesting,
473 "functionalization start");
474 m.def(
475 "_func_decrement_nesting",
476 &_func_decrement_nesting,
477 "functionalization end");
478 m.def("_grad_increment_nesting", &_grad_increment_nesting);
479 m.def("_grad_decrement_nesting", &_grad_decrement_nesting);
480 m.def("_jvp_increment_nesting", &_jvp_increment_nesting);
481 m.def("_jvp_decrement_nesting", &_jvp_decrement_nesting);
482 m.def("_wrap_for_grad", &_wrap_for_grad, "wrap as gradtrackingtensor");
483 m.def(
484 "_unwrap_for_grad", &_unwrap_for_grad, "unwrap from gradtrackingtensor");
485 m.def(
486 "_set_vmap_fallback_warning_enabled",
487 &at::functorch::setVmapFallbackWarningEnabled,
488 "Set vmap fallback warnings");
489 m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
490 m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
491 m.def(
492 "set_inplace_requires_grad_allowed",
493 &at::functorch::setInplaceRequiresGradAllowed);
494 m.def(
495 "get_inplace_requires_grad_allowed",
496 &at::functorch::getInplaceRequiresGradAllowed);
497 m.def(
498 "set_single_level_autograd_function_allowed",
499 &at::functorch::setSingleLevelAutogradFunctionAllowed);
500 m.def(
501 "get_single_level_autograd_function_allowed",
502 &at::functorch::getSingleLevelAutogradFunctionAllowed);
503 m.def("unwrap_if_dead", &unwrapIfDead);
504 m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper);
505 m.def("dlevel", &dlevel, "dlevel");
506 m.def("dump_tensor", &dump_tensor, "dump_tensor");
507 m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
508 m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof);
509 // various debugging things. Maybe we should offer these as first-class APIs
510 // on Tensors?
511 m.def("is_batchedtensor", &is_batchedtensor);
512 m.def("is_legacy_batchedtensor", &is_legacy_batchedtensor);
513 m.def("is_gradtrackingtensor", &is_gradtrackingtensor);
514 m.def("is_functionaltensor", &is_functionaltensor);
515 m.def("get_unwrapped", &get_unwrapped);
516 m.def("maybe_get_level", &maybe_get_level);
517 m.def("maybe_get_bdim", &maybe_get_bdim);
518 m.def("maybe_current_level", &maybe_current_level);
519 m.def("current_level", ¤tLevel);
520 m.def("tls_set_vmap_excluded", &tls_set_vmap_excluded);
521 m.def("_set_dynamic_layer_keys_included", &_set_dynamic_layer_keys_included);
522 m.def("dump_dls", &dump_dls);
523 m.def("dump_local_tls", &dump_local_tls);
524 m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) {
525 return maybe_get_level(tensor) != -1;
526 });
527 m.def(
528 "get_interpreter_stack", []() -> std::optional<std::vector<Interpreter>> {
529 const auto& stack = getDynamicLayerStack();
530 if (stack.empty()) {
531 return std::nullopt;
532 }
533 std::vector<Interpreter> result;
534 result.reserve(stack.size());
535 for (auto i : stack) {
536 result.push_back(i.interpreter());
537 }
538 return result;
539 });
540 m.def("peek_interpreter_stack", []() -> std::optional<Interpreter> {
541 const auto& stack = getDynamicLayerStack();
542 if (stack.empty()) {
543 return std::nullopt;
544 }
545 auto result = stack.back().interpreter();
546 return result;
547 });
548 m.def("get_dynamic_layer_stack_depth", []() -> size_t {
549 return getDynamicLayerStack().size();
550 });
551 m.def(
552 "pop_dynamic_layer_stack_and_undo_to_depth",
553 &popDynamicLayerStackToDepth);
554 m.def("pop_dynamic_layer_stack", &popDynamicLayer);
555 m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t {
556 return pushDynamicLayer(std::move(layer));
557 });
558 // NOLINTNEXTLINE(bugprone-unused-raii)
559 py::class_<DynamicLayer>(m, "DynamicLayer");
560
561 py::enum_<TransformType>(m, "TransformType")
562 .value("Torch", TransformType::Torch)
563 .value("Grad", TransformType::Grad)
564 .value("Jvp", TransformType::Jvp)
565 .value("Functionalize", TransformType::Functionalize)
566 .value("Vmap", TransformType::Vmap);
567 py::enum_<RandomnessType>(m, "RandomnessType")
568 .value("Error", RandomnessType::Error)
569 .value("Same", RandomnessType::Same)
570 .value("Different", RandomnessType::Different);
571 py::class_<Interpreter>(m, "CInterpreter")
572 .def("key", &Interpreter::key)
573 .def("level", &Interpreter::level);
574 py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
575 .def(py::init<const Interpreter*>())
576 .def("key", &GradInterpreterPtr::key)
577 .def("level", &GradInterpreterPtr::level)
578 .def("lift", &GradInterpreterPtr::lift)
579 .def("prevGradMode", &GradInterpreterPtr::prevGradMode);
580 py::class_<JvpInterpreterPtr>(m, "CJvpInterpreterPtr")
581 .def(py::init<const Interpreter*>())
582 .def("key", &JvpInterpreterPtr::key)
583 .def("level", &JvpInterpreterPtr::level)
584 .def("lift", &JvpInterpreterPtr::lift)
585 .def("prevFwdGradMode", &JvpInterpreterPtr::prevFwdGradMode);
586 py::class_<VmapInterpreterPtr>(m, "CVmapInterpreterPtr")
587 .def(py::init<const Interpreter*>())
588 .def("key", &VmapInterpreterPtr::key)
589 .def("level", &VmapInterpreterPtr::level)
590 .def("batchSize", &VmapInterpreterPtr::batchSize)
591 .def("randomness", &VmapInterpreterPtr::randomness);
592 py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr")
593 .def(py::init<const Interpreter*>())
594 .def("key", &FunctionalizeInterpreterPtr::key)
595 .def("level", &FunctionalizeInterpreterPtr::level)
596 .def(
597 "functionalizeAddBackViews",
598 &FunctionalizeInterpreterPtr::functionalizeAddBackViews);
599 }
600
601 } // namespace torch::functorch::impl
602