/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_

#include "absl/base/call_once.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/reffed_status_callback.h"
#include "absl/container/flat_hash_map.h"

namespace tensorflow {
// Holds each node's shape information for Concat/Split.
using EdgeShapes = absl::flat_hash_map<const Edge*, std::vector<int>>;
using GroupedEdges =
    absl::flat_hash_map<std::string, std::vector<const Edge*>>;

// Contains the attrs "T", "sharding", and "_tpu_replicate" for each
// XlaSharding op encountered while searching for inputs to replicated models.
using XlaShardingInfoMap = absl::flat_hash_map<
    std::string, std::tuple<tensorflow::DataType, std::string, std::string>>;

// Contains the attr "T" and a pointer to the tpu_replicated_metadata node
// (used as a control dependency) for each TpuReplicatedInput op encountered
// while searching for inputs to replicated models.
using TpuReplicatedInputInfoMap =
    absl::flat_hash_map<std::string,
                        std::tuple<tensorflow::DataType, Node*>>;

namespace tpu_functional_internal {

// Helper functions for graph rewrites.
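// Groups the input edges of the TPU computation by dtype and shape so that
// compatible tensors can later be packed (concatenated) into fewer transfers.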
GroupedEdges GroupTensorsForInputPacking(
    const EdgeShapes& tpu_input_shapes,
    const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
    bool input_shape_opt, bool group_tensors_for_packing);
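// Groups the output edges of the TPU computation by shape so that compatible
// tensors can later be packed together.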
GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
                                          EdgeShapes& tpu_output_shapes,
                                          GraphShapeInfo* shape_info);

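// Inserts a Concat/Split node pair for each group of packable input tensors:
// the tensors are concatenated before being transferred to the TPU and split
// back into the original tensors on the device. `xla_sharding_info` and
// `tpu_replicated_input_info` supply the sharding attributes used when
// `xla_spmd_input_sharded` is true.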
Status CreateConcatAndSplitNodesForInputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
    const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
        grouped_input_edges,
    int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
    const XlaShardingInfoMap& xla_sharding_info,
    const TpuReplicatedInputInfoMap& tpu_replicated_input_info);
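// Inserts a Concat/Split node pair for each group of packable output tensors:
// the tensors are concatenated on the device and split back into the original
// tensors on the host.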
Status CreateConcatAndSplitNodesForOutputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
    GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
    int32_t minimum_output_tensors_packing);

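// Inserts matching pairs of Reshape nodes on the TPU input edges recorded in
// `tpu_input_shapes`.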
Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
                              EdgeShapes* tpu_input_shapes,
                              int num_cores_per_replica);

}  // namespace tpu_functional_internal

typedef FunctionLibraryRuntime::Handle FHandle;

// A `TPUPartitionedCallOp` asynchronously executes a function on exactly one
// TPU core and potentially across multiple other devices, but within a single
// process. The kernel places and partitions the function's underlying graph,
// executing each of the partitioned subgraphs as a function.
//
// The core on which the TPU computation is executed must be specified via the
// `device_ordinal` input. Different invocations of this op may specify
// different device ordinals, making it possible to map TPU computations to
// different cores at runtime. Currently, macro-substitution of device ordinals
// is only supported for the following whitelisted ops:
//   * TPUExecute
//   * InfeedEnqueue
//   * InfeedEnqueueTuple
//
// Attempting to compute a TPUPartitionedCallOp whose function body has a
// non-whitelisted node bearing an attribute named "device_ordinal" will result
// in an error.
//
// TODO(akshayka): This class duplicates most of the logic of
// `PartitionedCallOp`; once that class and this one have evolved to stable
// states, and if at that time they remain sufficiently similar, either unify
// them in one op or set up an inheritance structure that allows for code reuse.
class TPUPartitionedCallOp : public AsyncOpKernel {
 public:
  explicit TPUPartitionedCallOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        pool_(ctx->env(), "InitializeVarOnTPUPool", 1),
        library_runtime_(nullptr) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    // If the importer has set the original function name, it means the function
    // attribute is referring to a rewritten function, but we need to use the
    // original function name in order to find it in the function library.
    std::string orig_f;
    if (ctx->GetAttr("_orig_f", &orig_f).ok()) {
      func_.set_name(orig_f);
    }
    auto status = ctx->GetAttr("autotuner_thresh", &autotuner_thresh_);
    if (!status.ok()) {
      autotuner_thresh_ = 0;
    }
    tensorflow::tpu::OpsApiFn()->TfTpu_GetTpuPartitionedCallParamsFn(
        &runtime_params_);
  }

  ~TPUPartitionedCallOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  struct DeviceAndFHandle {
    std::string device;
    FHandle handle;

    // The FLD passed to `library_runtime_` as an overlay function library for
    // instantiation of function `handle`. This is a snapshot of the current
    // `flib_def_`. Since `flib_def_` can be changed concurrently by another
    // graph rewrite when executing `handle`, we need to make sure each
    // `handle` uses a different FLD to avoid races. See b/181149591.
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
  };

  struct TPUMetadata {
    tpu::TopologyProto topology;
    int num_cores_per_replica = 1;
    std::vector<int> device_assignment;
  };

  // This method is thread-safe.
  Status GetTpuCoreOrdinal(OpKernelContext* ctx, uint64 input_hash,
                           int64_t* ordinal_selector_req_id,
                           int32_t* core_ordinal);

  // Helper to create and initialize a TPU variable given a CPU variable
  // var: the CPU variable created by the user
  // ndef: the node def of the corresponding TPU var handle that we created
  // device_ordinal: TPU device ordinal on which to initialize this variable
  Status InitializeVarOnTPU(OpKernelContext* ctx,
                            const core::RefCountPtr<Var>& var, NodeDef* ndef,
                            int device_ordinal, bool fast_mem)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper to create and initialize partitioned TPU variables given a CPU
  // variable with an XLA sharding annotation.
  // var: the CPU variable created by the user.
  // ndefs: the node defs of the corresponding TPU var handles on all the
  //   logical cores.
  // split_dim: the partition dimension of the variable. If -1, the variable is
  //   replicated.
  // tpu_devices: the TPU devices, one per logical core, on which to create the
  //   variable handles.
  // device_ordinal: The index of the TPU core that is scheduled to run
  //   the computation. In the case of XLA SPMD, it is the "primary" core, which
  //   is the smallest index of all the cores.
  Status InitializeShardedVarOnTPU(OpKernelContext* ctx,
                                   const core::RefCountPtr<Var>& var,
                                   std::vector<NodeDef>& ndefs, int split_dim,
                                   const std::vector<string>& tpu_devices)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Checks whether any of the immediate successors of `node` has the attribute
  // "_tpu_replicate".
  bool IsInputToTPUReplicate(Node* node) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node of type DT_RESOURCE with a VarHandleOp on TPU.
  Status ReplaceResourceArgsWithVarHandleOps(Graph* graph, OpKernelContext* ctx,
                                             int device_ordinal,
                                             bool enable_spmd_xla_partitioning,
                                             const TPUMetadata& tpu_metadata)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node that represents a variable on the CPU host with
  // sharded/replicated variables on all logical TPU devices.
  Status ReplaceAndPartitionXLAShardingVariable(
      Graph* graph, OpKernelContext* ctx, int device_ordinal,
      ResourceHandle& handle, Node* variable, const TPUMetadata& tpu_metadata)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

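  // Adds XLA sharding annotations to the inputs of the TPU cluster
  // `cluster_name` so that they are partitioned across `num_cores_per_replica`
  // logical cores.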
  Status ShardInputsWithXlaSharding(Graph* graph,
                                    const std::string& cluster_name,
                                    int num_cores_per_replica,
                                    OpKernelContext* ctx)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites the graph for input and output optimizations.
  // TODO(ylc): Move this function to a graph optimization pass.
  Status OptimizeTpuInputOutputTensors(
      Graph* graph, bool enable_spmd_xla_partitioning,
      int num_cores_per_replica,
      std::map<std::string, std::vector<int>>& named_input_shapes,
      OpKernelContext* ctx) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

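  // Runs shape inference over `graph`, taking the shapes of resource variables
  // (looked up through `ctx`) into account; results are recorded in
  // `arg_shapes` and `tpu_inferred_info`.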
  Status InferShapesWithResourceVar(Graph* graph, OpKernelContext* ctx,
                                    std::map<int, InferredShape>& arg_shapes,
                                    GraphShapeInfo* tpu_inferred_info);

  // Copies the graph backing `func_` into `graph`.
  Status GetGraphFromFunction(Graph* graph, int device_ordinal,
                              bool* use_spmd_for_xla_partitioning,
                              TPUMetadata* tpu_metadata)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Places the graph carried by `optimization_options` and runs graph
  // optimization passes (pre-placement, post-placement, and post-rewrite).
  Status PlacementHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options,
      const string& function_name);
  // Partitions `graph`, populates `subgraphs` with the partitions, and runs
  // the post-partitioning graph optimization passes.
  Status PartitionHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options, Graph* graph,
      std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs);

  // Adds and instantiates a function backed by `graph` with name
  // `function_name` on device `target_device`, storing the handle in `handle`.
  // If `out_flib_def` is not null, it will be set to a copy of `flib_def_` and
  // used for instantiation.
  Status InstantiatePartition(
      const Graph& graph, const string& function_name,
      const string& target_device, FHandle* handle,
      std::unique_ptr<FunctionLibraryDefinition>* out_flib_def)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Adds and instantiates functions for each subgraph in `subgraphs` after
  // rewriting nodes' `device_ordinal` attributes to match `replica_id` when
  // num_cores_per_replica == 1.
  Status InstantiateFunctionsFromSubgraphs(
      const DeviceSet& device_set, int replica_id, uint64 cache_hash,
      int num_cores_per_replica,
      std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites `graph` such that the device ordinal attributes of all whitelisted
  // nodes (see `IsSupportedTPUOp`) are set to `device_ordinal`;
  // `*modified` is set to true if the graph is modified in the process (i.e.,
  // if it contains a whitelisted node); otherwise the graph is left unmodified.
  //
  // Returns an error if
  //   (1) the graph contains a non-whitelisted node that carries an attribute
  //       with name "device_ordinal", or
  //   (2) the set of device ordinals found among the graph's nodes has
  //       cardinality greater than 1.
  Status SetDeviceOrdinal(const DeviceSet& device_set, int device_ordinal,
                          Graph* graph, bool* modified)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

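  // Runs the instantiated function `handle` on a remote (out-of-process)
  // device through `library_runtime_`, reporting the result via `done`.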
  void ExecuteRemoteFunction(const FunctionLibraryRuntime::Options& opts,
                             FHandle handle, OpKernelContext* ctx,
                             ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
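  // Runs the instantiated function `handle` in the local process with
  // `arguments` as inputs, reporting the result via `done`.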
  void ExecuteLocalFunction(const FunctionLibraryRuntime::Options& opts,
                            const OpInputList& arguments, FHandle handle,
                            OpKernelContext* ctx, ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
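  // Executes every (device, handle) pair in `functions`, using TPU core
  // `device_ordinal` for the TPU computation, and invokes `done` once all of
  // them have finished. `ordinal_selector_req_id` identifies the request to
  // the ordinal selector.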
  void ExecuteFunctions(const std::vector<DeviceAndFHandle>& functions,
                        OpKernelContext* ctx, int device_ordinal,
                        int64_t ordinal_selector_req_id, DoneCallback done)
      ABSL_LOCKS_EXCLUDED(mu_);

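  // Sets `*remote_execution` to true iff the parsed `target_device` is not
  // compatible with this kernel's device name `local_device_name_`.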
  Status ShouldUseRemoteExecutionForFn(const std::string& target_device,
                                       bool* remote_execution) {
    DeviceNameUtils::ParsedName target_device_parsed;
    DeviceNameUtils::ParsedName local_device_parsed;

    if (!DeviceNameUtils::ParseFullOrLocalName(target_device,
                                               &target_device_parsed)) {
      return errors::InvalidArgument("Cannot parse target device ",
                                     target_device);
    }
    if (!DeviceNameUtils::ParseFullOrLocalName(local_device_name_,
                                               &local_device_parsed)) {
      return errors::InvalidArgument("Cannot parse local device ",
                                     local_device_name_);
    }

    if (DeviceNameUtils::AreCompatibleDevNames(target_device_parsed,
                                               local_device_parsed)) {
      *remote_execution = false;
    } else {
      *remote_execution = true;
    }
    return OkStatus();
  }

  // Init once flags.
  absl::once_flag once_;
  absl::once_flag ordinal_selector_once_;

  // Device manager and device set.
  const DeviceMgr* device_mgr_;
  DeviceSet device_set_;

  // Threadpool.
  thread::ThreadPool pool_;

  // `func_` is the original function supplied to this OpKernel.
  NameAttrList func_;
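  // Name of the device this kernel is placed on; compared against partition
  // target devices to decide between local and remote execution.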
  string local_device_name_;
  // Maps each cache key to its corresponding functions, which are
  // represented as (device, handle) pairs.
  gtl::FlatMap<uint64, std::vector<DeviceAndFHandle>> partition_cache_
      ABSL_GUARDED_BY(mu_);

  // The set of device ordinals already seen. Used by variable initialization
  // on TPU.
  absl::flat_hash_set<int> seen_ordinals_;

  // Records the indices of the _Arg nodes with type DT_RESOURCE that go
  // into a TPU op.
  std::vector<bool> replaced_input_indices_;

  absl::Mutex mu_;
  // Function shards are added to `flib_def_`; a copy of `flib_def_` is later
  // passed to `library_runtime_` as an overlay function library for
  // instantiation.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* library_runtime_;

  // Used to uniquify function names in `flib_def_`.
  uint32 suffix_ = 0;

  // Minimum number of run steps (batches) necessary to trigger the XLA
  // autotuner.
  int autotuner_thresh_ = 0;

  // TPU core selection.
  std::shared_ptr<tpu::TPUOrdinalSelector> ordinal_selector_;

  // Maps input hash to TF fingerprint.
  absl::flat_hash_map<uint64, uint64> inputs_to_fingerprint_;

  // List of TPU devices
  std::vector<Device*> tpu_devices_;

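  // Runtime parameters for TPUPartitionedCall, fetched through the TPU ops
  // C API in the constructor.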
  TpuPartitionedCall_Params runtime_params_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_