xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_op_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
16 
17 #include <cstdint>
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/core/framework/resource_mgr.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/tpu/tpu_compile_interface.h"
24 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
25 
26 namespace tensorflow {
27 namespace tpu {
28 namespace {
CreateShapePrefix(const std::vector<tensorflow::TensorShape> & dynamic_shapes)29 std::string CreateShapePrefix(
30     const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
31   std::string shapes_prefix;
32   for (const TensorShape& shape : dynamic_shapes) {
33     for (int64_t size : shape.dim_sizes()) {
34       absl::StrAppend(&shapes_prefix, size, ",");
35     }
36     absl::StrAppend(&shapes_prefix, ";");
37   }
38   return shapes_prefix;
39 }
40 
41 // Include compilation configurations of the arguments that are not captured
42 // by the called graph.
CreateConfigPrefix(const TPUCompileMetadataProto & metadata)43 std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
44   std::string config_prefix;
45   for (const auto& arg : metadata.args()) {
46     if (arg.is_same_data_across_replicas()) {
47       absl::StrAppend(&config_prefix, ":s");
48       // Same.
49     } else {
50       // Different.
51       absl::StrAppend(&config_prefix, ":");
52     }
53     if (arg.enable_xla_sharding() ==
54         tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
55       // Enabled.
56       absl::StrAppend(&config_prefix, "e");
57     }
58     if (arg.unrestricted_layout()) {
59       // Unrestricted.
60       absl::StrAppend(&config_prefix, ":u");
61     }
62     absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
63     if (arg.has_shape()) {
64       absl::StrAppend(&config_prefix, ",shape(");
65       for (const auto& dim : arg.shape().dim()) {
66         absl::StrAppend(&config_prefix, dim.size(), ",");
67       }
68       absl::StrAppend(&config_prefix, ")");
69     }
70   }
71   return config_prefix;
72 }
73 }  // namespace
74 
CreateFingerprintWithNameAndShapes(uint64 name,const std::vector<tensorflow::TensorShape> & shapes)75 uint64 CreateFingerprintWithNameAndShapes(
76     uint64 name, const std::vector<tensorflow::TensorShape>& shapes) {
77   std::string shape_prefix = CreateShapePrefix(shapes);
78   VLOG(2) << "CreateFingerprintWithNameAndShapes, name: " << name
79           << ", shape_prefix: " << shape_prefix;
80   return TpuCompileInterface::Get()->FingerprintString(
81       absl::StrCat(name, "_", shape_prefix));
82 }
83 
84 // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
85 // data to compute the fingerprint.
GuaranteedConstFingerprint(const string & fingerprint_in_metadata,const OpInputList & guaranteed_constants)86 std::string GuaranteedConstFingerprint(
87     const string& fingerprint_in_metadata,
88     const OpInputList& guaranteed_constants) {
89   if (fingerprint_in_metadata.empty()) {
90     uint64_t fingerprint = 0;
91     for (const Tensor& constant : guaranteed_constants) {
92       fingerprint =
93           tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
94               fingerprint, constant.tensor_data().data(),
95               constant.tensor_data().size());
96     }
97     return std::to_string(fingerprint);
98   } else {
99     return fingerprint_in_metadata;
100   }
101 }
102 
103 // The `guaranteed_constants` must be passed as reference due to the lazy
104 // evaluation of `guaranteed_const_fingerprint()` callback.
CreateCompilationCacheKey(absl::string_view function_name,uint64 function_library_fingerprint,uint64 mlir_module_fingerprint,const OpInputList & guaranteed_constants,const std::vector<TensorShape> & dynamic_shapes,const TPUCompileMetadataProto & metadata,const TpuMeshStateInterface & mesh_state,uint64_t session_id,ResourceMgr * resource_mgr)105 TpuCompilationCacheKey CreateCompilationCacheKey(
106     absl::string_view function_name, uint64 function_library_fingerprint,
107     uint64 mlir_module_fingerprint, const OpInputList& guaranteed_constants,
108     const std::vector<TensorShape>& dynamic_shapes,
109     const TPUCompileMetadataProto& metadata,
110     const TpuMeshStateInterface& mesh_state, uint64_t session_id,
111     ResourceMgr* resource_mgr) {
112   VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
113   std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
114   VLOG(1) << "shapes_prefix = " << shapes_prefix;
115   std::string config_prefix = CreateConfigPrefix(metadata);
116   VLOG(1) << "config_prefix = " << config_prefix;
117   std::vector<int32_t> flattened_device_ids;
118   if (metadata.has_device_assignment()) {
119     for (const auto& device :
120          metadata.device_assignment().computation_devices()) {
121       flattened_device_ids.insert(flattened_device_ids.end(),
122                                   device.replica_device_ids().begin(),
123                                   device.replica_device_ids().end());
124     }
125   }
126   CompilationCacheKeyResult result =
127       tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
128           CompilationCacheKeyProperty{
129               config_prefix.data(), shapes_prefix.data(), function_name.data(),
130               mlir_module_fingerprint, flattened_device_ids.data(),
131               flattened_device_ids.size(), guaranteed_constants.size(),
132               function_library_fingerprint, metadata.num_cores_per_replica(),
133               metadata.num_replicas(), mesh_state.data(), session_id,
134               resource_mgr});
135   auto buffer_cleanup = gtl::MakeCleanup([result]() {
136     tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
137   });
138   TpuCompilationCacheKey key;
139   key.prefix = result.key;
140   key.debug_string = result.debug_string;
141   key.session_id = session_id;
142 
143   // Guaranteed constants can be different across sessions. Use session_handle
144   // and guaranteed_const fingerprint to guarantee no collision.
145   if (guaranteed_constants.size() > 0) {
146     key.has_guaranteed_const = true;
147     key.session_handle = metadata.session_handle();
148     // Both `metadata` and `guaranteed_constants` lifetime are captured by
149     // reference based on the assumption that these variables lifetime is
150     // managed through the `TPUCompileOpKernelImpl` that outlives the
151     // lifetime of the compilation cache lookups.
152     string fingerprint;
153     key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
154                                         fingerprint]() mutable {
155       if (fingerprint.empty()) {
156         fingerprint = GuaranteedConstFingerprint(
157             metadata.guaranteed_const_fingerprint(), guaranteed_constants);
158       }
159       return fingerprint;
160     };
161   }
162   return key;
163 }
164 }  // namespace tpu
165 }  // namespace tensorflow
166