/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_

#include <vector>

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"

namespace tensorflow {

// Hierarchical tree-algorithm implementation of collective broadcast.
class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
 public:
  HierarchicalTreeBroadcaster();
  ~HierarchicalTreeBroadcaster() override = default;

  // Establishes the subdiv permutations needed for a hierarchical broadcast.
  // If all devices are local, establishes a single subdiv comprising all
  // devices. If any devices are on a different task, establishes n+1 subdivs
  // for n tasks.
  // The first subdiv comprises one device per task which gets the tensor on
  // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task
  // i.
  Status InitializeCollectiveParams(CollectiveParams* col_params) override;

  // Initializes members of CollectiveContext not yet initialized, i.e. device
  // and device_locality. Also saves the CollectiveContext in this object.
  Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) override;

  // Begins async execution of the hierarchical tree broadcast.
  // Must be called in a blockable thread.
  // TODO(b/80529858): remove the previous warning when we have a dedicated
  // collective threadpool.
  void Run(StatusCallback done) override;

  // Returns the rank of the device from which this device should receive
  // its value, -1 if no value should be received.
  static int TreeRecvFrom(const CollectiveParams& cp, int subdiv);

  // Populates targets with the ranks of the devices to which this device
  // should forward the value.
  static void TreeSendTo(const CollectiveParams& cp, int subdiv,
                         std::vector<int>* targets);

 private:
  // Get the task to which the device at `device_rank` belongs.
  int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task);

  // Sends `src_tensor` asynchronously from this device to device at `dst_rank`
  // in `subdiv`. Calls `done` upon completion.
  void DispatchSend(int subdiv, int dst_rank, int src_rank,
                    const Tensor* src_tensor, const StatusCallback& done);

  // Receives a tensor into the memory buffer owned by `dst_tensor` at this
  // device from device at `src_rank` in `subdiv`. Calls `done` upon
  // completion.
  void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
                    const StatusCallback& done);

  // Executes the hierarchical broadcast defined by this op.
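  //
  // A minimal sketch of the expected per-subdiv flow (illustrative only, not
  // a verbatim copy of the implementation). It assumes, for example, 2 tasks
  // with 4 devices each, i.e. 3 subdivs: subdiv 0 holds one device per task,
  // and subdivs 1 and 2 are the task-local trees for tasks 0 and 1.
  // `num_subdivs`, `output_tensor`, `recv_done`, and `send_done` below are
  // placeholders, not members of this class:
  //
  //   for (int si = 0; si < num_subdivs; ++si) {
  //     int rank = /* this device's rank in subdiv si, or -1 if absent */;
  //     if (rank < 0) continue;
  //     int recv_from = TreeRecvFrom(*col_params_, si);
  //     if (recv_from >= 0) {
  //       // Receive the tensor from the parent rank before forwarding it.
  //       DispatchRecv(si, recv_from, rank, output_tensor, recv_done);
  //     }
  //     std::vector<int> targets;
  //     TreeSendTo(*col_params_, si, &targets);
  //     for (int dst : targets) {
  //       DispatchSend(si, dst, rank, output_tensor, send_done);
  //     }
  //   }
  //   done_(status_);  // Invoked once all sends/recvs have completed.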
  void RunTree();

  std::shared_ptr<CollectiveContext> col_ctx_;
  const CollectiveParams* col_params_;  // Not owned
  StatusCallback done_;
  Status status_;
  bool is_source_;
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_