xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/tensor_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/core/tensor_util.h>
2 
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Half.h>
5 #include <c10/util/complex.h>
6 #include <c10/util/irange.h>
7 #include <torch/csrc/lazy/backend/backend_device.h>
8 #include <torch/csrc/lazy/backend/backend_interface.h>
9 #include <torch/csrc/lazy/core/config.h>
10 #include <torch/csrc/lazy/core/helpers.h>
11 
12 #include <algorithm>
13 #include <cstring>
14 #include <functional>
15 #include <list>
16 #include <numeric>
17 #include <thread>
18 
19 namespace torch {
20 namespace lazy {
21 
ComputeArrayStrides(c10::ArrayRef<int64_t> sizes)22 std::vector<int64_t> ComputeArrayStrides(c10::ArrayRef<int64_t> sizes) {
23   std::vector<int64_t> strides(sizes.size(), 1);
24   for (int64_t i = sizes.size(); i > 1; --i) {
25     strides[i - 2] = strides[i - 1] * sizes[i - 1];
26   }
27   return strides;
28 }
29 
DataHandlesToTensors(c10::ArrayRef<BackendDataPtr> data_handles,at::ScalarType dest_element_type)30 std::vector<at::Tensor> DataHandlesToTensors(
31     c10::ArrayRef<BackendDataPtr> data_handles,
32     at::ScalarType dest_element_type) {
33   std::vector<at::Tensor> tensors;
34   for (const auto& handle : data_handles) {
35     tensors.push_back(
36         getBackend()->MakeTensorFromComputationData(handle, dest_element_type));
37   }
38   return tensors;
39 }
40 
TensorToDataHandle(const at::Tensor & tensor,const BackendDevice & device)41 BackendDataPtr TensorToDataHandle(
42     const at::Tensor& tensor,
43     const BackendDevice& device) {
44   return getBackend()->MakeComputationDataFromTensor(
45       tensor, Shape(tensor.scalar_type(), tensor.sizes()), device);
46 }
47 
CreateTensorsData(const std::vector<at::Tensor> & tensors,const std::vector<BackendDevice> & devices)48 std::vector<BackendDataPtr> CreateTensorsData(
49     const std::vector<at::Tensor>& tensors,
50     const std::vector<BackendDevice>& devices) {
51   TORCH_CHECK(tensors.size() == devices.size());
52   std::vector<BackendDataPtr> result;
53   result.reserve(tensors.size());
54   for (const auto i : c10::irange(tensors.size())) {
55     result.push_back(TensorToDataHandle(tensors[i], devices[i]));
56   }
57   return result;
58 }
59 
IsSpecialScalar(const at::Scalar & value)60 bool IsSpecialScalar(const at::Scalar& value) {
61   if (FLAGS_torch_lazy_handle_special_scalars &&
62       (value.isIntegral(false) || value.isFloatingPoint())) {
63     if (FLAGS_torch_lazy_all_numbers_special_scalars) {
64       return true;
65     }
66     double scalar_value = value.toDouble();
67     return scalar_value == 0.0 || std::fabs(scalar_value) == 1.0;
68   }
69   return false;
70 }
71 
72 } // namespace lazy
73 } // namespace torch
74