xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorIterator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #define TORCH_ASSERT_NO_OPERATORS
3 #include <ATen/TensorIterator.h>
4 #undef TORCH_ASSERT_NO_OPERATORS
5 
6 #include <ATen/core/Tensor.h>
7 
8 #include <ATen/ExpandUtils.h>
9 #include <ATen/Parallel.h>
10 #include <ATen/native/TypeProperties.h>
11 #include <ATen/MemoryOverlap.h>
12 #include <ATen/native/Resize.h>
13 #include <ATen/NamedTensorUtils.h>
14 #include <ATen/TensorOperators.h>
15 #include <ATen/TensorIteratorInternal.h>
16 
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #else
20 #include <ATen/ops/empty.h>
21 #include <ATen/ops/empty_strided.h>
22 #endif
23 
24 #include <c10/util/irange.h>
25 #include <c10/util/SmallBuffer.h>
26 
27 #include <array>
28 #include <algorithm>
29 #include <cmath>
30 
31 namespace at {
32 
33 using DimMask = TensorIteratorBase::DimMask;
34 using PtrVector = TensorIteratorBase::PtrVector;
35 using loop2d_t = TensorIteratorBase::loop2d_t;
36 using StrideVector = TensorIteratorBase::StrideVector;
37 
38 namespace {
39 
40 inline void get_base_ptrs(char** ptrs, ArrayRef<OperandInfo> operands) {
41   std::transform(operands.begin(), operands.end(), ptrs, [](const OperandInfo& op) {
42     return static_cast<char*>(op.data);
43   });
44 }
45 
46 inline void get_strides(int64_t* strides, ArrayRef<OperandInfo> operands, int64_t ndim) {
47   for (const auto dim : c10::irange(ndim)) {
48     for (const auto arg : c10::irange(operands.size())) {
49       *strides++ = operands[arg].stride_bytes[dim];
50     }
51   }
52   // Always at least 2d strides to support 2d for_each loops
53   if (ndim < 2) {
54     auto ntensors = operands.size();
55     std::fill_n(strides, (2 - ndim) * ntensors, 0);
56   }
57 }
58 
59 static OptionalTensorRef make_otr(const TensorBase &tensor) {
60   if (tensor.defined()) {
61     return OptionalTensorRef(tensor);
62   } else {
63     return OptionalTensorRef();
64   }
65 }
66 
67 }
68 
69 namespace internal {
70 
71 OpaqueOptionalTensorRef::OpaqueOptionalTensorRef() {
72   static_assert(alignof(OptionalTensorRef) == alignof(TensorBase));
73   static_assert(sizeof(OptionalTensorRef) == sizeof(TensorBase));
74   new (data_.data()) OptionalTensorRef();
75 }
76 
77 OpaqueOptionalTensorRef::~OpaqueOptionalTensorRef() {
78   get()->~OptionalTensorRef();
79 }
80 
81 const Tensor& OpaqueOptionalTensorRef::getTensor() const {
82   return get()->getTensorRef();
83 }
84 
85 }
86 
87 void OperandInfo::tensor(c10::MaybeOwned<TensorBase> &&tensor) {
88   tensor_base_ = std::move(tensor);
89   *tensor_storage_ = make_otr(*tensor_base_);
90 }
91 
92 void OperandInfo::exchange_tensor(c10::MaybeOwned<TensorBase> &&new_tensor) {
93   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!original_tensor_base_->defined());
94   original_tensor_base_ = std::exchange(tensor_base_, std::move(new_tensor));
95   *original_tensor_storage_ = std::exchange(*tensor_storage_, make_otr(*tensor_base_));
96 }
97 
98 void OperandInfo::restore_original_tensor() {
99   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(original_tensor_base_->defined());
100   tensor_base_ = std::move(original_tensor_base_);
101   *tensor_storage_ = std::exchange(*original_tensor_storage_, OptionalTensorRef{});
102 }
103 
104 /// Construction
105 TensorIteratorConfig& TensorIteratorConfig::add_owned_output(const TensorBase& output) {
106   TORCH_INTERNAL_ASSERT(
107       num_inputs_ == 0,
108       "Keep in mind that you have to add all outputs first before adding any input. "
109       "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
110   tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, output));
111   num_outputs_++;
112   return *this;
113 }
114 
115 TensorIteratorConfig& TensorIteratorConfig::add_owned_input(const TensorBase& input) {
116   tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, input));
117   num_inputs_++;
118   return *this;
119 }
120 
121 TensorIteratorConfig& TensorIteratorConfig::add_owned_const_input(const TensorBase& input) {
122   const_tensor_indices_.push_back(tensors_.size());
123   tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, input));
124   num_inputs_++;
125   return *this;
126 }
127 
128 TensorIteratorConfig& TensorIteratorConfig::add_borrowed_output(const TensorBase& output) {
129   TORCH_INTERNAL_ASSERT(
130       num_inputs_ == 0,
131       "Keep in mind that you have to add all outputs first before adding any input. "
132       "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
133   tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(output));
134   num_outputs_++;
135   return *this;
136 }
137 
138 TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase& input) {
139   tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
140   num_inputs_++;
141   return *this;
142 }
143 
144 TensorIteratorConfig& TensorIteratorConfig::add_borrowed_const_input(const TensorBase& input) {
145   const_tensor_indices_.push_back(tensors_.size());
146   tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
147   num_inputs_++;
148   return *this;
149 }
150 
151 TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
152   TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
153   static_dtype_ = dtype;
154   static_device_ = device;
155   return *this;
156 }
157 
158 TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype(ScalarType dtype) {
159   TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
160   static_dtype_ = dtype;
161   return *this;
162 }
163 
164 TensorIteratorConfig& TensorIteratorConfig::declare_static_device(Device device) {
165   static_device_ = device;
166   return *this;
167 }
168 
169 TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape) {
170   // WARNING:
171   //   This will bypass all shape checking in the TensorIterator. Kernels which call this method
172   //   are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
173   TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)");
174   static_shape_ = std::make_optional(DimVector(shape));
175   return *this;
176 }
177 
178 TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) {
179   declare_static_shape(shape);
180   if (static_shape_->empty()) return *this;
181   for (const auto& squash_dim : squash_dims) {
182     TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast<int64_t>(static_shape_->size()),
183                 "squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ").");
184     (*static_shape_)[squash_dim] = 1;
185   }
186   return *this;
187 }
188 
189 bool TensorIteratorConfig::is_tensor_const(size_t idx) {
190   return std::find(const_tensor_indices_.begin(), const_tensor_indices_.end(), idx) != const_tensor_indices_.end();
191 }
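// Usage sketch (illustrative only, not part of this file): a kernel typically builds
// an iterator through the builder methods above. `out`, `self`, and `other` are
// hypothetical tensors.
//
//   auto iter = at::TensorIteratorConfig()
//       .set_check_mem_overlap(true)
//       .add_owned_output(out)           // all outputs must be added before any input
//       .add_owned_const_input(self)
//       .add_owned_const_input(other)
//       .promote_inputs_to_common_dtype(true)
//       .build();
//   // iter.common_dtype(), iter.shape(), etc. are then available to the kernel.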
192 
193 // NOTE: [Computing output strides]
194 // We use the following algorithm to compute output strides
195 // If correctly sized output is provided, we respect its strides and don't change them
196 // Otherwise, if provided output is of incorrect size or no output is provided,
197 // we try to recover permutation that was applied to the inputs
198 // by sorting the strides of the inputs. Precedence is given to the inputs in the order they were added,
199 // and to permutations involving non-broadcasted dimensions
200 // 1. we loop over inputs starting from the first
201 // 2. for all inputs strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. If one
202 // of the dimensions being compared has a stride of 0, we move on to the next tensor to determine if
203 // these dimensions need to be swapped.
204 // 3. strides of dimensions equal to 1 participate in sorting
205 // 4. if 2 strides are equal and neither is 0, we try to break the tie by looking at the corresponding dimensions
206 // of the tensor. Dimensions were permuted if, when iterating from the end, dimensions corresponding to the
207 // same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie.
208 //
209 // Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly
210 // losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially
211 // improve traversal order of the second tensor. We chose the former option to better propagate channels last layout
212 // for example for a tensor with the sizes N1H1.
213 // These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all
214 // arguments are of the same size) or the argument that is not broadcasted, regardless of its position.
215 // As a bonus, it also results in reasonably well-behaved traversal order of the inputs and outputs - in the kernels
216 // output is traversed linearly, and since it closely follows input layouts, inputs are traversed linearly as well
217 //
218 // Examples:
219 // full size tensor + broadcasted tensor with 0 or 1 non-trivial dimensions => strides of output are same
220 // as strides of full size input regardless of the order
221 // 2 tensors of same size but different strides => output strides are the same as first argument
222 //
223 // We also have a fast path for memory-dense inputs with the same strides (or, trivially, a single memory-dense input)
224 // that outputs a tensor with the same strides as inputs. The only difference in result with the algorithm described
225 // above is for strides for trivial (1) dimensions, where in ambiguous cases for performance reasons we default to
226 // contiguous strides.
227 // Example: tensor with sizes NC11 and strides C1CC will produce output with strides C111 (note differences are only
228 // in the strides of trivial dimensions, so physical layout is unaffected but permutation information is lost)
229 // We might change this behavior in the future once performance considerations are resolved.
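//
// Illustrative sketch of the rules above (hypothetical tensors, not part of this file):
// adding a broadcasted bias to a channels-last tensor recovers the channels-last
// permutation for the output.
//
//   auto a = at::empty({2, 8, 4, 4}).contiguous(at::MemoryFormat::ChannelsLast);
//   auto b = at::ones({1, 8, 1, 1});   // broadcast over N, H and W
//   auto out = a + b;                  // out.strides() follows a's layout:
//                                      // {8*4*4, 1, 8*4, 8} = {128, 1, 32, 8}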
230 
231 void TensorIteratorBase::reorder_dimensions() {
232   // Sort the dimensions based on strides in ascending order with reduced dims
233   // at the front. NOTE: this inverts the order of C-contiguous tensors.
234   // strides[0] is the fastest moving dimension instead of strides[ndim - 1].
235   // See NOTE: [Computing output strides] and the inline comments for a more detailed description.
236 
237   perm_.resize(ndim());
238   if (ndim() == 1) {
239     perm_[0] = 0;
240     return;
241   }
242 
243   // initialize perm with n-1, n-2, ..., 1, 0
244   std::iota(perm_.rbegin(), perm_.rend(), 0);
245 
246   // Reordering dimensions changes iteration order
247   if (enforce_linear_iteration_) {
248     permute_dimensions(perm_);
249     return;
250   }
251 
252   // returns 1 if dim0 should come after dim1, -1 if dim0 should come
253   // before dim1, and 0 if the comparison is ambiguous.
254   auto should_swap = [&](size_t dim0, size_t dim1) {
255     for (const auto arg : c10::irange(ntensors())) {
256       // ignore undefined or incorrectly sized tensors
257       if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) {
258         continue;
259       }
260       int64_t stride0 = operands_[arg].stride_bytes[dim0];
261       int64_t stride1 = operands_[arg].stride_bytes[dim1];
262       if (is_reduction_ && operands_[arg].is_output) {
263         // move reduced dimensions to the front
264         // strides of reduced dimensions are always set to 0 by review_reduce_result
265         if ((stride0 == 0) != (stride1 == 0)) {
266           return stride1 == 0 ? 1 : -1;
267         }
268       }
269       //move on to the next input if one of the dimensions is broadcasted
270       if (stride0 == 0 || stride1 == 0) {
271         continue;
272       // it is important to return here only with strict comparisons, for equal strides we try to break the tie later
273       // by comparing corresponding dimensions or if that does not work, moving on to the next tensor
274       } else if (stride0 < stride1) {
275         return -1;
276       } else  if (stride0 > stride1) {
277         return 1;
278       } else { //equal strides, use dimensions themselves as the tie-breaker.
279         //at this point, with zero strides out of the way, we are guaranteed that operand dimensions are equal to shape_
280          auto t_dim0 = shape_[dim0];
281          auto t_dim1 = shape_[dim1];
282          //return only if dimensions should be swapped, otherwise move on to the next tensor
283          if (t_dim0 > t_dim1) {
284              return 1;
285          }
286       }
287     }
288     return 0;
289   };
290 
291   // insertion sort with support for ambiguous comparisons
292   for (const auto i : c10::irange(1, ndim())) {
293     int dim1 = i;
294     for (int dim0 = i - 1; dim0 >= 0; dim0--) {
295       int comparison = should_swap(perm_[dim0], perm_[dim1]);
296       if (comparison > 0) {
297         std::swap(perm_[dim0], perm_[dim1]);
298         dim1 = dim0;
299       } else if (comparison < 0) {
300         break;
301       }
302     }
303   }
304 
305   // perform re-ordering of shape and strides
306   permute_dimensions(perm_);
307 }
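// Worked example (sketch): for a single C-contiguous operand of shape {2, 3, 4} the
// element strides are {12, 4, 1}, so perm_ is initialized to {2, 1, 0} and no swaps
// are needed; after permute_dimensions, shape_ becomes {4, 3, 2} and index 0 refers
// to the fastest-moving dimension, as noted at the top of reorder_dimensions.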
308 
309 // Computes a common dtype using type promotion
310 // See the [Common Dtype Computation] note
311 ScalarType TensorIteratorBase::compute_common_dtype() {
312   at::native::ResultTypeState state = {};
313   for (const auto& op : operands_) {
314     if (op.is_output) {
315       continue;
316     }
317 
318     state = at::native::update_result_type_state(op.tensor(), state);
319   }
320 
321   common_dtype_ = at::native::result_type(state);
322   TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
323 
324   return common_dtype_;
325 }
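// Example of the promotion above (sketch): kLong and kFloat inputs resolve to kFloat,
// kFloat and kDouble resolve to kDouble, and a kFloat tensor combined with a wrapped
// double scalar still resolves to kFloat, because result_type gives tensor dtypes
// precedence over wrapped-number dtypes.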
326 
327 static TensorOptions original_options(const OperandInfo& op) {
328   if (op.original_tensor_base().defined()) {
329     return op.original_tensor_base().options();
330   } else {
331     return op.options();
332   }
333 }
334 
335 // Implements the behavior of the following flags:
336 //   - check_all_same_dtype_
337 //   - check_all_same_device_
338 //   - enforce_safe_casting_to_output_
339 //   - promote_inputs_to_common_dtype_
340 //   - cast_common_dtype_to_outputs_
341 //
342 // See their descriptions in TensorIterator.h for details.
343 // NOTE: Checks for more specific behaviors (e.g. the first and second
344 //   inputs must share a dtype, but the third must have the long dtype)
345 //   should be implemented directly and outside of TensorIterator.
346 void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
347   // Reviews operands (1/2)
348   //   - validates that all input tensors are defined
349   //   - computes common device
350   //   - determines if there are undefined outputs
351   //   - determines if there are different dtypes and attempts
352   //       to quickly acquire a common dtype
353   Device common_device = kCPU;
354   common_dtype_ = ScalarType::Undefined;
355   // NB: despite output_dtype's generic-sounding name, it is only
356   // used in a nontrivial way if check_all_same_dtype is true
357   ScalarType output_dtype = ScalarType::Undefined;
358   bool has_different_input_dtypes = false;
359   bool has_different_output_dtypes = false;
360   bool has_undefined_outputs = false;
361 
362   for (auto& op : operands_) {
363     // Validates that all inputs have type information, and that
364     //   if an output is missing type information that we can infer
365     //   the device it should be allocated on.
366     if (!op.is_type_defined()) {
367       TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");
368 
369       if (config.static_dtype_.has_value()) {
370         op.target_dtype = config.static_dtype_.value();
371       } else {
372         has_undefined_outputs = true;
373       }
374 
375       if (config.static_device_.has_value()) {
376         op.device = config.static_device_.value();
377       } else {
378         TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
379       }
380 
381       if (has_undefined_outputs || !op.device.has_value()) {
382         continue;
383       }
384     }
385 
386     // Validates input tensors are defined
387     if (!op.tensor_base().defined()) {
388       TORCH_INTERNAL_ASSERT(op.is_output, "Found undefined input tensor!");
389       continue;
390     }
391 
392     TORCH_INTERNAL_ASSERT(op.target_dtype == op.current_dtype);
393 
394     // Acquires the first non-CPU device (if any) as the common device
395     if (common_device == kCPU && !op.tensor_base().is_cpu()) {
396       common_device = op.tensor_base().device();
397     }
398 
399     if (!op.is_output) {
400       // Determines if there are varying input dtypes
401       // NOTE: the common dtype is set to the first defined input dtype observed
402       if (op.target_dtype != common_dtype_) {
403         if (common_dtype_ == ScalarType::Undefined) {
404           common_dtype_ = op.target_dtype;
405         } else {
406           has_different_input_dtypes = true;
407         }
408       }
409     } else {  // op.is_output
410       // Determines if there are varying output dtypes
411       // NOTE: the output dtype is set to the first defined output dtype observed
412       if (op.target_dtype != output_dtype) {
413         if (output_dtype == ScalarType::Undefined) {
414           output_dtype = op.target_dtype;
415         } else {
416           has_different_output_dtypes = true;
417         }
418       }
419     }
420   }
421 
422   // Checks that either the computation type is computable or unneeded
423   TORCH_INTERNAL_ASSERT(!(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ &&
424                         (has_undefined_outputs || config.enforce_safe_casting_to_output_ ||
425                         config.cast_common_dtype_to_outputs_)));
426 
427   // Checks that all inputs and defined outputs are the same dtype, if requested
428   if (config.check_all_same_dtype_ &&
429       (has_different_input_dtypes || has_different_output_dtypes ||
430       (common_dtype_ != output_dtype && output_dtype != ScalarType::Undefined))) {
431     // Throws an informative error message
432     for (auto& op : operands_) {
433       if (!op.tensor_base().defined()) {
434         continue;
435       }
436 
437       TORCH_CHECK(op.target_dtype == common_dtype_,
438                   "Found dtype ", op.target_dtype, " but expected ", common_dtype_);
439     }
440   }
441 
442   // Short-circuits if no additional work required
443   if (!has_undefined_outputs && !config.check_all_same_device_ &&
444       !config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
445       !config.enforce_safe_casting_to_output_) {
446     // Invalidates common_dtype_ if it could not be inferred
447     common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
448     return;
449   }
450 
451   // Computes a common dtype, if needed
452   if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) {
453     common_dtype_ = compute_common_dtype();
454   }
455 
456   // Promotes common dtype to the default float scalar type, if needed
457   if (config.promote_integer_inputs_to_float_ &&
458       c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
459     common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
460   }
461 
462   // Reviews operands (2/2)
463   //   - sets metadata for undefined outputs
464   //   - checks that all tensors are on the same device, if requested
465   //   - checks that the common dtype can safely cast to each output, if requested
466   //   - creates temporaries for CPU operations, if needed and requested
467   common_device_ = common_device;
468   int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
469   int current_cpu_scalars_on_non_cpu = 0;
470   for (auto& op : operands_) {
471     bool is_type_defined = op.is_type_defined();
472     bool is_device_defined = op.is_device_defined();
473 
474     if (!is_type_defined) {
475       op.target_dtype = common_dtype_;
476     }
477     if (!is_device_defined) {
478       op.device = common_device;
479     }
480 
481     if (!is_type_defined && !is_device_defined) {
482       continue;
483     }
484 
485     // Skips undefined tensors
486     if (!op.tensor_base().defined()) {
487       continue;
488     }
489 
490     // Checks all tensors are on the same device, if requested
491     if (config.check_all_same_device_) {
492       // Handles CPU scalars on CUDA kernels that support them
493       if (!common_device.is_cpu() &&
494           config.allow_cpu_scalars_ && !op.is_output && op.tensor_base().dim() == 0 &&
495           op.tensor_base().is_cpu()) {
496         TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
497                     "Trying to pass too many CPU scalars to non-CPU kernel!");
498         ++current_cpu_scalars_on_non_cpu;
499       } else if (op.device.value() != common_device) {
500         TORCH_CHECK(false,
501                     "Expected all tensors to be on the same device, but "
502                     "found at least two devices, ", common_device, " and ", op.device.value(), "!");
503       }
504     }
505 
506     // Checks safe casting, if requested
507     if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
508       TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
509                   "result type ", common_dtype_, " can't be cast to the "
510                   "desired output type ", op.current_dtype);
511     }
512 
513     // Creates temporaries for CPU operations, if needed and requested
514     // TODO: reuse temporaries when possible (e.g. for inplace operations)
515     if (common_device == kCPU) {
516       // Casts to outputs by creating temporaries of the correct dtype (if needed)
517       // NB: we skip this on is_meta_, because the temporary allocation here is
518       // unnecessary if we aren't going to actually do the compute
519       if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
520         TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
521         // Marker [Output original_tensor is set]
522         // NB: do NOT use set_output here, as the temporary is NOT a true output;
523         // op.tensor is the true output and it was pre-provided for us.
524         // TODO: The logic for cast_outputs will need to be handled by the
525         // structured kernels implementation.  What probably should happen
526         // is that we pass in the inferred dtype into the out kernel, and
527         // then after calling the out kernel, do the conversion (which
528         // is cast_outputs here), but integrating this with existing
529         // TensorIterator will take a little doing
530         op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
531             at::empty_like(op.tensor(),
532                            op.tensor_base().options().dtype(common_dtype_),
533                            LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
534         if (!names_.empty()) {
535           namedinference::propagate_names(op.tensor_base(), names_);
536         }
537         op.current_dtype = common_dtype_;
538         op.target_dtype = common_dtype_;
539       }
540 
541       // Promotes inputs by creating temporaries of the correct dtype
542       if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
543         op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(op.tensor().to(common_dtype_)));
544         op.current_dtype = common_dtype_;
545         op.target_dtype = common_dtype_;
546       }
547     }
548   }
549 }
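// Example of these flags interacting (sketch, hypothetical tensors): for a CPU
// torch.add(int_tensor, float_tensor, out=double_tensor) built with the binary-op
// config, promote_inputs_to_common_dtype picks kFloat as the common dtype, a kFloat
// temporary replaces the int input, a kFloat temporary stands in for the double
// output (canCast(kFloat, kDouble) passes the safe-cast check), and cast_outputs()
// later copies that temporary back into the original double output.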
550 
551 StrideVector TensorIteratorBase::compatible_stride(int64_t element_size) const {
552   auto stride = StrideVector();
553   int64_t next_stride = element_size;
554   for (const auto dim : c10::irange(ndim())) {
555     stride.push_back(next_stride);
556     next_stride *= shape_[dim];
557   }
558   return stride;
559 }
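// Worked example (sketch): with shape_ = {4, 3} (fastest dimension first) and
// element_size = 4, compatible_stride returns {4, 16}: dim 0 is contiguous and each
// later stride is the previous stride multiplied by the previous dimension's size.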
560 
561 DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const {
562   // Invert the permutation caused by reorder_dimensions. This is not valid
563   // after coalesce_dimensions is called.
564   TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_);
565   TORCH_INTERNAL_ASSERT(input.size()==perm_.size());
566   auto res = DimVector(input.size()); //no initialization needed, every value in res should be written to.
567   for (const auto dim : c10::irange(ndim())) {
568     res[perm_[dim]] = input[dim];
569   }
570   return res;
571 }
572 
573 void TensorIteratorBase::allocate_or_resize_outputs() {
574   for (const auto i : c10::irange(num_outputs_)) {
575     auto& op = operands_[i];
576     if (!op.tensor_base().defined() || op.will_resize) {
577       TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
578       auto element_size = elementSize(op.target_dtype);
579       op.stride_bytes = compatible_stride(static_cast<int64_t>(element_size));
580       // check if permutation is just an inverted order
581       bool inverted = true;
582       for (const auto j : c10::irange(ndim())) {
583         if (perm_[j] != ndim() - j - 1) {
584           inverted = false;
585           break;
586         }
587       }
588       auto tensor_shape = invert_perm(shape_);
589       if (inverted) {
590         // can just return contiguous output
591         // it is faster because it avoids allocating a 0-size tensor and
592         // resizing and restriding it
593         set_output_raw_strided(i, tensor_shape, {}, original_options(op), names_);
594       } else {
595         auto tensor_stride = invert_perm(op.stride_bytes);
596         for (const auto dim : c10::irange(ndim())) {
597           tensor_stride[dim] /= static_cast<int64_t>(element_size);
598         }
599         set_output_raw_strided(i, tensor_shape, tensor_stride, original_options(op), names_);
600       }
601       op.current_dtype = op.target_dtype;
602     } else if (op.tensor_base().defined()) {
603       // Even if we don't resize, we still need to tell set_output about
604       // the output, so that we properly set guard and propagate names
605       set_output_raw_strided(i, op.tensor_base().sizes(), {}, original_options(op), names_);
606     }
607   }
608 }
609 
610 void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) {
611   bool should_infer_names = std::any_of(
612       operands_.begin(),
613       operands_.end(),
614       [](const OperandInfo& op) {
615         return op.tensor_base().defined() && op.tensor_base().has_names();
616       });
617   if (!should_infer_names) {
618     return;
619   }
620 
621   for (auto& op : operands_) {
622     if (!op.tensor_base().defined()) continue;
623     // Don't include output tensors if we are resizing, since we will
624     // clobber their names in any case.  (If the output tensor was
625     // also an input tensor, we'll pick it up when it shows up again
626     // in operands).
627     if (config.resize_outputs_ && op.is_output) continue;
628     // perform name inference
629     if (names_.empty()) {
630       names_ = op.tensor_base().names();
631     } else {
632       names_ = NameVector(unify_from_right(names_, op.tensor_base().names()));
633     }
634   }
635 }
636 
637 void TensorIteratorBase::coalesce_dimensions() {
638   if (ndim() <= 1) {
639     return;
640   }
641 
642   // We can coalesce two adjacent dimensions if either dim has size 1 or if:
643   // shape[n] * stride[n] == stride[n + 1].
644   auto can_coalesce = [&](int dim0, int dim1) {
645     auto shape0 = shape_[dim0];
646     auto shape1 = shape_[dim1];
647     if (shape0 == 1 || shape1 == 1) {
648       return true;
649     }
650     for (const auto i : c10::irange(ntensors())) {
651       auto& stride = operands_[i].stride_bytes;
652       if (shape0 * stride[dim0] != stride[dim1]) {
653         return false;
654       }
655     }
656     return true;
657   };
658 
659   // replace each operand's stride at dim0 with its stride at dim1
660   auto replace_stride = [&](int dim0, int dim1) {
661     for (const auto i : c10::irange(ntensors())) {
662       auto& stride = operands_[i].stride_bytes;
663       stride[dim0] = stride[dim1];
664     }
665   };
666 
667   int prev_dim = 0;
668   for (const auto dim : c10::irange(1, ndim())) {
669     if (can_coalesce(prev_dim, dim)) {
670       if (shape_[prev_dim] == 1) {
671         replace_stride(prev_dim, dim);
672       }
673       shape_[prev_dim] *= shape_[dim];
674     } else {
675       prev_dim++;
676       if (prev_dim != dim) {
677         replace_stride(prev_dim, dim);
678         shape_[prev_dim] = shape_[dim];
679       }
680     }
681   }
682 
683   shape_.resize(prev_dim + 1);
684   for (const auto i : c10::irange(ntensors())) {
685     operands_[i].stride_bytes.resize(ndim());
686   }
687   has_coalesced_dimensions_ = true;
688 }
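// Worked example (sketch): a single contiguous float operand of original shape
// {2, 3, 4} reaches this point (after reorder_dimensions) with shape_ = {4, 3, 2}
// and stride_bytes = {4, 16, 48}. Since 4 * 4 == 16 and 12 * 4 == 48, everything
// coalesces into a single dimension: shape_ = {24}, stride_bytes = {4}.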
689 
690 int64_t TensorIteratorBase::numel() const {
691   int64_t numel = 1;
692   for (int64_t size : shape_) {
693     numel *= size;
694   }
695   return numel;
696 }
697 
698 StrideVector TensorIteratorBase::get_dim_strides(int dim) const {
699   auto dims = ndim();
700   auto inner_strides = StrideVector();
701   for (auto& op : operands_) {
702     inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
703   }
704   return inner_strides;
705 }
706 
707 SmallVector<char*, 4> TensorIteratorBase::get_base_ptrs() const {
708   auto ptrs = SmallVector<char*, 4>(ntensors());
709   at::get_base_ptrs(ptrs.data(), operands_);
710   return ptrs;
711 }
712 
713 bool TensorIteratorBase::is_dim_reduced(int dim) const {
714   for (auto& op : operands_) {
715     if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) {
716       return true;
717     }
718   }
719   return false;
720 }
721 
722 void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
723   TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));
724 
725   auto reorder = [perm](IntArrayRef data) {
726     auto res = DimVector(data.size(), 0);
727     for (const auto i : c10::irange(perm.size())) {
728       res[i] = data[perm[i]];
729     }
730     return res;
731   };
732 
733   // Update shape and strides
734   shape_ = reorder(shape_);
735   for (auto& op : operands_) {
736     if (!op.stride_bytes.empty()) {
737       op.stride_bytes = reorder(op.stride_bytes);
738     }
739   }
740 }
741 
742 int64_t TensorIteratorBase::num_output_elements() const {
743   int64_t elem = 1;
744   for (const auto dim : c10::irange(ndim())) {
745     if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0)  {
746       elem *= shape_[dim];
747     }
748   }
749   return elem;
750 }
751 
752 int TensorIteratorBase::num_reduce_dims() const {
753   int count = 0;
754   for (const auto dim : c10::irange(ndim())) {
755     if (operands_[0].stride_bytes[dim] == 0) {
756       count++;
757     }
758   }
759   return count;
760 }
761 
762 void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
763   int64_t numel = this->numel();
764   if (numel == 0) {
765     return;
766   } else if (numel < grain_size || at::get_num_threads() == 1) {
767     return serial_for_each(loop, {0, numel});
768   } else {
769     at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
770       serial_for_each(loop, {begin, end});
771     });
772   }
773 }
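// Usage sketch (illustrative, hypothetical tensors): kernels usually drive the
// iterator with a 1-d loop; data[0] is the output, data[1..] are the inputs, and
// strides are in bytes. This assumes every operand is kFloat.
//
//   auto iter = at::TensorIterator::binary_op(out, a, b);
//   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
//     for (const auto i : c10::irange(n)) {
//       *reinterpret_cast<float*>(data[0] + i * strides[0]) =
//           *reinterpret_cast<float*>(data[1] + i * strides[1]) +
//           *reinterpret_cast<float*>(data[2] + i * strides[2]);
//     }
//   });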
774 
775 StrideVector TensorIteratorBase::get_strides() const {
776   const auto dim = ndim();
777   StrideVector strides(static_cast<size_t>(std::max(dim, 2)) * ntensors());
778   at::get_strides(strides.data(), operands_, dim);
779   return strides;
780 }
781 
782 void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
783   if (range.size() == 0) {
784     return;
785   }
786 
787   const auto ntensors = this->ntensors();
788   const auto ndim = this->ndim();
789 
790   c10::SmallBuffer<char*, 4> ptrs(ntensors);
791   c10::SmallBuffer<int64_t, 8> strides(ntensors * static_cast<size_t>(std::max(ndim, 2)));
792 
793   at::get_base_ptrs(ptrs.data(), operands_);
794   at::get_strides(strides.data(), operands_, ndim);
795   at::internal::serial_for_each(
796       shape_, strides, ptrs.data(), ptrs.size(), loop, range);
797 }
798 
799 bool TensorIteratorBase::is_trivial_1d() const {
800   // TODO: check for casting once it's supported
801   return ndim() == 1;
802 }
803 
804 bool TensorIteratorBase::is_contiguous() const {
805   if (numel() == 1) {
806     return true;
807   }
808   if (ndim() != 1) {
809     return false;
810   }
811   return has_contiguous_first_dim();
812 }
813 
814 
815 bool TensorIteratorBase::is_scalar(int64_t arg) const {
816   const auto& stride = operands_[arg].stride_bytes;
817   for (const auto i : c10::irange(ndim())) {
818     if (stride[i] != 0 && shape_[i] != 1) {
819       return false;
820     }
821   }
822   return true;
823 }
824 
825 bool TensorIteratorBase::is_cpu_scalar(int64_t arg) const {
826   return is_scalar(arg) && device(arg).is_cpu();
827 }
828 
829 void TensorIteratorBase::cast_outputs() {
830   for (auto& op : operands_) {
831     if (op.is_output && op.original_tensor_base().defined() &&
832         op.original_tensor_base().scalar_type() != op.current_dtype) {
833       // TODO: Now that set_output resizes both the original_tensor
834       // and tensor, this condition should no longer ever be true
835       const auto &original_tensor = op.original_tensor();
836       const auto &tensor = op.tensor();
837       if (original_tensor.sizes() != tensor.sizes()) {
838         original_tensor.resize_as_(tensor).as_strided_(tensor.sizes(), tensor.strides());
839       }
840       original_tensor.copy_(tensor);
841       op.restore_original_tensor();
842     }
843   }
844 }
845 
846 void* TensorIteratorBase::data_ptr(int64_t arg) const {
847   return operands_[arg].data;
848 }
849 
850 void TensorIteratorBase::remove_operand(int64_t arg) {
851   operands_.erase(operands_.begin() + arg);
852 }
853 
854 void TensorIteratorBase::unsafe_replace_operand(int64_t arg, void* data) {
855   operands_[arg].data = data;
856 }
857 
858 void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
859   TORCH_INTERNAL_ASSERT(dim < ndim() && size >= 1);
860   shape_[dim] = size;
861   view_offsets_[dim] += start;
862   for (auto& op : operands_) {
863     op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
864   }
865   if (size == 1 && !is_reduction_) {
866     coalesce_dimensions();
867   }
868 }
869 
870 void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
871   TORCH_INTERNAL_ASSERT(start_dim <= ndim());
872   for (const auto i : c10::irange(start_dim, ndim())) {
873     for (auto& op : operands_) {
874       op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
875     }
876     shape_[i] = 1;
877   }
878 }
879 
880 #define BINARY_FLOAT_OP_CONFIG()                \
881   TensorIteratorConfig()                        \
882     .set_check_mem_overlap(true)                \
883     .allow_cpu_scalars(true)                    \
884     .promote_inputs_to_common_dtype(true)       \
885     .cast_common_dtype_to_outputs(true)         \
886     .enforce_safe_casting_to_output(true)       \
887     .promote_integer_inputs_to_float(true)
888 
889 // Helper to construct a binary op that promotes integer inputs to float.
890 void TensorIteratorBase::build_binary_float_op(
891     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
892   build(BINARY_FLOAT_OP_CONFIG()
893         .add_owned_output(out)
894         .add_owned_const_input(a)
895         .add_owned_const_input(b));
896 }
897 
898 void TensorIteratorBase::build_borrowing_binary_float_op(
899     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
900   build(BINARY_FLOAT_OP_CONFIG()
901         .add_output(out)
902         .add_const_input(a)
903         .add_const_input(b));
904 }
905 
906 static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
907   config.set_check_mem_overlap(true);
908   config.allow_cpu_scalars(true);
909   config.promote_inputs_to_common_dtype(true);
910 
911   // When 'out' isn't defined (e.g. for the functional operator 'a == b'), we
912   // want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
913   // don't coerce the output.
914   if (!out.defined()) {
915     config.declare_static_dtype(kBool);
916   }
917 
918   // Note [special-case bool outputs]
919   // We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
920   // has `bool` dtype. This is a performance optimization: the functional
921   // version of all comparison/logical ops uses a bool output tensor, and we'd like to
922   // avoid creating a temporary copy of the output.
923   // However, note that all kernels using this TensorIterator will need to special-case when
924   // the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
925   if (out.defined() && out.scalar_type() != kBool) {
926     config.cast_common_dtype_to_outputs(true);
927   }
928 }
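// Example of the two paths above (sketch, mirroring the note):
//
//   c = a == b                   // functional form: `out` is undefined, so the output
//                                // dtype is forced to kBool and no cast is requested
//   torch.eq(a, b, out=c_int)    // explicit non-bool out: its dtype is kept, and the
//                                // result computed in the common input dtype is cast into it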
929 
930 void TensorIteratorBase::build_comparison_op(
931     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
932   TensorIteratorConfig config;
933   set_up_comparison_op_config(config, out);
934 
935   config.add_owned_output(out);
936   config.add_owned_const_input(a);
937   config.add_owned_const_input(b);
938   build(config);
939 }
940 
941 void TensorIteratorBase::build_borrowing_comparison_op(
942     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
943   TensorIteratorConfig config;
944   set_up_comparison_op_config(config, out);
945 
946   config.add_borrowed_output(out);
947   config.add_borrowed_const_input(a);
948   config.add_borrowed_const_input(b);
949   build(config);
950 }
951 
952 void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
953     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
954   TensorIteratorConfig config;
955   set_up_comparison_op_config(config, out);
956 
957   config.add_borrowed_output(out);
958   config.add_borrowed_const_input(a);
959   config.add_owned_const_input(b);
960   build(config);
961 }
962 
963 void TensorIteratorBase::build_ternary_op(
964     const TensorBase& out, const TensorBase& a,
965     const TensorBase& b, const TensorBase& c) {
966   build(TensorIteratorConfig()
967       .promote_inputs_to_common_dtype(true)
968       .cast_common_dtype_to_outputs(true)
969       .enforce_safe_casting_to_output(true)
970       .add_owned_output(out)
971       .add_owned_const_input(a)
972       .add_owned_const_input(b)
973       .add_owned_const_input(c));
974 }
975 
976 // This cannot be a function because TensorIteratorConfig is not
977 // copyable or movable, so it can't be returned from the function.
978 #define BINARY_OP_CONFIG()                              \
979   TensorIteratorConfig()                                \
980     .set_check_mem_overlap(true)                        \
981     .allow_cpu_scalars(true)                            \
982     .promote_inputs_to_common_dtype(true)               \
983     .cast_common_dtype_to_outputs(true)                 \
984     .enforce_safe_casting_to_output(true)               \
985 
986 void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
987   build(BINARY_OP_CONFIG()
988       .add_owned_output(out)
989       .add_owned_const_input(a)
990       .add_owned_const_input(b));
991 }
992 
993 void TensorIteratorBase::build_borrowing_binary_op(
994     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
995   build(BINARY_OP_CONFIG()
996       .add_output(out)
997       .add_const_input(a)
998       .add_const_input(b));
999 }
1000 
1001 // This cannot be a function because TensorIteratorConfig is not
1002 // copyable or movable, so it can't be returned from the function.
1003 #define UNARY_FLOAT_OP_CONFIG()                                         \
1004   TensorIteratorConfig()                                                \
1005   .set_check_mem_overlap(true)                                          \
1006   .promote_inputs_to_common_dtype(true)                                 \
1007   .cast_common_dtype_to_outputs(true)                                   \
1008   .enforce_safe_casting_to_output(true)                                 \
1009   .promote_integer_inputs_to_float(true)
1010 
1011 void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
1012   build(UNARY_FLOAT_OP_CONFIG()
1013       .add_owned_output(out)
1014       .add_owned_const_input(a));
1015 }
1016 
1017 void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
1018   build(UNARY_FLOAT_OP_CONFIG()
1019       .add_output(out)
1020       .add_const_input(a));
1021 }
1022 
1023 // This cannot be a function because TensorIteratorConfig is not
1024 // copyable or movable, so it can't be returned from the function.
1025 #define UNARY_OP_CONFIG()                                \
1026   TensorIteratorConfig()                                 \
1027     .set_check_mem_overlap(true)                         \
1028     .cast_common_dtype_to_outputs(false)                 \
1029     .enforce_safe_casting_to_output(false)               \
1030     .check_all_same_dtype(true)
1031 
1032 void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
1033   build(UNARY_OP_CONFIG()
1034       .add_owned_output(out)
1035       .add_owned_const_input(a));
1036 }
1037 
1038 void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
1039   build(UNARY_OP_CONFIG()
1040       .add_output(out)
1041       .add_const_input(a));
1042 }
1043 
1044 void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
1045   build(UNARY_OP_CONFIG()
1046       .add_output(out)
1047       .add_owned_const_input(a));
1048 }
1049 
1050 // Helper to construct a unary op that forcibly promotes output to boolean.
1051 // Only to be used when the output tensor must have boolean type.
1052 void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
1053   build(TensorIteratorConfig()
1054       .set_check_mem_overlap(true)
1055       .check_all_same_dtype(false)
1056       .declare_static_dtype(at::kBool)
1057       .declare_static_device(a.device())
1058       .add_output(out)
1059       .add_const_input(a));
1060 }
1061 
1062 TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
1063   TensorIterator iter;
1064   iter.build_binary_op(out, a, b);
1065   return iter;
1066 }
1067 
1068 TensorIterator TensorIterator::borrowing_binary_op(
1069     const TensorBase& out, const TensorBase& a, const TensorBase& b) {
1070   TensorIterator iter;
1071   iter.build_borrowing_binary_op(out, a, b);
1072   return iter;
1073 }
1074 
1075 TensorIterator TensorIterator::binary_float_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
1076   TensorIterator iter;
1077   iter.build_binary_float_op(out, a, b);
1078   return iter;
1079 }
1080 
1081 TensorIterator TensorIterator::comparison_op(TensorBase& out, const TensorBase& a,
1082     const TensorBase& b) {
1083   TensorIterator iter;
1084   iter.build_comparison_op(out, a, b);
1085   return iter;
1086 }
1087 
1088 TensorIterator TensorIterator::unary_op(TensorBase& out, const TensorBase& a) {
1089   TensorIterator iter;
1090   iter.build_unary_op(out, a);
1091   return iter;
1092 }
1093 
1094 TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase& a) {
1095   TensorIterator iter;
1096   iter.build_unary_float_op(out, a);
1097   return iter;
1098 }
1099 
1100 #define NULLARY_OP_CONFIG()                                     \
1101   TensorIteratorConfig()                                        \
1102     .set_check_mem_overlap(true)                                \
1103     .check_all_same_dtype(false)                                \
1104   /* FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 */ \
1105     .resize_outputs(false)
1106 
1107 TensorIterator TensorIterator::nullary_op(TensorBase& out) {
1108   return NULLARY_OP_CONFIG()
1109     .add_owned_output(out)
1110     .build();
1111 }
1112 
1113 TensorIterator TensorIterator::borrowing_nullary_op(const TensorBase& out) {
1114   return NULLARY_OP_CONFIG()
1115     .add_output(out)
1116     .build();
1117 }
1118 
1119 TensorIterator TensorIterator::reduce_op(TensorBase& out, const TensorBase& a) {
1120   TORCH_INTERNAL_ASSERT(out.defined());
1121   return TensorIteratorConfig()
1122     .set_check_mem_overlap(false)
1123     .add_owned_output(out)
1124     .add_owned_const_input(a)
1125     .resize_outputs(false)
1126     .is_reduction(true)
1127     // TODO: not supporting casting to outputs is only really necessary for arg{min,max}
1128     .promote_inputs_to_common_dtype(true)
1129     .build();
1130 }
1131 
1132 TensorIterator TensorIterator::reduce_op(TensorBase& out1, TensorBase& out2, const TensorBase& a) {
1133   TORCH_INTERNAL_ASSERT(out1.defined());
1134   TORCH_INTERNAL_ASSERT(out2.defined());
1135   TORCH_CHECK(a.device() == out1.device() && out1.device() == out2.device(),
1136       "reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
1137       ", output1 is on ", out1.device(), " and output2 is on", out2.device());
1138   TORCH_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
1139       " and output2 has ", out2.dim());
1140   TORCH_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
1141       " and output2 has ", out2.sizes());
1142   TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
1143       " and output2 has ", out2.strides());
1144   return TensorIteratorConfig()
1145     .set_check_mem_overlap(false)
1146     .add_owned_output(out1)
1147     .add_owned_output(out2)
1148     .add_owned_const_input(a)
1149     .resize_outputs(false)
1150     .is_reduction(true)
1151     .check_all_same_dtype(false)
1152     .build();
1153 }
1154 
1155 void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
1156   for (const auto idx : c10::irange(config.tensors_.size())) {
1157     auto& tensor = config.tensors_[idx];
1158     // If *any* of the arguments is a meta tensor, the overall
1159     // computation is a meta computation (don't do any work,
1160     // just compute output information).  This aligns with
1161     // our multiple dispatch semantics.
1162     if (tensor->is_meta()) {
1163       is_meta_ = true;
1164     }
1165     operands_.emplace_back(std::move(tensor));
1166     operands_[idx].is_const = config.is_tensor_const(idx);
1167   }
1168   num_outputs_ = config.num_outputs_;
1169 }
1170 
1171 void TensorIteratorBase::mark_outputs() {
1172   // TODO: merge this into populate_operands
1173   for (const auto i : c10::irange(num_outputs_)) {
1174     operands_[i].is_output = true;
1175     const auto& output = tensor(i);
1176     if (!output.defined()) continue;
1177 
1178     // check if output is also an input
1179     for (const auto arg : c10::irange(num_outputs_, ntensors())) {
1180       const auto& input = tensor(arg);
1181       if (output.is_same(input)) {
1182         operands_[i].is_read_write = true;
1183       }
1184     }
1185   }
1186 }
1187 
1188 void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) {
1189   // Outputs cannot be broadcasted. Check that the shape of the outputs matches
1190   // the inferred shape. There's an exception for write-only tensors to support
1191   // our legacy behavior that functions with `out=` arguments resize their
1192   // outputs.
1193   if (config.static_shape_.has_value()) {
1194     return;
1195   }
1196   for (const auto i : c10::irange(num_outputs_)) {
1197     const auto& output = tensor(i);
1198     if (!output.defined()) {
1199       operands_[i].will_resize = true;
1200     }
1201     if (output.defined() && !output.sizes().equals(shape_)) {
1202       if (config.resize_outputs_ && !operands_[i].is_read_write) {
1203         operands_[i].will_resize = true;
1204         continue;
1205       }
1206       // for reductions, the output size does not match shape_: the output has the reduced size, while shape_ is the size of the input
1207       TORCH_CHECK(is_reduction_,  "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
1208                  shape_);
1209     }
1210   }
1211 }
1212 
1213 void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
1214   if (!config.check_mem_overlap_) {
1215     return;
1216   }
1217   for (const auto i : c10::irange(num_outputs_)) {
1218     const auto& output = tensor_base(i);
1219     if (!output.defined()) continue;
1220     assert_no_internal_overlap(output);
1221     for (const auto j : c10::irange(num_outputs_, ntensors())) {
1222       const auto& input = tensor_base(j);
1223       if (!input.is_same(output)) {
1224         assert_no_partial_overlap(output, input);
1225       }
1226     }
1227   }
1228 }
1229 
1230 void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {
1231   if (config.static_shape_.has_value()) {
1232     shape_ = *config.static_shape_;
1233     return;
1234   }
1235 
1236   all_ops_same_shape_ = true;
1237   bool has_scalars = false;
1238   bool has_tensors = false;
1239   for (auto& op : operands_) {
1240     if (!op.tensor_base().defined()) continue;
1241 
1242     // For now, don't include output tensors when we're resizing outputs.
1243     // These shapes don't participate in shape computation.
1244     // This preserves the legacy behavior where torch.add(..., out=dst) resizes
1245     // the destination tensor.  If the output tensor is also an input, we'll
1246     // pick it up later in the operands.
1247     if (config.resize_outputs_ && op.is_output) continue;
1248     TORCH_CHECK(!op.tensor_base().unsafeGetTensorImpl()->has_symbolic_sizes_strides(),
1249       "TensorIterator does not support symbolic shapes; please implement this operator in torch/_refs "
1250       "using the elementwise or reduction helpers (look at backtrace to find out what operator this is)");
1251     auto shape = op.tensor_base().sizes();
1252     if (shape.empty()) {
1253       has_scalars = true;
1254     } else {
1255       has_tensors = true;
1256     }
1257     if (has_scalars && has_tensors) {
1258       all_ops_same_shape_ = false;
1259     }
1260     if (shape_.empty()) {
1261       shape_ = shape;
1262     } else if (!shape.equals(shape_)) {
1263       all_ops_same_shape_ = false;
1264       shape_ = infer_size_dimvector(shape_, shape);
1265     }
1266   }
1267   all_ops_are_scalars_ = !has_tensors;
1268 }
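// Worked example (sketch): operands of sizes {2, 1, 4} and {3, 1} broadcast via
// infer_size_dimvector to shape_ = {2, 3, 4}; mixing a 0-d scalar with a non-scalar
// operand clears all_ops_same_shape_, while all_ops_are_scalars_ stays true only if
// every operand participating in shape computation is 0-d.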
1269 
1270 void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) {
1271   for (auto& op : operands_) {
1272     if (op.tensor_base().defined() && !op.will_resize) {
1273       IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor_base().sizes();
1274       auto original_stride = op.tensor_base().strides();
1275       auto element_size_in_bytes = op.tensor_base().element_size();
1276       auto offset = ndim() - original_shape.size();
1277       if (offset > 0)
1278           op.stride_bytes.resize(ndim(), 0);
1279       else
1280           op.stride_bytes.resize(ndim());
1281       for (const auto i : c10::irange(original_shape.size())) {
1282         // see NOTE: [Computing output strides]
1283         if (original_shape[i] == 1 && shape_[offset + i] !=1) {
1284           op.stride_bytes[offset + i] = 0;
1285         } else {
1286           op.stride_bytes[offset + i] = original_stride[i] * element_size_in_bytes;
1287         }
1288       }
1289     }
1290   }
1291 }
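// Worked example (added for clarity, not in the original source): a float
// operand of shape {3, 1} iterated under a broadcast shape_ of {3, 4} gets
// stride_bytes {4, 0}: the size-1 dimension is broadcast, so its stride is
// zeroed, while the other stride is the element stride (1) times the element
// size (4 bytes).  A 1-d operand of shape {4} under the same shape_ gets
// stride_bytes {0, 4}, with the missing leading dimension padded with a zero
// stride.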
1292 
1293 bool TensorIteratorBase::can_use_32bit_indexing() const {
1294   int64_t max_value = std::numeric_limits<int32_t>::max();
1295   if (numel() > max_value) {
1296     return false;
1297   }
1298   for (auto& op : operands_) {
1299     int64_t max_offset = 1;
1300     for (const auto dim : c10::irange(ndim())) {
1301       max_offset += (shape_[dim] - 1) * op.stride_bytes[dim];
1302     }
1303     if (max_offset > max_value) {
1304       return false;
1305     }
1306   }
1307   return true;
1308 }
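// Worked example (added for clarity, not in the original source): for
// shape_ = {2, 3} and an operand with stride_bytes = {12, 4}, the largest
// byte offset touched is 1 + (2-1)*12 + (3-1)*4 = 21, comfortably below
// INT32_MAX, so 32-bit indexing is allowed for that operand.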
1309 
1310 std::unique_ptr<TensorIterator> TensorIteratorBase::split(int dim) {
1311   TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && shape()[dim] >= 2);
1312   auto copy = std::make_unique<TensorIterator>(*this);
1313 
1314   bool overlaps = is_dim_reduced(dim);
1315   auto copy_size = shape_[dim] / 2;
1316   auto this_size = shape_[dim] - copy_size;
1317   copy->narrow(dim, 0, copy_size);
1318   copy->final_output_ &= !overlaps;
1319   this->narrow(dim, copy_size, this_size);
1320   this->accumulate_ |= overlaps;
1321 
1322   return copy;
1323 }
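// Illustrative example (added, not in the original source): splitting dim 0
// of a {10, 3} iterator narrows the returned copy to rows [0, 5) and this
// iterator to rows [5, 10).  If dim 0 is a reduced dimension, the two halves
// write to the same output, so the copy loses final_output_ and this iterator
// gains accumulate_.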
1324 
1325 
1326 int TensorIteratorBase::get_dim_to_split() const {
1327   TORCH_INTERNAL_ASSERT(ndim() >= 1);
1328   int64_t max_extent = -1;
1329   int dim_to_split = -1;
1330   for (int dim = ndim() - 1; dim >= 0; dim--) {
1331     const int64_t size = shape_[dim];
1332     if (size == 0) {
1333       continue;
1334     }
1335     for (auto& op : operands_) {
1336       // std::abs is necessary to handle some special cases where we support negative strides
1337       // see the CUDA backend of at::flip
1338       const int64_t extent = (size - 1) * std::abs(op.stride_bytes[dim]);
1339       if (extent > max_extent) {
1340         max_extent = extent;
1341         dim_to_split = dim;
1342       }
1343     }
1344   }
1345   TORCH_INTERNAL_ASSERT(max_extent >= 0);
1346   return dim_to_split;
1347 }
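// Worked example (added, not in the original source): with shape_ = {4, 5}
// and a single operand whose stride_bytes are {20, 4}, the extents are
// (4-1)*20 = 60 for dim 0 and (5-1)*4 = 16 for dim 1, so dim 0 is chosen as
// the dimension to split.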
1348 
1349 bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
1350   // This function tries to do a fast setup to avoid needless reordering of dimensions and tracking of output strides.
1351   // Returns true if the fast setup succeeded, false otherwise.
1352   // TODO: enable fast handling for reductions
1353   FastSetupType setup_type = compute_fast_setup_type(config);
1354   if (setup_type == FastSetupType::NONE) {
1355     return false;
1356   }
1357 
1358   // allocate memory for the outputs; the memory format depends on setup_type
1359   switch (setup_type) {
1360     case FastSetupType::CONTIGUOUS:
1361       {
1362         for (const auto i : c10::irange(num_outputs_)) {
1363           auto& op = operands_[i];
1364           if (!op.tensor_base().defined()) {
1365             TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1366           }
1367           set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::Contiguous), names_);
1368         }
1369         break;
1370       }
1371     case FastSetupType::CHANNELS_LAST:
1372       {
1373         for (const auto i : c10::irange(num_outputs_)) {
1374           auto& op = operands_[i];
1375           if (!op.tensor_base().defined()) {
1376             TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1377           }
1378           set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::ChannelsLast), names_);
1379         }
1380         break;
1381       }
1382     case FastSetupType::NON_OVERLAPPING_DENSE:
1383       {
1384         // find the index of a defined tensor in operands_, starting from the input tensors
1385         int i_defined; // NOLINT(cppcoreguidelines-init-variables)
1386         for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) {
1387           if (tensor(i_defined).defined()) break;
1388         }
1389         TORCH_CHECK(i_defined >= 0, "Can not find a defined tensor when fast allocating memory to outputs");
1390         for (const auto i : c10::irange(num_outputs_)) {
1391           auto& op = operands_[i];
1392           if (!op.tensor_base().defined()) {
1393             TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1394           }
1395           set_output_raw_strided(i, shape_, tensor_base(i_defined).strides(), original_options(op), names_);
1396         }
1397         break;
1398       }
1399     default:
1400       TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", std::to_string((int)setup_type));
1401   }
1402   // Coalescing dimensions consists of collapsing all dimensions into one (we are limited to contiguous, no-broadcast cases here).
1403   if (ndim() > 1){
1404     has_coalesced_dimensions_ = true;
1405   }
1406   if (ndim() >= 1) {
1407     shape_[0] = numel();
1408     shape_.resize(1);
1409   }
1410   for (auto& op : operands_ ) {
1411     auto element_size_in_bytes = op.tensor_base().element_size();
1412     op.stride_bytes.resize(ndim());
1413     if (ndim()>0) {
1414       op.stride_bytes[0] = element_size_in_bytes;
1415     }
1416   }
1417   return true;
1418 }
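// Illustrative example (added, not in the original source): for two already
// contiguous float tensors of shape {2, 3, 4} and a contiguous (or
// to-be-allocated) output, fast setup picks FastSetupType::CONTIGUOUS and
// then collapses the iteration space to shape_ = {24} with
// stride_bytes = {4} (the float element size) for every operand.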
1419 
1420 FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) {
1421   if (is_reduction_ || !all_ops_same_shape_) {
1422     return FastSetupType::NONE;
1423   }
1424 
1425   // For linear iteration, only contiguous tensors can be coalesced
1426   // Fast setup of any other format requires changing iteration order
1427   if (enforce_linear_iteration_) {
1428     for (const auto& op : operands_) {
1429       if (op.tensor_base().defined() && !op.will_resize) {
1430         auto is_contiguous = op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
1431         if (!is_contiguous) {
1432           return FastSetupType::NONE;
1433         }
1434       }
1435     }
1436     return FastSetupType::CONTIGUOUS;
1437   }
1438 
1439   bool is_contiguous = true;
1440   bool is_channels_last = true;
1441   bool is_non_overlapping_and_dense = true;
1442   for (const auto& op : operands_) {
1443     if (op.tensor_base().defined() && !op.will_resize) {
1444       is_contiguous &= op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
1445       is_channels_last &= op.tensor_base().is_contiguous(at::MemoryFormat::ChannelsLast);
1446       is_non_overlapping_and_dense &= op.tensor_base().is_non_overlapping_and_dense();
1447     }
1448   }
1449   // TODO: this causes ambiguous cases (NC11) to always be treated as contiguous
1450   if (is_contiguous) {
1451     return FastSetupType::CONTIGUOUS;
1452   }
1453   if (is_channels_last) {
1454     return FastSetupType::CHANNELS_LAST;
1455   }
1456   if (is_non_overlapping_and_dense) {
1457     int64_t prev = -1;
1458     // Fast setup is allowed only when all the defined tensors have the same shape and strides.
1459     // Iterate from the back to check input tensors' strides first, then output tensors'.
1460     for (int64_t i = ntensors() - 1; i >= 0; --i) {
1461       const auto& op = operands_[i];
1462       if (op.tensor_base().defined() && !op.will_resize) {
1463         if (prev < 0) {
1464           prev = i;
1465           continue;
1466         }
1467         if (!tensor_base(prev).strides().equals(op.tensor_base().strides())) {
1468           // [Note: stride check for non contiguous tensors in fast setup]
1469           // We prevent fast setup in 3 cases here:
1470           // 1. input tensors have different strides.
1471           // 2. output tensors won't be resized and have different strides.
1472           // 3. input tensors have the same strides, but output tensors have different strides from the input tensors.
1473           //    We don't allow re-striding output tensors in this case since it is not compatible with
1474           //    numpy. The behavior in numpy is that if the output tensor has the same shape as the input
1475           //    tensor but different strides, the strides of the output tensor are preserved, so we do
1476           //    the same in TensorIterator.
1477           return FastSetupType::NONE;
1478         }
1479       }
1480     }
1481     return FastSetupType::NON_OVERLAPPING_DENSE;
1482   }
1483   return FastSetupType::NONE;
1484 }
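// Note on the TODO above (added, not in the original source): a contiguous
// tensor of shape {N, C, 1, 1} satisfies both the Contiguous and the
// ChannelsLast contiguity checks, so such "NC11" operands always take the
// CONTIGUOUS branch simply because that check comes first.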
1485 
1486 TensorIteratorBase::TensorIteratorBase() = default;
1487 
1488 void TensorIteratorBase::build(TensorIteratorConfig& config) {
1489   // populate some persistent configuration fields
1490   is_reduction_ = config.is_reduction_;
1491   enforce_linear_iteration_ = config.enforce_linear_iteration_;
1492 
1493   // fill in operands_ based on configuration
1494   populate_operands(config);
1495   // set is_output and is_read_write flags on appropriate tensors
1496   mark_outputs();
1497   // Check that the outputs have no internal overlap
1498   // and do not share memory with inputs.
1499   compute_mem_overlaps(config);
1500   // Check that input dimensions are aligned correctly & compute outnames.
1501   compute_names(config);
1502   // compute the broadcasted shape
1503   compute_shape(config);
1504   // mark outputs for resizing if necessary
1505   mark_resize_outputs(config);
1506   // compute the result dtype and device
1507   compute_types(config);
1508   // try a fast setup of the output tensors; if that fails, fall back to the normal setup path
1509   if (!fast_set_up(config)) {
1510     // compute each tensor's stride after broadcasting
1511     compute_strides(config);
1512     // re-order dimensions to improve coalescing
1513     reorder_dimensions();
1514     // allocate the output tensor if it's not provided
1515     allocate_or_resize_outputs();
1516     // coalesce adjacent dimensions when possible
1517     if (!is_meta_) coalesce_dimensions();
1518   }
1519 
1520   if (is_meta_) return;
1521 
1522   auto has_storage = true;
1523   for (auto& op : operands_) {
1524     has_storage &= op.tensor_base().has_storage();
1525   }
1526   auto privateuse1_without_storage =
1527      common_device_.type() == DeviceType::PrivateUse1 &&
1528      !has_storage;
1529 
1530   // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
1531   // Nothing beyond this point is important for meta functions, so it's fine to exit early here.
1532   // Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
1533   if (privateuse1_without_storage  ||
1534       common_device_.type() == DeviceType::MTIA ||
1535       common_device_.type() == DeviceType::XLA  ||
1536       common_device_.type() == DeviceType::IPU  ||
1537       common_device_.type() == DeviceType::Lazy ||
1538       common_device_.type() == DeviceType::MAIA  ||
1539       common_device_.type() == DeviceType::HPU) return;
1540 
1541   for (auto& op : operands_) {
1542     TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
1543     if (op.is_const) {
1544       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1545       op.data = const_cast<void*>(op.tensor_base().const_data_ptr());
1546     } else {
1547       op.data = op.tensor_base().mutable_data_ptr();
1548     }
1549   }
1550 
1551   // zero out offsets
1552   // If the tensor is a scalar, we still leave room for one offset
1553   // so that index translations in reductions can access
1554   // a valid value for the offset.
1555   int64_t ndim_offsets = (ndim() ? ndim() : 1);
1556   view_offsets_ = DimVector(ndim_offsets, 0);
1557 }
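// Hedged usage sketch (added for illustration; not part of this file, and the
// exact builder calls may differ by version): a typical element-wise iterator
// is assembled roughly like
//
//   auto iter = TensorIteratorConfig()
//       .add_output(out)
//       .add_const_input(a)
//       .add_const_input(b)
//       .build();
//
// or via a helper such as TensorIterator::binary_op(out, a, b), and is then
// handed to a cpu_kernel/gpu_kernel-style loop that consumes the coalesced
// shape and per-operand byte strides computed here.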
1558 
1559 // This is the structured kernels' implementation of set_output.  It is
1560 // NEVER actually called directly; instead, a subclass of TensorIteratorBase
1561 // will override set_output to actually do the operation, and then call
1562 // set_output on the TensorIteratorBase to setup TI's metadata.
1563 // The precondition for this function is that maybe_get_output() now
1564 // unconditionally returns a real Tensor (prior to output setting,
1565 // maybe_get_output() may return an undefined tensor.)
1566 void TensorIteratorBase::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
1567   auto& op = operands_[output_idx];
1568   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
1569   const auto& t = maybe_get_output(output_idx);
1570   TORCH_INTERNAL_ASSERT(t.defined());
1571   if (!op.tensor_base().defined()) {
1572     op.tensor(c10::MaybeOwned<TensorBase>::borrowed(t));
1573     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.target_dtype == t.scalar_type());
1574   } else if (op.will_resize) {
1575     if (op.original_tensor_base().defined()) {
1576       // OK, so this is pretty weird.  To understand how we can end up in
1577       // this situation, first look at Marker [Output original_tensor is set].
1578       // That is the sole site where original_tensor may be set on an
1579       // output operand.  Essentially, when we are given an explicit output
1580       // tensor whose dtype doesn't match the computed common dtype from
1581       // the input operands, we do a switcheroo: we replace the (incorrectly
1582       // typed) output tensor with a correctly typed, *temporary* tensor,
1583       // and remember the original tensor in original_tensor (which will
1584       // then get written back to when we cast_outputs).
1585       //
1586       // Now, what if the given output tensor also happened to be zero
1587       // size (meaning that we will_resize it)?  Well, at the call site
1588       // above, we don't necessarily(*) know what the correct shape should
1589       // be, so we give the temporary tensor the same shape as the original.
1590       // At the time of set_output is when we DO know what the correct size
1591       // is, and the subclass's implementation of set_output in the structured class
1592       // is responsible for resizing original_tensor.  But we still have this
1593       // incorrectly sized temporary output which the structured subclass
1594       // knows nothing about, so we are obligated to also resize it here.
1595       //
1596       // This is a slight memory pessimization, because previously
1597       // original_tensor only got resized at the end of the computation, rather
1598       // than at the beginning (as happens here).  However, the peak memory
1599       // usage is the same, since you need to materialize both original tensor
1600       // and temporary tensor to do the copy.
1601       //
1602       // (*) Actually, technically, we probably do know what the shape
1603       // should be, since we do shape computation before dtype computation.
1604       // So hypothetically we could figure out what the correct shape is
1605       // at that point in time and directly allocate the temporary at
1606       // the right size.
1607       //
1608       // But a better solution is to delay allocation of temporaries until
1609       // after TensorIterator builder, waiting until we actually want
1610       // to do the computation.  That would also remove the necessity
1611       // for the is_meta_ test.
1612       TORCH_INTERNAL_ASSERT(op.original_tensor_base().is_same(t));
1613       TORCH_INTERNAL_ASSERT(!op.tensor_base().is_same(t));
1614       OptionalTensorRef tensor(op.tensor());
1615       at::native::resize_output(*tensor, sizes);
1616       if (!strides.empty()) {
1617         TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
1618         tensor->as_strided_(sizes, strides);
1619       } else if (options.memory_format_opt().has_value()) {
1620         tensor->unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
1621       }
1622     }
1623   }
1624   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1625       op.tensor_base().is_same(t) || op.current_dtype == op.tensor_base().scalar_type());
1626   // For simplicity, just always update the cached current_dtype.
1627   op.current_dtype = op.tensor_base().scalar_type();
1628 }
1629 
1630 // This is the "traditional" implementation of set_output.  On TensorIterator
1631 // instances, it is invoked directly from various call sites in this file.  No
1632 // funny business.
1633 void TensorIterator::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
1634   // NB: intentionally no superclass call
1635   auto& op = operands_[output_idx];
1636   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
1637   if (!op.tensor_base().defined()) {
1638       if (strides.empty()) {
1639         op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty(sizes, options)));
1640       } else {
1641         op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty_strided(sizes, strides, options)));
1642       }
1643       op.current_dtype = op.target_dtype;
1644   } else if (op.will_resize) {
1645       at::native::resize_output(op.tensor(), sizes);
1646       if (!strides.empty()) {
1647         TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
1648         op.tensor().as_strided_(sizes, strides);
1649       } else if (options.memory_format_opt().has_value()) {
1650         op.tensor_base().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
1651       }
1652   }
1653   if (!names.empty()) {
1654     TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
1655     namedinference::propagate_names(op.tensor_base(), names);
1656   }
1657 }
1658 
1659 // Not actually used by anything (TensorIterator subclass calls
1660 // its own implementation of set_output which knows exactly where
1661 // all the outputs are), but we have to provide all pure virtual methods
1662 // for MetaBase
1663 const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) {
1664   return output(output_idx);
1665 }
1666 
1667 SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const {
1668   return SplitUntil32Bit(*this);
1669 }
1670 
1671 /// SplitUntil32Bit. Recursively splits an iterator into sub-iterators that
1672 /// can use 32-bit indexing.
1673 
1674 SplitUntil32Bit::iterator::iterator(const TensorIteratorBase& iter) {
1675   vec.emplace_back(new TensorIterator(iter));
1676   vec.emplace_back(nullptr); // ++ first pops the last element
1677   ++(*this);
1678 }
1679 
1680 SplitUntil32Bit::iterator& SplitUntil32Bit::iterator::operator++() {
1681   vec.pop_back();
1682   while (!vec.empty() && !vec.back()->can_use_32bit_indexing()) {
1683     auto& iter = *vec.back();
1684     auto split_dim = iter.get_dim_to_split();
1685     vec.emplace_back(iter.split(split_dim));
1686   }
1687   return *this;
1688 }
1689 
1690 TensorIterator& SplitUntil32Bit::iterator::operator*() const {
1691   return *vec.back();
1692 }
1693 
1694 SplitUntil32Bit::iterator SplitUntil32Bit::begin() const {
1695   return SplitUntil32Bit::iterator(iter);
1696 }
1697 
1698 SplitUntil32Bit::iterator SplitUntil32Bit::end() const {
1699   return SplitUntil32Bit::iterator();
1700 }
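// Hedged usage sketch (added for illustration; not part of this file): 64-bit
// iterators are typically consumed by recursing over the 32-bit-indexable
// sub-iterators, e.g.
//
//   if (!iter.can_use_32bit_indexing()) {
//     for (auto& sub_iter : iter.with_32bit_indexing()) {
//       launch_my_kernel(sub_iter);  // hypothetical per-chunk launch
//     }
//     return;
//   }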
1701 
1702 DimCounter::DimCounter(IntArrayRef shape, Range range)
1703   : shape(shape)
1704   , range(range)
1705   , values(shape.size())
1706   , offset(range.begin) {
1707   std::fill(values.begin(), values.end(), 0);
1708   if (range.begin == 0) {
1709     return;
1710   }
1711 
1712   int64_t linear_offset = range.begin;
1713   auto ndim = values.size();
1714   for (const auto dim : c10::irange(ndim)) {
1715     int64_t size = shape[dim];
1716     if (size > 0) {
1717       values[dim] = linear_offset % size;
1718       linear_offset /= size;
1719     }
1720   }
1721   TORCH_INTERNAL_ASSERT(linear_offset == 0);
1722 }
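// Worked example (added, not in the original source): for shape = {2, 3} and
// range.begin = 5, the constructor decodes the linear offset as
// values = {5 % 2, (5 / 2) % 3} = {1, 2}, i.e. the counter resumes at linear
// index 1 + 2*2 = 5 with dimension 0 varying fastest.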
1723 
1724 bool DimCounter::is_done() const {
1725   return offset >= range.end;
1726 }
1727 
1728 void DimCounter::increment(const std::array<int64_t, 2>& step) {
1729   offset += step[0] * step[1];
1730   auto ndim = values.size();
1731   int64_t overflow = step[0];
1732   size_t i = 0;
1733   if (step[1] != 1) {
1734     TORCH_INTERNAL_ASSERT(step[0] == shape[0] && values[0] == 0);
1735     i = 1;
1736     overflow = step[1];
1737   }
1738   for (; i < ndim && overflow > 0; i++) {
1739     auto size = shape[i];
1740     auto prev = values[i];
1741     auto value = prev + overflow;
1742     if (value >= size) {
1743       overflow = 1;
1744       value -= size;
1745       TORCH_INTERNAL_ASSERT(value < size);
1746     } else {
1747       overflow = 0;
1748     }
1749     values[i] = static_cast<int64_t>(value);
1750   }
1751   TORCH_INTERNAL_ASSERT(overflow == 0 || overflow == 1);
1752 }
1753 
1754 std::array<int64_t, 2> DimCounter::max_2d_step() const {
1755   int64_t step0 = std::min(shape[0] - values[0], range.end - offset);
1756   int64_t step1 = 1;
1757   if (step0 == shape[0] && shape.size() >= 2) {
1758     step1 = std::min(shape[1] - values[1], (range.end - offset) / shape[0]);
1759   }
1760   return {step0, step1};
1761 }
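// Worked example (added, not in the original source): with shape = {2, 3},
// values = {0, 0} and range = [0, 6), step0 = min(2, 6) = 2 covers all of
// dim 0, so step1 = min(3, 6 / 2) = 3 and the whole 2x3 tile is handled in a
// single {2, 3} step; with range = [0, 5) instead, the result is {2, 2}.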
1762 
1763 }  // namespace at
1764