Home
last modified time | relevance | path

Searched defs:per_sample_weights (Results 1 – 20 of 20) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/
H A DEmbeddingBag.cpp875 const std::optional<Tensor>& per_sample_weights, in check_arguments()
972 const std::optional<Tensor>& per_sample_weights, in make_offset2bag_out()
1028 const std::optional<Tensor>& per_sample_weights, in make_offset2bag()
1132 const std::optional<Tensor>& per_sample_weights, in _embedding_bag_cpu_impl_out()
1183 const Tensor& per_sample_weights, in _embedding_bag_cpu_impl()
1233 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in embedding_bag() local
1276 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in _embedding_bag_forward_only_cpu() local
1299 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in _embedding_bag_cpu() local
1325 const std::optional<at::Tensor>& per_sample_weights, in _embedding_bag_cpu_out()
1400 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in _embedding_bag_backward_symint() local
[all …]
/aosp_15_r20/external/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/
H A Dembedding_bag.py245 per_sample_weights, argument
318 per_sample_weights, argument
439 def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg): argument
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/
H A DEmbeddingBag.cu117 const scalar_t* per_sample_weights, int64_t per_sample_weights_stride, in EmbeddingBag_updateOutputKernel_sum_mean()
175 const Tensor& per_sample_weights, in embedding_bag_backward_cuda_sum_avg()
316 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in _embedding_bag_forward_only_cuda() local
350 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in _embedding_bag_cuda() local
436 const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; in _embedding_bag_dense_backward_cuda() local
H A DEmbeddingBackwardKernel.cu85 const scalar_t* per_sample_weights, int64_t per_sample_weights_stride, in compute_grad_weight_bags()
222 const Tensor &per_sample_weights) { in embedding_backward_cuda_kernel()
/aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/
H A Dembedding.cpp138 const Tensor& per_sample_weights) { in forward()
/aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/
H A Dembedding.h98 const Tensor& per_sample_weights, in embedding_bag()
/aosp_15_r20/external/pytorch/torch/ao/nn/qat/modules/
H A Dembedding_ops.py180 def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor: argument
/aosp_15_r20/external/pytorch/torch/onnx/
H A Dsymbolic_opset18.py237 per_sample_weights, argument
H A Dsymbolic_opset10.py598 per_sample_weights, argument
H A Dsymbolic_opset11.py1211 per_sample_weights, argument
H A Dsymbolic_helper.py1948 per_sample_weights, argument
H A Dsymbolic_opset9.py939 per_sample_weights, argument
/aosp_15_r20/external/pytorch/test/nn/
H A Dtest_embedding.py925 per_sample_weights=None, argument
/aosp_15_r20/external/pytorch/test/quantization/eager/
H A Dtest_quantize_eager_ptq.py848 def forward(self, indices, offsets, per_sample_weights, linear_in): argument
/aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/
H A Dops.cpp1135 const auto per_sample_weights = in __anon11f46a8b3102() local
1173 const auto per_sample_weights = in __anon11f46a8b3302() local
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/
H A Dshim_common.cpp469 AtenTensorHandle per_sample_weights, // optional argument in aoti_torch__embedding_bag()
/aosp_15_r20/external/pytorch/torch/
H A D_meta_registrations.py3356 per_sample_weights=None, argument
6360 per_sample_weights, argument
6401 per_sample_weights, argument
/aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/rpc/
H A Ddist_autograd_test.py2050 def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights): argument
/aosp_15_r20/external/pytorch/test/quantization/core/
H A Dtest_quantized_op.py4618 include_last_offset, weights, per_sample_weights, argument
/aosp_15_r20/external/pytorch/torch/testing/_internal/
H A Dcommon_quantization.py2587 def forward(self, indices, offsets, per_sample_weights): argument