Home
last modified time | relevance | path

Searched defs:all_scatter (Results 1 – 3 of 3) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/
H A Ddtensor_allreduce_scatter_optimization.cc50 mlir::TF::DTensorAllScatterOp all_scatter, int scatter_dim) { in GetScatterGroupAssignment()
81 if (auto all_scatter = mlir::dyn_cast<mlir::TF::DTensorAllScatterOp>( in ApplyOptimization() local
H A Dcollectives.cc137 mlir::TF::DTensorAllScatterOp all_scatter = in EmitAllScatter() local
274 auto all_scatter = in EmitRelayout() local
/aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/utils/
H A Dcollective_lowering.cc630 mlir::TF::DTensorAllScatterOp all_scatter) { in LowerAllScatterOp()
795 module.walk([&](mlir::TF::DTensorAllScatterOp all_scatter) { in runOnOperation()
799 for (mlir::TF::DTensorAllScatterOp all_scatter : all_scatters) in runOnOperation() local