// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/WrapDimUtils.h>

namespace at::functorch {

Tensor moveBatchDimToFront(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  if (maybe_batch_dim.value() == 0) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}
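// e.g. moveBatchDimToFront(t, 2) on a tensor of shape [3, 4, B] returns a
// view of shape [B, 3, 4]; with std::nullopt (no batch dim) it is a no-op.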

int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}
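// e.g. a physical [B, 3, 4] tensor with a batch dim has logical rank 2;
// without a batch dim the logical rank is just tensor.dim().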

int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim) {
    return tensor.numel();
  }
  return tensor.numel() / tensor.size(*maybe_batch_dim);
}
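// e.g. for a [2, 3, 4] tensor with batch dim 0 this returns 12, the numel of
// a single [3, 4] slice.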

std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val) {
  if (maybe_empty.has_value()) {
    return new_val;
  }
  return std::nullopt;
}
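// e.g. valIfNonempty(bdim, 0) is 0 if `bdim` holds a value and std::nullopt
// otherwise; handy for propagating an output bdim only when the input was batched.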

int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
  // NB: assumes the batch dim is at the front of the tensor
  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : std::nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
  if (has_batch_dim) {
    return wrapped_dim + 1;
  }
  return wrapped_dim;
}
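// e.g. for a physical [B, 3, 4] tensor with has_batch_dim=true, logical_dim=-1
// wraps to 1 and maps to physical dim 2.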

VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) {
  // NB: assumes the batch dim is at the front of the tensor
  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : std::nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  VmapDimVector result;
  result.reserve(logical_dims.size());
  for (auto d : logical_dims) {
    if (has_batch_dim) {
      result.push_back(maybe_wrap_dim(d, rank) + 1);
    } else {
      result.push_back(maybe_wrap_dim(d, rank));
    }
  }
  return result;
}
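// e.g. logical dims {0, -1} on a physical [B, 3, 4] tensor with
// has_batch_dim=true map to physical dims {1, 2}.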

Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapSymDimVector new_sizes(tensor.sym_sizes().begin(), tensor.sym_sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view_symint(SymIntArrayRef{new_sizes.begin(), new_sizes.end()});
}
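// e.g. padding a batched [B, 3] tensor (logical rank 1) to logical rank 3
// yields a [B, 1, 1, 3] view; unbatched tensors are returned unchanged.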

void check_randomness(RandomnessType randomness, bool any_tensor_batched) {
  TORCH_CHECK(
    randomness != RandomnessType::Error,
    "vmap: called random operation while in randomness error mode. Please either use the "
    "'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap"
  );

  TORCH_CHECK(
    !(randomness == RandomnessType::Same && any_tensor_batched),
    "Vmap does not currently support same randomness with a batched tensor input. ",
    "Please file an issue with functorch"
  );
}

void check_randomness(RandomnessType randomness) {
  // ops that take no tensor inputs can never hit the 'same randomness with
  // batched input' error, so pass any_tensor_batched=false
  check_randomness(randomness, false);
}

Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x) {
  auto x_dim = x.dim();
  src = maybe_wrap_dim(src, x_dim);
  dst = maybe_wrap_dim(dst, x_dim - 1); // Returned Tensor has one fewer dim
  VmapDimVector new_shape(x.sizes().begin(), x.sizes().end());
  new_shape.erase(new_shape.begin() + src);
  new_shape[dst] *= x.sizes()[src];
  return at::reshape(x.movedim(src, dst), new_shape);
}
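// e.g. reshape_dim_into(0, 0, x) on a [2, 3, 5] tensor collapses dim 0 into
// dim 0 of the result, giving shape [6, 5].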

Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x) {
  src = maybe_wrap_dim(src, x.dim());
  VmapDimVector shape(x.sizes().begin(), x.sizes().end());
  if (shape[src] != 0) {
    // NOTE: 0 % 0 leads to FPE
    TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0);
  }
  // a 0-sized dim may be split by any size1; size2 then stays 0
  int64_t size2 = 0;
  if (shape[src] != 0) {
    size2 = shape[src] / size1;
  }
  shape[src] = size1;
  shape.insert(shape.begin() + src + 1, size2);
  return at::reshape(x, shape);
}
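// e.g. reshape_dim_outof(0, 2, x) on a [6, 5] tensor splits dim 0 into (2, 3),
// giving shape [2, 3, 5]; a 0-sized dim splits into (size1, 0).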

Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x) {
  src = maybe_wrap_dim(src, x.dim());
  c10::SymDimVector shape(x.sym_sizes().begin(), x.sym_sizes().end());
  if (shape[src] != 0) {
    // NOTE: 0 % 0 leads to FPE
    TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0);
  }
  // a 0-sized dim may be split by any size1; size2 then stays 0
  c10::SymInt size2;
  if (shape[src] == 0) {
    size2 = 0;
  } else {
    size2 = shape[src] / size1;
  }
  shape[src] = size1;
  shape.insert(shape.begin() + src + 1, size2);
  return at::reshape_symint(x, shape);
}
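// Same splitting logic as reshape_dim_outof above, but computed on SymInt
// sizes (c10::SymInt / reshape_symint) so it also handles symbolic shapes.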

void vmapIncompatibleInplaceError(const char* schema_name) {
  TORCH_CHECK(false,
    "vmap: ", schema_name, "(self, *extra_args) is not possible because ",
    "there exists a Tensor `other` in extra_args that has more elements ",
    "than `self`. This happened due to `other` being vmapped over but ",
    "`self` not being vmapped over in a vmap. ",
    "Please try to use out-of-place operators instead of ", schema_name, ". ",
    "If said operator is being called inside the PyTorch framework, ",
    "please file a bug report instead.");
}

static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}
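// e.g. a batched double logical scalar combined with a float tensor gets cast
// to float: indexing with [0] makes the scalar participate in promotion as a
// 0-dim tensor, matching scalar/tensor promotion semantics.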

std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim,
    bool do_type_promotion) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  if (do_type_promotion) {
    auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
    auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
    if (tensor_is_logical_scalar && !other_is_logical_scalar) {
      handleScalarTypePromotion(tensor_, other_);
    }
    if (other_is_logical_scalar && !tensor_is_logical_scalar) {
      handleScalarTypePromotion(other_, tensor_);
    }
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  return std::make_tuple(tensor_, other_);
}
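// e.g. a batched [B, 3] tensor (bdim 0) and an unbatched [2, 5, 3] tensor come
// back as ([B, 1, 1, 3], [2, 5, 3]), ready for broadcasting by the batch rule.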

} // namespace at::functorch