Searched defs:embedding_rows (Results 1 – 2 of 2) sorted by relevance

/aosp_15_r20/external/pytorch/test/
test_pruning_op.py
18 def _generate_rowwise_mask(self, embedding_rows): argument
24 def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype): argument
67 def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype): argument
79 def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype): argument
/aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/
qembeddingbag_prepack.cpp
64 int64_t embedding_rows = qweight.size(0); in prepack() local
243 const int64_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes); in qembeddingbag_byte_prepack_out() local
366 int64_t embedding_rows = weight.size(0); in _qembeddingbag_nbit_prepack_helper() local