/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/codegen.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tfcompile {

namespace {

using BufferInfo = xla::cpu_function_runtime::BufferInfo;

bool IsAlpha(char c) {
  return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}

bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }

// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
  switch (type) {
    case xla::PRED:
      *str = "bool";
      break;
    case xla::S8:
      *str = "tensorflow::int8";
      break;
    case xla::S16:
      *str = "tensorflow::int16";
      break;
    case xla::S32:
      *str = "tensorflow::int32";
      break;
    case xla::S64:
      *str = "int64_t";
      break;
    case xla::U8:
      *str = "tensorflow::uint8";
      break;
    case xla::U16:
      *str = "tensorflow::uint16";
      break;
    case xla::U32:
      *str = "tensorflow::uint32";
      break;
    case xla::U64:
      *str = "tensorflow::uint64";
      break;
    case xla::F32:
      *str = "float";
      break;
    case xla::F64:
      *str = "double";
      break;
    default:
      return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
                                   " has no equivalent in C++");
  }
  return OkStatus();
}

// Returns the sum of the size of each buffer in `buffer_infos`.
size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
  return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
                         [](size_t size, const BufferInfo& buffer_info) {
                           return size + buffer_info.size();
                         });
}

// Returns a vector of BufferInfo instances in `buffer_infos` that are entry
// parameter buffers.
std::vector<BufferInfo> ExtractEntryParamBufferInfos(
    const std::vector<BufferInfo>& buffer_infos) {
  std::vector<BufferInfo> result;
  std::copy_if(buffer_infos.begin(), buffer_infos.end(),
               std::back_inserter(result), [](const BufferInfo& buffer_info) {
                 return buffer_info.is_entry_parameter();
               });
  return result;
}

// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
// buffers.
std::vector<BufferInfo> ExtractTempBufferInfos(
    const std::vector<BufferInfo>& buffer_infos) {
  std::vector<BufferInfo> result;
  std::copy_if(buffer_infos.begin(), buffer_infos.end(),
               std::back_inserter(result), [](const BufferInfo& buffer_info) {
                 return buffer_info.is_temp_buffer();
               });
  return result;
}

// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
// are used to generate methods for args and results.
Status AddRewritesForShape(int i, const xla::Shape& shape,
                           std::vector<std::pair<string, string>>* rewrites) {
  string type;
  TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
  std::vector<string> dim_vars;
  string dim_sizes, indices;
  int count = 1;
  if (shape.rank() == 0 ||
      (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
    dim_sizes = "[1]";
    indices = "[0]";
  } else {
    for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
      dim_vars.push_back(absl::StrCat("size_t dim", dim));
      dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
      indices += absl::StrCat("[dim", dim, "]");
      count *= shape.dimensions(dim);
    }
  }
  rewrites->push_back({"{{I}}", absl::StrCat(i)});
  rewrites->push_back({"{{TYPE}}", type});
  rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
  rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
  rewrites->push_back({"{{INDICES}}", indices});
  rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
  return OkStatus();
}
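
// Worked example (hypothetical shape, for illustration only): for an argument
// at index 0 with XLA shape f32[2,3], the pairs added above would be roughly
//   {{I}}          -> "0"
//   {{TYPE}}       -> "float"
//   {{DIM_VARS}}   -> "size_t dim0, size_t dim1"
//   {{DIM_SIZES}}  -> "[2][3]"
//   {{INDICES}}    -> "[dim0][dim1]"
//   {{COUNT}}      -> "6"
// A scalar (rank-0) shape would instead get {{DIM_SIZES}} -> "[1]" and
// {{INDICES}} -> "[0]".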

// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
// for the name. Note that the rewriting strategy is roughly O(N*M), where N is
// the size of the code and M is the number of rewrites. It's fine for now
// since N and M are pretty small.
//
// TODO(toddw): If this becomes a problem, we should be able to change the
// algorithm to O(N) by using a state machine, e.g. regexps or a real
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
                       const std::vector<std::pair<string, string>>& rewrites) {
  absl::StrReplaceAll(rewrites, &code);
  absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
  return code;
}
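
// Worked example (hypothetical inputs): RewriteWithName("_x",
// "int arg{{NAME}}_size();", {{"{{I}}", "0"}}) applies the rewrite pairs first
// (here {{I}} does not occur in the code) and then substitutes the name,
// yielding "int arg_x_size();".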

// Generate methods for args (inputs).
Status GenArgMethods(const tf2xla::Config& config,
                     const xla::ProgramShapeProto& ps,
                     const CompileResult& compile_result, string* methods) {
  const int num_args = ps.parameters_size();
  // feed_size() + variable_size() is the maximum number of args as an
  // implementation may not create an argument for an unused variable.
  if (config.feed_size() + config.variable_size() < num_args) {
    return errors::InvalidArgument(
        "mismatch between feed_size(", config.feed_size(), ")+variable_size(",
        config.variable_size(), ") and num_args(", num_args, ")");
  }
  for (int i = 0; i < config.feed_size(); ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(
        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
    const string code = R"(
  void set_arg{{NAME}}_data(const void* data) {
    set_arg_data({{I}}, data);
  }
  {{TYPE}}* arg{{NAME}}_data() {
    return static_cast<{{TYPE}}*>(arg_data({{I}}));
  }
  {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* arg{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(arg_data({{I}}));
  }
  const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  int arg{{NAME}}_size() const {
    return {{COUNT}} * sizeof({{TYPE}});
  }
  int arg{{NAME}}_count() const {
    return {{COUNT}};
  }
)";
    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.feed(i).name().empty()) {
      *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
    }
  }
  return OkStatus();
}
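
// Worked example (hypothetical feed): for a feed named "x" at index 0, the
// loop above emits one set of positional accessors (set_arg0_data, arg0_data,
// arg0, arg0_size, arg0_count) and, because the feed is named, a second set
// with the name spliced in (set_arg_x_data, arg_x_data, arg_x, arg_x_size,
// arg_x_count).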

// Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config,
                        const xla::ProgramShapeProto& ps, string* methods) {
  if (ps.result().element_type() != xla::TUPLE) {
    // The XlaCompiler we use to build the xla computation always generates a
    // tuple result, and we rely on this to simplify code generation.
    return errors::Internal("codegen requires the XLA result to be a tuple");
  }
  size_t num_results = ps.result().tuple_shapes_size();
  int readonly_variables = absl::c_count_if(
      config.variable(),
      [](const tf2xla::Variable& var) { return var.readonly(); });
  const int actual_num_results =
      config.fetch_size() + config.variable_size() - readonly_variables;
  if (actual_num_results != num_results) {
    return errors::InvalidArgument("mismatch between fetch_size(",
                                   config.fetch_size(), ")+variable_size(",
                                   config.variable_size(), ") and tuple_size(",
                                   ps.result().tuple_shapes_size(), ")");
  }
  for (int i = 0; i < config.fetch_size(); ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(AddRewritesForShape(
        i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
    string code = R"(
  {{TYPE}}* result{{NAME}}_data() {
    return static_cast<{{TYPE}}*>(result_data({{I}}));
  }
  {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
        result_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* result{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(result_data({{I}}));
  }
  const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        result_data({{I}}))){{INDICES}};
  }
  int result{{NAME}}_size() const {
    return {{COUNT}} * sizeof({{TYPE}});
  }
  int result{{NAME}}_count() const {
    return {{COUNT}};
  }
)";
    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.fetch(i).name().empty()) {
      *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
    }
  }
  return OkStatus();
}

// Generate methods for variables.
Status GenVariableMethods(const tf2xla::Config& config,
                          const xla::ProgramShapeProto& ps, string* methods) {
  const int num_args = ps.parameters_size();
  for (int i = config.feed_size(); i < num_args; ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(
        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
    const string code = R"(
  void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
    set_arg_data({{I}}, data);
  }
  {{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
    return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
  }
  {{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* var_{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(arg_data({{I}}));
  }
  const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  int var_{{NAME}}_size() const {
    return {{COUNT}} * sizeof({{TYPE}});
  }
  int var_{{NAME}}_count() const {
    return {{COUNT}};
  }
)";
    const tf2xla::Variable& var = config.variable(i - config.feed_size());
    rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
    *methods += RewriteWithName(
        var.name().empty() ? var.node_name() : var.name(), code, rewrites);
  }
  return OkStatus();
}

// Generates code implementing {Arg,Variable,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch,Variable}. Each feed, fetch, or variable name results in
// a C-style string literal in the array, with nullptr terminating the array.
template <typename T>
string GenNameToIndexCode(const T& entries, bool generate) {
  // No need for a static array if we're not supposed to generate the data.
  if (!generate) {
    return "{\n    return nullptr;\n  }";
  }
  // Determine when to stop. We stop emitting string literals after the last
  // non-empty name.
  int end = entries.size();
  for (int i = entries.size() - 1; i >= 0; --i) {
    if (!entries[i].name().empty()) {
      break;
    }
    end = i;
  }
  // Emit string literals up to the last non-empty name.
  string code = "{\n    static const char* kNames[] = {";
  for (int i = 0; i < end; ++i) {
    if (i > 0) {
      code += ", ";
    }
    code += "\"";
    code += entries[i].name();
    code += "\"";
  }
  if (end > 0) {
    code += ", ";
  }
  code += "nullptr};\n    return kNames;\n  }";
  return code;
}
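
// Worked example (hypothetical feeds named "x" and "y", generate == true):
// the returned snippet looks roughly like
//   {
//     static const char* kNames[] = {"x", "y", nullptr};
//     return kNames;
//   }
// With generate == false the returned body simply returns nullptr.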

Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
  for (const tf2xla::Feed& feed : config.feed()) {
    if (!feed.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
    }
  }
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    if (!fetch.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
    }
  }
  for (const tf2xla::Variable& variable : config.variable()) {
    if (!variable.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name"));
    } else {
      TF_RETURN_IF_ERROR(
          ValidateCppIdent(variable.node_name(), "variable name"));
    }
  }
  return OkStatus();
}

// Returns a list of C++ expressions that, when executed, will construct the
// BufferInfo instances in `buffer_infos`.
std::vector<string> BufferInfosToCppExpression(
    const std::vector<BufferInfo>& buffer_infos) {
  std::vector<string> buffer_infos_as_strings;
  std::transform(buffer_infos.begin(), buffer_infos.end(),
                 std::back_inserter(buffer_infos_as_strings),
                 [](const BufferInfo& buffer_info) {
                   std::pair<uint64, uint64> encoded = buffer_info.Encode();
                   string encoded_second_as_str =
                       encoded.second == ~0ULL
                           ? "~0ULL"
                           : absl::StrCat(encoded.second, "ULL");
                   return absl::StrCat(
                       "::xla::cpu_function_runtime::BufferInfo({",
                       encoded.first, "ULL, ", encoded_second_as_str, "})");
                 });
  return buffer_infos_as_strings;
}
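
// Worked example (hypothetical encoded values): if buffer_info.Encode()
// returned the pair {10, 16}, the emitted expression would be
//   ::xla::cpu_function_runtime::BufferInfo({10ULL, 16ULL})
// A second element equal to ~0ULL is printed as "~0ULL" rather than as a
// decimal literal.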
}  // namespace

Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
                      const CompileResult& compile_result,
                      const MetadataResult& metadata_result, string* header) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
  const int64_t result_index = compile_result.aot->result_buffer_index();
  const std::vector<BufferInfo>& buffer_infos =
      compile_result.aot->buffer_infos();
  const std::vector<int32> arg_index_table =
      ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
  std::vector<string> buffer_infos_as_strings =
      BufferInfosToCppExpression(buffer_infos);
  const int64_t buffer_infos_size = buffer_infos.size();
  if (result_index < 0 || result_index >= buffer_infos_size) {
    return errors::InvalidArgument("result index: ", result_index,
                                   " is outside the range of temp sizes: [0,",
                                   buffer_infos.size(), ")");
  }

  // Compute sizes and generate methods.
  std::vector<BufferInfo> buffer_infos_for_args =
      ExtractEntryParamBufferInfos(buffer_infos);
  std::vector<BufferInfo> buffer_infos_for_temps =
      ExtractTempBufferInfos(buffer_infos);
  const xla::ProgramShapeProto& ps = compile_result.program_shape;
  string methods_arg, methods_result, methods_variable;
  TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
  TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
  TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
  const size_t arg_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_args.data(), buffer_infos_for_args.size(),
          /*allocate_entry_params=*/true);
  const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
  const size_t temp_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
          /*allocate_entry_params=*/true);
  const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);

  // Create rewrite strings for namespace start and end.
  string ns_start;
  for (const string& n : opts.namespaces) {
    ns_start += absl::StrCat("namespace ", n, " {\n");
  }
  ns_start += "\n";
  string ns_end("\n");
  for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
    const string& n = opts.namespaces[i];
    ns_end += absl::StrCat("} // end namespace ", n, "\n");
  }

  // Generate metadata.
  const string arg_names_code =
      GenNameToIndexCode(config.feed(), opts.gen_name_to_index);

  auto variable_copy = config.variable();
  for (auto& var : variable_copy) {
    if (var.name().empty()) {
      var.set_name(var.node_name());
    }
  }
  const string variable_names_code =
      GenNameToIndexCode(variable_copy, opts.gen_name_to_index);

  const string result_names_code =
      GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
  const string include_xla_data_proto =
      opts.gen_program_shape
          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          : "";

  const string include_hlo_profile_printer_data_proto =
      opts.gen_hlo_profile_printer_data
          ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
          : "";

  // When HLO profiling is disabled we only forward declare the
  // HloProfilePrinter protobuf. So we can only conditionally emit this code
  // calling HloProfilePrinter::profile_counters_size.
  const string assign_profile_counters_size =
      opts.gen_hlo_profile_printer_data
          ? "set_static_data_profile_counters_size(data, "
            "get_static_data_hlo_profile_printer_data(data)->"
            "profile_counters_size());"
          : "";

  // Use a poor-man's text templating mechanism; first populate the full header
  // with placeholder tokens, and then rewrite the tokens with real values.
  *header =
      R"(// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph. An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)

{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
    void* result, const ::xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, XlaCustomCallStatus* status,
    int64_t* profile_counters);

{{DECLS_FROM_OBJ_FILE}}

{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   {{CLASS}} computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
//   o Calls to non-const methods require exclusive access to the object.
//   o Concurrent calls to const methods are OK, if those calls are made while it
//     is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   {{PROGRAM_SHAPE}}
//
// Memory stats:
//   arg bytes total:    {{ARG_BYTES_TOTAL}}
//   arg bytes aligned:  {{ARG_BYTES_ALIGNED}}
//   temp bytes total:   {{TEMP_BYTES_TOTAL}}
//   temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = {{ARG_NUM}};

  // Number of variables for the compiled computation.
  static constexpr size_t kNumVariables = {{VARIABLE_NUM}};

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const ::int64_t ArgSize(::tensorflow::int32 index) {
    return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
          new XlaCompiledCpuFunction::StaticData;
      set_static_data_raw_function(data, {{ENTRY}});
      set_static_data_buffer_infos(data, BufferInfos());
      set_static_data_num_buffers(data, kNumBuffers);
      set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
      set_static_data_num_args(data, kNumArgs);
      set_static_data_num_variables(data, kNumVariables);
      set_static_data_result_index(data, kResultIndex);
      set_static_data_arg_names(data, StaticArgNames());
      set_static_data_variable_names(data, StaticVariableNames());
      set_static_data_result_names(data, StaticResultNames());
      set_static_data_program_shape(data, StaticProgramShape());
      set_static_data_hlo_profile_printer_data(
          data, StaticHloProfilePrinterData());
      {{ASSIGN_PROFILE_COUNTERS_SIZE}}
      return data;
    }();
    return *kStaticData;
  }

  {{CLASS}}(AllocMode alloc_mode =
            AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  {{CLASS}}(const {{CLASS}}&) = delete;
  {{CLASS}}& operator=(const {{CLASS}}&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  //   void set_argN_data(void* data)
  //     Sets the buffer of type T for positional argument N. May be called in
  //     any AllocMode. Must be called before Run to have an effect. Must be
  //     called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //     argument, to set the argument buffers.
  //
  //   T* argN_data()
  //     Returns the buffer of type T for positional argument N.
  //
  //   T& argN(...dim indices...)
  //     Returns a reference to the value of type T for positional argument N,
  //     with dim indices specifying which value. No bounds checking is performed
  //     on dim indices.
{{METHODS_ARG}}

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  //   T* resultN_data()
  //     Returns the buffer of type T for positional result N.
  //
  //   T& resultN(...dim indices...)
  //     Returns a reference to the value of type T for positional result N,
  //     with dim indices specifying which value. No bounds checking is performed
  //     on dim indices.
  //
  //   Unlike the arg methods, there is no set_resultN_data method. The result
  //   buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}

  // Methods for managing variable buffers. Buffers are in row-major order.
  //
  // For read-write variables we generate the following methods:
  //
  //   void set_var_X_data(T* data)
  //     Sets the buffer for variable X. Must be called before Run if the
  //     allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
  //
  //   T* var_X_data()
  //     Returns the buffer of type T for variable X. If the allocation mode is
  //     RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
  //     buffer passed to set_var_X_data.
  //
  //   T& var_X(...dim indices...)
  //     Returns a reference to the value of type T for variable X,
  //     with dim indices specifying which value. No bounds checking is performed
  //     on dim indices.
  //
  // For readonly variables we generate the same set of methods, except that we
  // use `const T` instead of `T`. We use `const T` to avoid erasing the
  // constness of the buffer passed to `set_var_X_data` but the underlying
  // buffer is not const (and thus the const can be safely const-cast'ed away)
  // unless `set_var_X_data` is called with a pointer to constant storage.
{{METHODS_VARIABLE}}

 private:
  // Number of buffers for the compiled computation.
  static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};

  static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
    static const ::xla::cpu_function_runtime::BufferInfo
        kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
    };
    return kBufferInfos;
  }

  static const ::tensorflow::int32* ArgIndexToBufferIndex() {
    static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
    };
    return kArgIndexToBufferIndex;
  }

  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = {{RESULT_INDEX}};

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() {{ARG_NAMES_CODE}}

  // Array of names of each positional variable, terminated by nullptr.
  static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() {{RESULT_NAMES_CODE}}

  // Shape of the args and results.
  static const ::xla::ProgramShapeProto* StaticProgramShape() {
    static const ::xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
    return kShape;
  }

  // Metadata that can be used to pretty-print profile counters.
  static const ::xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
    static const ::xla::HloProfilePrinterData* kHloProfilePrinterData =
        {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
    return kHloProfilePrinterData;
  }
};
{{NS_END}}

#endif  // TFCOMPILE_GENERATED_{{ENTRY}}_H_

// clang-format on
)";
  // The replacement strategy is naive, but good enough for our purposes.
  const std::vector<std::pair<string, string>> rewrites = {
      {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
      {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
      {"{{ARG_NAMES_CODE}}", arg_names_code},
      {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
      {"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
      {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
      {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
      {"{{CLASS}}", opts.class_name},
      {"{{DECLS_FROM_OBJ_FILE}}",
       absl::StrJoin(metadata_result.header_variable_decls, "\n")},
      {"{{ENTRY}}", compile_result.entry_point},
      {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
       metadata_result.hlo_profile_printer_data_access_shim},
      {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
      {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
       include_hlo_profile_printer_data_proto},
      {"{{METHODS_ARG}}\n", methods_arg},
      {"{{METHODS_RESULT}}\n", methods_result},
      {"{{METHODS_VARIABLE}}\n", methods_variable},
      {"{{NS_END}}\n", ns_end},
      {"{{NS_START}}\n", ns_start},
      {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
       metadata_result.program_shape_access_shim},
      {"{{VARIABLE_NAMES_CODE}}", variable_names_code},
      {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
      {"{{RESULT_NAMES_CODE}}", result_names_code},
      {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
      {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
      {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
      {"{{BUFFER_INFOS_AS_STRING}}",
       absl::StrJoin(buffer_infos_as_strings, ",\n")}};
  absl::StrReplaceAll(rewrites, header);
  return OkStatus();
}

static string CreateUniqueIdentifier(const CodegenOpts& opts,
                                     absl::string_view suffix) {
  string result = "__tfcompile";
  for (const string& n : opts.namespaces) {
    absl::StrAppend(&result, "_", n);
  }

  absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
  return result;
}
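
// Worked example (hypothetical options): with namespaces {"foo", "bar"}, class
// name "MyClass", and suffix "ProgramShapeProto", this returns
// "__tfcompile_foo_bar_MyClass_ProgramShapeProto".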

Status GenerateMetadata(const CodegenOpts& opts,
                        const CompileResult& compile_result,
                        MetadataResult* metadata_result) {
  std::unique_ptr<xla::ProgramShapeProto> program_shape;

  if (opts.gen_program_shape) {
    program_shape =
        absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);

    // The parameter names are currently meaningless, and redundant with the
    // rest of our metadata, so clear them out to avoid confusion and save
    // space.
    program_shape->clear_parameter_names();
  }

  // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffers
  // gives a shim that evaluates to nullptr, which is what we want.

  ProtobufToEmbed program_shape_protobuf{
      CreateUniqueIdentifier(opts, "ProgramShapeProto"),
      "::xla::ProgramShapeProto", program_shape.get()};

  ProtobufToEmbed hlo_profile_printer_data_protobuf{
      CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
      "::xla::HloProfilePrinterData",
      compile_result.aot->hlo_profile_printer_data()};

  TF_ASSIGN_OR_RETURN(
      EmbeddedProtocolBuffers embedded_protobufs,
      CreateEmbeddedProtocolBuffers(
          opts.target_triple,
          {program_shape_protobuf, hlo_profile_printer_data_protobuf}));

  metadata_result->program_shape_access_shim =
      std::move(embedded_protobufs.cpp_shims[0].expression);
  metadata_result->hlo_profile_printer_data_access_shim =
      std::move(embedded_protobufs.cpp_shims[1].expression);
  metadata_result->header_variable_decls.emplace_back(
      std::move(embedded_protobufs.cpp_shims[0].variable_decl));
  metadata_result->header_variable_decls.emplace_back(
      std::move(embedded_protobufs.cpp_shims[1].variable_decl));
  metadata_result->object_file_data =
      std::move(embedded_protobufs.object_file_data);
  return OkStatus();
}

Status ParseCppClass(const string& cpp_class, string* class_name,
                     std::vector<string>* namespaces) {
  class_name->clear();
  namespaces->clear();
  if (cpp_class.empty()) {
    return errors::InvalidArgument("empty cpp_class: " + cpp_class);
  }
  std::vector<string> parts = absl::StrSplit(cpp_class, "::");
  if (parts.front().empty()) {
    // Allow a fully qualified name that starts with "::".
    parts.erase(parts.begin());
  }
  for (int i = 0, end = parts.size(); i < end; ++i) {
    if (i < end - 1) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(
          parts[i], "in namespace component of cpp_class: " + cpp_class));
      namespaces->push_back(parts[i]);
    } else {
      TF_RETURN_IF_ERROR(ValidateCppIdent(
          parts[i], "in class name of cpp_class: " + cpp_class));
      *class_name = parts[i];
    }
  }
  return OkStatus();
}
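
// Worked example (hypothetical input): ParseCppClass("::foo::bar::MyClass",
// &class_name, &namespaces) sets namespaces to {"foo", "bar"} and class_name
// to "MyClass"; a bare "MyClass" leaves namespaces empty.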

Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
  if (ident.empty()) {
    return errors::InvalidArgument("empty identifier: ", msg);
  }
  // Require that the identifier starts with a nondigit, and is composed of
  // nondigits and digits, as specified in section [2.11 Identifiers] of the
  // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
  // defined as [0-9].
  //
  // Technically the standard also allows for `universal-character-name`, with a
  // table of allowed unicode ranges, as well as `other implementation-defined
  // characters`. We disallow those here to give better error messages, at the
  // expense of being more restrictive than the standard.
  if (ident[0] != '_' && !IsAlpha(ident[0])) {
    return errors::InvalidArgument("illegal leading char: ", msg);
  }
  for (size_t pos = 1; pos < ident.size(); ++pos) {
    if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
      return errors::InvalidArgument("illegal char: ", msg);
    }
  }
  return OkStatus();
}

}  // namespace tfcompile
}  // namespace tensorflow