xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/aot/embedded_protocol_buffers.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_replace.h"
23 #include "llvm/ADT/Triple.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/GlobalVariable.h"
26 #include "llvm/IR/LLVMContext.h"
27 #include "llvm/IR/LegacyPassManager.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/MC/TargetRegistry.h"
30 #include "llvm/Target/TargetMachine.h"
31 #include "llvm/Target/TargetOptions.h"
32 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
33 #include "tensorflow/compiler/xla/util.h"
34 
35 namespace tensorflow {
36 namespace tfcompile {
37 
38 using xla::llvm_ir::AsStringRef;
39 
AddEmbeddedProtocolBufferToLlvmModule(llvm::Module * module,const::tensorflow::protobuf::MessageLite & proto,absl::string_view unique_identifier,string * protobuf_array_symbol_name,int64_t * protobuf_array_size)40 static void AddEmbeddedProtocolBufferToLlvmModule(
41     llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
42     absl::string_view unique_identifier, string* protobuf_array_symbol_name,
43     int64_t* protobuf_array_size) {
44   string protobuf_array_contents = proto.SerializeAsString();
45   *protobuf_array_symbol_name =
46       absl::StrCat(unique_identifier, "_protobuf_array_contents");
47   *protobuf_array_size = protobuf_array_contents.size();
48 
49   llvm::Constant* protobuf_array_initializer =
50       llvm::ConstantDataArray::getString(module->getContext(),
51                                          AsStringRef(protobuf_array_contents),
52                                          /*AddNull=*/false);
53   new llvm::GlobalVariable(
54       *module, protobuf_array_initializer->getType(),
55       /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
56       protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
57 }
58 
CreateCPPShimExpression(absl::string_view qualified_cpp_protobuf_name,absl::string_view protobuf_array_symbol_name,int64_t protobuf_array_size)59 static string CreateCPPShimExpression(
60     absl::string_view qualified_cpp_protobuf_name,
61     absl::string_view protobuf_array_symbol_name, int64_t protobuf_array_size) {
62   string code =
63       "[]() {\n"
64       "    {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
65       "    proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n"
66       "    return proto;\n"
67       "  }()";
68 
69   return absl::StrReplaceAll(
70       code,
71       {
72           {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
73           {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
74           {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
75       });
76 }
77 
CodegenModule(llvm::TargetMachine * target_machine,std::unique_ptr<llvm::Module> module)78 static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
79                                       std::unique_ptr<llvm::Module> module) {
80   llvm::SmallVector<char, 0> stream_buffer;
81   llvm::raw_svector_ostream ostream(stream_buffer);
82   llvm::legacy::PassManager codegen_passes;
83 
84   if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr,
85                                           llvm::CGFT_ObjectFile)) {
86     return xla::InternalError(
87         "Could not create pass pipeline to generate object file");
88   }
89 
90   codegen_passes.run(*module);
91 
92   return string(stream_buffer.begin(), stream_buffer.end());
93 }
94 
95 static StatusOr<std::unique_ptr<llvm::TargetMachine>>
GetTargetMachineFromTriple(absl::string_view target_triple)96 GetTargetMachineFromTriple(absl::string_view target_triple) {
97   std::string error;
98   std::string normalized_triple =
99       llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
100   const llvm::Target* target =
101       llvm::TargetRegistry::lookupTarget(normalized_triple, error);
102   if (target == nullptr) {
103     return xla::InternalError("TargetRegistry::lookupTarget failed: %s",
104                               error.c_str());
105   }
106 
107   return absl::WrapUnique(target->createTargetMachine(
108       normalized_triple, /*CPU=*/"",
109       /*Features=*/"", llvm::TargetOptions(), llvm::None));
110 }
111 
CreateEmbeddedProtocolBuffers(absl::string_view target_triple,absl::Span<const ProtobufToEmbed> protobufs_to_embed)112 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
113     absl::string_view target_triple,
114     absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
115   TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
116                       GetTargetMachineFromTriple(target_triple));
117 
118   llvm::LLVMContext llvm_context;
119   std::unique_ptr<llvm::Module> module_with_serialized_proto =
120       absl::make_unique<llvm::Module>("embedded_data_module", llvm_context);
121 
122   EmbeddedProtocolBuffers result;
123 
124   for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
125     string cpp_shim, cpp_variable_decl;
126     if (protobuf_to_embed.message) {
127       string protobuf_array_symbol_name;
128       int64_t protobuf_array_size;
129 
130       AddEmbeddedProtocolBufferToLlvmModule(
131           module_with_serialized_proto.get(), *protobuf_to_embed.message,
132           protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
133           &protobuf_array_size);
134       cpp_shim = CreateCPPShimExpression(
135           protobuf_to_embed.qualified_cpp_protobuf_name,
136           protobuf_array_symbol_name, protobuf_array_size);
137 
138       cpp_variable_decl =
139           absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
140     } else {
141       cpp_shim = "nullptr";
142     }
143     result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
144   }
145 
146   TF_ASSIGN_OR_RETURN(result.object_file_data,
147                       CodegenModule(target_machine.get(),
148                                     std::move(module_with_serialized_proto)));
149   return result;
150 }
151 
152 }  // namespace tfcompile
153 }  // namespace tensorflow
154