Home
last modified time | relevance | path

Searched defs:batch_dims (Results 1 – 25 of 64) sorted by relevance

123

/aosp_15_r20/external/tensorflow/tensorflow/python/ops/ragged/
H A Dragged_gather_ops.py38 batch_dims=0, argument
112 def _gather(params, indices, axis, batch_dims): argument
174 def _batch_gather(params, indices, axis, batch_dims): argument
331 batch_dims=0): argument
341 batch_dims=0, argument
485 batch_dims=0): argument
H A Dragged_gather_op_test.py192 batch_dims=0, argument
396 batch_dims=0): argument
/aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/math_ops/
H A Dbanded_triangular_solve_op_test.py27 def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None): argument
40 def _verifySolveAllWaysReal(self, x, y, batch_dims=None): argument
43 def _verifySolveAllWaysComplex(self, x, y, batch_dims=None): argument
51 batch_dims=None, argument
/aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/linalg/
H A Dmatrix_triangular_solve_op_test.py27 def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None): argument
40 def _verifySolveAllWaysReal(self, x, y, batch_dims=None): argument
43 def _verifySolveAllWaysComplex(self, x, y, batch_dims=None): argument
51 batch_dims=None, argument
H A Dmatrix_solve_op_test.py37 def _verifySolve(self, x, y, batch_dims=None): argument
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
H A Dqr_expander.cc109 Status House(XlaOp x, XlaOp k, absl::Span<const int64_t> batch_dims, in House()
214 std::vector<int64_t> batch_dims(num_batch_dims); in QrBlock() local
319 PrimitiveType type, absl::Span<const int64_t> batch_dims, XlaOp vs, in CompactWYRepresentation()
394 std::vector<int64_t> batch_dims(num_batch_dims); in BuildQrDecomposition() local
465 std::vector<int64_t> batch_dims(num_batch_dims); in ProductOfElementaryHouseholderReflectors() local
H A Ddot_as_convolution_util.h53 std::vector<DimNums> batch_dims; member
H A Dindexed_array_analysis.cc981 absl::Span<const int64_t> batch_dims) { in GetOnlyNonContractingNonBatchDim()
1006 absl::Span<const int64_t> batch_dims) { in CanFoldDotIntoIndexedArray()
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/
H A Dgather_scatter_handler.cc170 PartitionedHlo& indices, absl::Span<const int64_t> batch_dims, in PartitionGatherIndexPassthroughPartition()
272 const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims, in PartitionGatherPassthroughOperand()
376 const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims, in PartitionGatherTrivialIndexedOperandDimension()
448 std::vector<int64_t> batch_dims; in PartitionGatherTrivialIndexedOperandDimension() local
487 const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims, in PartitionGatherIndexParallelDimensions()
640 const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims, in PartitionGather()
690 std::vector<int64_t> batch_dims; in HandleGather() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/
H A Dmatmul_utils.cc51 const Shape& shape, absl::Span<const int64_t> batch_dims, in GetNonContractingDims()
69 absl::Span<const int64_t> batch_dims, in GetBatchRowColumnShape()
163 const Shape& shape, absl::Span<const int64_t> batch_dims, in For()
184 auto batch_dims = absl::Span<const int64_t>(dims).first(num_batch_dims); in For() local
216 auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions() in CanFoldTransposeOperandIntoDot() local
H A Dgpu_layout_assignment.cc403 absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims, in SetDotOperandLayout()
425 absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims, in SetOperandBatchRowsColsLayout()
H A Dcusolver_rewriter.cc62 std::vector<int64_t> batch_dims(a_shape.dimensions().begin(), in CreateCholesky() local
/aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/array_ops/
H A Dgather_op_test.py476 def testBatchDims(self, params, indices, batch_dims, expected=None, argument
594 batch_dims, axis, output_shape): argument
629 def _batchNumpyGather(self, params, indices, axis, batch_dims): argument
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/
H A Dsvd.cc120 std::vector<int64_t> batch_dims(num_batch_dims); in HouseRow() local
185 std::vector<int64_t> batch_dims(num_batch_dims); in HouseCol() local
259 std::vector<int64_t> batch_dims(num_batch_dims); in HouseHolderBidiagonalization() local
461 std::vector<int64_t> batch_dims(num_batch_dims); in OneSidedJacobiUpdate() local
840 std::vector<int64_t> batch_dims(num_batch_dims); in SVD() local
H A Dlu_decomposition.cc35 const std::vector<int64_t> batch_dims( in LuDecomposition() local
H A Dself_adjoint_eig.cc70 const std::vector<int64_t> batch_dims( in SelfAdjointEig() local
/aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/
H A Dgather_spmd_expander.cc54 int batch_dims = gather_op.batch_dims(); in ExpandOp() local
174 int batch_dims = gather_op.batch_dims(); in ComputeLayoutForward() local
253 int batch_dims = gather_op.batch_dims(); in ComputeLayoutBackward() local
/aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/
H A Dstructured_array_ops_test.py1113 def testGather(self, params, indices, axis, batch_dims, expected): argument
1150 def testGatherRagged(self, params, indices, axis, batch_dims, expected): argument
1200 indices, axis, batch_dims, argument
/aosp_15_r20/external/tensorflow/tensorflow/python/ops/
H A Darray_ops.py5107 batch_dims=0): # pylint: disable=g-doc-args argument
5317 batch_dims=0, argument
5347 def _batch_gather(params, indices, batch_dims, axis=None): argument
5473 def gather_nd(params, indices, name=None, batch_dims=0): argument
5731 def gather_nd_v2(params, indices, batch_dims=0, name=None): argument
5738 def batch_gather_nd(params, indices, batch_dims, name=None): argument
H A Darray_grad.py595 def _GetBatchIndices(params_shape, indices, batch_dims): argument
617 def _BatchGatherGrad(params_shape, values, indices, batch_dims, argument
/aosp_15_r20/external/tensorflow/tensorflow/core/ops/
H A Dsparse_csr_matrix_ops.cc279 ShapeHandle batch_dims; in __anona88575120702() local
432 ShapeHandle batch_dims; in __anona88575120a02() local
/aosp_15_r20/external/tensorflow/tensorflow/cc/gradients/
H A Darray_grad.cc632 const Output& indices, int batch_dims) { in GetBatchIndices()
660 Output indices, int batch_dims, Output gather_dim_size) { in BatchGatherGrad()
715 int batch_dims; in GatherV2Grad() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/tests/
H A Dmatrix_solve_op_test.py54 def testSolve(self, n, nrhs, batch_dims, rhs_batch_dims, adjoint): argument
/aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/
H A Dgather.h38 int batch_dims = op_params.batch_dims; in Gather() local
/aosp_15_r20/external/pytorch/aten/src/ATen/test/
H A Dlegacy_vmap_test.cpp94 auto batch_dims = maxBatchDimsAtFront(); in TEST() local
234 auto batch_dims = maxBatchDimsAtFront(); in TEST() local
261 BatchDims batch_dims = { in TEST() local

123