Home
last modified time | relevance | path

Searched refs:tuple_shape_tree (Results 1 – 3 of 3) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/
H A Dtpu_compile_op_support.cc152 const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) { in GetSubtree() argument
154 xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(), in GetSubtree()
160 element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {}); in GetSubtree()
167 ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape); in GetPerDeviceShape() local
171 HloSharding element_sharding = tuple_shape_tree.element({i}); in GetPerDeviceShape()
173 element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i)); in GetPerDeviceShape()
H A Dtpu_compile_op_support.h117 const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xrt/kernels/
H A Dxrt_state_ops.h128 xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree, in ParseTupleTree() argument
139 *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>( in ParseTupleTree()
141 tuple_shape_tree->ForEachMutableElement( in ParseTupleTree()
144 if (tuple_shape_tree->IsLeaf(index)) { in ParseTupleTree()
449 xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree; in Compute() local
457 &tuple_shape_tree, &device_ordinal, rm)); in Compute()
469 device_ref.device_ordinal(), tuple_shape_tree, in Compute()