Searched refs:tuple_shape_tree (Results 1 – 3 of 3) sorted by relevance
/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/ |
H A D | tpu_compile_op_support.cc | 152 const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) { in GetSubtree() argument 154 xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(), in GetSubtree() 160 element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {}); in GetSubtree() 167 ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape); in GetPerDeviceShape() local 171 HloSharding element_sharding = tuple_shape_tree.element({i}); in GetPerDeviceShape() 173 element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i)); in GetPerDeviceShape()
|
H A D | tpu_compile_op_support.h | 117 const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xrt/kernels/ |
H A D | xrt_state_ops.h | 128 xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree, in ParseTupleTree() argument 139 *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>( in ParseTupleTree() 141 tuple_shape_tree->ForEachMutableElement( in ParseTupleTree() 144 if (tuple_shape_tree->IsLeaf(index)) { in ParseTupleTree() 449 xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree; in Compute() local 457 &tuple_shape_tree, &device_ordinal, rm)); in Compute() 469 device_ref.device_ordinal(), tuple_shape_tree, in Compute()
|