#pragma once

#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/shape.h>

#include <ATen/FunctionalTensorWrapper.h>

#include <string>
#include <vector>

namespace torch {
namespace lazy {

// Computes the strides of a dense, row-major contiguous layout for the given
// sizes.
TORCH_API std::vector<int64_t> ComputeArrayStrides(
    c10::ArrayRef<int64_t> sizes);
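
// Illustrative sketch (not part of the original header): for sizes {2, 3, 4}
// the dense row-major strides are {12, 4, 1}.
//   std::vector<int64_t> strides = ComputeArrayStrides({2, 3, 4});
//   // strides == {12, 4, 1}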

// Materializes ATen tensors with the given element type from backend data
// handles (the inverse direction of CreateTensorsData below).
TORCH_API std::vector<at::Tensor> DataHandlesToTensors(
    c10::ArrayRef<BackendDataPtr> data_handles,
    at::ScalarType dest_element_type);

// Uploads an ATen tensor's data to the device and fetches the corresponding
// device data handle.
TORCH_API BackendDataPtr
TensorToDataHandle(const at::Tensor& tensor, const BackendDevice& device);
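
// Illustrative usage sketch (assumes `device` is a valid BackendDevice for
// the active backend; not part of the original header):
//   at::Tensor cpu_tensor = at::ones({2, 2});
//   BackendDataPtr handle = TensorToDataHandle(cpu_tensor, device);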

// Retrieves the device data handles by uploading the tensor data to the
// corresponding devices in parallel.
TORCH_API std::vector<BackendDataPtr> CreateTensorsData(
    const std::vector<at::Tensor>& tensors,
    const std::vector<BackendDevice>& devices);
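
// Illustrative round trip (assumes `device` is a valid BackendDevice for the
// active backend; not part of the original header):
//   std::vector<at::Tensor> tensors = {at::ones({2, 2}), at::zeros({3})};
//   std::vector<BackendDevice> devices = {device, device};
//   std::vector<BackendDataPtr> handles = CreateTensorsData(tensors, devices);
//   std::vector<at::Tensor> back =
//       DataHandlesToTensors(handles, at::ScalarType::Float);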

// Makes a deep copy of an ATen tensor.
inline at::Tensor CopyTensor(const at::Tensor& ref) {
  return ref.to(ref.options(), /*non_blocking=*/false, /*copy=*/true);
}

// Same as above, with an additional cast to `dest_type`.
inline at::Tensor CopyTensor(
    const at::Tensor& ref,
    at::ScalarType dest_type,
    bool copy = true) {
  return ref.to(ref.options().dtype(dest_type), /*non_blocking=*/false, copy);
}
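
// Illustrative sketch covering both overloads (not part of the original
// header):
//   at::Tensor src = at::rand({4});
//   at::Tensor dup = CopyTensor(src);                         // deep copy
//   at::Tensor dbl = CopyTensor(src, at::ScalarType::Double); // copy + cast
//   // With copy=false the cast overload may return the input unchanged when
//   // no dtype conversion is needed, since at::Tensor::to() only copies on
//   // demand.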

// Returns `*value` cast to `T` when the optional holds a value, otherwise
// `defval`.
template <typename T, typename S>
T OptionalOr(const std::optional<S>& value, T defval) {
  return value ? static_cast<T>(*value) : defval;
}
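
// Illustrative sketch (not part of the original header):
//   std::optional<double> maybe_scale; // empty
//   int64_t scale = OptionalOr<int64_t>(maybe_scale, 1); // -> 1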

// Casts the tensor to the target dtype if it is a wrapped number (a scalar
// promoted to a 0-dim tensor); otherwise returns it unchanged.
inline at::Tensor UnwrapNumber(const at::Tensor& tensor, at::ScalarType dtype) {
  return tensor.unsafeGetTensorImpl()->is_wrapped_number() ? tensor.to(dtype)
                                                           : tensor;
}
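
// Illustrative sketch (not part of the original header): in an expression
// like `1.0 + tensor`, the scalar arrives as a 0-dim "wrapped number" tensor;
// unwrapping casts it to the target dtype so it does not steer type
// promotion.
//   at::Tensor wrapped = at::scalar_tensor(1.0);
//   wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
//   at::Tensor unwrapped = UnwrapNumber(wrapped, at::ScalarType::Float);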

// Builds an at::Scalar holding `value` converted to int64_t.
template <typename T>
at::Scalar MakeIntScalar(T value) {
  return at::Scalar(static_cast<int64_t>(value));
}
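
// Illustrative sketch (not part of the original header):
//   uint32_t raw = 42;
//   at::Scalar s = MakeIntScalar(raw);
//   // s.isIntegral(/*includeBool=*/false) == true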

// Routing values to device data maximizes the chances of compilation cache
// hits, but it can prevent the compiler from performing optimizations. So
// tensor values that fall within a given set are routed to constant scalars
// when this API returns true.
TORCH_API bool IsSpecialScalar(const at::Scalar& value);
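
// Illustrative sketch (not part of the original header; as far as we can
// tell, the upstream implementation treats 0 and +/-1 as the special set,
// gated by runtime flags):
//   IsSpecialScalar(at::Scalar(1));  // expected: true (flags permitting)
//   IsSpecialScalar(at::Scalar(42)); // expected: false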

// Note: returns a reference instead of a fresh tensor to avoid refcount bumps.
inline const at::Tensor& maybe_unwrap_functional(const at::Tensor& tensor) {
  if (at::functionalization::impl::isFunctionalTensor(tensor)) {
    return at::functionalization::impl::unsafeGetFunctionalWrapper(tensor)
        ->value();
  } else {
    return tensor;
  }
}
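
// Illustrative sketch (not part of the original header; `maybe_wrapped` is a
// hypothetical tensor): a tensor that functionalization has wrapped in a
// FunctionalTensorWrapper is unwrapped to its inner value; any other tensor
// passes through untouched.
//   const at::Tensor& inner = maybe_unwrap_functional(maybe_wrapped);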

} // namespace lazy
} // namespace torch