#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/BucketizationUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/bucketize_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/searchsorted_native.h>
#endif

/* Implement a numpy-like searchsorted and a TF-like bucketize function running on CPU
 *
 * - torch.searchsorted(sorted_sequence, values, right=False, side=None, out_int32=False, sorter=None)
 *   sorted_sequence - N*D or 1D (applies to all values) tensor containing sorted sequences in the last dimension
 *   values          - N*D tensor or a Scalar (when sorted_sequence is 1D) containing the search values
 *   right           - corresponds to lower bound if False and upper bound if True
 *   side            - (preferred to right) corresponds to lower bound if 'left' and upper bound if 'right'
 *   out_int32       - the output tensor is int64_t type if False and int (32-bit normally) type if True
 *   sorter          - if provided, sorted_sequence may not be sorted and the sorted order is given by this tensor
 *
 * - torch.bucketize(values, boundaries, right=False, out_int32=False)
 *   values     - N*D tensor or a Scalar containing the search values
 *   boundaries - 1D tensor containing a sorted sequence
 *   right      - corresponds to lower bound if False and upper bound if True
 *   out_int32  - the output tensor is int64_t type if False and int (32-bit normally) type if True
 *
 * - Restrictions are defined in searchsorted_pre_check()
 */
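/* Illustrative example of the semantics above (values chosen for exposition only,
 * not taken from any test):
 *   sorted_sequence = [1, 3, 5, 7, 9], values = [3, 6, 9]
 *   torch.searchsorted(sorted_sequence, values)              -> [1, 3, 4]  (lower bound, side='left')
 *   torch.searchsorted(sorted_sequence, values, right=True)  -> [2, 3, 5]  (upper bound, side='right')
 * torch.bucketize(values, sorted_sequence) behaves the same way, with the roles of the
 * two tensor arguments swapped and the boundaries tensor restricted to 1D.
 */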
namespace at::native {

namespace {

// minimal size for searchsorted_cpu_contiguous to run parallel (multithread)
constexpr int64_t SEARCHSORTED_GRAIN_SIZE = 200;

// customized lower_bound func to ensure the lower bound of 'nan', 'inf' etc. is the end of the boundary,
// and so that we can properly handle a sorter argument.
// std::lower_bound can not be used here since its customized comparator needs strict weak ordering
// and the customized comparators require both arguments to have the same type, which wouldn't
// happen when comparing val of input_t to an indexer value from sorter of int64
template<typename input_t>
int64_t cus_lower_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
  // i.e. the second row of a 3x3 tensor starts at element 3 but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val >= val)) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }
  return start;
}

// customized upper_bound func to ensure we can properly handle a sorter argument.
// std::upper_bound can not be used here since its customized comparator requires both arguments to have the
// same type, which wouldn't happen when comparing val of input_t to an indexer value from sorter of int64
template<typename input_t>
int64_t cus_upper_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset
  // i.e. the second row of a 3x3 tensor starts at element 3 but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val > val)) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }
  return start;
}
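// Worked sketch of the sorter offset handling above (values are illustrative only):
// for a 2x3 boundaries tensor flattened as bd = [30, 10, 20, 5, 15, 25] with
// sorter = [[1, 2, 0], [0, 1, 2]], a search over the second row uses start = 3 and
// end = 6, so orig_start = 3 and each probe reads bd[sort[mid] + 3]; the sorter
// entries stay row-relative (0, 1, or 2) while bd is indexed absolutely.
// searchsorted_cpu_contiguous below selects that window for a flat value index i via
// start_bd = i / idim_in * idim_bd whenever boundaries is not 1D.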
template<typename input_t, typename output_t>
void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tensor& boundaries, const bool& right, const Tensor& sorter) {
  int64_t numel_in = input.numel();
  bool is_scalar_input = input.dim() == 0 && numel_in == 1;
  // inner most dim size of input and boundaries
  int64_t idim_in = is_scalar_input ? 1 : input.sizes().back();
  int64_t idim_bd = boundaries.sizes().back();

  const input_t *data_in = input.const_data_ptr<input_t>();
  const input_t *data_bd = boundaries.const_data_ptr<input_t>();
  const int64_t *data_st = sorter.defined() ? sorter.const_data_ptr<int64_t>() : nullptr;
  output_t *data_out = result.data_ptr<output_t>();

  bool is_1d_boundaries = boundaries.dim() == 1;
  at::parallel_for(0, numel_in, SEARCHSORTED_GRAIN_SIZE, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      // If boundaries tensor is 1d, we always search the entire boundary tensor
      int64_t start_bd = is_1d_boundaries ? 0 : i / idim_in * idim_bd;
      int64_t end_bd = start_bd + idim_bd;

      int64_t pos = !right ?
        cus_lower_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd :
        cus_upper_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd;

      // type conversion might happen here
      data_out[i] = pos;
    }
  });
}

void dispatch(
    Tensor& result,
    const Tensor& input,
    const Tensor& boundaries,
    bool out_int32,
    bool right,
    const Tensor& sorter) {
  if (!out_int32) {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int64_t>(
              result, input, boundaries, right, sorter);
        });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int32_t>(
              result, input, boundaries, right, sorter);
        });
  }
}

} // namespace

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> sorter_maybe_owned = at::borrow_from_optional_tensor(sorter_opt);
  const Tensor& sorter = *sorter_maybe_owned;
  searchsorted_pre_check(sorted_sequence, self, result, out_int32, right, side_opt, sorter);
  resize_output(result, self.sizes());

  // we have two inputs to set right, pre_check checks that they aren't set to opposites
  bool is_right = side_opt ? *side_opt == "right" : right;
  if (self.numel() == 0) {
    return result;
  }

  // for non-contiguous result tensors, we write the output to a contiguous copy so we can
  // later copy back, maintaining the original result tensor
  Tensor out = result;
  if (!result.is_contiguous()) {
    out = result.contiguous();
  }
  if (sorted_sequence.is_contiguous() && self.is_contiguous() && sorted_sequence.dtype() == self.dtype() && sorter.is_contiguous()) {
    dispatch(out, self, sorted_sequence, out_int32, is_right, sorter);
  } else {
    Tensor trimmed_input;
    Tensor trimmed_boundaries;
    Tensor trimmed_sorter;
    searchsorted_maybe_trim_input_tensors(trimmed_input, trimmed_boundaries, trimmed_sorter, self, sorted_sequence, sorter);
    const Tensor& final_input = trimmed_input.defined() ? trimmed_input : self;
    const Tensor& final_boundaries = trimmed_boundaries.defined() ? trimmed_boundaries : sorted_sequence;
    const Tensor& final_sorter = trimmed_sorter.defined() ? trimmed_sorter : sorter;
    dispatch(out, final_input, final_boundaries, out_int32, is_right, final_sorter);
  }

  // if result is non-contiguous, we wrote the answer to a copied version, so we copy back to the original result tensor
  if (!result.is_contiguous()) {
    result.copy_(out);
  }
  return result;
}

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt,
    Tensor& result) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_out_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt, result);
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::searchsorted_out_cpu(sorted_sequence, self, out_int32, right, side_opt, sorter_opt, result);
  return result;
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const std::optional<std::string_view> side_opt,
    const std::optional<Tensor>& sorter_opt) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt);
}

Tensor& bucketize_out_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
  TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, std::nullopt, std::nullopt, result);
  return result;
}

Tensor bucketize_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::bucketize_out_cpu(self, boundaries, out_int32, right, result);
  return result;
}

Tensor bucketize_cpu(const Scalar& self, const Tensor& boundaries, bool out_int32, bool right) {
  return bucketize_cpu(searchsorted_scalar_tensor(self, boundaries.device()), boundaries, out_int32, right);
}

} // namespace at::native