/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"

#include <climits>
#include <memory>

#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/casts.h"
#if defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#endif
#include "absl/cleanup/cleanup.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow {
namespace tpu {
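// Returns the channel credentials used for TPU compilation cache RPCs; this
// build uses insecure (unauthenticated, unencrypted) gRPC channels.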
std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
  return ::grpc::InsecureChannelCredentials();  // NOLINT
}

#if defined(LIBTPU_ON_GCE)
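// Converts a GetTpuProgramResponseExternal fetched from the remote
// compilation cache into a local CacheEntry keyed by `local_proto_key`. A
// non-empty response is deserialized into a single-program TpuProgramGroup.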
template <>
Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
    absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
    std::shared_ptr<CacheEntry>* cache_entry) {
  CHECK_NE(response, nullptr);
  CHECK_NE(cache_entry, nullptr);
  *cache_entry = std::make_shared<CacheEntry>();
  CacheEntry& entry = **cache_entry;
  entry.key = std::string(local_proto_key);

  if (response->is_empty()) {
    entry.size = 0;
  } else {
    TpuSerializedProto serialized_response_proto =
        stream_executor::tpu::SerializeProto(*response);
    auto cleanup = absl::MakeCleanup([&serialized_response_proto]() {
      stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
    });
    // When we look up from the remote cache, we fetch a TPU program for a
    // specific core, hence we allocate a TPU program group for a single
    // program.
    auto tpu_program_group = absl::make_unique<TpuProgramGroup>();

    // TODO(b/166575150): can be optimized by sending the buffer over gRPC
    // without an extra deserialization.
    TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromRpcResponseProtos(
        {serialized_response_proto}));
    entry.tpu_program_group = std::move(tpu_program_group);
    entry.size = entry.tpu_program_group->program_size();
  }

  return Status::OK();
}

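// Serializes `cache_entry` into gRPC buffer slices. An entry without a TPU
// program group produces a header marked empty; otherwise the header carries
// the serialized executable, the may-modify-variables flag, and the compiler
// metadata for the requested core.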
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
    const TpuCompilationCacheEntry& cache_entry) {
  if (cache_entry.tpu_program_group() == nullptr) {
    // It's possible that the sharding/unsharding entry does not exist, but the
    // main entry must exist.
    GetTpuProgramResponseExternal header;
    header.set_is_empty(true);
    std::string encoded_header;
    if (!header.AppendToString(&encoded_header)) {
      return errors::Internal("Failed to serialize TPU program metadata.");
    }
    ::grpc::Slice slice(encoded_header);
    return std::vector<::grpc::Slice>{slice};
  }

  const TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const TpuProgramGroup*>(
          cache_entry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);
  CHECK_GE(tpu_program_group->program_count(), 0);
  // The requested core index must address a program within the group.
  CHECK_GE(cache_entry.core_index(), 0);
  CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
  const int64 program_size = tpu_program_group->program_size();
  if (program_size > INT_MAX) {
    return errors::Internal("TPU program exceeded 2 GiB.");
  }

  TpuExecutableSerializedProto executable;
  auto cleanup_executable = absl::MakeCleanup([&executable]() {
    if (executable.size > 0) {
      stream_executor::tpu::SerializedProto_Free(executable);
    }
  });
  auto get_executable_status = tpu_program_group->SerializeExecutable(
      cache_entry.core_index(), &executable);
  if (!get_executable_status.ok()) {
    return errors::Internal("Failed to serialize TPU program.");
  }

  // Encode and serialize header fields.
  GetTpuProgramResponseExternal header;
  if (!header.mutable_proto()->ParseFromArray(executable.bytes,
                                              executable.size)) {
    return errors::Internal("Failed to deserialize TPU program.");
  }
  header.set_is_empty(false);

  bool may_modify_variables =
      tpu_program_group->may_modify_variables(cache_entry.core_index());
  header.set_may_modify_variables(may_modify_variables);

  CompilerMetadataSerializedProto compiler_metadata;
  auto cleanup_compiler_metadata = absl::MakeCleanup([&compiler_metadata]() {
    if (compiler_metadata.size > 0) {
      stream_executor::tpu::SerializedProto_Free(compiler_metadata);
    }
  });
  Status get_compiler_metadata_status =
      tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
                                                   &compiler_metadata);
  if (!get_compiler_metadata_status.ok()) {
    return errors::Internal("Failed to serialize compiler metadata.");
  }
  if (!header.mutable_compiler_metadata()->ParseFromArray(
          compiler_metadata.bytes, compiler_metadata.size)) {
    return errors::Internal("Failed to deserialize compiler metadata.");
  }
  std::string encoded_header;
  if (!header.AppendToString(&encoded_header)) {
    return errors::Internal("Failed to serialize TPU program metadata.");
  }

  return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
}
#endif  // LIBTPU_ON_GCE
}  // namespace tpu
}  // namespace tensorflow