xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qembeddingbag.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <ATen/core/Tensor.h>
#include <cstdint>
#include <optional>
4 
5 namespace at {
6 namespace native {
7 Tensor& embedding_bag_byte_rowwise_offsets_out(
8     Tensor& output,
9     const Tensor& weight,
10     const Tensor& indices,
11     const std::optional<Tensor>& offsets_in,
12     const bool /* scale_grad_by_freq */,
13     const int64_t /* mode */,
14     bool pruned_weights,
15     const std::optional<Tensor>& per_sample_weights_,
16     const std::optional<Tensor>& compressed_indices_mapping,
17     bool include_last_offset);
18 
19 Tensor& embedding_bag_4bit_rowwise_offsets_out(
20     Tensor& output,
21     const Tensor& weight,
22     const Tensor& indices,
23     const std::optional<Tensor>& offsets_in,
24     const bool /* scale_grad_by_freq */,
25     const int64_t /* mode */,
26     bool pruned_weights,
27     const std::optional<Tensor>& per_sample_weights_,
28     const std::optional<Tensor>& compressed_indices_mapping,
29     bool include_last_offset);
30 
31 Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight);
32 
33 } // native
34 } // at
35