// aten/src/ATen/LegacyBatchedTensorImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/LegacyBatchedTensorImpl.h>

#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

namespace at {

BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
  : TensorImpl(
      c10::DispatchKeySet(DispatchKey::Batched),
      value.dtype(),
      value.device()
    )
  , value_(std::move(value))
  , bdims_(std::move(bdims))
{
  TORCH_INTERNAL_ASSERT(value_.defined());
  set_storage_access_should_throw();
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  checkInvariants();

  const auto public_dims = value_.dim() - bdims_.size();
  const auto value_sizes = value_.sizes();
  const auto value_strides = value_.strides();
  sizes_and_strides_.resize(public_dims);
  for (const auto dim : c10::irange(public_dims)) {
    auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
    sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
    sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
  }
  storage_offset_ = value_.storage_offset();
  refresh_numel();
  refresh_contiguous();
}
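
// A worked example of the bookkeeping above, with hypothetical shapes: if
// value_ has shape [2, 3, 5] and bdims_ holds a single batch dim at value_
// dim 0, then public_dims == 2 and the sizes exposed to the user are [3, 5];
// public dim 0 maps to value_ dim 1 and public dim 1 maps to value_ dim 2
// (see actualDim below).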

int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
  if (wrap_dim) {
    const auto ndim = sizes_and_strides_.size();
    dim = maybe_wrap_dim(dim, ndim);
  }
  auto is_bdim = createBatchDimBitset(bdims_);

  // Example: assume dim = 3, and is_bdim = 10010011000...
  // The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor.
  // actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent
  // to asking "where does the 3rd (0-indexed) zero occur in the bitset?".
  // The answer to that is index 5.
  //
  // TODO(rzou): the PDEP instruction does exactly this
  // (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int)
  // but it might require newer (>= ~2015) CPUs. We should clean this up
  // if/when we have dropped support for older CPUs.
  int64_t non_bdim_count = 0;
  for (const auto actual_dim : c10::irange(kVmapMaxTensorDims)) {
    if (is_bdim[actual_dim]) {
      continue;
    }
    if (non_bdim_count == dim) {
      return actual_dim;
    }
    non_bdim_count++;
  }
  // If we hit this assert, then that means
  // `non_bdim_count` + #num_bdims > kVmapMaxTensorDims. We restrict the number
  // of dims a BatchedTensorImpl can have to kVmapMaxTensorDims so this should
  // never be hit.
  TORCH_INTERNAL_ASSERT(false);
}
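
// Another worked example, with a hypothetical layout: if value_ has 4 dims
// and bdims_ marks value_ dims 0 and 2 as batch dims, then is_bdim = 1010...,
// the non-batch dims of value_ are dims 1 and 3, and so actualDim(0) == 1 and
// actualDim(1) == 3.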

void BatchedTensorImpl::checkInvariants() const {
  int64_t prev_level = -1;
  for (const auto& bdim : bdims_) {
    TORCH_INTERNAL_ASSERT(bdim.level() > prev_level);
    prev_level = bdim.level();
  }
}
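
// In other words, bdims_ must be sorted by strictly increasing vmap level:
// e.g. levels (0, 1, 3) satisfy the invariant, while (1, 1) or (2, 0) do not.
// addBatchDim below appends the new batch dim at the end of the list, so it
// relies on callers passing a level greater than any level already present.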

// The following are publicly exposed as methods of Tensor

IntArrayRef BatchedTensorImpl::strides_custom() const {
  return strides_default();
}

// TODO: implement proper contiguity on batched tensor, then put
// sizes_strides_policy back to Default
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: querying is_contiguous inside of vmap for memory_format ",
      "other than torch.contiguous_format");
  return is_contiguous_;
}

// The following are some internal inherited methods that we do not support.
// They should never get called.
void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl");
}
void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl");
}
void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl");
}
#ifdef DEBUG
bool BatchedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

const char* BatchedTensorImpl::tensorimpl_type_name() const {
  return "BatchedTensorImpl";
}

Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
  TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  auto tensor_dim = tensor.dim();
  TORCH_CHECK(
      tensor_dim <= kVmapMaxTensorDims,
      "vmap only supports tensors of dimensionality up to ", kVmapMaxTensorDims,
      "; got a tensor with dim ", tensor_dim);
  TORCH_INTERNAL_ASSERT(
      std::all_of(bdims.begin(), bdims.end(),
          [](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
      "We only support up to ", kVmapNumLevels, " nested vmaps");
  return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
}
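
// Usage sketch (illustrative; assumes a defined Tensor `x` with sizes [2, 3]):
//
//   BatchDims bdims;
//   bdims.emplace_back(/*level=*/0, /*dim=*/0);
//   Tensor bx = makeBatched(x, std::move(bdims));
//   // bx wraps x in a BatchedTensorImpl; its public shape is [3].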

Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
  const auto* batched = maybeGetBatchedImpl(tensor);
  if (!batched) {
    BatchDims bdims;
    bdims.emplace_back(level, dim);
    return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
  }
  BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
  auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);
  new_bdims.emplace_back(level, actual_bdim);
  return makeBatched(batched->value(), std::move(new_bdims));
}
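
// Example: suppose `t` is already batched, with value_ of shape [2, 5, 7] and
// a batch dim at value_ dim 0 (level 0). Then addBatchDim(t, /*level=*/1,
// /*dim=*/1) resolves public dim 1 to value_ dim 2 via actualDim and returns
// a tensor batched over value_ dims {0, 2} at levels {0, 1}.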

bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) {
  const auto* other_batched = maybeGetBatchedImpl(other);
  if (!other_batched) {
    return true;
  }
  const auto* self_batched = maybeGetBatchedImpl(self);
  if (!self_batched) {
    // self is not batched but other is batched
    return false;
  }
  auto self_levels = createVmapLevelsBitset(self_batched->bdims());
  auto other_levels = createVmapLevelsBitset(other_batched->bdims());
  return self_levels == (self_levels | other_levels);
}
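
// Example: for an in-place op like self.add_(other), if self is batched at
// vmap levels {0, 1} and other is batched only at level {1}, this returns
// true (other's levels are a subset of self's). If other were batched at a
// level self lacks, e.g. level 2, this returns false: self has no batch dim
// in which to hold the result for that extra level.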

} // namespace at