#include <torch/csrc/lazy/core/tensor_util.h>

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/helpers.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <list>
#include <numeric>
#include <thread>

namespace torch {
namespace lazy {

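// Computes row-major (contiguous) strides for a shape given as a list of
// dimension sizes; the innermost dimension gets stride 1.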
std::vector<int64_t> ComputeArrayStrides(c10::ArrayRef<int64_t> sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = sizes.size(); i > 1; --i) {
    strides[i - 2] = strides[i - 1] * sizes[i - 1];
  }
  return strides;
}

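// Converts a list of backend data handles back into at::Tensors of the
// requested element type, delegating the conversion to the active backend.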
std::vector<at::Tensor> DataHandlesToTensors(
    c10::ArrayRef<BackendDataPtr> data_handles,
    at::ScalarType dest_element_type) {
  std::vector<at::Tensor> tensors;
  for (const auto& handle : data_handles) {
    tensors.push_back(
        getBackend()->MakeTensorFromComputationData(handle, dest_element_type));
  }
  return tensors;
}

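// Wraps a single at::Tensor as backend computation data on the given device,
// using the tensor's own scalar type and sizes as the shape.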
BackendDataPtr TensorToDataHandle(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  return getBackend()->MakeComputationDataFromTensor(
      tensor, Shape(tensor.scalar_type(), tensor.sizes()), device);
}

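// Batched form of TensorToDataHandle: converts tensors[i] to backend data on
// devices[i]. The two vectors must have the same length.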
std::vector<BackendDataPtr> CreateTensorsData(
    const std::vector<at::Tensor>& tensors,
    const std::vector<BackendDevice>& devices) {
  TORCH_CHECK(tensors.size() == devices.size());
  std::vector<BackendDataPtr> result;
  result.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    result.push_back(TensorToDataHandle(tensors[i], devices[i]));
  }
  return result;
}

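// Returns true when special-scalar handling is enabled and the value is an
// integral or floating-point scalar equal to 0 or +/-1 (or any numeric value,
// if torch_lazy_all_numbers_special_scalars is set).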
bool IsSpecialScalar(const at::Scalar& value) {
  if (FLAGS_torch_lazy_handle_special_scalars &&
      (value.isIntegral(false) || value.isFloatingPoint())) {
    if (FLAGS_torch_lazy_all_numbers_special_scalars) {
      return true;
    }
    double scalar_value = value.toDouble();
    return scalar_value == 0.0 || std::fabs(scalar_value) == 1.0;
  }
  return false;
}

} // namespace lazy
} // namespace torch