xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/dtensor_device.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_
17 #define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_
18 
19 #include <string>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/c/eager/c_api_experimental.h"
25 
26 namespace tensorflow {
27 namespace dtensor {
28 
// Configure a custom device which runs dtensor while executing
// operations on `underlying_devices`. Allocates `device_info` and fills
// `device`, which should then be passed to
// TFE_RegisterCustomDevice. This only affects eager execution.
//
// `device_name` arg should match the `device_name` argument to
// TFE_RegisterCustomDevice, and is the name of the custom device itself
// (e.g. pass it to `tf.device` to place operations on it from Python).
//
// Both `device` and `device_info` are out-parameters filled by this call.
// NOTE(review): lifetime of `*device_info` appears to be managed by the
// custom-device delete callback installed in `device` — confirm in the
// implementation before freeing it manually.
void AllocateDTensorDevice(absl::string_view device_name,
                           TFE_CustomDevice* device, void** device_info);
39 
// Add a mesh to the layout propagator indicated by `device_info`.
//
// `serialized_mesh` is a serialized Mesh proto.
//
// `is_async` indicates whether DTensor operations on this mesh will return
// immediately (with "non-ready" handles) or block until executed. This is
// exposed as an option for ease of debugging, and will typically be on.
//
// `is_host_mesh` indicates this is a CPU mesh used only for sea-of-donuts-style
// host collectives.
//
// Errors (e.g. an unparseable mesh proto) are reported through `status`.
void AddMesh(const std::string& serialized_mesh, void* device_info,
             bool is_async, bool is_host_mesh, TF_Status* status);
52 
// Sets a requested layout for outputs of all operations.
// `serialized_layout` is a serialized Layout proto; errors are reported
// through `status`.
void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
                                  void* device_info, TF_Status* status);
// Clears any default layout previously set via
// ExperimentalSetDefaultLayout.
void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status);
57 
// TODO(b/175928457): remove once the bug is fixed.
// Sets a requested default mesh. `serialized_mesh` is a serialized Mesh
// proto; errors are reported through `status`.
void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
                                void* device_info, TF_Status* status);
// Clears any default mesh previously set via ExperimentalSetDefaultMesh.
void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status);
63 
// Determines whether tensors with a shape previously associated with only one
// layout use that layout if nothing else can be inferred.
// `enabled` turns the policy on or off for the device in `device_info`.
void SetSameShapePolicy(void* device_info, bool enabled);
67 
// Sets the global device ID-to-core ID mapping for a mesh. Global device IDs
// are equal to XLA replica IDs for the single XLA computation used by DTensor.
//
// `mesh_name` selects which mesh the mapping applies to; `tpu_core_ids` is
// indexed by global device ID. Errors are reported through `status`.
//
// See the comment above Mesh::tpu_core_ids() for some nuances.
void SetTPUCoreIDs(const std::string& mesh_name,
                   const std::vector<int>& tpu_core_ids, void* device_info,
                   TF_Status* status);
75 
// Clears all mappings previously installed via SetTPUCoreIDs.
// TODO(b/187112276): Delete once we have the TPUCoreIDs live with Device.
void ClearTPUCoreIDs(void* device_info);
78 
// Returns TPU core locations when given a list of TPU core IDs.
// Each entry of the result is the location vector for the corresponding
// entry of `tpu_core_ids`.
std::vector<std::vector<int>> TPUCoreIDsToLocations(
    TFE_Context* context, const std::vector<int>& tpu_core_ids,
    void* device_info);
83 
// Returns TPU core IDs when given a list of TPU core locations.
// Inverse of TPUCoreIDsToLocations: each entry of the result is the core ID
// for the corresponding entry of `tpu_core_locations`.
std::vector<int> TPUCoreLocationsToIDs(
    TFE_Context* context,
    const std::vector<std::vector<int>>& tpu_core_locations, void* device_info);
88 
// Pack `inputs` tensors into a single parallel tensor handle.
//
// `inputs` points at `num_inputs` component handles; `string_layout` is the
// serialized Layout describing how the components are placed on the mesh.
// Errors (e.g. a layout/mesh mismatch) are reported through `status`; the
// returned handle is owned by the caller.
TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
                       TFE_TensorHandle** inputs,
                       const std::string& string_layout, void* device_info,
                       TF_Status* status);
94 
// Returns the raw components placed on each device of `input`'s mesh.
// Inverse of Pack; errors are reported through `status` and the returned
// handles are owned by the caller.
std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
                                      TFE_TensorHandle* input,
                                      void* device_info, TF_Status* status);
99 
// Returns the layout of the dtensor `input` as a serialized string.
// Errors are reported through `status`.
std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
                        void* device_info, TF_Status* status);
103 
// Pack `indices`, `values`, `shapes` tensors into a SparseTensorWithLayout.
//
// NOTE(review): presumably each of the three arrays holds `num_inputs`
// component handles (one per device), mirroring Pack — confirm against the
// implementation. `string_layout` is the serialized Layout for the packed
// result; errors are reported through `status` and the returned handle is
// owned by the caller.
TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
                             TFE_TensorHandle** indices,
                             TFE_TensorHandle** values,
                             TFE_TensorHandle** shapes,
                             const std::string& string_layout,
                             void* device_info, TF_Status* status);
111 
// Returns whether `input` is a sparse dtensor. Used in `Unpack` at the python
// level to determine whether we should wrap component tensors back into a
// SparseTensor. Errors are reported through `status`.
bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
                     void* device_info, TF_Status* status);
117 
// Returns a dictionary with cache hits and cache miss information.
// Cache hit count is mapped under 'hit', and cache miss count is mapped under
// 'miss'. Counts refer to the DTensor function-compilation cache for the
// device in `device_info`; errors are reported through `status`.
std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
    TFE_Context* context, void* device_info, TF_Status* status);
123 }  // namespace dtensor
124 }  // namespace tensorflow
125 
126 #endif  // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_
127