xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/aot/codegen.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/aot/codegen.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_join.h"
25 #include "absl/strings/str_replace.h"
26 #include "absl/strings/str_split.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
30 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
31 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
32 #include "tensorflow/compiler/xla/service/compiler.h"
33 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 
38 namespace tensorflow {
39 namespace tfcompile {
40 
41 namespace {
42 
43 using BufferInfo = xla::cpu_function_runtime::BufferInfo;
44 
// Returns true iff `c` is an ASCII letter. Checks the ranges explicitly
// (rather than via std::isalpha) so the result is locale-independent.
bool IsAlpha(char c) {
  return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z');
}
48 
IsAlphaNum(char c)49 bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
50 
51 // Convert an XLA type into a C++ type.
XLATypeToCpp(xla::PrimitiveType type,string * str)52 Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
53   switch (type) {
54     case xla::PRED:
55       *str = "bool";
56       break;
57     case xla::S8:
58       *str = "tensorflow::int8";
59       break;
60     case xla::S16:
61       *str = "tensorflow::int16";
62       break;
63     case xla::S32:
64       *str = "tensorflow::int32";
65       break;
66     case xla::S64:
67       *str = "int64_t";
68       break;
69     case xla::U8:
70       *str = "tensorflow::uint8";
71       break;
72     case xla::U16:
73       *str = "tensorflow::uint16";
74       break;
75     case xla::U32:
76       *str = "tensorflow::uint32";
77       break;
78     case xla::U64:
79       *str = "tensorflow::uint64";
80       break;
81     case xla::F32:
82       *str = "float";
83       break;
84     case xla::F64:
85       *str = "double";
86       break;
87     default:
88       return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
89                                    " has no equivalent in C++");
90   }
91   return OkStatus();
92 }
93 
94 // Returns the sum of the size of each buffer in `buffer_infos`.
TotalBufferBytes(const std::vector<BufferInfo> & buffer_infos)95 size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
96   return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
97                          [](size_t size, const BufferInfo& buffer_info) {
98                            return size + buffer_info.size();
99                          });
100 }
101 
102 // Returns a vector of BufferInfo instances in `buffer_infos` that are entry
103 // parameter buffers.
ExtractEntryParamBufferInfos(const std::vector<BufferInfo> & buffer_infos)104 std::vector<BufferInfo> ExtractEntryParamBufferInfos(
105     const std::vector<BufferInfo>& buffer_infos) {
106   std::vector<BufferInfo> result;
107   std::copy_if(buffer_infos.begin(), buffer_infos.end(),
108                std::back_inserter(result), [](const BufferInfo& buffer_info) {
109                  return buffer_info.is_entry_parameter();
110                });
111   return result;
112 }
113 
114 // Returns a vector of BufferInfo instances in `buffer_infos` that are temp
115 // buffers.
ExtractTempBufferInfos(const std::vector<BufferInfo> & buffer_infos)116 std::vector<BufferInfo> ExtractTempBufferInfos(
117     const std::vector<BufferInfo>& buffer_infos) {
118   std::vector<BufferInfo> result;
119   std::copy_if(buffer_infos.begin(), buffer_infos.end(),
120                std::back_inserter(result), [](const BufferInfo& buffer_info) {
121                  return buffer_info.is_temp_buffer();
122                });
123   return result;
124 }
125 
126 // Add (from,to) rewrite pairs based on the given shape.  These rewrite pairs
127 // are used to generate methods for args and results.
AddRewritesForShape(int i,const xla::Shape & shape,std::vector<std::pair<string,string>> * rewrites)128 Status AddRewritesForShape(int i, const xla::Shape& shape,
129                            std::vector<std::pair<string, string>>* rewrites) {
130   string type;
131   TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
132   std::vector<string> dim_vars;
133   string dim_sizes, indices;
134   int count = 1;
135   if (shape.rank() == 0 ||
136       (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
137     dim_sizes = "[1]";
138     indices = "[0]";
139   } else {
140     for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
141       dim_vars.push_back(absl::StrCat("size_t dim", dim));
142       dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
143       indices += absl::StrCat("[dim", dim, "]");
144       count *= shape.dimensions(dim);
145     }
146   }
147   rewrites->push_back({"{{I}}", absl::StrCat(i)});
148   rewrites->push_back({"{{TYPE}}", type});
149   rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
150   rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
151   rewrites->push_back({"{{INDICES}}", indices});
152   rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
153   return OkStatus();
154 }
155 
// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
// for the name.  Note that the rewriting strategy is roughly O(N*M), where N is
// the size of the code and M is the number of rewrites.  It's fine for now
// since N and M are pretty small.
//
// `name` is substituted for {{NAME}}; `code` is taken by value as the template
// text and rewritten in place; `rewrites` holds the (token, replacement) pairs
// applied in the first pass.  Because the {{NAME}} pass runs second, {{NAME}}
// tokens contained in replacement values from the first pass are expanded too.
//
// TODO(toddw): If this becomes a problem, we should be able to change the
// algorithm to O(N) by using a state machine, e.g. regexps or a real
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
                       const std::vector<std::pair<string, string>>& rewrites) {
  absl::StrReplaceAll(rewrites, &code);
  absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
  return code;
}
170 
// Generate methods for args (inputs).
//
// For each feed in `config`, appends to `*methods` a set of typed accessors
// (set_argN_data, argN_data, argN, argN_size, argN_count) generated from the
// template below: once under the positional index N and, when the feed has a
// name, once more under that name.  Returns InvalidArgument if the program
// shape `ps` declares more parameters than feeds plus variables.
Status GenArgMethods(const tf2xla::Config& config,
                     const xla::ProgramShapeProto& ps,
                     const CompileResult& compile_result, string* methods) {
  const int num_args = ps.parameters_size();
  // feed_size() + variable_size() is the maximum number of args as an
  // implementation may not create an argument for an unused variable.
  if (config.feed_size() + config.variable_size() < num_args) {
    return errors::InvalidArgument(
        "mismatch between feed_size(", config.feed_size(), ")+variable_size(",
        config.variable_size(), ") and num_args(", num_args, ")");
  }
  for (int i = 0; i < config.feed_size(); ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(
        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
    const string code = R"(
  void set_arg{{NAME}}_data(const void* data) {
    set_arg_data({{I}}, data);
  }
  {{TYPE}}* arg{{NAME}}_data() {
    return static_cast<{{TYPE}}*>(arg_data({{I}}));
  }
  {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* arg{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(arg_data({{I}}));
  }
  const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  int arg{{NAME}}_size() const {
    return {{COUNT}} * sizeof({{TYPE}});
  }
  int arg{{NAME}}_count() const {
    return {{COUNT}};
  }
)";
    // Positional accessors (arg0_data, arg1_data, ...).
    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.feed(i).name().empty()) {
      // Named accessors (arg_myfeed_data, ...) for feeds with a name.
      *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
    }
  }
  return OkStatus();
}
219 
// Generate methods for results (outputs).
//
// For each fetch in `config`, appends to `*methods` typed accessors
// (resultN_data, resultN, resultN_size, resultN_count): once under the
// positional index N and, when the fetch is named, once more under that name.
// Requires the XLA result in `ps` to be a tuple whose arity equals
// fetch_size() plus the number of non-readonly variables.
Status GenResultMethods(const tf2xla::Config& config,
                        const xla::ProgramShapeProto& ps, string* methods) {
  if (ps.result().element_type() != xla::TUPLE) {
    // The XlaCompiler we use to build the xla computation always generates a
    // tuple result, and we rely on this to simplify code generation.
    return errors::Internal("codegen requires the XLA result to be a tuple");
  }
  size_t num_results = ps.result().tuple_shapes_size();
  // Readonly variables are subtracted below: per the arity check, they do not
  // contribute an entry to the result tuple.
  int readonly_variables = absl::c_count_if(
      config.variable(),
      [](const tf2xla::Variable& var) { return var.readonly(); });
  const int actual_num_results =
      config.fetch_size() + config.variable_size() - readonly_variables;
  if (actual_num_results != num_results) {
    return errors::InvalidArgument("mismatch between fetch_size(",
                                   config.fetch_size(), ")+variable_size(",
                                   config.variable_size(), ") and tuple_size(",
                                   ps.result().tuple_shapes_size(), ")");
  }
  for (int i = 0; i < config.fetch_size(); ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(AddRewritesForShape(
        i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
    string code = R"(
  {{TYPE}}* result{{NAME}}_data() {
    return static_cast<{{TYPE}}*>(result_data({{I}}));
  }
  {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
        result_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* result{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(result_data({{I}}));
  }
  const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        result_data({{I}}))){{INDICES}};
  }
  int result{{NAME}}_size() const {
    return {{COUNT}} * sizeof({{TYPE}});
  }
  int result{{NAME}}_count() const {
    return {{COUNT}};
  }
)";
    // Positional accessors (result0_data, ...).
    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.fetch(i).name().empty()) {
      // Named accessors (result_myfetch_data, ...) for fetches with a name.
      *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
    }
  }
  return OkStatus();
}
273 
// Generate methods for variables.
//
// Program-shape parameters after the feeds correspond, in order, to the
// variables in `config`.  For each one, appends typed accessors
// (set_var_X_data, var_X_data, var_X, var_X_size, var_X_count) to `*methods`,
// named after the variable's name (or node name when no name was set).  For
// readonly variables, {{MAYBE_CONST}} expands to "const " so the non-const
// accessors also take/return `const T`.
Status GenVariableMethods(const tf2xla::Config& config,
                          const xla::ProgramShapeProto& ps, string* methods) {
  const int num_args = ps.parameters_size();
  for (int i = config.feed_size(); i < num_args; ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(
        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
    const string code = R"(
  void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
    set_arg_data({{I}}, data);
  }
  {{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
    return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
  }
  {{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* var_{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(arg_data({{I}}));
  }
  const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  int var_{{NAME}}_size() const {
    return {{COUNT}} * sizeof({{TYPE}});
  }
  int var_{{NAME}}_count() const {
    return {{COUNT}};
  }
)";
    // Arg index i maps to variable (i - feed_size) in declaration order.
    const tf2xla::Variable& var = config.variable(i - config.feed_size());
    rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
    *methods += RewriteWithName(
        var.name().empty() ? var.node_name() : var.name(), code, rewrites);
  }
  return OkStatus();
}
314 
315 // Generates code implementing {Arg,Result}Names(), where T is one of
316 // tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
317 // string literal in the array, with nullptr terminating the array.
318 template <typename T>
GenNameToIndexCode(const T & entries,bool generate)319 string GenNameToIndexCode(const T& entries, bool generate) {
320   // No need for a static array if we're not supposed to generate the data.
321   if (!generate) {
322     return "{\n    return nullptr;\n  }";
323   }
324   // Determine when to stop. We stop emitting string literals after the last
325   // non-empty name.
326   int end = entries.size();
327   for (int i = entries.size() - 1; i >= 0; --i) {
328     if (!entries[i].name().empty()) {
329       break;
330     }
331     end = i;
332   }
333   // Emit string literals up to the last non-empty name.
334   string code = "{\n    static const char* kNames[] = {";
335   for (int i = 0; i < end; ++i) {
336     if (i > 0) {
337       code += ", ";
338     }
339     code += "\"";
340     code += entries[i].name();
341     code += "\"";
342   }
343   if (end > 0) {
344     code += ", ";
345   }
346   code += "nullptr};\n    return kNames;\n  }";
347   return code;
348 }
349 
ValidateFeedFetchCppNames(const tf2xla::Config & config)350 Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
351   for (const tf2xla::Feed& feed : config.feed()) {
352     if (!feed.name().empty()) {
353       TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
354     }
355   }
356   for (const tf2xla::Fetch& fetch : config.fetch()) {
357     if (!fetch.name().empty()) {
358       TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
359     }
360   }
361   for (const tf2xla::Variable& variable : config.variable()) {
362     if (!variable.name().empty()) {
363       TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name"));
364     } else {
365       TF_RETURN_IF_ERROR(
366           ValidateCppIdent(variable.node_name(), "variable name"));
367     }
368   }
369   return OkStatus();
370 }
371 
372 // Returns a list of C++ expressions that, when executed, will construct the
373 // BufferInfo instances in `buffer_infos`.
BufferInfosToCppExpression(const std::vector<BufferInfo> & buffer_infos)374 std::vector<string> BufferInfosToCppExpression(
375     const std::vector<BufferInfo>& buffer_infos) {
376   std::vector<string> buffer_infos_as_strings;
377   std::transform(buffer_infos.begin(), buffer_infos.end(),
378                  std::back_inserter(buffer_infos_as_strings),
379                  [](const BufferInfo& buffer_info) {
380                    std::pair<uint64, uint64> encoded = buffer_info.Encode();
381                    string encoded_second_as_str =
382                        encoded.second == ~0ULL
383                            ? "~0ULL"
384                            : absl::StrCat(encoded.second, "ULL");
385                    return absl::StrCat(
386                        "::xla::cpu_function_runtime::BufferInfo({",
387                        encoded.first, "ULL, ", encoded_second_as_str, "})");
388                  });
389   return buffer_infos_as_strings;
390 }
391 }  // namespace
392 
// Generates the tfcompile output header: validates `config`, computes buffer
// and argument metadata from `compile_result`, expands the header template
// below via placeholder rewrites, and writes the result to `*header`.
// Returns an error if the config is invalid, a name is not a C++ identifier,
// the result buffer index is out of range, or method generation fails.
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
                      const CompileResult& compile_result,
                      const MetadataResult& metadata_result, string* header) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
  const int64_t result_index = compile_result.aot->result_buffer_index();
  const std::vector<BufferInfo>& buffer_infos =
      compile_result.aot->buffer_infos();
  const std::vector<int32> arg_index_table =
      ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
  std::vector<string> buffer_infos_as_strings =
      BufferInfosToCppExpression(buffer_infos);
  const int64_t buffer_infos_size = buffer_infos.size();
  if (result_index < 0 || result_index >= buffer_infos_size) {
    return errors::InvalidArgument("result index: ", result_index,
                                   " is outside the range of temp sizes: [0,",
                                   buffer_infos.size(), ")");
  }

  // Compute sizes and generate methods.
  std::vector<BufferInfo> buffer_infos_for_args =
      ExtractEntryParamBufferInfos(buffer_infos);
  std::vector<BufferInfo> buffer_infos_for_temps =
      ExtractTempBufferInfos(buffer_infos);
  const xla::ProgramShapeProto& ps = compile_result.program_shape;
  string methods_arg, methods_result, methods_variable;
  TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
  TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
  TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
  const size_t arg_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_args.data(), buffer_infos_for_args.size(),
          /*allocate_entry_params=*/true);
  const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
  const size_t temp_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
          /*allocate_entry_params=*/true);
  const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);

  // Create rewrite strings for namespace start and end.
  string ns_start;
  for (const string& n : opts.namespaces) {
    ns_start += absl::StrCat("namespace ", n, " {\n");
  }
  ns_start += "\n";
  string ns_end("\n");
  // Close namespaces in reverse order of opening.
  for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
    const string& n = opts.namespaces[i];
    ns_end += absl::StrCat("}  // end namespace ", n, "\n");
  }

  // Generate metadata.
  const string arg_names_code =
      GenNameToIndexCode(config.feed(), opts.gen_name_to_index);

  // Variables with no explicit name are reported under their node name.
  auto variable_copy = config.variable();
  for (auto& var : variable_copy) {
    if (var.name().empty()) {
      var.set_name(var.node_name());
    }
  }
  const string variable_names_code =
      GenNameToIndexCode(variable_copy, opts.gen_name_to_index);

  const string result_names_code =
      GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
  const string include_xla_data_proto =
      opts.gen_program_shape
          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          : "";

  const string include_hlo_profile_printer_data_proto =
      opts.gen_hlo_profile_printer_data
          ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
          : "";

  // When HLO profiling is disabled we only forward declare the
  // HloProfilePrinter protobuf.  So we can only conditionally emit this code
  // calling HloProfilePrinter::profile_counters_size.
  const string assign_profile_counters_size =
      opts.gen_hlo_profile_printer_data
          ? "set_static_data_profile_counters_size(data, "
            "get_static_data_hlo_profile_printer_data(data)->"
            "profile_counters_size());"
          : "";

  // Use a poor-man's text templating mechanism; first populate the full header
  // with placeholder tokens, and then rewrite the tokens with real values.
  *header =
      R"(// Generated by tfcompile, the TensorFlow graph compiler.  DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph.  An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)

