#pragma once

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/core/IListRef.h>

namespace at {

// This file contains abstractions used for transforming *logical* vmap
// arguments into *physical* arguments. (Keep reading for definitions of these
// terms.)

// NOTE: [Logical vs physical args]
// Consider the following vmap.
//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0, but the physical dim is dim 1
// (the first non-batch dimension).

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec =
    SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, call one or more at:: operators that handle the
// physical arguments, and then convert the physical results back to
// logical results.

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
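
// As an illustration only (this sketch is not part of this header, and the
// operator signature is simplified): a unary batching rule typically
// round-trips through these abstractions as follows, using the
// VmapPhysicalView methods declared below.
//
//   Tensor sum_batching_rule(const Tensor& self, int64_t dim) {
//     // Logical -> physical: batch dims are moved to the front.
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // Translate the logical reduction dim to a physical dim.
//     auto dim_physical = self_physical.getPhysicalDim(dim);
//     // Run the regular operator on the physical tensor.
//     auto result = at::sum(self_physical.tensor(), dim_physical);
//     // Physical -> logical: re-wrap the result as a BatchedTensor.
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }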

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch
//   dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};

// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented.)
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor`, and the rightmost
// set bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is at position 3, indicating that we are
// inside at most 3 nested vmaps at this point in time.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
  }

  Tensor& tensor() {
    return tensor_;
  }

  const Tensor& tensor() const {
    return tensor_;
  }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the two set bits in `levels` tell us that the first two
  // dimensions of `tensor_` are batch dimensions, so a logical dim of `n` is
  // actually a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by prepending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};
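
// As an illustration only (not part of this header; the signature is
// simplified): a binary batching rule can combine BroadcastingVmapTransform
// with VmapPhysicalView like so, transforming all inputs together so their
// batch dims line up before calling the regular at:: operator.
//
//   Tensor add_batching_rule(const Tensor& self, const Tensor& other) {
//     auto physical_args =
//         BroadcastingVmapTransform::logicalToPhysical({self, other});
//     auto result =
//         at::add(physical_args[0].tensor(), physical_args[1].tensor());
//     // Any of the views can produce the physical-to-logical map: after the
//     // transform they all share the same levels.
//     return physical_args[0].getPhysicalToLogicalMap().apply(result);
//   }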

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
      : levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1,lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};

} // namespace at