Home
last modified time | relevance | path

Searched defs:bdims (Results 1 – 11 of 11) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/
H A DLegacyBatchedTensorImpl.h63 BatchDimsRef bdims() const { in bdims() function
128 BatchDimsRef bdims) { in createBatchDimBitset()
137 inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) { in createVmapLevelsBitset()
H A DLegacyBatchedTensorImpl.cpp9 BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) in BatchedTensorImpl()
116 Tensor makeBatched(const Tensor& tensor, BatchDims bdims) { in makeBatched()
133 BatchDims bdims; in addBatchDim() local
H A DLegacyVmapTransforms.cpp9 static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) { in areBdimsAtFrontInOrder()
21 auto bdims = batched->bdims(); in permuteBatchDimsToFront() local
92 BatchDims bdims; in computeFrontBatchDimsFromLevels() local
/aosp_15_r20/external/pytorch/aten/src/ATen/native/
H A DLegacyBatching.cpp23 auto bdims = batched->bdims(); in has_level() local
50 auto bdims = batched->bdims(); in remove_existing_batch_dim() local
/aosp_15_r20/external/pytorch/aten/src/ATen/test/
H A Dlegacy_vmap_test.cpp198 BatchDims bdims = {{/*lvl*/1, /*dim*/0}, {/*lvl*/3, /*dim*/1}}; in TEST() local
207 BatchDims bdims = {{/*lvl*/1, /*dim*/1}}; in TEST() local
217 BatchDims bdims = {{/*lvl*/1, /*dim*/1}, {/*lvl*/2,/*dim*/2}, {/*lvl*/3,/*dim*/0}}; in TEST() local
305 static void checkBatchDimsEqual(BatchDimsRef bdims, BatchDimsRef expected_bdims) { in checkBatchDimsEqual()
/aosp_15_r20/external/pytorch/torch/_functorch/
H A Dvmap.py513 def wrap_batched(args, bdims, level): argument
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
H A Dtriangular_solve_expander.cc146 int bdims = b_shape.rank(); in SolveWithInvertedDiagonalBlocks() local
/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/
H A DBatchRulesScatterOps.cpp20 static bool any_has_value(ArrayRef<std::optional<int64_t>> bdims) { in any_has_value()
/aosp_15_r20/external/pytorch/test/
H A Dtest_legacy_vmap.py848 def slice_inputs(inputs, bdims, i): argument
/aosp_15_r20/external/tensorflow/tensorflow/compiler/tests/
H A Drandomized_tests.cc970 std::vector<int64_t> bdims(dims.begin() + skip, dims.end()); in BroadcastableToDims() local
985 auto bdims = BroadcastableToDims(dims); in BroadcastableDims() local
/aosp_15_r20/external/pytorch/test/functorch/
H A Dtest_vmap.py1265 def slice_inputs(inputs, bdims, i): argument