{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
    void* result, const ::xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, XlaCustomCallStatus* status,
    int64_t* profile_counters);

{{DECLS_FROM_OBJ_FILE}}

{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   {{CLASS}} computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while it
//   is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   {{PROGRAM_SHAPE}}
//
// Memory stats:
//   arg bytes total:    {{ARG_BYTES_TOTAL}}
//   arg bytes aligned:  {{ARG_BYTES_ALIGNED}}
//   temp bytes total:   {{TEMP_BYTES_TOTAL}}
//   temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = {{ARG_NUM}};

  // Number of variables for the compiled computation.
  static constexpr size_t kNumVariables = {{VARIABLE_NUM}};

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const ::int64_t ArgSize(::tensorflow::int32 index) {
    return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
        new XlaCompiledCpuFunction::StaticData;
      set_static_data_raw_function(data, {{ENTRY}});
      set_static_data_buffer_infos(data, BufferInfos());
      set_static_data_num_buffers(data, kNumBuffers);
      set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
      set_static_data_num_args(data, kNumArgs);
      set_static_data_num_variables(data, kNumVariables);
      set_static_data_result_index(data, kResultIndex);
      set_static_data_arg_names(data, StaticArgNames());
      set_static_data_variable_names(data, StaticVariableNames());
      set_static_data_result_names(data, StaticResultNames());
      set_static_data_program_shape(data, StaticProgramShape());
      set_static_data_hlo_profile_printer_data(
          data, StaticHloProfilePrinterData());
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
      return data;
    }();
    return *kStaticData;
  }

  {{CLASS}}(AllocMode alloc_mode =
            AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  {{CLASS}}(const {{CLASS}}&) = delete;
  {{CLASS}}& operator=(const {{CLASS}}&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  // void set_argN_data(void* data)
  //   Sets the buffer of type T for positional argument N. May be called in
  //   any AllocMode. Must be called before Run to have an affect. Must be
  //   called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //   argument, to set the argument buffers.
  //
  // T* argN_data()
  //   Returns the buffer of type T for positional argument N.
  //
  // T& argN(...dim indices...)
  //   Returns a reference to the value of type T for positional argument N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
{{METHODS_ARG}}

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  // T* resultN_data()
  //   Returns the buffer of type T for positional result N.
  //
  // T& resultN(...dim indices...)
  //   Returns a reference to the value of type T for positional result N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // Unlike the arg methods, there is no set_resultN_data method. The result
  // buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}

  // Methods for managing variable buffers. Buffers are in row-major order.
  //
  // For read-write variables we generate the following methods:
  //
  // void set_var_X_data(T* data)
  //   Sets the buffer for variable X.  Must be called before Run if the
  //   allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
  //
  // T* var_X_data()
  //   Returns the buffer of type T for variable X.  If the allocation mode is
  //   RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
  //   buffer passed to set_var_X_data.
  //
  // T& var_X(...dim indices...)
  //   Returns a reference to the value of type T for variable X,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // For readonly variables we generate the same set of methods, except that we
  // use `const T` instead of `T`.  We use `const T` to avoid erasing the
  // constness of the buffer passed to `set_var_X_data` but the underlying
  // buffer is not const (and thus the const can be safely const-cast'ed away)
  // unless `set_var_X_data` is called with a pointer to constant storage.
{{METHODS_VARIABLE}}

 private:
  // Number of buffers for the compiled computation.
  static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};

  static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
    static const ::xla::cpu_function_runtime::BufferInfo
      kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
      };
    return kBufferInfos;
  }

  static const ::tensorflow::int32* ArgIndexToBufferIndex() {
    static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
    };
    return kArgIndexToBufferIndex;
  }

  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = {{RESULT_INDEX}};

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() {{ARG_NAMES_CODE}}

  // Array of names of each positional variable, terminated by nullptr.
  static const char** StaticVariableNames() {{VARIABLE_NAMES_CODE}}

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() {{RESULT_NAMES_CODE}}

  // Shape of the args and results.
  static const ::xla::ProgramShapeProto* StaticProgramShape() {
    static const ::xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
    return kShape;
  }

  // Metadata that can be used to pretty-print profile counters.
  static const ::xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
    static const ::xla::HloProfilePrinterData* kHloProfilePrinterData =
      {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
    return kHloProfilePrinterData;
  }
};
{{NS_END}}

