// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/BatchedTensorImpl.h>

#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>

#include <c10/util/irange.h>

namespace at::functorch {

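// A BatchedTensorImpl wraps a regular Tensor (`value_`) together with the
// physical dimension being vmapped over (`bdim_`) and the vmap level
// (`level_`). The batch dimension is hidden from the public metadata: for
// example, wrapping a tensor of physical shape [2, 3, 5] with bdim = 0
// yields a batched tensor whose public shape is [3, 5].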
BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, int64_t bdim, int64_t level)
  : TensorImpl(
      key_set.add(
          value.is_nested() ? DispatchKey::BatchedNestedTensor : DispatchKey::FuncTorchBatched),
      value.dtype(),
      value.device()
    )
  , value_(std::move(value))
  , level_(level)
  , bdim_(bdim)
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  if (value_.is_nested() || value_.key_set().has(DispatchKey::BatchedNestedTensor)) {
    TORCH_CHECK(bdim_ == 0,
        "Nested tensors can only be vmapped over dim=0, but got dim=", bdim_);
    TORCH_CHECK(level_ == 1,
        "Only one level of vmap is supported when vmapping over nested tensors");
  }
  set_storage_access_should_throw();
  set_custom_sizes_strides(
      value_.is_nested() ? SizesStridesPolicy::CustomSizes : SizesStridesPolicy::CustomStrides);
  checkInvariants();
  refreshTensorMetadata();
}

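// Recomputes the public (per-example) sizes, strides, and storage offset from
// `value_` by skipping the batch dimension. For nested tensors, only
// `sizes_and_strides_` is resized to the public rank, since their
// per-dimension sizes may be ragged; dense tensors copy the symbolic
// size/stride of every non-batch dimension from `value_`.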
void BatchedTensorImpl::refreshTensorMetadata() {
  const auto public_dims = value_.dim() - 1;
  if (value_.is_nested()) {
    sizes_and_strides_.resize(public_dims);
    storage_offset_ = value_.storage_offset();
    refresh_numel();
    refresh_contiguous();
  } else {
    c10::SymDimVector new_sizes;
    c10::SymDimVector new_strides;
    new_sizes.reserve(public_dims);
    new_strides.reserve(public_dims);

    // Update the sizes, strides, and storage_offset, preserving symbolic
    // sizes and strides from the underlying tensor.
    const auto value_sizes = value_.sym_sizes();
    const auto value_strides = value_.sym_strides();

    for (const auto dim : c10::irange(0, public_dims)) {
      auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
      new_sizes.push_back(value_sizes.at(actual_dim));
      new_strides.push_back(value_strides.at(actual_dim));
    }

    // `set_sizes_and_strides` takes care of calling `refresh_numel` and
    // `refresh_contiguous`.
    set_sizes_and_strides(new_sizes, new_strides, value_.sym_storage_offset());
  }
}

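// Maps a public (logical) dimension index to the corresponding physical
// dimension of `value_` by skipping over the hidden batch dimension.
// Illustrative example: with a physical rank of 3 and bdim_ == 1, public
// dim 0 maps to physical dim 0 and public dim 1 maps to physical dim 2.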
int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
  if (wrap_dim) {
    const auto ndim = sizes_and_strides_.size();
    dim = maybe_wrap_dim(dim, static_cast<int64_t>(ndim));
  }
  if (bdim_ <= dim) {
    return dim + 1;
  } else {
    return dim;
  }
}

void BatchedTensorImpl::checkInvariants() const {
  TORCH_INTERNAL_ASSERT(level_ > -1);
}

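// For dense tensors the public sizes are cached by refreshTensorMetadata(),
// so size() can read straight from sizes_default(). Batched nested tensors
// have no cached sizes; the query is translated to the physical dimension
// and forwarded to the underlying nested tensor instead.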
int64_t BatchedTensorImpl::size_custom(int64_t d) const {
  if (!value_.is_nested()) {
    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
    return sizes_default()[d];
  }
  // TODO: Error messages will mention the actualDim, which could be confusing; fix this
  auto actual_dim = actualDim(d, /*wrap_dim=*/true);
  return value_.size(actual_dim);
}

c10::SymInt BatchedTensorImpl::sym_size_custom(int64_t d) const {
  if (!value_.is_nested()) {
    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
    return sym_sizes_default()[d];
  }
  // TODO: Error messages will mention the actualDim, which could be confusing; fix this
  auto actual_dim = actualDim(d, /*wrap_dim=*/true);
  return value_.sym_size(actual_dim);
}

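// A batched nested tensor can be ragged along its non-batch dimensions, so a
// flat sizes()/sym_sizes() array is not well defined; callers must query
// individual dimensions through size()/sym_size() above instead.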
IntArrayRef BatchedTensorImpl::sizes_custom() const {
  TORCH_CHECK(!value_.is_nested(), "sizes() is not supported for batched nested tensors");
  return sizes_default();
}

SymIntArrayRef BatchedTensorImpl::sym_sizes_custom() const {
  TORCH_CHECK(!value_.is_nested(), "sizes() is not supported for batched nested tensors");
  return sym_sizes_default();
}

// The following are publicly exposed as methods of Tensor

IntArrayRef BatchedTensorImpl::strides_custom() const {
  return strides_default();
}

SymIntArrayRef BatchedTensorImpl::sym_strides_custom() const {
  return sym_strides_default();
}

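// For dense tensors, contiguity is answered from the cached public sizes and
// strides refreshed in refreshTensorMetadata(); only queries for
// torch.contiguous_format are currently supported.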
// TODO: implement proper contiguity on batched tensor, then put
// sizes_strides_policy back to Default
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: querying is_contiguous inside of vmap for memory_format ",
      "other than torch.contiguous_format");
  return is_contiguous_default(memory_format);
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
}
void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
}
#ifdef DEBUG
bool BatchedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

const char* BatchedTensorImpl::tensorimpl_type_name() const {
  return "BatchedTensorImpl";
}

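// shallow_copy_and_detach / shallow_copy_from are the hooks behind `.data`
// access and detach-style copies; both are deliberately disabled so that
// reading or mutating `.data` under a vmap transform fails loudly.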
c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed");
  return nullptr;
}

c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed");
  return nullptr;
}

void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
}

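// Wraps `tensor` in a BatchedTensorImpl at the given vmap level, propagating
// wrapper dispatch keys from the input. When the input is already batched,
// the new level must be strictly greater than the existing one, i.e. nested
// vmaps wrap from the inside out.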
Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    auto batched_level = batched->level();
    TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level);
  }
  return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, bdim, level);
}

Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level) {
  return makeBatched(tensor, dim, level);
}

} // namespace at::functorch