1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h"
17
18 #include <string>
19
20 #include "absl/strings/match.h"
21 #include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
22 #include "tensorflow/compiler/xla/service/hlo.pb.h"
23 #include "tensorflow/core/platform/path.h"
24 #include "tensorflow/core/public/session.h"
25
26 namespace tensorflow {
27 namespace {
28
29 // Creates a float tensor of linearly increasing values, starting from offset.
CreateInputTensor(const TensorShape & shape,float offset)30 Tensor CreateInputTensor(const TensorShape& shape, float offset) {
31 Tensor tensor(DT_FLOAT, shape);
32 for (int64 i = 0; i < tensor.flat<float>().size(); ++i) {
33 tensor.flat<float>()(i) = offset + i;
34 }
35 return tensor;
36 }
37
MakeNode(absl::string_view name,absl::string_view op,absl::Span<const std::string> inputs,absl::Span<const std::pair<std::string,FunctionDefHelper::AttrValueWrapper>> attrs)38 NodeDef MakeNode(
39 absl::string_view name, absl::string_view op,
40 absl::Span<const std::string> inputs,
41 absl::Span<
42 const std::pair<std::string, FunctionDefHelper::AttrValueWrapper>>
43 attrs) {
44 NodeDef node;
45 node.set_name(std::string(name));
46 node.set_op(std::string(op));
47 for (const auto& input : inputs) node.add_input(input);
48 for (const auto& attr : attrs)
49 node.mutable_attr()->insert({attr.first, attr.second.proto});
50 return node;
51 }
52
53 } // namespace
54
GetTestGraph(const PartialTensorShape & input_shape)55 GraphDef XlaCompilationCacheSerializeTest::GetTestGraph(
56 const PartialTensorShape& input_shape) {
57 FunctionDef make_test_fn = FunctionDefHelper::Define(
58 "TestFn", {"a:float", "b:float", "c:float"}, {"m:float"}, {},
59 {{{"d"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
60 {{"e"}, "Mul", {"d", "c"}, {{"T", DT_FLOAT}}},
61 {{"f"}, "Add", {"e", "a"}, {{"T", DT_FLOAT}}},
62 {{"g"}, "Mul", {"f", "b"}, {{"T", DT_FLOAT}}},
63 // Force two clusters by excluding this node explicitly.
64 {{"h"}, "Add", {"g", "f"}, {{"T", DT_FLOAT}, {"_XlaCompile", false}}},
65 {{"i"}, "Add", {"h", "e"}, {{"T", DT_FLOAT}}},
66 {{"j"}, "Add", {"i", "h"}, {{"T", DT_FLOAT}}},
67 {{"k"}, "Add", {"j", "h"}, {{"T", DT_FLOAT}}},
68 {{"l"}, "Add", {"k", "h"}, {{"T", DT_FLOAT}}},
69 {{"m"}, "Identity", {"l"}, {{"T", DT_FLOAT}}}});
70
71 GraphDef graph;
72 *graph.mutable_library()->add_function() = make_test_fn;
73 *graph.add_node() = MakeNode("a", "Placeholder", {},
74 {{"dtype", DT_FLOAT}, {"shape", input_shape}});
75 *graph.add_node() = MakeNode("b", "Placeholder", {},
76 {{"dtype", DT_FLOAT}, {"shape", input_shape}});
77 *graph.add_node() = MakeNode("c", "Placeholder", {},
78 {{"dtype", DT_FLOAT}, {"shape", input_shape}});
79 *graph.add_node() = MakeNode("m", "TestFn", {"a", "b", "c"}, {});
80 return graph;
81 }
82
ExecuteWithBatch(const GraphDef & graph,int batch)83 Status XlaCompilationCacheSerializeTest::ExecuteWithBatch(const GraphDef& graph,
84 int batch) {
85 const TensorShape shape({batch, 4});
86
87 // Compute the golden output tensor
88 std::vector<Tensor> golden_output_tensors;
89 {
90 SessionOptions options;
91 std::unique_ptr<Session> session(NewSession(options));
92 TF_RETURN_IF_ERROR(session->Create(graph));
93 RunOptions run_options;
94
95 Tensor input_a = CreateInputTensor(shape, 0);
96 Tensor input_b = CreateInputTensor(shape, shape.num_elements());
97 Tensor input_c = CreateInputTensor(shape, 2 * shape.num_elements());
98 TF_RETURN_IF_ERROR(session->Run(
99 run_options,
100 {std::make_pair("a", input_a), std::make_pair("b", input_b),
101 std::make_pair("c", input_c)},
102 {"m"}, {}, &golden_output_tensors, nullptr));
103 TF_RETURN_IF_ERROR(session->Close());
104 }
105
106 // Compute the XLA compiled output
107 std::vector<Tensor> output_tensors;
108 {
109 SessionOptions options;
110 auto& opts =
111 *options.config.mutable_graph_options()->mutable_optimizer_options();
112 opts.set_global_jit_level(OptimizerOptions::ON_1);
113 opts.set_cpu_global_jit(true);
114
115 std::unique_ptr<Session> session(NewSession(options));
116 TF_RETURN_IF_ERROR(session->Create(graph));
117 RunOptions run_options;
118 Tensor input_a = CreateInputTensor(shape, 0);
119 Tensor input_b = CreateInputTensor(shape, shape.num_elements());
120 Tensor input_c = CreateInputTensor(shape, 2 * shape.num_elements());
121 TF_RETURN_IF_ERROR(session->Run(
122 run_options,
123 {std::make_pair("a", input_a), std::make_pair("b", input_b),
124 std::make_pair("c", input_c)},
125 {"m"}, {}, &output_tensors, nullptr));
126 TF_RETURN_IF_ERROR(session->Close());
127 }
128
129 Tensor f32_input(DT_FLOAT, shape);
130 for (int64 i = 0; i < f32_input.NumElements(); ++i) {
131 EXPECT_NEAR(golden_output_tensors[0].flat<float>()(i),
132 output_tensors[0].flat<float>()(i), 1e-3);
133 }
134 return OkStatus();
135 }
136
137 Status
AlterPersistentCacheEntryHloModuleNames(absl::string_view persistent_cache_dir_path,absl::string_view file_prefix)138 XlaCompilationCacheSerializeTest::AlterPersistentCacheEntryHloModuleNames(
139 absl::string_view persistent_cache_dir_path,
140 absl::string_view file_prefix) {
141 Env* env = Env::Default();
142 std::vector<string> file_names;
143 TF_RETURN_IF_ERROR(
144 env->GetChildren(tensorflow::testing::TmpDir(), &file_names));
145
146 bool altered = false;
147 for (const auto& file_name : file_names) {
148 if (absl::EndsWith(file_name, ".pb") &&
149 absl::StartsWith(file_name, file_prefix)) {
150 XlaSerializedCacheEntry entry;
151 auto file_path = io::JoinPath(persistent_cache_dir_path, file_name);
152 TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry));
153 entry.mutable_hlo_module()->set_name(
154 absl::StrCat(entry.hlo_module().name(), "_altered"));
155 TF_RETURN_IF_ERROR(WriteBinaryProto(env, file_path, entry));
156 altered = true;
157 }
158 }
159
160 if (!altered) {
161 return errors::NotFound(
162 "Did not find any persistent XLA compilation cache entries to alter.");
163 }
164 return OkStatus();
165 }
166
167 } // namespace tensorflow
168