// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <bitset>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at::functorch {

using Tensor = at::Tensor;

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;
// Valid vmap levels lie in the range [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchedTensorImpl holds an underlying Tensor and a single batch dim.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimension is treated as being "private"; it is not user-visible.
// For example, in the following Tensor,
//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
// dimension 0 is the batch dimension.
//
// bt.sizes() returns (3, 5, 7); bt.sum(0) performs a reduction over the
// (public) dim 0, which is equivalent to dim 1 in the underlying
// ones(2, 3, 5, 7) tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);

  // Returns the batch dimension of this tensor
  int64_t bdim() const { return bdim_; }

  // Returns the vmap level of this tensor
  int64_t level() const { return level_; }

  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const { return value_; }

  // Given a public dimension index, return the dimension index in the
  // underlying value() tensor.
  // For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
  // then:
  //    bt.actualDim(0) -> 1
  //    bt.actualDim(1) -> 2
  //    bt.actualDim(2) -> 3
  //    bt.actualDim(3) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
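
  // For intuition, actualDim computes roughly the following mapping (a
  // sketch only; the real implementation also handles dim wrapping and
  // error checking):
  //   actual = (dim >= bdim_) ? dim + 1 : dim;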

  IntArrayRef sizes_custom() const override;
  SymIntArrayRef sym_sizes_custom() const override;
  int64_t size_custom(int64_t d) const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

  void refreshTensorMetadata();

  // Used in torchdim. torchdim uses non-lexical BatchedTensors; it
  // accomplishes this via a hack that mutates the level of a BatchedTensor
  // to match the level of the current vmap transform.
  void _unsafe_set_level(int64_t level) {
    level_ = level;
  }

  // Used in the batching rules for in-place view operations that can change
  // the index of the bdim (think squeeze_, unsqueeze_)
  void unsafe_set_bdim(int64_t bdim) {
    // NB: you MUST call refreshTensorMetadata after doing this.
    bdim_ = bdim;
  }
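
  // For example, a batching rule for an in-place view op might do (a sketch
  // only, with `new_bdim` as a hypothetical placeholder):
  //   batched->unsafe_set_bdim(new_bdim);
  //   batched->refreshTensorMetadata();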
 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  int64_t level_;
  int64_t bdim_;
};

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
      tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}
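
// A minimal usage sketch of the accessors above (illustrative only;
// `peelUnderlying` is a hypothetical helper, not part of this header's API):
// if `tensor` is batched, peel off one layer and return the underlying
// value; otherwise return `tensor` unchanged.
inline Tensor peelUnderlying(const Tensor& tensor) {
  if (auto* batched = maybeGetBatchedImpl(tensor)) {
    // batched->bdim() and batched->level() describe where the (private)
    // batch dimension lives and which vmap level owns it.
    return batched->value();
  }
  return tensor;
}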

// Returns a bitset. If bit i is set, then dim i is a batch dim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  is_bdim.set(dim);
  return is_bdim;
}

// Creates a bitset for the given vmap level
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
  std::bitset<kVmapNumLevels> result;
  result.set(level);
  return result;
}
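
// Illustrative sketch of how these bitsets compose (`collectVmapLevels` is a
// hypothetical helper, not part of this header): collect the vmap levels
// present across a list of tensors, e.g. the inputs to a batching rule.
inline std::bitset<kVmapNumLevels> collectVmapLevels(ArrayRef<Tensor> tensors) {
  std::bitset<kVmapNumLevels> levels;
  for (const Tensor& tensor : tensors) {
    if (auto* batched = maybeGetBatchedImpl(tensor)) {
      // Mark the level owning this tensor's batch dimension.
      levels |= createVmapLevelsBitset(batched->level());
    }
  }
  return levels;
}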

// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
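
// Usage sketch (illustrative): wrap a plain tensor for a level-1 vmap, with
// the batch dimension at dim 0 of the underlying tensor:
//   Tensor batched = addBatchDim(tensor, /*dim=*/0, /*level=*/1);
// If tensor.sizes() was (2, 3, 5), then batched.sizes() is (3, 5): dim 0
// has become private.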

// Certain dispatch keys must be propagated to the BatchedTensor (or, in
// general, to any wrapper Tensor subclass). This is because there are methods
// on Tensor that skip dispatch and check for the presence of a dispatch key
// (e.g. is_cpu()).
// TODO: should probably contain more (or all?) backend keys
constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::Negative,
  DispatchKey::Conjugate,
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  // NB: intersect with the caller-supplied set, not the default constant,
  // so that a narrower `to_propagate` is actually respected.
  return key_set & to_propagate;
}
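
// Illustrative sketch (hypothetical; the real logic lives alongside
// makeBatched's implementation): assembling a wrapper's key set by combining
// the batched dispatch key with the keys propagated from the inner tensor:
//   auto key_set = c10::DispatchKeySet(DispatchKey::FuncTorchBatched) |
//                  getKeysToPropagateToWrapper(tensor);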

} // namespace at::functorch