1 #include <limits>
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <ATen/core/Tensor.h>
4 #include <c10/core/Device.h>
5 #include <c10/core/Layout.h>
6 #include <c10/core/MemoryFormat.h>
7 #include <c10/core/Scalar.h>
8 #include <c10/core/ScalarType.h>
9 #include <optional>
10
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Functions.h>
13 #include <ATen/NativeFunctions.h>
14 #else
15 #include <ATen/ops/_functional_sym_constrain_range_native.h>
16 #include <ATen/ops/_make_dep_token_native.h>
17 #include <ATen/ops/empty.h>
18 #include <ATen/ops/sym_constrain_range_native.h>
19 #include <ATen/ops/sym_constrain_range_for_size_native.h>
20 #include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
21 #endif
22
23 namespace at::native {
24
sym_constrain_range(const Scalar & size,std::optional<int64_t> min,std::optional<int64_t> max)25 void sym_constrain_range(
26 const Scalar& size,
27 std::optional<int64_t> min,
28 std::optional<int64_t> max) {
29
30 int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
31 int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
32 int64_t size_as_int = size.toLong();
33
34 TORCH_CHECK(
35 max_val >= min_val,
36 "Max must be greater than or equal to min. Got min=",
37 min_val,
38 " max=",
39 max_val
40 );
41
42 TORCH_CHECK(
43 min_val <= size_as_int && size_as_int <= max_val,
44 "Invalid value range for ",
45 size_as_int,
46 " between [",
47 min_val,
48 ", ",
49 max_val,
50 "]."
51 );
52 }
53
// Token-threading variant of sym_constrain_range.
//
// Performs the same range validation on `size` (raising on failure) and then
// returns a clone of `dep_token` rather than the token itself. Presumably the
// token exists to give the check a data dependency for ordering purposes —
// confirm against the op's schema/callers.
Tensor _functional_sym_constrain_range(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  sym_constrain_range(size, min, max);
  Tensor next_token = dep_token.clone();
  return next_token;
}
62
sym_constrain_range_for_size(const Scalar & size,std::optional<int64_t> min,std::optional<int64_t> max)63 void sym_constrain_range_for_size(const Scalar& size, std::optional<int64_t> min, std::optional<int64_t> max) {
64 int64_t min_val = min.has_value() ? min.value() : 0;
65 if (max.has_value() && max.value() <= 2) {
66 TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
67 }
68 sym_constrain_range(size, min_val, max);
69 }
70
// Token-threading variant of sym_constrain_range_for_size.
//
// Runs the size-oriented range validation on `size` (raising on failure) and
// returns a clone of `dep_token`. As with _functional_sym_constrain_range,
// the token is presumably used to order this check relative to other ops —
// confirm against the op's schema/callers.
Tensor _functional_sym_constrain_range_for_size(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  sym_constrain_range_for_size(size, min, max);
  Tensor next_token = dep_token.clone();
  return next_token;
}
79
_make_dep_token_cpu(std::optional<ScalarType> dtype_opt,std::optional<Layout> layout_opt,std::optional<Device> device_opt,std::optional<bool> pin_memory_opt,std::optional<c10::MemoryFormat> memory_format_opt)80 Tensor _make_dep_token_cpu(
81 std::optional<ScalarType> dtype_opt,
82 std::optional<Layout> layout_opt,
83 std::optional<Device> device_opt,
84 std::optional<bool> pin_memory_opt,
85 std::optional<c10::MemoryFormat> memory_format_opt) {
86 return at::empty(
87 {}, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
88 }
89
90 } // namespace at::native
91