Home
last modified time | relevance | path

Searched refs:parallel_tensor (Results 1 – 4 of 4) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/python/distribute/parallel_device/
H A Dparallel_device.py130 def _unpack_tensor(self, parallel_tensor): argument
132 if not isinstance(parallel_tensor, (
135 "Expected a tensor, got {}.".format(parallel_tensor))
138 parallel_tensor, num_replicas=len(self.components))
140 def unpack(self, parallel_tensor): argument
157 parallel_tensor = variable_utils.convert_variables_to_tensors(
158 parallel_tensor)
159 for tensor in nest.flatten(parallel_tensor, expand_composites=True):
163 return [nest.pack_sequence_as(parallel_tensor, unpacked,
/aosp_15_r20/external/tensorflow/tensorflow/c/eager/parallel_device/
H A Dparallel_device.cc162 std::unique_ptr<ParallelTensor> parallel_tensor( in ExecuteWithSpecialOps() local
166 parallel_inputs.push_back(parallel_tensor.get()); in ExecuteWithSpecialOps()
167 implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor)); in ExecuteWithSpecialOps()
233 ParallelTensor* parallel_tensor = reinterpret_cast<ParallelTensor*>(data); in ParallelTensorSummarize() local
235 Status cpp_status = parallel_tensor->SummarizeValue(summary); in ParallelTensorSummarize()
293 ParallelTensor* parallel_tensor = reinterpret_cast<ParallelTensor*>( in CopyTensorFromParallelDevice() local
296 if (parallel_tensor->num_tensors() == 1) { in CopyTensorFromParallelDevice()
299 return TFE_TensorHandleCopySharingTensor(parallel_tensor->tensor(0), in CopyTensorFromParallelDevice()
/aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/
H A Ddtensor_device_util.cc83 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in BroadcastTensorHandleToParallelTensor() local
87 return parallel_tensor; in BroadcastTensorHandleToParallelTensor()
149 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in BroadcastResourceTensor() local
154 std::move(parallel_tensor), mesh, in BroadcastResourceTensor()
307 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in Broadcast() local
312 Status s = parallel_tensor->Shape(&shape); in Broadcast()
326 std::move(parallel_tensor), mesh, std::move(layout), *shape, in Broadcast()
H A Ddtensor_device.cc883 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in Pack() local
903 TensorWithLayout::Wrap(std::move(parallel_tensor), in Pack()