// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/ATen.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/Metaprogramming.h>

// This file contains batching rules for operations that return Tensors of
// dynamic shape. We generally don't support those with vmap, so we raise
// errors for them.
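//
// For example, each slice of a batched input may contain a different number
// of nonzero elements, so there is no single stacked Tensor that vmap could
// return. A hypothetical repro (Python, for illustration only; the exact
// vmap spelling may differ across versions):
//
//   x = torch.randn(3, 5)
//   torch.vmap(torch.nonzero)(x)  # raises the "dynamic shape" error below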

namespace at::functorch {

namespace {
void unsupportedDynamicOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false, "vmap: We do not support batching operators that can output dynamic shape. ",
      "Attempted to vmap over ", op.schema().operator_name(), ". ",
      "Please voice your support at https://github.com/pytorch/functorch/issues/256");
}
#define UNSUPPORTED_DYNAMIC(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedDynamicOp>());
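// For instance, UNSUPPORTED_DYNAMIC(nonzero) expands to
//   m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&unsupportedDynamicOp>());
// i.e. it registers the boxed error kernel above as the batching rule for
// that op (#op stringizes the operator name).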

// NB: item and is_nonzero can decompose to this...
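// (Tensor.item() is implemented in terms of aten::_local_scalar_dense, and
// is_nonzero can reach it the same way, so either may hit this kernel even
// when their own error kernels below are not dispatched first.)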
void unsupportedLocalScalarDense(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false,
      "vmap: It looks like you're either (1) calling .item() on a Tensor, ",
      "(2) attempting to use a Tensor in some data-dependent control flow, or ",
      "(3) encountering this error in PyTorch internals. ",
      "For (1): we don't support vmap over calling .item() on a Tensor; please try to ",
      "rewrite what you're doing with other operations. ",
      "For (2): we don't support data-dependent control flow yet; please shout over at ",
      "https://github.com/pytorch/functorch/issues/257 . ",
      "For (3): please file an issue.");
}

void unsupportedItem(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false,
      "vmap: It looks like you're calling .item() on a Tensor. ",
      "We don't support vmap over calling .item() on a Tensor; please try to ",
      "rewrite what you're doing with other operations. If this error is occurring ",
      "somewhere inside PyTorch internals, please file a bug report.");
}

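// aten::is_nonzero is what a Tensor's boolean conversion dispatches to, so
// this kernel typically fires on data-dependent Python control flow such as
// `if x > 0:` inside a vmapped function.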
void unsupportedIsNonzero(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false,
      "vmap: It looks like you're attempting to use a Tensor in some ",
      "data-dependent control flow. ",
      "We don't support that yet; please shout over at ",
      "https://github.com/pytorch/functorch/issues/257 .");
}

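// torch.allclose returns a plain bool computed from the tensors' values, so
// its result is data-dependent and cannot be batched.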
void unsupportedAllclose(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false,
      "vmap over torch.allclose isn't supported yet. Please voice your ",
      "support over at github.com/pytorch/functorch/issues/275");
}
} // anonymous namespace

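// Register the error kernels above for the FuncTorchBatched dispatch key, so
// they run in place of the generic batched fallback whenever a BatchedTensor
// reaches one of these ops.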
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  UNSUPPORTED_DYNAMIC(nonzero);
  UNSUPPORTED_DYNAMIC(where);
  UNSUPPORTED_DYNAMIC(unique_dim);
  UNSUPPORTED_DYNAMIC(unique_consecutive);
  UNSUPPORTED_DYNAMIC(unique_dim_consecutive);
  UNSUPPORTED_DYNAMIC(_unique2);
  m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&unsupportedLocalScalarDense>());
  m.impl("item", torch::CppFunction::makeFromBoxedFunction<&unsupportedItem>());
  m.impl("is_nonzero", torch::CppFunction::makeFromBoxedFunction<&unsupportedIsNonzero>());
  m.impl("allclose", torch::CppFunction::makeFromBoxedFunction<&unsupportedAllclose>());
}

} // namespace at::functorch