// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h"

#include <string>

#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
27 namespace {
28 
29 // Creates a float tensor of linearly increasing values, starting from offset.
CreateInputTensor(const TensorShape & shape,float offset)30 Tensor CreateInputTensor(const TensorShape& shape, float offset) {
31   Tensor tensor(DT_FLOAT, shape);
32   for (int64 i = 0; i < tensor.flat<float>().size(); ++i) {
33     tensor.flat<float>()(i) = offset + i;
34   }
35   return tensor;
36 }
37 
MakeNode(absl::string_view name,absl::string_view op,absl::Span<const std::string> inputs,absl::Span<const std::pair<std::string,FunctionDefHelper::AttrValueWrapper>> attrs)38 NodeDef MakeNode(
39     absl::string_view name, absl::string_view op,
40     absl::Span<const std::string> inputs,
41     absl::Span<
42         const std::pair<std::string, FunctionDefHelper::AttrValueWrapper>>
43         attrs) {
44   NodeDef node;
45   node.set_name(std::string(name));
46   node.set_op(std::string(op));
47   for (const auto& input : inputs) node.add_input(input);
48   for (const auto& attr : attrs)
49     node.mutable_attr()->insert({attr.first, attr.second.proto});
50   return node;
51 }
52 
53 }  // namespace
54 
GetTestGraph(const PartialTensorShape & input_shape)55 GraphDef XlaCompilationCacheSerializeTest::GetTestGraph(
56     const PartialTensorShape& input_shape) {
57   FunctionDef make_test_fn = FunctionDefHelper::Define(
58       "TestFn", {"a:float", "b:float", "c:float"}, {"m:float"}, {},
59       {{{"d"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
60        {{"e"}, "Mul", {"d", "c"}, {{"T", DT_FLOAT}}},
61        {{"f"}, "Add", {"e", "a"}, {{"T", DT_FLOAT}}},
62        {{"g"}, "Mul", {"f", "b"}, {{"T", DT_FLOAT}}},
63        // Force two clusters by excluding this node explicitly.
64        {{"h"}, "Add", {"g", "f"}, {{"T", DT_FLOAT}, {"_XlaCompile", false}}},
65        {{"i"}, "Add", {"h", "e"}, {{"T", DT_FLOAT}}},
66        {{"j"}, "Add", {"i", "h"}, {{"T", DT_FLOAT}}},
67        {{"k"}, "Add", {"j", "h"}, {{"T", DT_FLOAT}}},
68        {{"l"}, "Add", {"k", "h"}, {{"T", DT_FLOAT}}},
69        {{"m"}, "Identity", {"l"}, {{"T", DT_FLOAT}}}});
70 
71   GraphDef graph;
72   *graph.mutable_library()->add_function() = make_test_fn;
73   *graph.add_node() = MakeNode("a", "Placeholder", {},
74                                {{"dtype", DT_FLOAT}, {"shape", input_shape}});
75   *graph.add_node() = MakeNode("b", "Placeholder", {},
76                                {{"dtype", DT_FLOAT}, {"shape", input_shape}});
77   *graph.add_node() = MakeNode("c", "Placeholder", {},
78                                {{"dtype", DT_FLOAT}, {"shape", input_shape}});
79   *graph.add_node() = MakeNode("m", "TestFn", {"a", "b", "c"}, {});
80   return graph;
81 }
82 
ExecuteWithBatch(const GraphDef & graph,int batch)83 Status XlaCompilationCacheSerializeTest::ExecuteWithBatch(const GraphDef& graph,
84                                                           int batch) {
85   const TensorShape shape({batch, 4});
86 
87   // Compute the golden output tensor
88   std::vector<Tensor> golden_output_tensors;
89   {
90     SessionOptions options;
91     std::unique_ptr<Session> session(NewSession(options));
92     TF_RETURN_IF_ERROR(session->Create(graph));
93     RunOptions run_options;
94 
95     Tensor input_a = CreateInputTensor(shape, 0);
96     Tensor input_b = CreateInputTensor(shape, shape.num_elements());
97     Tensor input_c = CreateInputTensor(shape, 2 * shape.num_elements());
98     TF_RETURN_IF_ERROR(session->Run(
99         run_options,
100         {std::make_pair("a", input_a), std::make_pair("b", input_b),
101          std::make_pair("c", input_c)},
102         {"m"}, {}, &golden_output_tensors, nullptr));
103     TF_RETURN_IF_ERROR(session->Close());
104   }
105 
106   // Compute the XLA compiled output
107   std::vector<Tensor> output_tensors;
108   {
109     SessionOptions options;
110     auto& opts =
111         *options.config.mutable_graph_options()->mutable_optimizer_options();
112     opts.set_global_jit_level(OptimizerOptions::ON_1);
113     opts.set_cpu_global_jit(true);
114 
115     std::unique_ptr<Session> session(NewSession(options));
116     TF_RETURN_IF_ERROR(session->Create(graph));
117     RunOptions run_options;
118     Tensor input_a = CreateInputTensor(shape, 0);
119     Tensor input_b = CreateInputTensor(shape, shape.num_elements());
120     Tensor input_c = CreateInputTensor(shape, 2 * shape.num_elements());
121     TF_RETURN_IF_ERROR(session->Run(
122         run_options,
123         {std::make_pair("a", input_a), std::make_pair("b", input_b),
124          std::make_pair("c", input_c)},
125         {"m"}, {}, &output_tensors, nullptr));
126     TF_RETURN_IF_ERROR(session->Close());
127   }
128 
129   Tensor f32_input(DT_FLOAT, shape);
130   for (int64 i = 0; i < f32_input.NumElements(); ++i) {
131     EXPECT_NEAR(golden_output_tensors[0].flat<float>()(i),
132                 output_tensors[0].flat<float>()(i), 1e-3);
133   }
134   return OkStatus();
135 }
136 
137 Status
AlterPersistentCacheEntryHloModuleNames(absl::string_view persistent_cache_dir_path,absl::string_view file_prefix)138 XlaCompilationCacheSerializeTest::AlterPersistentCacheEntryHloModuleNames(
139     absl::string_view persistent_cache_dir_path,
140     absl::string_view file_prefix) {
141   Env* env = Env::Default();
142   std::vector<string> file_names;
143   TF_RETURN_IF_ERROR(
144       env->GetChildren(tensorflow::testing::TmpDir(), &file_names));
145 
146   bool altered = false;
147   for (const auto& file_name : file_names) {
148     if (absl::EndsWith(file_name, ".pb") &&
149         absl::StartsWith(file_name, file_prefix)) {
150       XlaSerializedCacheEntry entry;
151       auto file_path = io::JoinPath(persistent_cache_dir_path, file_name);
152       TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry));
153       entry.mutable_hlo_module()->set_name(
154           absl::StrCat(entry.hlo_module().name(), "_altered"));
155       TF_RETURN_IF_ERROR(WriteBinaryProto(env, file_path, entry));
156       altered = true;
157     }
158   }
159 
160   if (!altered) {
161     return errors::NotFound(
162         "Did not find any persistent XLA compilation cache entries to alter.");
163   }
164   return OkStatus();
165 }
166 
}  // namespace tensorflow