xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Constraints.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <limits>
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <ATen/core/Tensor.h>
4 #include <c10/core/Device.h>
5 #include <c10/core/Layout.h>
6 #include <c10/core/MemoryFormat.h>
7 #include <c10/core/Scalar.h>
8 #include <c10/core/ScalarType.h>
9 #include <optional>
10 
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Functions.h>
13 #include <ATen/NativeFunctions.h>
14 #else
15 #include <ATen/ops/_functional_sym_constrain_range_native.h>
16 #include <ATen/ops/_make_dep_token_native.h>
17 #include <ATen/ops/empty.h>
18 #include <ATen/ops/sym_constrain_range_native.h>
19 #include <ATen/ops/sym_constrain_range_for_size_native.h>
20 #include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
21 #endif
22 
23 namespace at::native {
24 
sym_constrain_range(const Scalar & size,std::optional<int64_t> min,std::optional<int64_t> max)25 void sym_constrain_range(
26     const Scalar& size,
27     std::optional<int64_t> min,
28     std::optional<int64_t> max) {
29 
30     int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
31     int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
32     int64_t size_as_int = size.toLong();
33 
34     TORCH_CHECK(
35       max_val >= min_val,
36       "Max must be greater than or equal to min. Got min=",
37       min_val,
38       " max=",
39       max_val
40     );
41 
42     TORCH_CHECK(
43       min_val <= size_as_int && size_as_int <= max_val,
44       "Invalid value range for ",
45       size_as_int,
46       " between [",
47       min_val,
48       ", ",
49       max_val,
50       "]."
51     );
52 }
53 
// Value-returning variant of sym_constrain_range.
//
// Runs exactly the same range check on `size` against [min, max]. On success
// it returns a clone of `dep_token`. NOTE(review): the returned token
// presumably lets functionalized graphs order this check through a data
// dependency; confirm against the op's callers.
//
// @param size       Scalar whose integral value is checked.
// @param min        Optional inclusive lower bound (see sym_constrain_range).
// @param max        Optional inclusive upper bound (see sym_constrain_range).
// @param dep_token  Input token tensor; it is cloned, not modified.
// @return A clone of `dep_token`.
// @throws Whatever sym_constrain_range throws; in that case no clone is made.
Tensor _functional_sym_constrain_range(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  sym_constrain_range(size, min, max);
  Tensor token_out = dep_token.clone();
  return token_out;
}
62 
sym_constrain_range_for_size(const Scalar & size,std::optional<int64_t> min,std::optional<int64_t> max)63 void sym_constrain_range_for_size(const Scalar& size, std::optional<int64_t> min, std::optional<int64_t> max) {
64   int64_t min_val = min.has_value() ? min.value() : 0;
65   if (max.has_value() && max.value() <= 2) {
66     TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
67   }
68   sym_constrain_range(size, min_val, max);
69 }
70 
// Value-returning variant of sym_constrain_range_for_size.
//
// Runs exactly the same size-oriented range check (missing `min` defaults to
// 0, and a given `max` must be > 2). On success it returns a clone of
// `dep_token`. NOTE(review): the token presumably carries a data dependency
// for functionalized graphs; confirm against callers.
//
// @param size       Scalar whose integral value is checked.
// @param min        Optional inclusive lower bound.
// @param max        Optional inclusive upper bound.
// @param dep_token  Input token tensor; it is cloned, not modified.
// @return A clone of `dep_token`.
// @throws Whatever sym_constrain_range_for_size throws; no clone is made then.
Tensor _functional_sym_constrain_range_for_size(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  sym_constrain_range_for_size(size, min, max);
  Tensor token_out = dep_token.clone();
  return token_out;
}
79 
_make_dep_token_cpu(std::optional<ScalarType> dtype_opt,std::optional<Layout> layout_opt,std::optional<Device> device_opt,std::optional<bool> pin_memory_opt,std::optional<c10::MemoryFormat> memory_format_opt)80 Tensor _make_dep_token_cpu(
81     std::optional<ScalarType> dtype_opt,
82     std::optional<Layout> layout_opt,
83     std::optional<Device> device_opt,
84     std::optional<bool> pin_memory_opt,
85     std::optional<c10::MemoryFormat> memory_format_opt) {
86   return at::empty(
87       {}, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
88 }
89 
90 } // namespace at::native
91