xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
16 
17 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
18 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
19 #include "tensorflow/core/platform/casts.h"
20 #if defined(LIBTPU_ON_GCE)
21 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
22 #endif
23 #include "absl/cleanup/cleanup.h"
24 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
25 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 
28 namespace tensorflow {
29 namespace tpu {
CreateChannelCredentials()30 std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
31   return ::grpc::InsecureChannelCredentials();  // NOLINT
32 }
33 
34 #if defined(LIBTPU_ON_GCE)
35 template <>
DeserializeRpcResponseToCacheEntry(absl::string_view local_proto_key,GetTpuProgramResponseExternal * response,std::shared_ptr<CacheEntry> * cache_entry)36 Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
37     absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
38     std::shared_ptr<CacheEntry>* cache_entry) {
39   CHECK_NE(response, nullptr);
40   CHECK_NE(cache_entry, nullptr);
41   *cache_entry = std::make_shared<CacheEntry>();
42   CacheEntry& entry = **cache_entry;
43   entry.key = std::string(local_proto_key);
44 
45   if (response->is_empty()) {
46     entry.size = 0;
47   } else {
48     TpuSerializedProto serialized_response_proto =
49         stream_executor::tpu::SerializeProto(*response);
50     auto cleanup = absl::MakeCleanup([&serialized_response_proto]() {
51       stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
52     });
53     // When we lookup from remote cache, we fetch a TPU program for a specific
54     // core, hence we allocate TPU program group for a single program.
55     auto tpu_program_group = absl::make_unique<TpuProgramGroup>();
56 
57     // TODO(b/166575150): can be optimized by sending the buffer over the gRPC
58     // without an extra deserializing.
59     TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromRpcResponseProtos(
60         {serialized_response_proto}));
61     entry.tpu_program_group = std::move(tpu_program_group);
62     entry.size = entry.tpu_program_group->program_size();
63   }
64 
65   return Status::OK();
66 }
67 
SerializeCacheEntryToBufferSlices(const TpuCompilationCacheEntry & cache_entry)68 xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
69     const TpuCompilationCacheEntry& cache_entry) {
70   if (cache_entry.tpu_program_group() == nullptr) {
71     // It's possible that the sharding/unsharding entry does not exist, but the
72     // main entry must exist.
73     GetTpuProgramResponseExternal header;
74     header.set_is_empty(true);
75     std::string encoded_header;
76     if (!header.AppendToString(&encoded_header)) {
77       return errors::Internal("Failed to serialize TPU program metadata.");
78     }
79     ::grpc::Slice slice(encoded_header);
80     return std::vector<::grpc::Slice>{slice};
81   }
82 
83   const TpuProgramGroup* tpu_program_group =
84       tensorflow::down_cast<const TpuProgramGroup*>(
85           cache_entry.tpu_program_group());
86   CHECK_NE(tpu_program_group, nullptr);
87   CHECK_GE(tpu_program_group->program_count(), 0);
88   CHECK_GE(cache_entry.core_index(), 0);
89   CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
90   const int64 program_size = tpu_program_group->program_size();
91   if (program_size > INT_MAX) {
92     return errors::Internal("TPU program exceeded 2 GiB.");
93   }
94 
95   TpuExecutableSerializedProto executable;
96   auto cleanup_executable = absl::MakeCleanup([&executable]() {
97     if (executable.size > 0) {
98       stream_executor::tpu::SerializedProto_Free(executable);
99     }
100   });
101   auto get_executable_status = tpu_program_group->SerializeExecutable(
102       cache_entry.core_index(), &executable);
103   if (!get_executable_status.ok()) {
104     return errors::Internal("Failed to serialize TPU program.");
105   }
106 
107   // Encode and serialize header fields.
108   GetTpuProgramResponseExternal header;
109   if (!header.mutable_proto()->ParseFromArray(executable.bytes,
110                                               executable.size)) {
111     return errors::Internal("Failed to serialize TPU program.");
112   }
113   header.set_is_empty(false);
114 
115 
116   bool may_modify_variables =
117       tpu_program_group->may_modify_variables(cache_entry.core_index());
118   header.set_may_modify_variables(may_modify_variables);
119 
120   CompilerMetadataSerializedProto compiler_metadata;
121   auto cleanup_compiler_metadata = absl::MakeCleanup([&compiler_metadata]() {
122     if (compiler_metadata.size > 0) {
123       stream_executor::tpu::SerializedProto_Free(compiler_metadata);
124     }
125   });
126   Status get_compiler_metadata_status =
127       tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
128                                                    &compiler_metadata);
129   if (!get_compiler_metadata_status.ok()) {
130     return errors::Internal("Failed to serialize compiler metadata.");
131   }
132   if (!header.mutable_compiler_metadata()->ParseFromArray(
133           compiler_metadata.bytes, compiler_metadata.size)) {
134     return errors::Internal("Failed to deserialize compiler metadata.");
135   }
136   std::string encoded_header;
137   if (!header.AppendToString(&encoded_header)) {
138     return errors::Internal("Failed to serialize TPU program metadata.");
139   }
140 
141   return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
142 }
143 #endif  // LIBTPU_ON_GCE
144 }  // namespace tpu
145 }  // namespace tensorflow
146