xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorCompare.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 
// Forward declarations only: the function-pointer signatures below mention
// these types solely by reference/pointer, so the full class definitions are
// not required and the heavier headers can be avoided here.
namespace c10 {
class Scalar;
}

namespace at {
class Tensor;
struct TensorIterator;
struct TensorIteratorBase;
}
14 
namespace at::native {

// Signature for dim-wise min/max reduction kernels.
// Presumably (values_out, indices_out, self, dim, keepdim) — the mutable
// Tensor& outputs come first; confirm against the kernel implementations.
using reduce_minmax_fn =
    void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
// Structured-kernel variant of the above: outputs are pre-allocated by the
// structured-op machinery, so they are passed as const Tensor& (the const
// refers to the Tensor handle, not the underlying data).
using structured_reduce_minmax_fn =
    void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);

// Per-backend kernel entry points for max.dim / min.dim (registered via
// REGISTER_DISPATCH in the backend-specific kernel files).
DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);

// torch.where: element-wise select driven by a TensorIterator.
using where_fn = void (*)(TensorIterator &);
DECLARE_DISPATCH(where_fn, where_kernel);

// isposinf / isneginf: element-wise infinity checks over a TensorIteratorBase.
using is_infinity_op_fn = void (*)(TensorIteratorBase &);
DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);

// torch.mode: same parameter shape as reduce_minmax_fn — presumably
// (values_out, indices_out, self, dim, keepdim); verify against mode kernels.
using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
DECLARE_DISPATCH(mode_fn, mode_stub);

// clamp with Tensor min/max bounds; bounds travel as iterator operands.
using clamp_tensor_fn = void (*)(TensorIteratorBase &);
DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub);

namespace detail {
    // Which clamp bound(s) a kernel applies. Not referenced in this header;
    // presumably used by the kernel implementations to share one template.
    enum class ClampLimits {Min, Max, MinMax};
}

// clamp with Scalar bounds: (iter, min, max) for the two-sided form, and a
// single Scalar bound for the one-sided min/max forms.
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub);

// torch.isin fallback kernel. Presumably (elements, test_elements, invert,
// out) — the trailing const Tensor& is the pre-allocated bool output; confirm
// against the isin kernel registrations.
using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
DECLARE_DISPATCH(isin_default_fn, isin_default_stub);

} // namespace at::native
50