1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #define TORCH_ASSERT_NO_OPERATORS
3 #include <ATen/TensorIterator.h>
4 #undef TORCH_ASSERT_NO_OPERATORS
5
6 #include <ATen/core/Tensor.h>
7
8 #include <ATen/ExpandUtils.h>
9 #include <ATen/Parallel.h>
10 #include <ATen/native/TypeProperties.h>
11 #include <ATen/MemoryOverlap.h>
12 #include <ATen/native/Resize.h>
13 #include <ATen/NamedTensorUtils.h>
14 #include <ATen/TensorOperators.h>
15 #include <ATen/TensorIteratorInternal.h>
16
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #else
20 #include <ATen/ops/empty.h>
21 #include <ATen/ops/empty_strided.h>
22 #endif
23
24 #include <c10/util/irange.h>
25 #include <c10/util/SmallBuffer.h>
26
27 #include <array>
28 #include <algorithm>
29 #include <cmath>
30
31 namespace at {
32
33 using DimMask = TensorIteratorBase::DimMask;
34 using PtrVector = TensorIteratorBase::PtrVector;
35 using loop2d_t = TensorIteratorBase::loop2d_t;
36 using StrideVector = TensorIteratorBase::StrideVector;
37
38 namespace {
39
inline void get_base_ptrs(char** ptrs, ArrayRef<OperandInfo> operands) {
41 std::transform(operands.begin(), operands.end(), ptrs, [](const OperandInfo& op) {
42 return static_cast<char*>(op.data);
43 });
44 }
45
inline void get_strides(int64_t* strides, ArrayRef<OperandInfo> operands, int64_t ndim) {
47 for (const auto dim : c10::irange(ndim)) {
48 for (const auto arg : c10::irange(operands.size())) {
49 *strides++ = operands[arg].stride_bytes[dim];
50 }
51 }
52 // Always at least 2d strides to support 2d for_each loops
53 if (ndim < 2) {
54 auto ntensors = operands.size();
55 std::fill_n(strides, (2 - ndim) * ntensors, 0);
56 }
57 }
58
static OptionalTensorRef make_otr(const TensorBase &tensor) {
60 if (tensor.defined()) {
61 return OptionalTensorRef(tensor);
62 } else {
63 return OptionalTensorRef();
64 }
65 }
66
67 }
68
69 namespace internal {
70
OpaqueOptionalTensorRef::OpaqueOptionalTensorRef() {
72 static_assert(alignof(OptionalTensorRef) == alignof(TensorBase));
73 static_assert(sizeof(OptionalTensorRef) == sizeof(TensorBase));
74 new (data_.data()) OptionalTensorRef();
75 }
76
OpaqueOptionalTensorRef::~OpaqueOptionalTensorRef() {
78 get()->~OptionalTensorRef();
79 }
80
const Tensor& OpaqueOptionalTensorRef::getTensor() const {
82 return get()->getTensorRef();
83 }
84
85 }
86
void OperandInfo::tensor(c10::MaybeOwned<TensorBase> &&tensor) {
88 tensor_base_ = std::move(tensor);
89 *tensor_storage_ = make_otr(*tensor_base_);
90 }
91
void OperandInfo::exchange_tensor(c10::MaybeOwned<TensorBase> &&new_tensor) {
93 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!original_tensor_base_->defined());
94 original_tensor_base_ = std::exchange(tensor_base_, std::move(new_tensor));
95 *original_tensor_storage_ = std::exchange(*tensor_storage_, make_otr(*tensor_base_));
96 }
97
void OperandInfo::restore_original_tensor() {
99 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(original_tensor_base_->defined());
100 tensor_base_ = std::move(original_tensor_base_);
101 *tensor_storage_ = std::exchange(*original_tensor_storage_, OptionalTensorRef{});
102 }
103
104 /// Construction
TensorIteratorConfig& TensorIteratorConfig::add_owned_output(const TensorBase& output) {
106 TORCH_INTERNAL_ASSERT(
107 num_inputs_ == 0,
108 "Keep in mind that you have to add all outputs first before adding any input. "
109 "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
110 tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, output));
111 num_outputs_++;
112 return *this;
113 }
114
TensorIteratorConfig& TensorIteratorConfig::add_owned_input(const TensorBase& input) {
116 tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, input));
117 num_inputs_++;
118 return *this;
119 }
120
TensorIteratorConfig& TensorIteratorConfig::add_owned_const_input(const TensorBase& input) {
122 const_tensor_indices_.push_back(tensors_.size());
123 tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, input));
124 num_inputs_++;
125 return *this;
126 }
127
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_output(const TensorBase& output) {
129 TORCH_INTERNAL_ASSERT(
130 num_inputs_ == 0,
131 "Keep in mind that you have to add all outputs first before adding any input. "
132 "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
133 tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(output));
134 num_outputs_++;
135 return *this;
136 }
137
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase& input) {
139 tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
140 num_inputs_++;
141 return *this;
142 }
143
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_const_input(const TensorBase& input) {
145 const_tensor_indices_.push_back(tensors_.size());
146 tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
147 num_inputs_++;
148 return *this;
149 }
150
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
152 TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
153 static_dtype_ = dtype;
154 static_device_ = device;
155 return *this;
156 }
157
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype(ScalarType dtype) {
159 TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
160 static_dtype_ = dtype;
161 return *this;
162 }
163
TensorIteratorConfig& TensorIteratorConfig::declare_static_device(Device device) {
165 static_device_ = device;
166 return *this;
167 }
168
TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape) {
170 // WARNING:
171 // This will bypass all shape checking in the TensorIterator. Kernels which call this method
172 // are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)");
174 static_shape_ = std::make_optional(DimVector(shape));
175 return *this;
176 }
177
TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) {
179 declare_static_shape(shape);
180 if (static_shape_->empty()) return *this;
181 for (const auto& squash_dim : squash_dims) {
182 TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast<int64_t>(static_shape_->size()),
183 "squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ").");
184 (*static_shape_)[squash_dim] = 1;
185 }
186 return *this;
187 }
188
bool TensorIteratorConfig::is_tensor_const(size_t idx) {
190 return std::find(const_tensor_indices_.begin(), const_tensor_indices_.end(), idx) != const_tensor_indices_.end();
191 }
192
193 // NOTE: [Computing output strides]
194 // We use the following algorithm to compute output strides
195 // If correctly sized output is provided, we respect its strides and don't change them
196 // Otherwise, if provided output is of incorrect size or no output is provided,
197 // we try to recover permutation that was applied to the inputs
198 // by sorting the strides of the inputs. Precedence is given to the inputs in the order they were added,
199 // and to permutations involving non-broadcasted dimensions
200 // 1. we loop over inputs starting from the first
201 // 2. for all inputs strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. If one
202 // of the dimensions being compared has a stride of 0, we move on to the next tensor to determine if
203 // these dimensions need to be swapped.
204 // 3. strides of dimensions equal to 1 participate in sorting
205 // 4. if 2 strides are equal and neither is 0, we try to break the tie by looking at the corresponding dimensions
206 // of the tensor. Dimensions were permuted if, when iterating from the end, dimensions corresponding to the
207 // same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie.
208 //
209 // Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly
// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially
211 // improve traversal order of the second tensor. We chose the former option to better propagate channels last layout
212 // for example for a tensor with the sizes N1H1
213 // These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all
214 // arguments are of the same size) or the argument that is not broadcasted, regardless of its position.
// As a bonus, it also results in reasonably well-behaved traversal order of the inputs and outputs - in the kernels
216 // output is traversed linearly, and since it closely follows input layouts, inputs are traversed linearly as well
217 //
218 // Examples:
219 // full size tensor + broadcasted tensor with 0 or 1 non-trivial dimensions => strides of output are same
220 // as strides of full size input regardless of the order
221 // 2 tensors of same size but different strides => output strides are the same as first argument
222 //
// We also have a fast path for memory-dense inputs with the same strides (or, trivially, a single memory-dense input)
224 // that outputs a tensor with the same strides as inputs. The only difference in result with the algorithm described
225 // above is for strides for trivial (1) dimensions, where in ambiguous cases for performance reasons we default to
226 // contiguous strides.
227 // Example: tensor with sizes NC11 and strides C1CC will produce output with strides C111 (note differences are only
228 // in the strides of trivial dimensions, so physical layout is unaffected but permutation information is lost)
229 // We might change this behavior in future once performance considerations are resolved
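// Worked example of the rules above (illustrative): adding a channels-last NCHW
// tensor (strides H*W*C, 1, W*C, C) to a bias broadcast as (1, C, 1, 1) recovers
// the channels-last permutation from the first input, so a freshly allocated
// output is laid out channels-last rather than contiguously.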
230
void TensorIteratorBase::reorder_dimensions() {
232 // Sort the dimensions based on strides in ascending order with reduced dims
// at the front. Note that this inverts the order of C-contiguous tensors.
234 // strides[0] is the fastest moving dimension instead of strides[ndim - 1].
235 // See NOTE: [Computing output strides] and inline comments for more detailed description
236
237 perm_.resize(ndim());
238 if (ndim() == 1) {
239 perm_[0] = 0;
240 return;
241 }
242
243 // initialize perm with n-1, n-2, ..., 1, 0
244 std::iota(perm_.rbegin(), perm_.rend(), 0);
245
// Reordering dimensions changes iteration order
247 if (enforce_linear_iteration_) {
248 permute_dimensions(perm_);
249 return;
250 }
251
252 // returns 1 if the dim0 should come after dim1, -1 if dim0 should come
253 // before dim1, and 0 if the comparison is ambiguous.
254 auto should_swap = [&](size_t dim0, size_t dim1) {
255 for (const auto arg : c10::irange(ntensors())) {
256 // ignore undefined or incorrectly sized tensors
257 if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) {
258 continue;
259 }
260 int64_t stride0 = operands_[arg].stride_bytes[dim0];
261 int64_t stride1 = operands_[arg].stride_bytes[dim1];
262 if (is_reduction_ && operands_[arg].is_output) {
263 // move reduced dimensions to the front
264 // strides of reduced dimensions are always set to 0 by review_reduce_result
265 if ((stride0 == 0) != (stride1 == 0)) {
266 return stride1 == 0 ? 1 : -1;
267 }
268 }
269 //move on to the next input if one of the dimensions is broadcasted
270 if (stride0 == 0 || stride1 == 0) {
271 continue;
272 // it is important to return here only with strict comparisons, for equal strides we try to break the tie later
273 // by comparing corresponding dimensions or if that does not work, moving on to the next tensor
274 } else if (stride0 < stride1) {
275 return -1;
276 } else if (stride0 > stride1) {
277 return 1;
278 } else { //equal strides, use dimensions themselves as the tie-breaker.
279 //at this point, with zero strides out of the way, we are guaranteed that operand dimensions are equal to shape_
280 auto t_dim0 = shape_[dim0];
281 auto t_dim1 = shape_[dim1];
282 //return only if dimensions should be swapped, otherwise move on to the next tensor
283 if (t_dim0 > t_dim1) {
284 return 1;
285 }
286 }
287 }
288 return 0;
289 };
290
291 // insertion sort with support for ambiguous comparisons
292 for (const auto i : c10::irange(1, ndim())) {
293 int dim1 = i;
294 for (int dim0 = i - 1; dim0 >= 0; dim0--) {
295 int comparison = should_swap(perm_[dim0], perm_[dim1]);
296 if (comparison > 0) {
297 std::swap(perm_[dim0], perm_[dim1]);
298 dim1 = dim0;
299 } else if (comparison < 0) {
300 break;
301 }
302 }
303 }
304
305 // perform re-ordering of shape and strides
306 permute_dimensions(perm_);
307 }
308
309 // Computes a common dtype using type promotion
310 // See the [Common Dtype Computation] note
ScalarType TensorIteratorBase::compute_common_dtype() {
312 at::native::ResultTypeState state = {};
313 for (const auto& op : operands_) {
314 if (op.is_output) {
315 continue;
316 }
317
318 state = at::native::update_result_type_state(op.tensor(), state);
319 }
320
321 common_dtype_ = at::native::result_type(state);
322 TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
323
324 return common_dtype_;
325 }
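// For typical (non-scalar) inputs the result follows c10::promoteTypes: for example,
// an int64 tensor combined with a float32 tensor promotes to float32, and float32
// combined with float64 promotes to float64.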
326
static TensorOptions original_options(const OperandInfo& op) {
328 if (op.original_tensor_base().defined()) {
329 return op.original_tensor_base().options();
330 } else {
331 return op.options();
332 }
333 }
334
335 // Implements the behavior of the following flags:
336 // - check_all_same_dtype_
337 // - check_all_same_device_
338 // - enforce_safe_casting_to_output_
339 // - promote_inputs_to_common_dtype_
340 // - cast_common_dtype_to_outputs_
341 //
342 // See their descriptions in TensorIterator.h for details.
343 // NOTE: Checks for more specific behaviors (e.g. the first and second
344 // inputs must share a dtype, but the third must have the long dtype)
345 // should be implemented directly and outside of TensorIterator.
void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
347 // Reviews operands (1/2)
348 // - validates that all input tensors are defined
349 // - computes common device
350 // - determines if there are undefined outputs
351 // - determines if there are different dtypes and attempts
352 // to quickly acquire a common dtype
353 Device common_device = kCPU;
354 common_dtype_ = ScalarType::Undefined;
355 // NB: despite output_dtype's generic sounding name, it only is
356 // used in a nontrivial way if check_all_same_dtype is true
357 ScalarType output_dtype = ScalarType::Undefined;
358 bool has_different_input_dtypes = false;
359 bool has_different_output_dtypes = false;
360 bool has_undefined_outputs = false;
361
362 for (auto& op : operands_) {
363 // Validates that all inputs have type information, and that
364 // if an output is missing type information that we can infer
365 // the device it should be allocated on.
366 if (!op.is_type_defined()) {
367 TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");
368
369 if (config.static_dtype_.has_value()) {
370 op.target_dtype = config.static_dtype_.value();
371 } else {
372 has_undefined_outputs = true;
373 }
374
375 if (config.static_device_.has_value()) {
376 op.device = config.static_device_.value();
377 } else {
378 TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
379 }
380
381 if (has_undefined_outputs || !op.device.has_value()) {
382 continue;
383 }
384 }
385
386 // Validates input tensors are defined
387 if (!op.tensor_base().defined()) {
388 TORCH_INTERNAL_ASSERT(op.is_output, "Found undefined input tensor!");
389 continue;
390 }
391
TORCH_INTERNAL_ASSERT(op.target_dtype == op.current_dtype);
393
394 // Acquires the first non-CPU device (if any) as the common device
395 if (common_device == kCPU && !op.tensor_base().is_cpu()) {
396 common_device = op.tensor_base().device();
397 }
398
399 if (!op.is_output) {
400 // Determines if there are varying input dtypes
401 // NOTE: the common dtype is set to the first defined input dtype observed
402 if (op.target_dtype != common_dtype_) {
403 if (common_dtype_ == ScalarType::Undefined) {
404 common_dtype_ = op.target_dtype;
405 } else {
406 has_different_input_dtypes = true;
407 }
408 }
409 } else { // op.is_output
410 // Determines if there are varying output dtypes
411 // NOTE: the output dtype is set to the first defined output dtype observed
412 if (op.target_dtype != output_dtype) {
413 if (output_dtype == ScalarType::Undefined) {
414 output_dtype = op.target_dtype;
415 } else {
416 has_different_output_dtypes = true;
417 }
418 }
419 }
420 }
421
422 // Checks that either the computation type is computable or unneeded
423 TORCH_INTERNAL_ASSERT(!(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ &&
424 (has_undefined_outputs || config.enforce_safe_casting_to_output_ ||
425 config.cast_common_dtype_to_outputs_)));
426
427 // Checks that all inputs and defined outputs are the same dtype, if requested
428 if (config.check_all_same_dtype_ &&
429 (has_different_input_dtypes || has_different_output_dtypes ||
430 (common_dtype_ != output_dtype && output_dtype != ScalarType::Undefined))) {
431 // Throws an informative error message
432 for (auto& op : operands_) {
433 if (!op.tensor_base().defined()) {
434 continue;
435 }
436
437 TORCH_CHECK(op.target_dtype == common_dtype_,
438 "Found dtype ", op.target_dtype, " but expected ", common_dtype_);
439 }
440 }
441
442 // Short-circuits if no additional work required
443 if (!has_undefined_outputs && !config.check_all_same_device_ &&
444 !config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
445 !config.enforce_safe_casting_to_output_) {
446 // Invalidates common_dtype_ if it could not be inferred
447 common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
448 return;
449 }
450
451 // Computes a common dtype, if needed
452 if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) {
453 common_dtype_ = compute_common_dtype();
454 }
455
456 // Promotes common dtype to the default float scalar type, if needed
457 if (config.promote_integer_inputs_to_float_ &&
458 c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
459 common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
460 }
461
462 // Reviews operands (2/2)
463 // - sets metadata for undefined outputs
464 // - checks that all tensors are on the same device, if requested
465 // - checks that the common dtype can safely cast to each output, if requested
466 // - creates temporaries for CPU operations, if needed and requested
467 common_device_ = common_device;
468 int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
469 int current_cpu_scalars_on_non_cpu = 0;
470 for (auto& op : operands_) {
471 bool is_type_defined = op.is_type_defined();
472 bool is_device_defined = op.is_device_defined();
473
474 if (!is_type_defined) {
475 op.target_dtype = common_dtype_;
476 }
477 if (!is_device_defined) {
478 op.device = common_device;
479 }
480
481 if (!is_type_defined && !is_device_defined) {
482 continue;
483 }
484
485 // Skips undefined tensors
486 if (!op.tensor_base().defined()) {
487 continue;
488 }
489
490 // Checks all tensors are on the same device, if requested
491 if (config.check_all_same_device_) {
492 // Handles CPU scalars on CUDA kernels that support them
493 if (!common_device.is_cpu() &&
494 config.allow_cpu_scalars_ && !op.is_output && op.tensor_base().dim() == 0 &&
495 op.tensor_base().is_cpu()) {
496 TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
497 "Trying to pass too many CPU scalars to non-CPU kernel!");
498 ++current_cpu_scalars_on_non_cpu;
499 } else if (op.device.value() != common_device) {
500 TORCH_CHECK(false,
501 "Expected all tensors to be on the same device, but "
502 "found at least two devices, ", common_device, " and ", op.device.value(), "!");
503 }
504 }
505
506 // Checks safe casting, if requested
507 if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
508 TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
509 "result type ", common_dtype_, " can't be cast to the "
510 "desired output type ", op.current_dtype);
511 }
512
513 // Creates temporaries for CPU operations, if needed and requested
514 // TODO: reuse temporaries when possible (e.g. for inplace operations)
515 if (common_device == kCPU) {
516 // Casts to outputs by creating temporaries of the correct dtype (if needed)
517 // NB: we skip this on is_meta_, because the temporary allocation here is
518 // unnecessary if we aren't going to actually do the compute
519 if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
520 TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
521 // Marker [Output original_tensor is set]
522 // NB: do NOT use set_output here, as the temporary is NOT a true output;
523 // op.tensor is the true output and it was pre-provided for us.
524 // TODO: The logic for cast_outputs will need to be handled by the
525 // structured kernels implementation. What probably should happen
526 // is that we pass in the inferred dtype into the out kernel, and
527 // then after calling the out kernel, do the conversion (which
528 // is cast_outputs here), but integrating this with existing
529 // TensorIterator will take a little doing
530 op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
531 at::empty_like(op.tensor(),
532 op.tensor_base().options().dtype(common_dtype_),
533 LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
534 if (!names_.empty()) {
535 namedinference::propagate_names(op.tensor_base(), names_);
536 }
537 op.current_dtype = common_dtype_;
538 op.target_dtype = common_dtype_;
539 }
540
541 // Promotes inputs by creating temporaries of the correct dtype
542 if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
543 op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(op.tensor().to(common_dtype_)));
544 op.current_dtype = common_dtype_;
545 op.target_dtype = common_dtype_;
546 }
547 }
548 }
549 }
550
StrideVector TensorIteratorBase::compatible_stride(int64_t element_size) const {
552 auto stride = StrideVector();
553 int64_t next_stride = element_size;
554 for (const auto dim : c10::irange(ndim())) {
555 stride.push_back(next_stride);
556 next_stride *= shape_[dim];
557 }
558 return stride;
559 }
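// Example: with shape_ = {2, 3, 5} (already in iteration order) and element_size = 4,
// compatible_stride returns {4, 8, 24}, i.e. a dense layout whose fastest-moving
// dimension is dim 0.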
560
DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const {
562 // Invert the permutation caused by reorder_dimensions. This is not valid
563 // after coalesce_dimensions is called.
564 TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_);
565 TORCH_INTERNAL_ASSERT(input.size()==perm_.size());
566 auto res = DimVector(input.size()); //no initialization needed, every value in res should be written to.
567 for (const auto dim : c10::irange(ndim())) {
568 res[perm_[dim]] = input[dim];
569 }
570 return res;
571 }
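// Example: with perm_ = {2, 0, 1}, invert_perm({a, b, c}) returns {b, c, a},
// mapping iteration-order values back to the tensors' original dimension order.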
572
void TensorIteratorBase::allocate_or_resize_outputs() {
574 for (const auto i : c10::irange(num_outputs_)) {
575 auto& op = operands_[i];
576 if (!op.tensor_base().defined() || op.will_resize) {
577 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
578 auto element_size = elementSize(op.target_dtype);
579 op.stride_bytes = compatible_stride(static_cast<int64_t>(element_size));
580 // check if permutation is just an inverted order
581 bool inverted = true;
582 for (const auto j : c10::irange(ndim())) {
583 if (perm_[j] != ndim() - j - 1) {
584 inverted = false;
585 break;
586 }
587 }
588 auto tensor_shape = invert_perm(shape_);
589 if (inverted) {
590 // can just return contiguous output
591 // it is faster because it avoids allocating 0 size tensor and
592 // resizing and restriding it
593 set_output_raw_strided(i, tensor_shape, {}, original_options(op), names_);
594 } else {
595 auto tensor_stride = invert_perm(op.stride_bytes);
596 for (const auto dim : c10::irange(ndim())) {
597 tensor_stride[dim] /= static_cast<int64_t>(element_size);
598 }
599 set_output_raw_strided(i, tensor_shape, tensor_stride, original_options(op), names_);
600 }
601 op.current_dtype = op.target_dtype;
602 } else if (op.tensor_base().defined()) {
603 // Even if we don't resize, we still need to tell set_output about
604 // the output, so that we properly set guard and propagate names
605 set_output_raw_strided(i, op.tensor_base().sizes(), {}, original_options(op), names_);
606 }
607 }
608 }
609
void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) {
611 bool should_infer_names = std::any_of(
612 operands_.begin(),
613 operands_.end(),
614 [](const OperandInfo& op) {
615 return op.tensor_base().defined() && op.tensor_base().has_names();
616 });
617 if (!should_infer_names) {
618 return;
619 }
620
621 for (auto& op : operands_) {
622 if (!op.tensor_base().defined()) continue;
623 // Don't include output tensors if we are resizing, since we will
624 // clobber their names in any case. (If the output tensor was
625 // also an input tensor, we'll pick it up when it shows up again
626 // in operands).
627 if (config.resize_outputs_ && op.is_output) continue;
628 // perform name inference
629 if (names_.empty()) {
630 names_ = op.tensor_base().names();
631 } else {
632 names_ = NameVector(unify_from_right(names_, op.tensor_base().names()));
633 }
634 }
635 }
636
void TensorIteratorBase::coalesce_dimensions() {
638 if (ndim() <= 1) {
639 return;
640 }
641
642 // We can coalesce two adjacent dimensions if either dim has size 1 or if:
643 // shape[n] * stride[n] == stride[n + 1].
644 auto can_coalesce = [&](int dim0, int dim1) {
645 auto shape0 = shape_[dim0];
646 auto shape1 = shape_[dim1];
647 if (shape0 == 1 || shape1 == 1) {
648 return true;
649 }
650 for (const auto i : c10::irange(ntensors())) {
651 auto& stride = operands_[i].stride_bytes;
652 if (shape0 * stride[dim0] != stride[dim1]) {
653 return false;
654 }
655 }
656 return true;
657 };
658
// replace each operand's stride at dim0 with its stride at dim1
660 auto replace_stride = [&](int dim0, int dim1) {
661 for (const auto i : c10::irange(ntensors())) {
662 auto& stride = operands_[i].stride_bytes;
663 stride[dim0] = stride[dim1];
664 }
665 };
666
667 int prev_dim = 0;
668 for (const auto dim : c10::irange(1, ndim())) {
669 if (can_coalesce(prev_dim, dim)) {
670 if (shape_[prev_dim] == 1) {
671 replace_stride(prev_dim, dim);
672 }
673 shape_[prev_dim] *= shape_[dim];
674 } else {
675 prev_dim++;
676 if (prev_dim != dim) {
677 replace_stride(prev_dim, dim);
678 shape_[prev_dim] = shape_[dim];
679 }
680 }
681 }
682
683 shape_.resize(prev_dim + 1);
684 for (const auto i : c10::irange(ntensors())) {
685 operands_[i].stride_bytes.resize(ndim());
686 }
687 has_coalesced_dimensions_ = true;
688 }
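// Example: two contiguous 4x5 float tensors enter this function with shape_ = {5, 4}
// and per-operand byte strides {4, 20} after reorder_dimensions; since 5 * 4 == 20,
// the two dims coalesce into shape_ = {20} with a single stride of {4}.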
689
int64_t TensorIteratorBase::numel() const {
691 int64_t numel = 1;
692 for (int64_t size : shape_) {
693 numel *= size;
694 }
695 return numel;
696 }
697
StrideVector TensorIteratorBase::get_dim_strides(int dim) const {
699 auto dims = ndim();
700 auto inner_strides = StrideVector();
701 for (auto& op : operands_) {
702 inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
703 }
704 return inner_strides;
705 }
706
SmallVector<char*, 4> TensorIteratorBase::get_base_ptrs() const {
708 auto ptrs = SmallVector<char*, 4>(ntensors());
709 at::get_base_ptrs(ptrs.data(), operands_);
710 return ptrs;
711 }
712
bool TensorIteratorBase::is_dim_reduced(int dim) const {
714 for (auto& op : operands_) {
715 if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) {
716 return true;
717 }
718 }
719 return false;
720 }
721
void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
723 TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));
724
725 auto reorder = [perm](IntArrayRef data) {
726 auto res = DimVector(data.size(), 0);
727 for (const auto i : c10::irange(perm.size())) {
728 res[i] = data[perm[i]];
729 }
730 return res;
731 };
732
733 // Update shape and strides
734 shape_ = reorder(shape_);
735 for (auto& op : operands_) {
736 if (!op.stride_bytes.empty()) {
737 op.stride_bytes = reorder(op.stride_bytes);
738 }
739 }
740 }
741
int64_t TensorIteratorBase::num_output_elements() const {
743 int64_t elem = 1;
744 for (const auto dim : c10::irange(ndim())) {
745 if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0) {
746 elem *= shape_[dim];
747 }
748 }
749 return elem;
750 }
751
int TensorIteratorBase::num_reduce_dims() const {
753 int count = 0;
754 for (const auto dim : c10::irange(ndim())) {
755 if (operands_[0].stride_bytes[dim] == 0) {
756 count++;
757 }
758 }
759 return count;
760 }
761
void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
763 int64_t numel = this->numel();
764 if (numel == 0) {
765 return;
766 } else if (numel < grain_size || at::get_num_threads() == 1) {
767 return serial_for_each(loop, {0, numel});
768 } else {
769 at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
770 serial_for_each(loop, {begin, end});
771 });
772 }
773 }
774
StrideVector TensorIteratorBase::get_strides() const {
776 const auto dim = ndim();
777 StrideVector strides(static_cast<size_t>(std::max(dim, 2)) * ntensors());
778 at::get_strides(strides.data(), operands_, dim);
779 return strides;
780 }
781
void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
783 if (range.size() == 0) {
784 return;
785 }
786
787 const auto ntensors = this->ntensors();
788 const auto ndim = this->ndim();
789
790 c10::SmallBuffer<char*, 4> ptrs(ntensors);
791 c10::SmallBuffer<int64_t, 8> strides(ntensors * static_cast<size_t>(std::max(ndim, 2)));
792
793 at::get_base_ptrs(ptrs.data(), operands_);
794 at::get_strides(strides.data(), operands_, ndim);
795 at::internal::serial_for_each(
796 shape_, strides, ptrs.data(), ptrs.size(), loop, range);
797 }
798
bool TensorIteratorBase::is_trivial_1d() const {
800 // TODO: check for casting once it's supported
801 return ndim() == 1;
802 }
803
bool TensorIteratorBase::is_contiguous() const {
805 if (numel() == 1) {
806 return true;
807 }
808 if (ndim() != 1) {
809 return false;
810 }
811 return has_contiguous_first_dim();
812 }
813
814
bool TensorIteratorBase::is_scalar(int64_t arg) const {
816 const auto& stride = operands_[arg].stride_bytes;
817 for (const auto i : c10::irange(ndim())) {
818 if (stride[i] != 0 && shape_[i] != 1) {
819 return false;
820 }
821 }
822 return true;
823 }
824
bool TensorIteratorBase::is_cpu_scalar(int64_t arg) const {
826 return is_scalar(arg) && device(arg).is_cpu();
827 }
828
void TensorIteratorBase::cast_outputs() {
830 for (auto& op : operands_) {
831 if (op.is_output && op.original_tensor_base().defined() &&
832 op.original_tensor_base().scalar_type() != op.current_dtype) {
833 // TODO: Now that set_output resizes both the original_tensor
834 // and tensor, this condition should no longer ever be true
835 const auto &original_tensor = op.original_tensor();
836 const auto &tensor = op.tensor();
837 if (original_tensor.sizes() != tensor.sizes()) {
838 original_tensor.resize_as_(tensor).as_strided_(tensor.sizes(), tensor.strides());
839 }
840 original_tensor.copy_(tensor);
841 op.restore_original_tensor();
842 }
843 }
844 }
845
void* TensorIteratorBase::data_ptr(int64_t arg) const {
847 return operands_[arg].data;
848 }
849
void TensorIteratorBase::remove_operand(int64_t arg) {
851 operands_.erase(operands_.begin() + arg);
852 }
853
void TensorIteratorBase::unsafe_replace_operand(int64_t arg, void* data) {
855 operands_[arg].data = data;
856 }
857
void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
859 TORCH_INTERNAL_ASSERT(dim < ndim() && size >= 1);
860 shape_[dim] = size;
861 view_offsets_[dim] += start;
862 for (auto& op : operands_) {
863 op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
864 }
865 if (size == 1 && !is_reduction_) {
866 coalesce_dimensions();
867 }
868 }
869
void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
871 TORCH_INTERNAL_ASSERT(start_dim <= ndim());
872 for (const auto i : c10::irange(start_dim, ndim())) {
873 for (auto& op : operands_) {
874 op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
875 }
876 shape_[i] = 1;
877 }
878 }
879
880 #define BINARY_FLOAT_OP_CONFIG() \
881 TensorIteratorConfig() \
882 .set_check_mem_overlap(true) \
883 .allow_cpu_scalars(true) \
884 .promote_inputs_to_common_dtype(true) \
885 .cast_common_dtype_to_outputs(true) \
886 .enforce_safe_casting_to_output(true) \
887 .promote_integer_inputs_to_float(true)
888
889 // Helper to construct a binary op that promotes integer inputs to float.
void TensorIteratorBase::build_binary_float_op(
891 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
892 build(BINARY_FLOAT_OP_CONFIG()
893 .add_owned_output(out)
894 .add_owned_const_input(a)
895 .add_owned_const_input(b));
896 }
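// Sketch of where the helper above is used (illustrative): ops whose integer inputs
// must produce floating-point results, such as true division, build their iterator
// through this path so that promote_integer_inputs_to_float applies.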
897
void TensorIteratorBase::build_borrowing_binary_float_op(
899 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
900 build(BINARY_FLOAT_OP_CONFIG()
901 .add_output(out)
902 .add_const_input(a)
903 .add_const_input(b));
904 }
905
static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
907 config.set_check_mem_overlap(true);
908 config.allow_cpu_scalars(true);
909 config.promote_inputs_to_common_dtype(true);
910
911 // When 'out' isn't defined (e.g. for the functional operator 'a == b'), we
912 // want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
913 // don't coerce the output.
914 if (!out.defined()) {
915 config.declare_static_dtype(kBool);
916 }
917
918 // Note [special-case bool outputs]
919 // We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
920 // has `bool` dtype. This is a performance optimization: the functional
921 // version of all comparison/logical ops uses a bool output tensor, and we'd like to
922 // avoid creating a temporary copy of the output.
923 // However, note that all kernels using this TensorIterator will need to special-case when
924 // the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
925 if (out.defined() && out.scalar_type() != kBool) {
926 config.cast_common_dtype_to_outputs(true);
927 }
928 }
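// Illustrative consequence of the config above: a functional comparison like
// torch.eq(a, b) iterates with a bool output and no cast temporaries, while
// torch.eq(a, b, out=c) with a non-bool `c` casts the common-dtype result into
// `c` through cast_common_dtype_to_outputs.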
929
void TensorIteratorBase::build_comparison_op(
931 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
932 TensorIteratorConfig config;
933 set_up_comparison_op_config(config, out);
934
935 config.add_owned_output(out);
936 config.add_owned_const_input(a);
937 config.add_owned_const_input(b);
938 build(config);
939 }
940
void TensorIteratorBase::build_borrowing_comparison_op(
942 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
943 TensorIteratorConfig config;
944 set_up_comparison_op_config(config, out);
945
946 config.add_borrowed_output(out);
947 config.add_borrowed_const_input(a);
948 config.add_borrowed_const_input(b);
949 build(config);
950 }
951
void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
953 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
954 TensorIteratorConfig config;
955 set_up_comparison_op_config(config, out);
956
957 config.add_borrowed_output(out);
958 config.add_borrowed_const_input(a);
959 config.add_owned_const_input(b);
960 build(config);
961 }
962
void TensorIteratorBase::build_ternary_op(
964 const TensorBase& out, const TensorBase& a,
965 const TensorBase& b, const TensorBase& c) {
966 build(TensorIteratorConfig()
967 .promote_inputs_to_common_dtype(true)
968 .cast_common_dtype_to_outputs(true)
969 .enforce_safe_casting_to_output(true)
970 .add_owned_output(out)
971 .add_owned_const_input(a)
972 .add_owned_const_input(b)
973 .add_owned_const_input(c));
974 }
975
976 // This cannot be a function because TensorIteratorConfig is not
977 // copyable or movable, so it can't be returned from the function.
978 #define BINARY_OP_CONFIG() \
979 TensorIteratorConfig() \
980 .set_check_mem_overlap(true) \
981 .allow_cpu_scalars(true) \
982 .promote_inputs_to_common_dtype(true) \
983 .cast_common_dtype_to_outputs(true) \
984 .enforce_safe_casting_to_output(true) \
985
void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
987 build(BINARY_OP_CONFIG()
988 .add_owned_output(out)
989 .add_owned_const_input(a)
990 .add_owned_const_input(b));
991 }
992
void TensorIteratorBase::build_borrowing_binary_op(
994 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
995 build(BINARY_OP_CONFIG()
996 .add_output(out)
997 .add_const_input(a)
998 .add_const_input(b));
999 }
1000
1001 // This cannot be a function because TensorIteratorConfig is not
1002 // copyable or movable, so it can't be returned from the function.
1003 #define UNARY_FLOAT_OP_CONFIG() \
1004 TensorIteratorConfig() \
1005 .set_check_mem_overlap(true) \
1006 .promote_inputs_to_common_dtype(true) \
1007 .cast_common_dtype_to_outputs(true) \
1008 .enforce_safe_casting_to_output(true) \
1009 .promote_integer_inputs_to_float(true)
1010
void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
1012 build(UNARY_FLOAT_OP_CONFIG()
1013 .add_owned_output(out)
1014 .add_owned_const_input(a));
1015 }
1016
void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
1018 build(UNARY_FLOAT_OP_CONFIG()
1019 .add_output(out)
1020 .add_const_input(a));
1021 }
1022
1023 // This cannot be a function because TensorIteratorConfig is not
1024 // copyable or movable, so it can't be returned from the function.
1025 #define UNARY_OP_CONFIG() \
1026 TensorIteratorConfig() \
1027 .set_check_mem_overlap(true) \
1028 .cast_common_dtype_to_outputs(false) \
1029 .enforce_safe_casting_to_output(false) \
1030 .check_all_same_dtype(true)
1031
void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
1033 build(UNARY_OP_CONFIG()
1034 .add_owned_output(out)
1035 .add_owned_const_input(a));
1036 }
1037
void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
1039 build(UNARY_OP_CONFIG()
1040 .add_output(out)
1041 .add_const_input(a));
1042 }
1043
void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
1045 build(UNARY_OP_CONFIG()
1046 .add_output(out)
1047 .add_owned_const_input(a));
1048 }
1049
1050 // Helper to construct a unary op that forcibly promotes output to boolean.
// Use only when the output tensor must have boolean type.
void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
1053 build(TensorIteratorConfig()
1054 .set_check_mem_overlap(true)
1055 .check_all_same_dtype(false)
1056 .declare_static_dtype(at::kBool)
1057 .declare_static_device(a.device())
1058 .add_output(out)
1059 .add_const_input(a));
1060 }
1061
TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
1063 TensorIterator iter;
1064 iter.build_binary_op(out, a, b);
1065 return iter;
1066 }
1067
TensorIterator TensorIterator::borrowing_binary_op(
1069 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
1070 TensorIterator iter;
1071 iter.build_borrowing_binary_op(out, a, b);
1072 return iter;
1073 }
1074
TensorIterator TensorIterator::binary_float_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
1076 TensorIterator iter;
1077 iter.build_binary_float_op(out, a, b);
1078 return iter;
1079 }
1080
TensorIterator TensorIterator::comparison_op(TensorBase& out, const TensorBase& a,
1082 const TensorBase& b) {
1083 TensorIterator iter;
1084 iter.build_comparison_op(out, a, b);
1085 return iter;
1086 }
1087
TensorIterator TensorIterator::unary_op(TensorBase& out, const TensorBase& a) {
1089 TensorIterator iter;
1090 iter.build_unary_op(out, a);
1091 return iter;
1092 }
1093
TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase& a) {
1095 TensorIterator iter;
1096 iter.build_unary_float_op(out, a);
1097 return iter;
1098 }
1099
1100 #define NULLARY_OP_CONFIG() \
1101 TensorIteratorConfig() \
1102 .set_check_mem_overlap(true) \
1103 .check_all_same_dtype(false) \
1104 /* FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 */ \
1105 .resize_outputs(false)
1106
TensorIterator TensorIterator::nullary_op(TensorBase& out) {
1108 return NULLARY_OP_CONFIG()
1109 .add_owned_output(out)
1110 .build();
1111 }
1112
TensorIterator TensorIterator::borrowing_nullary_op(const TensorBase& out) {
1114 return NULLARY_OP_CONFIG()
1115 .add_output(out)
1116 .build();
1117 }
1118
TensorIterator TensorIterator::reduce_op(TensorBase& out, const TensorBase& a) {
1120 TORCH_INTERNAL_ASSERT(out.defined());
1121 return TensorIteratorConfig()
1122 .set_check_mem_overlap(false)
1123 .add_owned_output(out)
1124 .add_owned_const_input(a)
1125 .resize_outputs(false)
1126 .is_reduction(true)
1127 // TODO: not supporting casting to outputs is only really necessary for arg{min,max}
1128 .promote_inputs_to_common_dtype(true)
1129 .build();
1130 }
1131
TensorIterator TensorIterator::reduce_op(TensorBase& out1, TensorBase& out2, const TensorBase& a) {
1133 TORCH_INTERNAL_ASSERT(out1.defined());
1134 TORCH_INTERNAL_ASSERT(out2.defined());
1135 TORCH_CHECK(a.device() == out1.device() && out1.device() == out2.device(),
1136 "reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
1137 ", output1 is on ", out1.device(), " and output2 is on", out2.device());
1138 TORCH_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
1139 " and output2 has ", out2.dim());
1140 TORCH_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
1141 " and output2 has ", out2.sizes());
1142 TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
1143 " and output2 has ", out2.strides());
1144 return TensorIteratorConfig()
1145 .set_check_mem_overlap(false)
1146 .add_owned_output(out1)
1147 .add_owned_output(out2)
1148 .add_owned_const_input(a)
1149 .resize_outputs(false)
1150 .is_reduction(true)
1151 .check_all_same_dtype(false)
1152 .build();
1153 }
1154
void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
1156 for (const auto idx : c10::irange(config.tensors_.size())) {
1157 auto& tensor = config.tensors_[idx];
1158 // If *any* of the arguments is a meta tensor, the overall
1159 // computation is a meta computation (don't do any work,
1160 // just compute output information). This aligns with
1161 // our multiple dispatch semantics.
1162 if (tensor->is_meta()) {
1163 is_meta_ = true;
1164 }
1165 operands_.emplace_back(std::move(tensor));
1166 operands_[idx].is_const = config.is_tensor_const(idx);
1167 }
1168 num_outputs_ = config.num_outputs_;
1169 }
1170
void TensorIteratorBase::mark_outputs() {
1172 // TODO: merge this into populate_operands
1173 for (const auto i : c10::irange(num_outputs_)) {
1174 operands_[i].is_output = true;
1175 const auto& output = tensor(i);
1176 if (!output.defined()) continue;
1177
1178 // check if output is also an input
1179 for (const auto arg : c10::irange(num_outputs_, ntensors())) {
1180 const auto& input = tensor(arg);
1181 if (output.is_same(input)) {
1182 operands_[i].is_read_write = true;
1183 }
1184 }
1185 }
1186 }
1187
void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) {
1189 // Outputs cannot be broadcasted. Check that the shape of the outputs matches
1190 // the inferred shape. There's an exception for write-only tensors to support
1191 // our legacy behavior that functions with `out=` arguments resize their
1192 // outputs.
1193 if (config.static_shape_.has_value()) {
1194 return;
1195 }
1196 for (const auto i : c10::irange(num_outputs_)) {
1197 const auto& output = tensor(i);
1198 if (!output.defined()) {
1199 operands_[i].will_resize = true;
1200 }
1201 if (output.defined() && !output.sizes().equals(shape_)) {
1202 if (config.resize_outputs_ && !operands_[i].is_read_write) {
1203 operands_[i].will_resize = true;
1204 continue;
1205 }
1206 // for reduction, output size does not match shape_, as output is reduced size, and shape_ is size of the input
1207 TORCH_CHECK(is_reduction_, "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
1208 shape_);
1209 }
1210 }
1211 }
1212
void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
1214 if (!config.check_mem_overlap_) {
1215 return;
1216 }
1217 for (const auto i : c10::irange(num_outputs_)) {
1218 const auto& output = tensor_base(i);
1219 if (!output.defined()) continue;
1220 assert_no_internal_overlap(output);
1221 for (const auto j : c10::irange(num_outputs_, ntensors())) {
1222 const auto& input = tensor_base(j);
1223 if (!input.is_same(output)) {
1224 assert_no_partial_overlap(output, input);
1225 }
1226 }
1227 }
1228 }
1229
void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {
1231 if (config.static_shape_.has_value()) {
1232 shape_ = *config.static_shape_;
1233 return;
1234 }
1235
1236 all_ops_same_shape_ = true;
1237 bool has_scalars = false;
1238 bool has_tensors = false;
1239 for (auto& op : operands_) {
1240 if (!op.tensor_base().defined()) continue;
1241
1242 // For now, don't include output tensors when we're resizing outputs.
1243 // These shapes don't participate in shape computation.
1244 // This preserves the legacy behavior where torch.add(..., out=dst) resizes
1245 // the destination tensor. If the output tensor is also an input, we'll
1246 // pick it up later in the operands.
1247 if (config.resize_outputs_ && op.is_output) continue;
1248 TORCH_CHECK(!op.tensor_base().unsafeGetTensorImpl()->has_symbolic_sizes_strides(),
1249 "TensorIterator does not support symbolic shapes; please implement this operator in torch/_refs "
1250 "using the elementwise or reduction helpers (look at backtrace to find out what operator this is)");
1251 auto shape = op.tensor_base().sizes();
1252 if (shape.empty()) {
1253 has_scalars = true;
1254 } else {
1255 has_tensors = true;
1256 }
1257 if (has_scalars && has_tensors) {
1258 all_ops_same_shape_ = false;
1259 }
1260 if (shape_.empty()) {
1261 shape_ = shape;
1262 } else if (!shape.equals(shape_)) {
1263 all_ops_same_shape_ = false;
1264 shape_ = infer_size_dimvector(shape_, shape);
1265 }
1266 }
1267 all_ops_are_scalars_ = !has_tensors;
1268 }
1269
void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) {
1271 for (auto& op : operands_) {
1272 if (op.tensor_base().defined() && !op.will_resize) {
1273 IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor_base().sizes();
1274 auto original_stride = op.tensor_base().strides();
1275 auto element_size_in_bytes = op.tensor_base().element_size();
1276 auto offset = ndim() - original_shape.size();
1277 if (offset > 0)
1278 op.stride_bytes.resize(ndim(), 0);
1279 else
1280 op.stride_bytes.resize(ndim());
1281 for (const auto i : c10::irange(original_shape.size())) {
1282 // see NOTE: [Computing output strides]
if (original_shape[i] == 1 && shape_[offset + i] != 1) {
1284 op.stride_bytes[offset + i] = 0;
1285 } else {
1286 op.stride_bytes[offset + i] = original_stride[i] * element_size_in_bytes;
1287 }
1288 }
1289 }
1290 }
1291 }
1292
bool TensorIteratorBase::can_use_32bit_indexing() const {
1294 int64_t max_value = std::numeric_limits<int32_t>::max();
1295 if (numel() > max_value) {
1296 return false;
1297 }
1298 for (auto& op : operands_) {
1299 int64_t max_offset = 1;
1300 for (const auto dim : c10::irange(ndim())) {
1301 max_offset += (shape_[dim] - 1) * op.stride_bytes[dim];
1302 }
1303 if (max_offset > max_value) {
1304 return false;
1305 }
1306 }
1307 return true;
1308 }
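// Example: a contiguous float32 tensor with 2^30 elements passes the numel()
// check (2^30 < 2^31 - 1) but fails the per-operand check, because its largest
// byte offset is about 4 * 2^30 and exceeds INT32_MAX, so the caller must either
// split the iterator or use 64-bit indexing.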
1309
std::unique_ptr<TensorIterator> TensorIteratorBase::split(int dim) {
1311 TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && shape()[dim] >= 2);
1312 auto copy = std::make_unique<TensorIterator>(*this);
1313
1314 bool overlaps = is_dim_reduced(dim);
1315 auto copy_size = shape_[dim] / 2;
1316 auto this_size = shape_[dim] - copy_size;
1317 copy->narrow(dim, 0, copy_size);
1318 copy->final_output_ &= !overlaps;
1319 this->narrow(dim, copy_size, this_size);
1320 this->accumulate_ |= overlaps;
1321
1322 return copy;
1323 }
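// Typical usage sketch (illustrative): GPU code splits until 32-bit indexing is
// safe, e.g.
//
//   for (auto& sub_iter : iter.with_32bit_indexing()) {
//     launch_kernel(sub_iter);   // hypothetical per-chunk launch
//   }
//
// where with_32bit_indexing() yields sub-iterators produced by split().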
1324
1325
int TensorIteratorBase::get_dim_to_split() const {
1327 TORCH_INTERNAL_ASSERT(ndim() >= 1);
1328 int64_t max_extent = -1;
1329 int dim_to_split = -1;
1330 for (int dim = ndim() - 1; dim >= 0; dim--) {
1331 const int64_t size = shape_[dim];
1332 if (size == 0) {
1333 continue;
1334 }
1335 for (auto& op : operands_) {
1336 // std::abs is necessary to handle some special cases where we support negative strides
1337 // see the CUDA backend of at::flip
1338 const int64_t extent = (size - 1) * std::abs(op.stride_bytes[dim]);
1339 if (extent > max_extent) {
1340 max_extent = extent;
1341 dim_to_split = dim;
1342 }
1343 }
1344 }
1345 TORCH_INTERNAL_ASSERT(max_extent >= 0);
1346 return dim_to_split;
1347 }
1348
bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
1350 // This function tries to do a fast setup to avoid needless reordering of dimensions and tracking output strides
1351 // Return true if it can do fast setup or false otherwise
1352 // TODO enable fast handling for reductions
1353 FastSetupType setup_type = compute_fast_setup_type(config);
1354 if (setup_type == FastSetupType::NONE) {
1355 return false;
1356 }
1357
1358 // allocate memory for output, memory format depends on setup_type
1359 switch (setup_type) {
1360 case FastSetupType::CONTIGUOUS:
1361 {
1362 for (const auto i : c10::irange(num_outputs_)) {
1363 auto& op = operands_[i];
1364 if (!op.tensor_base().defined()) {
1365 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1366 }
1367 set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::Contiguous), names_);
1368 }
1369 break;
1370 }
1371 case FastSetupType::CHANNELS_LAST:
1372 {
1373 for (const auto i : c10::irange(num_outputs_)) {
1374 auto& op = operands_[i];
1375 if (!op.tensor_base().defined()) {
1376 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1377 }
1378 set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::ChannelsLast), names_);
1379 }
1380 break;
1381 }
1382 case FastSetupType::NON_OVERLAPPING_DENSE:
1383 {
// find the index of a defined tensor in operands_, starting from the input tensors
1385 int i_defined; // NOLINT(cppcoreguidelines-init-variables)
1386 for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) {
1387 if (tensor(i_defined).defined()) break;
1388 }
1389 TORCH_CHECK(i_defined >= 0, "Can not find a defined tensor when fast allocating memory to outputs");
1390 for (const auto i : c10::irange(num_outputs_)) {
1391 auto& op = operands_[i];
1392 if (!op.tensor_base().defined()) {
1393 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1394 }
1395 set_output_raw_strided(i, shape_, tensor_base(i_defined).strides(), original_options(op), names_);
1396 }
1397 break;
1398 }
1399 default:
1400 TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", std::to_string((int)setup_type));
1401 }
1402 //coalescing dimensions consists of collapsing dimensions to 1 (we are limited to contiguous no-broadcast cases here)
1403 if (ndim() > 1){
1404 has_coalesced_dimensions_ = true;
1405 }
1406 if (ndim() >= 1) {
1407 shape_[0] = numel();
1408 shape_.resize(1);
1409 }
1410 for (auto& op : operands_ ) {
1411 auto element_size_in_bytes = op.tensor_base().element_size();
1412 op.stride_bytes.resize(ndim());
1413 if (ndim()>0) {
1414 op.stride_bytes[0] = element_size_in_bytes;
1415 }
1416 }
1417 return true;
1418 }
1419
FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) {
1421 if (is_reduction_ || !all_ops_same_shape_) {
1422 return FastSetupType::NONE;
1423 }
1424
1425 // For linear iteration, only contiguous tensors can be coalesced
1426 // Fast setup of any other format requires changing iteration order
1427 if (enforce_linear_iteration_) {
1428 for (const auto& op : operands_) {
1429 if (op.tensor_base().defined() && !op.will_resize) {
1430 auto is_contiguous = op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
1431 if (!is_contiguous) {
1432 return FastSetupType::NONE;
1433 }
1434 }
1435 }
1436 return FastSetupType::CONTIGUOUS;
1437 }
1438
1439 bool is_contiguous = true;
1440 bool is_channels_last = true;
1441 bool is_non_overlapping_and_dense = true;
1442 for (const auto& op : operands_) {
1443 if (op.tensor_base().defined() && !op.will_resize) {
1444 is_contiguous &= op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
1445 is_channels_last &= op.tensor_base().is_contiguous(at::MemoryFormat::ChannelsLast);
1446 is_non_overlapping_and_dense &= op.tensor_base().is_non_overlapping_and_dense();
1447 }
1448 }
1449 // TODO this leads to ambiguous cases (NC11) to be always treated as contiguous
1450 if (is_contiguous) {
1451 return FastSetupType::CONTIGUOUS;
1452 }
1453 if (is_channels_last) {
1454 return FastSetupType::CHANNELS_LAST;
1455 }
1456 if (is_non_overlapping_and_dense) {
1457 int64_t prev = -1;
1458 // Fast setup is allowed only when all the defined tensors have the same shape and strides,
1459 // Iterate from back to check input tensors' strides first, then output tensors'.
1460 for (int64_t i = ntensors() - 1; i >= 0; --i) {
1461 const auto& op = operands_[i];
1462 if (op.tensor_base().defined() && !op.will_resize) {
1463 if (prev < 0) {
1464 prev = i;
1465 continue;
1466 }
1467 if (!tensor_base(prev).strides().equals(op.tensor_base().strides())) {
1468 // [Note: stride check for non contiguous tensors in fast setup]
1469 // We prevent 3 cases doing fast setup here:
1470 // 1. input tensors have different strides.
1471 // 2. output tensors won't be resized and have different strides.
1472 // 3. input tensors have the same strides, but output tensors have different strides with input tensors.
1473 // We don't allow re-stride output tensors in this case since it is not compatible with
1474 // numpy. The behavior in numpy is that if the output tensor has same shape as the input
1475 // tensor but different strides, the strides of output tensor will be preserved, so we do
1476 // the same in tensor iterator.
1477 return FastSetupType::NONE;
1478 }
1479 }
1480 }
1481 return FastSetupType::NON_OVERLAPPING_DENSE;
1482 }
1483 return FastSetupType::NONE;
1484 }
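
// Illustrative instance of [Note: stride check for non contiguous tensors in
// fast setup] above, with hypothetical shapes: given two non-overlapping,
// dense inputs of shape {2, 3} with (column-major) strides {1, 2} and a
// preallocated, non-resizable output with (row-major) strides {3, 1}, the
// stride comparison fails and FastSetupType::NONE is returned. The slow path
// then preserves the output's own strides, matching NumPy's behavior of not
// re-striding a same-shaped out= array.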

TensorIteratorBase::TensorIteratorBase() = default;

void TensorIteratorBase::build(TensorIteratorConfig& config) {
  // populate some persistent configuration fields
  is_reduction_ = config.is_reduction_;
  enforce_linear_iteration_ = config.enforce_linear_iteration_;

  // fill in operands_ based on configuration
  populate_operands(config);
  // set is_output and is_read_write flags on appropriate tensors
  mark_outputs();
  // Check that the outputs have no internal overlap
  // and do not share memory with inputs.
  compute_mem_overlaps(config);
  // Check that input dimensions are aligned correctly & compute outnames.
  compute_names(config);
  // compute the broadcasted shape
  compute_shape(config);
  // mark outputs for resizing if necessary
  mark_resize_outputs(config);
  // compute the result dtype and device
  compute_types(config);
  // try fast setup of the output tensors; if it fails, fall back to the normal setup
  if (!fast_set_up(config)) {
    // compute each tensor's stride after broadcasting
    compute_strides(config);
    // re-order dimensions to improve coalescing
    reorder_dimensions();
    // allocate the output tensor if it's not provided
    allocate_or_resize_outputs();
    // coalesce adjacent dimensions when possible
    if (!is_meta_) coalesce_dimensions();
  }

  if (is_meta_) return;

  auto has_storage = true;
  for (auto& op : operands_) {
    has_storage &= op.tensor_base().has_storage();
  }
  auto privateuse1_without_storage =
      common_device_.type() == DeviceType::PrivateUse1 &&
      !has_storage;

  // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
  // Nothing beyond this point is important for meta functions, so it's fine to exit early here.
  // Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
  if (privateuse1_without_storage ||
      common_device_.type() == DeviceType::MTIA ||
      common_device_.type() == DeviceType::XLA ||
      common_device_.type() == DeviceType::IPU ||
      common_device_.type() == DeviceType::Lazy ||
      common_device_.type() == DeviceType::MAIA ||
      common_device_.type() == DeviceType::HPU) return;

  for (auto& op : operands_) {
    TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
    if (op.is_const) {
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      op.data = const_cast<void*>(op.tensor_base().const_data_ptr());
    } else {
      op.data = op.tensor_base().mutable_data_ptr();
    }
  }

  // zero out offsets
  // If the tensor is a scalar, we leave room for it
  // so index translations in reduction can access
  // a valid value for the offset
  int64_t ndim_offsets = (ndim() ? ndim() : 1);
  view_offsets_ = DimVector(ndim_offsets, 0);
}
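
// The usual way the sequence above gets exercised is through the builder
// pattern documented in TensorIterator.h; schematically (with `output` and
// `input` standing in for real tensors):
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(output)
//       .add_owned_input(input)
//       .build();   // runs TensorIteratorBase::build() as laid out above
//
// after which iter.for_each(...) (or a device kernel) can consume the
// shape/stride metadata computed here.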

// This is the structured kernels' implementation of set_output. It is
// NEVER actually called directly; instead, a subclass of TensorIteratorBase
// will override set_output to actually do the operation, and then call
// set_output on the TensorIteratorBase to set up TI's metadata.
// The precondition for this function is that maybe_get_output() now
// unconditionally returns a real Tensor (prior to output setting,
// maybe_get_output() may return an undefined tensor.)
void TensorIteratorBase::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
  auto& op = operands_[output_idx];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  const auto& t = maybe_get_output(output_idx);
  TORCH_INTERNAL_ASSERT(t.defined());
  if (!op.tensor_base().defined()) {
    op.tensor(c10::MaybeOwned<TensorBase>::borrowed(t));
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.target_dtype == t.scalar_type());
  } else if (op.will_resize) {
    if (op.original_tensor_base().defined()) {
      // OK, so this is pretty weird. To understand how we can end up in
      // this situation, first look at Marker [Output original_tensor is set].
      // That is the sole site where original_tensor may be set on an
      // output operand. Essentially, when we are given an explicit output
      // tensor whose dtype doesn't match the computed common dtype from
      // the input operands, we do a switcheroo: we replace the (incorrectly
      // typed) output tensor with a correctly typed, *temporary* tensor,
      // and remember the original tensor in original_tensor (which will
      // then get written back to when we cast_outputs).
      //
      // Now, what if the given output tensor also happened to be zero
      // size (meaning that we will_resize it)? Well, at the call site
      // above, we don't necessarily(*) know what the correct shape should
      // be, so we give the temporary tensor the same shape as the original.
      // The call to set_output is when we DO know what the correct size
      // is, and the subclass's implementation of set_output in the structured
      // class is responsible for resizing original_tensor. But we still have this
      // incorrectly sized temporary output which the structured subclass
      // knows nothing about, so we are obligated to also resize it here.
      //
      // This is a slight memory pessimization, because previously
      // original_tensor only got resized at the end of the computation, rather
      // than at the beginning (as happens here). However, the peak memory
      // usage is the same, since you need to materialize both original tensor
      // and temporary tensor to do the copy.
      //
      // (*) Actually, technically, we probably do know what the shape
      // should be, since we do shape computation before dtype computation.
      // So hypothetically we could figure out what the correct shape is
      // at that point in time and directly allocate the temporary at
      // the right size.
      //
      // But a better solution is to delay allocation of temporaries until
      // after the TensorIterator builder, waiting until we actually want
      // to do the computation. That would also remove the necessity
      // for the is_meta_ test.
      TORCH_INTERNAL_ASSERT(op.original_tensor_base().is_same(t));
      TORCH_INTERNAL_ASSERT(!op.tensor_base().is_same(t));
      OptionalTensorRef tensor(op.tensor());
      at::native::resize_output(*tensor, sizes);
      if (!strides.empty()) {
        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
        tensor->as_strided_(sizes, strides);
      } else if (options.memory_format_opt().has_value()) {
        tensor->unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
      }
    }
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      op.tensor_base().is_same(t) || op.current_dtype == op.tensor_base().scalar_type());
  // For simplicity, just always update the cached current_type.
  op.current_dtype = op.tensor_base().scalar_type();
}
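
// Rough shape of the subclass contract described in the comment above this
// function (schematic only; `structured_foo_out` is a made-up name, not the
// actual generated code): the structured kernel class creates or resizes the
// user-visible output itself, then forwards to this base implementation so
// the iterator's operand metadata stays in sync:
//
//   void structured_foo_out::set_output_raw_strided(
//       int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
//       TensorOptions options, DimnameList names) {
//     // ... allocate/resize the real output tensor here ...
//     TensorIteratorBase::set_output_raw_strided(
//         output_idx, sizes, strides, options, names);
//   }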

// This is the "traditional" implementation of set_output. On TensorIterator
// instances, it is invoked directly from various call sites in this file. No
// funny business.
void TensorIterator::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
  // NB: intentionally no superclass call
  auto& op = operands_[output_idx];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  if (!op.tensor_base().defined()) {
    if (strides.empty()) {
      op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty(sizes, options)));
    } else {
      op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty_strided(sizes, strides, options)));
    }
    op.current_dtype = op.target_dtype;
  } else if (op.will_resize) {
    at::native::resize_output(op.tensor(), sizes);
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      op.tensor().as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      op.tensor_base().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
  if (!names.empty()) {
    TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
    namedinference::propagate_names(op.tensor_base(), names);
  }
}

// Not actually used by anything (TensorIterator subclass calls
// its own implementation of set_output which knows exactly where
// all the outputs are), but we have to provide all pure virtual methods
// for MetaBase
const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) {
  return output(output_idx);
}

SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const {
  return SplitUntil32Bit(*this);
}

/// SplitUntil32Bit. Recursively splits an iterator into sub-iterators that
/// can use 32-bit indexing.

SplitUntil32Bit::iterator::iterator(const TensorIteratorBase& iter) {
  vec.emplace_back(new TensorIterator(iter));
  vec.emplace_back(nullptr); // ++ first pops the last element
  ++(*this);
}

SplitUntil32Bit::iterator& SplitUntil32Bit::iterator::operator++() {
  vec.pop_back();
  while (!vec.empty() && !vec.back()->can_use_32bit_indexing()) {
    auto& iter = *vec.back();
    auto split_dim = iter.get_dim_to_split();
    vec.emplace_back(iter.split(split_dim));
  }
  return *this;
}

TensorIterator& SplitUntil32Bit::iterator::operator*() const {
  return *vec.back();
}

SplitUntil32Bit::iterator SplitUntil32Bit::begin() const {
  return SplitUntil32Bit::iterator(iter);
}

SplitUntil32Bit::iterator SplitUntil32Bit::end() const {
  return SplitUntil32Bit::iterator();
}
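
// Typical consumption of the splitting machinery above (a sketch of the
// pattern used by device kernels; `launch_kernel` is a placeholder, not a
// function defined in this file):
//
//   if (iter.can_use_32bit_indexing()) {
//     launch_kernel(iter);
//   } else {
//     for (auto& sub_iter : iter.with_32bit_indexing()) {
//       launch_kernel(sub_iter);  // every sub-iterator fits 32-bit indexing
//     }
//   }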

DimCounter::DimCounter(IntArrayRef shape, Range range)
  : shape(shape)
  , range(range)
  , values(shape.size())
  , offset(range.begin) {
  std::fill(values.begin(), values.end(), 0);
  if (range.begin == 0) {
    return;
  }

  int64_t linear_offset = range.begin;
  auto ndim = values.size();
  for (const auto dim : c10::irange(ndim)) {
    int64_t size = shape[dim];
    if (size > 0) {
      values[dim] = linear_offset % size;
      linear_offset /= size;
    }
  }
  TORCH_INTERNAL_ASSERT(linear_offset == 0);
}
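
// Worked example of the decomposition above (hypothetical values): with
// shape = {3, 4, 5} and range.begin = 7, the loop computes
//   values[0] = 7 % 3 = 1,  linear_offset = 7 / 3 = 2
//   values[1] = 2 % 4 = 2,  linear_offset = 2 / 4 = 0
//   values[2] = 0 % 5 = 0,  linear_offset = 0
// i.e. linear index 7 corresponds to the multi-index (1, 2, 0), with dim 0
// varying fastest.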

bool DimCounter::is_done() const {
  return offset >= range.end;
}

void DimCounter::increment(const std::array<int64_t, 2>& step) {
  offset += step[0] * step[1];
  auto ndim = values.size();
  int64_t overflow = step[0];
  size_t i = 0;
  if (step[1] != 1) {
    TORCH_INTERNAL_ASSERT(step[0] == shape[0] && values[0] == 0);
    i = 1;
    overflow = step[1];
  }
  for (; i < ndim && overflow > 0; i++) {
    auto size = shape[i];
    auto prev = values[i];
    auto value = prev + overflow;
    if (value >= size) {
      overflow = 1;
      value -= size;
      TORCH_INTERNAL_ASSERT(value < size);
    } else {
      overflow = 0;
    }
    values[i] = static_cast<int64_t>(value);
  }
  TORCH_INTERNAL_ASSERT(overflow == 0 || overflow == 1);
}

std::array<int64_t, 2> DimCounter::max_2d_step() const {
  int64_t step0 = std::min(shape[0] - values[0], range.end - offset);
  int64_t step1 = 1;
  if (step0 == shape[0] && shape.size() >= 2) {
    step1 = std::min(shape[1] - values[1], (range.end - offset) / shape[0]);
  }
  return {step0, step1};
}
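
// Continuing the worked example from the DimCounter constructor (hypothetical
// values): with shape = {3, 4, 5}, values = {1, 2, 0}, offset = 7 and
// range.end = 60,
//   step0 = min(3 - 1, 60 - 7) = 2   // finish the innermost dimension
//   step0 != shape[0], so step1 stays 1
// and a subsequent increment({2, 1}) advances offset to 9 and carries the
// counter to values = {0, 3, 0}.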

} // namespace at