1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_context.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/tf2xla/literal_util.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/core/common_runtime/dma_helper.h"
34 #include "tensorflow/core/platform/logging.h"
35
36 namespace tensorflow {
37
// Name under which the per-step XlaContext is registered in the step
// container's resource manager; XlaContext::Get() looks it up by this name.
const char XlaContext::kXlaContextResourceName[] = "_xla_context";
39
40 // Looks up the context associated with the current step. It is stored
41 // in a resource container managed by the device.
Get(const OpKernelContext * ctx)42 /* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) {
43 // When an Op kernel wants to use an XLA JIT context, the
44 // per-step context is looked up in the resource manager. The
45 // JIT will prepopulate the JITContext.
46 XlaContext* context;
47 TF_CHECK_OK(ctx->step_container()->Lookup(ctx->resource_manager(),
48 kXlaContextResourceName, &context));
49 // The resource manager handed us a fresh reference to 'context', but retains
50 // a reference itself so the context won't be freed. The resource manager will
51 // outlive the JIT compilation.
52 context->Unref();
53 return *context;
54 }
55
set_args(std::vector<XlaExpression> args)56 void XlaContext::set_args(std::vector<XlaExpression> args) {
57 args_ = std::move(args);
58 }
59
XlaContext(XlaCompiler * compiler,xla::XlaBuilder * builder,const Graph * graph)60 XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
61 const Graph* graph)
62 : compiler_(compiler), builder_(builder) {
63 if (graph) {
64 for (const Node* node : graph->nodes()) {
65 stack_traces_[node->name()] = node->GetStackTrace();
66 }
67 }
68 }
69
DebugString() const70 string XlaContext::DebugString() const { return "XLA JIT context"; }
71
SetRetval(int index,const XlaExpression & expression)72 void XlaContext::SetRetval(int index, const XlaExpression& expression) {
73 const int64_t retvals_size = retvals_.size();
74 if (retvals_size <= index) {
75 retvals_.resize(index + 1);
76 }
77 retvals_[index] = expression;
78 }
79
AddResource(std::unique_ptr<XlaResource> resource)80 XlaResource* XlaContext::AddResource(std::unique_ptr<XlaResource> resource) {
81 resources_.push_back(std::move(resource));
82 return resources_.back().get();
83 }
84
GetOrCreateMax(const DataType type)85 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
86 return LookupOrCreate(type, &max_func_, [type] {
87 const string type_string = DataTypeString(type);
88 VLOG(1) << "Building Max() for " << type_string;
89 xla::XlaBuilder b("max<" + type_string + ">");
90 xla::PrimitiveType xla_type;
91 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
92 auto x =
93 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
94 auto y =
95 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
96 xla::Max(x, y);
97 return b.Build().value();
98 });
99 }
100
GetOrCreateMin(const DataType type)101 const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
102 return LookupOrCreate(type, &min_func_, [type] {
103 const string type_string = DataTypeString(type);
104 VLOG(1) << "Building Min() for " << type_string;
105 xla::XlaBuilder b("min<" + type_string + ">");
106 xla::PrimitiveType xla_type;
107 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
108 auto x =
109 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
110 auto y =
111 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
112 xla::Min(x, y);
113 return b.Build().value();
114 });
115 }
116
GetOrCreateAdd(const DataType type)117 const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
118 return LookupOrCreate(type, &add_func_, [type] {
119 const string type_string = DataTypeString(type);
120 VLOG(1) << "Building Add() for " << type_string;
121 xla::XlaBuilder b("add<" + type_string + ">");
122 xla::PrimitiveType xla_type;
123 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
124 auto x =
125 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
126 auto y =
127 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
128 xla::Add(x, y);
129 return b.Build().value();
130 });
131 }
132
GetOrCreateMul(const DataType type)133 const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
134 return LookupOrCreate(type, &mul_func_, [type] {
135 const string type_string = DataTypeString(type);
136 VLOG(1) << "Building Mul() for " << type_string;
137 xla::XlaBuilder b("mul<" + type_string + ">");
138 xla::PrimitiveType xla_type;
139 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
140 auto x =
141 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
142 auto y =
143 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
144 xla::Mul(x, y);
145 return b.Build().value();
146 });
147 }
148
LookupOrCreate(DataType type,ComputationMap * out,const std::function<xla::XlaComputation ()> & create)149 const xla::XlaComputation* XlaContext::LookupOrCreate(
150 DataType type, ComputationMap* out,
151 const std::function<xla::XlaComputation()>& create) {
152 {
153 const auto& entry = (*out)[type];
154 if (!entry.IsNull()) {
155 return &entry;
156 }
157 }
158 auto new_entry = create();
159 {
160 // Somebody else might have made one concurrently.
161 auto& entry = (*out)[type];
162 if (entry.IsNull()) {
163 entry = std::move(new_entry);
164 }
165 return &entry;
166 }
167 }
168
RecordCollectiveInfoFromNestedCompilationResult(const XlaCompilationResult & result)169 Status XlaContext::RecordCollectiveInfoFromNestedCompilationResult(
170 const XlaCompilationResult& result) {
171 if (result.collective_info) {
172 return RecordCollectiveInfo(result.collective_info->group_key,
173 result.collective_info->group_size)
174 .status();
175 }
176 return OkStatus();
177 }
178
RecordCollectiveInfo(int group_key,int group_size)179 StatusOr<int64_t> XlaContext::RecordCollectiveInfo(int group_key,
180 int group_size) {
181 if (!collective_info_) {
182 collective_info_ = {group_key, group_size, 0};
183 } else if (collective_info_->group_key != group_key ||
184 collective_info_->group_size != group_size) {
185 return errors::InvalidArgument(
186 "Only single configuration of CollectiveReduceV2Op is ",
187 "supported in a given cluster. Recorded group_key=",
188 collective_info_->group_key,
189 " attempting to insert group_key=", group_key);
190 }
191
192 // Create the channel_id to be used for the collective. Avoid having the
193 // same channel_id to be used for 2 or more collectives since XLA attempts
194 // to "gang schedule" all collectives with the same channel_id.
195 return (static_cast<int64_t>(group_key) << 32) | collective_info_->next_id++;
196 }
197
198 } // namespace tensorflow
199