#endif  // TFCOMPILE_GENERATED_{{ENTRY}}_H_

// clang-format on
)";
  // The replacement strategy is naive, but good enough for our purposes.
  // Note the "{{METHODS_*}}\n", "{{NS_START}}\n" and "{{NS_END}}\n" keys
  // include the trailing newline, so the placeholder line itself collapses
  // when the substituted text supplies its own line endings.
  const std::vector<std::pair<string, string>> rewrites = {
      {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
      {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
      {"{{ARG_NAMES_CODE}}", arg_names_code},
      {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
      {"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
      {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
      {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
      {"{{CLASS}}", opts.class_name},
      {"{{DECLS_FROM_OBJ_FILE}}",
       absl::StrJoin(metadata_result.header_variable_decls, "\n")},
      {"{{ENTRY}}", compile_result.entry_point},
      {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
       metadata_result.hlo_profile_printer_data_access_shim},
      {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
      {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
       include_hlo_profile_printer_data_proto},
      {"{{METHODS_ARG}}\n", methods_arg},
      {"{{METHODS_RESULT}}\n", methods_result},
      {"{{METHODS_VARIABLE}}\n", methods_variable},
      {"{{NS_END}}\n", ns_end},
      {"{{NS_START}}\n", ns_start},
      {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
       metadata_result.program_shape_access_shim},
      {"{{VARIABLE_NAMES_CODE}}", variable_names_code},
      {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
      {"{{RESULT_NAMES_CODE}}", result_names_code},
      {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
      {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
      {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
      {"{{BUFFER_INFOS_AS_STRING}}",
       absl::StrJoin(buffer_infos_as_strings, ",\n")}};
  absl::StrReplaceAll(rewrites, header);
  return OkStatus();
}
733 
CreateUniqueIdentifier(const CodegenOpts & opts,absl::string_view suffix)734 static string CreateUniqueIdentifier(const CodegenOpts& opts,
735                                      absl::string_view suffix) {
736   string result = "__tfcompile";
737   for (const string& n : opts.namespaces) {
738     absl::StrAppend(&result, "_", n);
739   }
740 
741   absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
742   return result;
743 }
744 
GenerateMetadata(const CodegenOpts & opts,const CompileResult & compile_result,MetadataResult * metadata_result)745 Status GenerateMetadata(const CodegenOpts& opts,
746                         const CompileResult& compile_result,
747                         MetadataResult* metadata_result) {
748   std::unique_ptr<xla::ProgramShapeProto> program_shape;
749 
750   if (opts.gen_program_shape) {
751     program_shape =
752         absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
753 
754     // The parameter names are currently meaningless, and redundant with the
755     // rest of our metadata, so clear them out to avoid confusion and save
756     // space.
757     program_shape->clear_parameter_names();
758   }
759 
760   // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives
761   // a shim that evaluates to nullptr, which is what we want.
762 
763   ProtobufToEmbed program_shape_protobuf{
764       CreateUniqueIdentifier(opts, "ProgramShapeProto"),
765       "::xla::ProgramShapeProto", program_shape.get()};
766 
767   ProtobufToEmbed hlo_profile_printer_data_protobuf{
768       CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
769       "::xla::HloProfilePrinterData",
770       compile_result.aot->hlo_profile_printer_data()};
771 
772   TF_ASSIGN_OR_RETURN(
773       EmbeddedProtocolBuffers embedded_protobufs,
774       CreateEmbeddedProtocolBuffers(
775           opts.target_triple,
776           {program_shape_protobuf, hlo_profile_printer_data_protobuf}));
777 
778   metadata_result->program_shape_access_shim =
779       std::move(embedded_protobufs.cpp_shims[0].expression);
780   metadata_result->hlo_profile_printer_data_access_shim =
781       std::move(embedded_protobufs.cpp_shims[1].expression);
782   metadata_result->header_variable_decls.emplace_back(
783       std::move(embedded_protobufs.cpp_shims[0].variable_decl));
784   metadata_result->header_variable_decls.emplace_back(
785       std::move(embedded_protobufs.cpp_shims[1].variable_decl));
786   metadata_result->object_file_data =
787       std::move(embedded_protobufs.object_file_data);
788   return OkStatus();
789 }
790 
ParseCppClass(const string & cpp_class,string * class_name,std::vector<string> * namespaces)791 Status ParseCppClass(const string& cpp_class, string* class_name,
792                      std::vector<string>* namespaces) {
793   class_name->clear();
794   namespaces->clear();
795   if (cpp_class.empty()) {
796     return errors::InvalidArgument("empty cpp_class: " + cpp_class);
797   }
798   std::vector<string> parts = absl::StrSplit(cpp_class, "::");
799   if (parts.front().empty()) {
800     // Allow a fully qualified name that starts with "::".
801     parts.erase(parts.begin());
802   }
803   for (int i = 0, end = parts.size(); i < end; ++i) {
804     if (i < end - 1) {
805       TF_RETURN_IF_ERROR(ValidateCppIdent(
806           parts[i], "in namespace component of cpp_class: " + cpp_class));
807       namespaces->push_back(parts[i]);
808     } else {
809       TF_RETURN_IF_ERROR(ValidateCppIdent(
810           parts[i], "in class name of cpp_class: " + cpp_class));
811       *class_name = parts[i];
812     }
813   }
814   return OkStatus();
815 }
816 
ValidateCppIdent(absl::string_view ident,absl::string_view msg)817 Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
818   if (ident.empty()) {
819     return errors::InvalidArgument("empty identifier: ", msg);
820   }
821   // Require that the identifier starts with a nondigit, and is composed of
822   // nondigits and digits, as specified in section [2.11 Identifiers] of the
823   // C++11 Standard.  Note that nondigit is defined as [_a-zA-Z] and digit is
824   // defined as [0-9].
825   //
826   // Technically the standard also allows for `universal-character-name`, with a
827   // table of allowed unicode ranges, as well as `other implementation-defined
828   // characters`.  We disallow those here to give better error messages, at the
829   // expensive of being more restrictive than the standard.
830   if (ident[0] != '_' && !IsAlpha(ident[0])) {
831     return errors::InvalidArgument("illegal leading char: ", msg);
832   }
833   for (size_t pos = 1; pos < ident.size(); ++pos) {
834     if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
835       return errors::InvalidArgument("illegal char: ", msg);
836     }
837   }
838   return OkStatus();
839 }
840 
841 }  // namespace tfcompile
842 }  // namespace tensorflow
843