/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/reffed_status_callback.h"

namespace tensorflow {

// Holds node's shape information for Concat/Split.
using EdgeShapes = absl::flat_hash_map<const Edge*, std::vector<int>>;
using GroupedEdges =
    absl::flat_hash_map<std::string, std::vector<const Edge*>>;

// Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op
// that we find as part of searching for inputs to models that are replicated.
using XlaShardingInfoMap = absl::flat_hash_map<
    std::string, std::tuple<tensorflow::DataType, std::string, std::string>>;

// Contains attr "T", and a pointer to tpu_replicated_metadata for the ctrl
// dep, for each TpuReplicatedInput op that we find as part of searching for
// inputs to models that are replicated.
using TpuReplicatedInputInfoMap =
    absl::flat_hash_map<std::string,
                        std::tuple<tensorflow::DataType, Node*>>;

namespace tpu_functional_internal {

// Helper functions for graph rewrites.
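//
// A rough caller-side sketch of the input-packing helpers (illustrative only;
// `cluster_input_edges` is a hypothetical list of edges feeding the TPU
// cluster, and the shapes are made up):
//
//   EdgeShapes tpu_input_shapes;
//   absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes;
//   for (const Edge* e : cluster_input_edges) {
//     tpu_input_shapes[e] = {8, 128};  // e.g. a [batch, feature] input
//     tpu_input_dtypes[e] = e->src()->output_type(e->src_output());
//   }
//   GroupedEdges groups = GroupTensorsForInputPacking(
//       tpu_input_shapes, tpu_input_dtypes,
//       /*input_shape_opt=*/true, /*group_tensors_for_packing=*/true);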
GroupedEdges GroupTensorsForInputPacking(
    const EdgeShapes& tpu_input_shapes,
    const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
    bool input_shape_opt, bool group_tensors_for_packing);
GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
                                          EdgeShapes& tpu_output_shapes,
                                          GraphShapeInfo* shape_info);

Status CreateConcatAndSplitNodesForInputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
    const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
        grouped_input_edges,
    int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
    const XlaShardingInfoMap& xla_sharding_info,
    const TpuReplicatedInputInfoMap& tpu_replicated_input_info);
Status CreateConcatAndSplitNodesForOutputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
    GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
    int32_t minimum_output_tensors_packing);

Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
                              EdgeShapes* tpu_input_shapes,
                              int num_cores_per_replica);

}  // namespace tpu_functional_internal

typedef FunctionLibraryRuntime::Handle FHandle;

// A `TPUPartitionedCallOp` asynchronously executes a function on exactly one
// TPU core and potentially across multiple other devices, but within a single
// process. The kernel places and partitions the function's underlying graph,
// executing each of the partitioned subgraphs as a function.
//
// The core on which the TPU computation is executed must be specified via the
// `device_ordinal` input. Different invocations of this op may specify
// different device ordinals, making it possible to map TPU computations to
// different cores at runtime. Currently, macro-substitution of device ordinals
// is only supported for the following whitelisted ops:
//   * TPUExecute
//   * InfeedEnqueue
//   * InfeedEnqueueTuple
//
// Attempting to compute a TPUPartitionedCallOp whose function body has a
// non-whitelisted node bearing an attribute named "device_ordinal" will result
// in an error.
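//
// A minimal construction sketch, assuming the standard TPUPartitionedCall op
// registration (the function name, tensor names, and dtypes below are
// hypothetical):
//
//   NameAttrList f;
//   f.set_name("my_tpu_fn");
//   std::vector<NodeDefBuilder::NodeOut> args = {{"x", 0, DT_FLOAT}};
//   NodeDef call;
//   TF_CHECK_OK(NodeDefBuilder("tpu_call", "TPUPartitionedCall")
//                   .Input(args)                           // Tin arguments
//                   .Input("device_ordinal", 0, DT_INT32)  // core selector
//                   .Attr("Tout", {DT_FLOAT})              // Tin is inferred
//                   .Attr("f", f)
//                   .Finalize(&call));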
//
// TODO(akshayka): This class duplicates most of the logic of
// `PartitionedCallOp`; once that class and this one have evolved to stable
// states, and if at that time they remain sufficiently similar, either unify
// them in one op or set up an inheritance structure that allows for code
// reuse.
class TPUPartitionedCallOp : public AsyncOpKernel {
 public:
  explicit TPUPartitionedCallOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        pool_(ctx->env(), "InitializeVarOnTPUPool", 1),
        library_runtime_(nullptr) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    // If the importer has set the original function name, the function
    // attribute refers to a rewritten function, but we need to use the
    // original function name in order to find it in the function library.
    std::string orig_f;
    if (ctx->GetAttr("_orig_f", &orig_f).ok()) {
      func_.set_name(orig_f);
    }
    auto status = ctx->GetAttr("autotuner_thresh", &autotuner_thresh_);
    if (!status.ok()) {
      autotuner_thresh_ = 0;
    }
    tensorflow::tpu::OpsApiFn()->TfTpu_GetTpuPartitionedCallParamsFn(
        &runtime_params_);
  }

  ~TPUPartitionedCallOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  struct DeviceAndFHandle {
    std::string device;
    FHandle handle;

    // The FLD passed to `library_runtime_` as an overlay function library for
    // instantiation of function `handle`. This is a snapshot of the current
    // `flib_def_`. Since `flib_def_` can be changed concurrently by another
    // graph rewrite while executing `handle`, we need to make sure each
    // `handle` uses a different FLD to avoid races. See b/181149591.
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
  };

  struct TPUMetadata {
    tpu::TopologyProto topology;
    int num_cores_per_replica = 1;
    std::vector<int> device_assignment;
  };

  // This method is thread-safe.
  Status GetTpuCoreOrdinal(OpKernelContext* ctx, uint64 input_hash,
                           int64_t* ordinal_selector_req_id,
                           int32_t* core_ordinal);

  // Helper to create and initialize a TPU variable given a CPU variable.
  // var: the CPU variable created by the user.
  // ndef: the node def of the corresponding TPU var handle that we created.
  // device_ordinal: TPU device ordinal on which to initialize this variable.
  Status InitializeVarOnTPU(OpKernelContext* ctx,
                            const core::RefCountPtr<Var>& var, NodeDef* ndef,
                            int device_ordinal, bool fast_mem)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper to create and initialize partitioned TPU variables given a CPU
  // variable with an XLA sharding annotation.
  // var: the CPU variable created by the user.
  // ndefs: the node defs of the corresponding TPU var handles on all the
  // logical cores.
  // split_dim: the partition dimension of the variable. If -1, the variable is
  // replicated.
  // device_ordinal: the index of the TPU core that is scheduled to run the
  // computation. In the case of XLA SPMD, it is the "primary" core, which is
  // the smallest index of all the cores.
  Status InitializeShardedVarOnTPU(OpKernelContext* ctx,
                                   const core::RefCountPtr<Var>& var,
                                   std::vector<NodeDef>& ndefs, int split_dim,
                                   const std::vector<string>& tpu_devices)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Checks whether any of the immediate successors of `node` has the attribute
  // "_tpu_replicate".
  bool IsInputToTPUReplicate(Node* node) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node of type DT_RESOURCE with a VarHandleOp on TPU.
  Status ReplaceResourceArgsWithVarHandleOps(Graph* graph, OpKernelContext* ctx,
                                             int device_ordinal,
                                             bool enable_spmd_xla_partitioning,
                                             const TPUMetadata& tpu_metadata)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node that indicates a variable on the CPU host with
  // sharded/replicated variables on all logical TPU devices.
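  // For example, with num_cores_per_replica = 4, a variable of shape [8, 128]
  // that is partitioned along dimension 0 ends up as four [2, 128] per-core
  // shards, whereas a replicated variable gets a full [8, 128] copy on every
  // logical core (shapes here are illustrative only).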
  Status ReplaceAndPartitionXLAShardingVariable(
      Graph* graph, OpKernelContext* ctx, int device_ordinal,
      ResourceHandle& handle, Node* variable, const TPUMetadata& tpu_metadata)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status ShardInputsWithXlaSharding(Graph* graph,
                                    const std::string& cluster_name,
                                    int num_cores_per_replica,
                                    OpKernelContext* ctx)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites the graph for input and output optimizations.
  // TODO(ylc): Move this function to a graph optimization pass.
  Status OptimizeTpuInputOutputTensors(
      Graph* graph, bool enable_spmd_xla_partitioning,
      int num_cores_per_replica,
      std::map<std::string, std::vector<int>>& named_input_shapes,
      OpKernelContext* ctx) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status InferShapesWithResourceVar(Graph* graph, OpKernelContext* ctx,
                                    std::map<int, InferredShape>& arg_shapes,
                                    GraphShapeInfo* tpu_inferred_info);

  // Copies the graph backing `func_` into `graph`.
  Status GetGraphFromFunction(Graph* graph, int device_ordinal,
                              bool* use_spmd_for_xla_partitioning,
                              TPUMetadata* tpu_metadata)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Places the graph carried by `optimization_options` and runs the graph
  // optimization passes (pre-placement, post-placement, and post-rewrite).
  Status PlacementHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options,
      const string& function_name);
  // Partitions `graph`, populates `subgraphs` with the partitions, and runs
  // the post-partitioning graph optimization passes.
  Status PartitionHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options, Graph* graph,
      std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs);

  // Adds and instantiates a function backed by `graph` with name
  // `function_name` on device `target_device`, storing the handle in `handle`.
  // If `out_flib_def` is not null, it will be set to a copy of `flib_def_` and
  // used for instantiation.
  Status InstantiatePartition(
      const Graph& graph, const string& function_name,
      const string& target_device, FHandle* handle,
      std::unique_ptr<FunctionLibraryDefinition>* out_flib_def)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Adds and instantiates functions for each subgraph in `subgraphs` after
  // rewriting nodes' `device_ordinal` attributes to match `replica_id` when
  // num_cores_per_replica == 1.
  Status InstantiateFunctionsFromSubgraphs(
      const DeviceSet& device_set, int replica_id, uint64 cache_hash,
      int num_cores_per_replica,
      std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites `graph` such that the device ordinal attributes of all
  // whitelisted nodes (see `IsSupportedTPUOp`) are set to `device_ordinal`;
  // `*modified` is set to true if the graph is modified in the process (i.e.,
  // if it contains a whitelisted node); otherwise the graph is left
  // unmodified.
  //
  // Returns an error if
  //   (1) the graph contains a non-whitelisted node that carries an attribute
  //       with name "device_ordinal", or
  //   (2) the set of device ordinals found among the graph's nodes has
  //       cardinality greater than 1.
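  //
  // For example, if a subgraph contains a TPUExecute node whose
  // "device_ordinal" attr currently holds 0 and this kernel was invoked with
  // `device_ordinal` == 3, the attr is rewritten to 3 and `*modified` is set
  // to true (the ordinal values are illustrative only).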
  Status SetDeviceOrdinal(const DeviceSet& device_set, int device_ordinal,
                          Graph* graph, bool* modified)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void ExecuteRemoteFunction(const FunctionLibraryRuntime::Options& opts,
                             FHandle handle, OpKernelContext* ctx,
                             ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
  void ExecuteLocalFunction(const FunctionLibraryRuntime::Options& opts,
                            const OpInputList& arguments, FHandle handle,
                            OpKernelContext* ctx, ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
  void ExecuteFunctions(const std::vector<DeviceAndFHandle>& functions,
                        OpKernelContext* ctx, int device_ordinal,
                        int64_t ordinal_selector_req_id, DoneCallback done)
      ABSL_LOCKS_EXCLUDED(mu_);

  Status ShouldUseRemoteExecutionForFn(const std::string& target_device,
                                       bool* remote_execution) {
    DeviceNameUtils::ParsedName target_device_parsed;
    DeviceNameUtils::ParsedName local_device_parsed;

    if (!DeviceNameUtils::ParseFullOrLocalName(target_device,
                                               &target_device_parsed)) {
      return errors::InvalidArgument("Cannot parse target device ",
                                     target_device);
    }
    if (!DeviceNameUtils::ParseFullOrLocalName(local_device_name_,
                                               &local_device_parsed)) {
      return errors::InvalidArgument("Cannot parse local device ",
                                     local_device_name_);
    }

    if (DeviceNameUtils::AreCompatibleDevNames(target_device_parsed,
                                               local_device_parsed)) {
      *remote_execution = false;
    } else {
      *remote_execution = true;
    }
    return OkStatus();
  }

  // Init-once flags.
  absl::once_flag once_;
  absl::once_flag ordinal_selector_once_;

  // Device manager and device set.
  const DeviceMgr* device_mgr_;
  DeviceSet device_set_;

  // Threadpool.
  thread::ThreadPool pool_;

  // `func_` is the original function supplied to this OpKernel.
  NameAttrList func_;
  string local_device_name_;
  // Maps each cache key to its corresponding functions, which are represented
  // as (device, handle) pairs.
  gtl::FlatMap<uint64, std::vector<DeviceAndFHandle>> partition_cache_
      ABSL_GUARDED_BY(mu_);

  // The set of device ordinals seen so far. Used by variable initialization
  // on TPU.
  absl::flat_hash_set<int> seen_ordinals_;

  // Records the indices of the _Arg nodes with type DT_RESOURCE that go into
  // a TPU op.
  std::vector<bool> replaced_input_indices_;

  absl::Mutex mu_;
  // Function shards are added to `flib_def_`; later on, a copy of `flib_def_`
  // is passed to `library_runtime_` as an overlay function library for
  // instantiation.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* library_runtime_;

  // Used to uniquify function names in `flib_def_`.
  uint32 suffix_ = 0;

  // Minimum number of run steps (batches) necessary to trigger the XLA
  // autotuner.
  int autotuner_thresh_ = 0;

  // TPU core selection.
  std::shared_ptr<tpu::TPUOrdinalSelector> ordinal_selector_;

  // Maps input hash to TF fingerprint.
  absl::flat_hash_map<uint64, uint64> inputs_to_fingerprint_;

  // List of TPU devices.
  std::vector<Device*> tpu_devices_;

  TpuPartitionedCall_Params runtime_params_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_