1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
16
17 #include <cstdint>
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/core/framework/resource_mgr.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/tpu/tpu_compile_interface.h"
24 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
25
26 namespace tensorflow {
27 namespace tpu {
28 namespace {
CreateShapePrefix(const std::vector<tensorflow::TensorShape> & dynamic_shapes)29 std::string CreateShapePrefix(
30 const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
31 std::string shapes_prefix;
32 for (const TensorShape& shape : dynamic_shapes) {
33 for (int64_t size : shape.dim_sizes()) {
34 absl::StrAppend(&shapes_prefix, size, ",");
35 }
36 absl::StrAppend(&shapes_prefix, ";");
37 }
38 return shapes_prefix;
39 }
40
41 // Include compilation configurations of the arguments that are not captured
42 // by the called graph.
CreateConfigPrefix(const TPUCompileMetadataProto & metadata)43 std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
44 std::string config_prefix;
45 for (const auto& arg : metadata.args()) {
46 if (arg.is_same_data_across_replicas()) {
47 absl::StrAppend(&config_prefix, ":s");
48 // Same.
49 } else {
50 // Different.
51 absl::StrAppend(&config_prefix, ":");
52 }
53 if (arg.enable_xla_sharding() ==
54 tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
55 // Enabled.
56 absl::StrAppend(&config_prefix, "e");
57 }
58 if (arg.unrestricted_layout()) {
59 // Unrestricted.
60 absl::StrAppend(&config_prefix, ":u");
61 }
62 absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
63 if (arg.has_shape()) {
64 absl::StrAppend(&config_prefix, ",shape(");
65 for (const auto& dim : arg.shape().dim()) {
66 absl::StrAppend(&config_prefix, dim.size(), ",");
67 }
68 absl::StrAppend(&config_prefix, ")");
69 }
70 }
71 return config_prefix;
72 }
73 } // namespace
74
CreateFingerprintWithNameAndShapes(uint64 name,const std::vector<tensorflow::TensorShape> & shapes)75 uint64 CreateFingerprintWithNameAndShapes(
76 uint64 name, const std::vector<tensorflow::TensorShape>& shapes) {
77 std::string shape_prefix = CreateShapePrefix(shapes);
78 VLOG(2) << "CreateFingerprintWithNameAndShapes, name: " << name
79 << ", shape_prefix: " << shape_prefix;
80 return TpuCompileInterface::Get()->FingerprintString(
81 absl::StrCat(name, "_", shape_prefix));
82 }
83
84 // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
85 // data to compute the fingerprint.
GuaranteedConstFingerprint(const string & fingerprint_in_metadata,const OpInputList & guaranteed_constants)86 std::string GuaranteedConstFingerprint(
87 const string& fingerprint_in_metadata,
88 const OpInputList& guaranteed_constants) {
89 if (fingerprint_in_metadata.empty()) {
90 uint64_t fingerprint = 0;
91 for (const Tensor& constant : guaranteed_constants) {
92 fingerprint =
93 tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
94 fingerprint, constant.tensor_data().data(),
95 constant.tensor_data().size());
96 }
97 return std::to_string(fingerprint);
98 } else {
99 return fingerprint_in_metadata;
100 }
101 }
102
103 // The `guaranteed_constants` must be passed as reference due to the lazy
104 // evaluation of `guaranteed_const_fingerprint()` callback.
CreateCompilationCacheKey(absl::string_view function_name,uint64 function_library_fingerprint,uint64 mlir_module_fingerprint,const OpInputList & guaranteed_constants,const std::vector<TensorShape> & dynamic_shapes,const TPUCompileMetadataProto & metadata,const TpuMeshStateInterface & mesh_state,uint64_t session_id,ResourceMgr * resource_mgr)105 TpuCompilationCacheKey CreateCompilationCacheKey(
106 absl::string_view function_name, uint64 function_library_fingerprint,
107 uint64 mlir_module_fingerprint, const OpInputList& guaranteed_constants,
108 const std::vector<TensorShape>& dynamic_shapes,
109 const TPUCompileMetadataProto& metadata,
110 const TpuMeshStateInterface& mesh_state, uint64_t session_id,
111 ResourceMgr* resource_mgr) {
112 VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
113 std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
114 VLOG(1) << "shapes_prefix = " << shapes_prefix;
115 std::string config_prefix = CreateConfigPrefix(metadata);
116 VLOG(1) << "config_prefix = " << config_prefix;
117 std::vector<int32_t> flattened_device_ids;
118 if (metadata.has_device_assignment()) {
119 for (const auto& device :
120 metadata.device_assignment().computation_devices()) {
121 flattened_device_ids.insert(flattened_device_ids.end(),
122 device.replica_device_ids().begin(),
123 device.replica_device_ids().end());
124 }
125 }
126 CompilationCacheKeyResult result =
127 tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
128 CompilationCacheKeyProperty{
129 config_prefix.data(), shapes_prefix.data(), function_name.data(),
130 mlir_module_fingerprint, flattened_device_ids.data(),
131 flattened_device_ids.size(), guaranteed_constants.size(),
132 function_library_fingerprint, metadata.num_cores_per_replica(),
133 metadata.num_replicas(), mesh_state.data(), session_id,
134 resource_mgr});
135 auto buffer_cleanup = gtl::MakeCleanup([result]() {
136 tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
137 });
138 TpuCompilationCacheKey key;
139 key.prefix = result.key;
140 key.debug_string = result.debug_string;
141 key.session_id = session_id;
142
143 // Guaranteed constants can be different across sessions. Use session_handle
144 // and guaranteed_const fingerprint to guarantee no collision.
145 if (guaranteed_constants.size() > 0) {
146 key.has_guaranteed_const = true;
147 key.session_handle = metadata.session_handle();
148 // Both `metadata` and `guaranteed_constants` lifetime are captured by
149 // reference based on the assumption that these variables lifetime is
150 // managed through the `TPUCompileOpKernelImpl` that outlives the
151 // lifetime of the compilation cache lookups.
152 string fingerprint;
153 key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
154 fingerprint]() mutable {
155 if (fingerprint.empty()) {
156 fingerprint = GuaranteedConstFingerprint(
157 metadata.guaranteed_const_fingerprint(), guaranteed_constants);
158 }
159 return fingerprint;
160 };
161 }
162 return key;
163 }
164 } // namespace tpu
165 } // namespace tensorflow
166