/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_
#define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_

#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/c/eager/c_api_experimental.h"

namespace tensorflow {
namespace dtensor {

// Declarations for the DTensor eager custom device. Throughout this API,
// `device_info` is the opaque per-device state allocated by
// AllocateDTensorDevice and threaded back into every subsequent call, and
// `TF_Status* status` is an out-parameter used to report errors.

// Configure a custom device which runs dtensor while executing
// operations on `underlying_devices`. Allocates `device_info` and fills
// `device`, which should then be passed to
// TFE_RegisterCustomDevice. This only affects eager execution.
//
// `device_name` arg should match the `device_name` argument to
// TFE_RegisterCustomDevice, and is the name of the custom device itself
// (e.g. pass it to `tf.device` to place operations on it from Python).
void AllocateDTensorDevice(absl::string_view device_name,
                           TFE_CustomDevice* device, void** device_info);

// Add a mesh to the layout propagator indicated by `device_info`.
//
// `serialized_mesh` is a serialized Mesh proto.
//
// `is_async` indicates whether DTensor operations on this mesh will return
// immediately (with "non-ready" handles) or block until executed. This is
// exposed as an option for ease of debugging, and will typically be on.
//
// `is_host_mesh` indicates this is a CPU mesh used only for sea-of-donuts-style
// host collectives.
void AddMesh(const std::string& serialized_mesh, void* device_info,
             bool is_async, bool is_host_mesh, TF_Status* status);

// Sets a requested layout (`serialized_layout`, a serialized Layout) for
// outputs of all operations, until cleared.
void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
                                  void* device_info, TF_Status* status);
// Clears any default layout previously set via ExperimentalSetDefaultLayout.
void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status);

// TODO(b/175928457): remove once the bug is fixed.
// Sets a requested default mesh (`serialized_mesh`, a serialized Mesh proto).
void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
                                void* device_info, TF_Status* status);
// Clears any default mesh previously set via ExperimentalSetDefaultMesh.
void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status);

// Determines whether tensors with a shape previously associated with only one
// layout use that layout if nothing else can be inferred.
void SetSameShapePolicy(void* device_info, bool enabled);

// Sets the global device ID-to-core ID mapping for a mesh. Global device IDs
// are equal to XLA replica IDs for the single XLA computation used by DTensor.
//
// See the comment above Mesh::tpu_core_ids() for some nuances.
void SetTPUCoreIDs(const std::string& mesh_name,
                   const std::vector<int>& tpu_core_ids, void* device_info,
                   TF_Status* status);

// TODO(b/187112276): Delete once we have the TPUCoreIDs live with Device.
// Clears all mappings previously installed via SetTPUCoreIDs.
void ClearTPUCoreIDs(void* device_info);

// Returns TPU core locations when given a list of TPU core IDs.
std::vector<std::vector<int>> TPUCoreIDsToLocations(
    TFE_Context* context, const std::vector<int>& tpu_core_ids,
    void* device_info);

// Returns TPU core IDs when given a list of TPU core locations.
std::vector<int> TPUCoreLocationsToIDs(
    TFE_Context* context,
    const std::vector<std::vector<int>>& tpu_core_locations, void* device_info);

// Pack `inputs` tensors into a single parallel tensor handle.
// `string_layout` is the target layout (serialized); `num_inputs` is the
// length of the `inputs` array.
TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
                       TFE_TensorHandle** inputs,
                       const std::string& string_layout, void* device_info,
                       TF_Status* status);

// Returns the raw components placed on each device of `inputs`'s mesh.
std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
                                      TFE_TensorHandle* input,
                                      void* device_info, TF_Status* status);

// Returns the layout of the dtensor 'input' as a string (presumably a
// serialized Layout, mirroring the `string_layout` accepted by Pack).
std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
                        void* device_info, TF_Status* status);

// Pack `indices`, `values`, `shapes` tensors into a SparseTensorWithLayout.
// Each of the three arrays has `num_inputs` entries; `string_layout` is the
// target layout (serialized).
TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
                             TFE_TensorHandle** indices,
                             TFE_TensorHandle** values,
                             TFE_TensorHandle** shapes,
                             const std::string& string_layout,
                             void* device_info, TF_Status* status);

// Returns whether `input` is a sparse dtensor. Used in `Unpack` at the python
// level to determine whether we should wrap component tensors back into a
// SparseTensor.
bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
                     void* device_info, TF_Status* status);

// Returns a dictionary with cache hits and cache miss information.
// Cache hit count is mapped under 'hit', and cache miss count is mapped under
// 'miss'.
std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
    TFE_Context* context, void* device_info, TF_Status* status);

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_