xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_context.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_context.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/tf2xla/literal_util.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/core/common_runtime/dma_helper.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace tensorflow {
37 
38 const char XlaContext::kXlaContextResourceName[] = "_xla_context";
39 
40 // Looks up the context associated with the current step. It is stored
41 // in a resource container managed by the device.
Get(const OpKernelContext * ctx)42 /* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) {
43   // When an Op kernel wants to use an XLA JIT context, the
44   // per-step context is looked up in the resource manager. The
45   // JIT will prepopulate the JITContext.
46   XlaContext* context;
47   TF_CHECK_OK(ctx->step_container()->Lookup(ctx->resource_manager(),
48                                             kXlaContextResourceName, &context));
49   // The resource manager handed us a fresh reference to 'context', but retains
50   // a reference itself so the context won't be freed. The resource manager will
51   // outlive the JIT compilation.
52   context->Unref();
53   return *context;
54 }
55 
set_args(std::vector<XlaExpression> args)56 void XlaContext::set_args(std::vector<XlaExpression> args) {
57   args_ = std::move(args);
58 }
59 
XlaContext(XlaCompiler * compiler,xla::XlaBuilder * builder,const Graph * graph)60 XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
61                        const Graph* graph)
62     : compiler_(compiler), builder_(builder) {
63   if (graph) {
64     for (const Node* node : graph->nodes()) {
65       stack_traces_[node->name()] = node->GetStackTrace();
66     }
67   }
68 }
69 
DebugString() const70 string XlaContext::DebugString() const { return "XLA JIT context"; }
71 
SetRetval(int index,const XlaExpression & expression)72 void XlaContext::SetRetval(int index, const XlaExpression& expression) {
73   const int64_t retvals_size = retvals_.size();
74   if (retvals_size <= index) {
75     retvals_.resize(index + 1);
76   }
77   retvals_[index] = expression;
78 }
79 
AddResource(std::unique_ptr<XlaResource> resource)80 XlaResource* XlaContext::AddResource(std::unique_ptr<XlaResource> resource) {
81   resources_.push_back(std::move(resource));
82   return resources_.back().get();
83 }
84 
GetOrCreateMax(const DataType type)85 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
86   return LookupOrCreate(type, &max_func_, [type] {
87     const string type_string = DataTypeString(type);
88     VLOG(1) << "Building Max() for " << type_string;
89     xla::XlaBuilder b("max<" + type_string + ">");
90     xla::PrimitiveType xla_type;
91     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
92     auto x =
93         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
94     auto y =
95         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
96     xla::Max(x, y);
97     return b.Build().value();
98   });
99 }
100 
GetOrCreateMin(const DataType type)101 const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
102   return LookupOrCreate(type, &min_func_, [type] {
103     const string type_string = DataTypeString(type);
104     VLOG(1) << "Building Min() for " << type_string;
105     xla::XlaBuilder b("min<" + type_string + ">");
106     xla::PrimitiveType xla_type;
107     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
108     auto x =
109         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
110     auto y =
111         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
112     xla::Min(x, y);
113     return b.Build().value();
114   });
115 }
116 
GetOrCreateAdd(const DataType type)117 const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
118   return LookupOrCreate(type, &add_func_, [type] {
119     const string type_string = DataTypeString(type);
120     VLOG(1) << "Building Add() for " << type_string;
121     xla::XlaBuilder b("add<" + type_string + ">");
122     xla::PrimitiveType xla_type;
123     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
124     auto x =
125         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
126     auto y =
127         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
128     xla::Add(x, y);
129     return b.Build().value();
130   });
131 }
132 
GetOrCreateMul(const DataType type)133 const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
134   return LookupOrCreate(type, &mul_func_, [type] {
135     const string type_string = DataTypeString(type);
136     VLOG(1) << "Building Mul() for " << type_string;
137     xla::XlaBuilder b("mul<" + type_string + ">");
138     xla::PrimitiveType xla_type;
139     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
140     auto x =
141         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
142     auto y =
143         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
144     xla::Mul(x, y);
145     return b.Build().value();
146   });
147 }
148 
LookupOrCreate(DataType type,ComputationMap * out,const std::function<xla::XlaComputation ()> & create)149 const xla::XlaComputation* XlaContext::LookupOrCreate(
150     DataType type, ComputationMap* out,
151     const std::function<xla::XlaComputation()>& create) {
152   {
153     const auto& entry = (*out)[type];
154     if (!entry.IsNull()) {
155       return &entry;
156     }
157   }
158   auto new_entry = create();
159   {
160     // Somebody else might have made one concurrently.
161     auto& entry = (*out)[type];
162     if (entry.IsNull()) {
163       entry = std::move(new_entry);
164     }
165     return &entry;
166   }
167 }
168 
RecordCollectiveInfoFromNestedCompilationResult(const XlaCompilationResult & result)169 Status XlaContext::RecordCollectiveInfoFromNestedCompilationResult(
170     const XlaCompilationResult& result) {
171   if (result.collective_info) {
172     return RecordCollectiveInfo(result.collective_info->group_key,
173                                 result.collective_info->group_size)
174         .status();
175   }
176   return OkStatus();
177 }
178 
RecordCollectiveInfo(int group_key,int group_size)179 StatusOr<int64_t> XlaContext::RecordCollectiveInfo(int group_key,
180                                                    int group_size) {
181   if (!collective_info_) {
182     collective_info_ = {group_key, group_size, 0};
183   } else if (collective_info_->group_key != group_key ||
184              collective_info_->group_size != group_size) {
185     return errors::InvalidArgument(
186         "Only single configuration of CollectiveReduceV2Op is ",
187         "supported in a given cluster. Recorded group_key=",
188         collective_info_->group_key,
189         " attempting to insert group_key=", group_key);
190   }
191 
192   // Create the channel_id to be used for the collective. Avoid having the
193   // same channel_id to be used for 2 or more collectives since XLA attempts
194   // to "gang schedule" all collectives with the same channel_id.
195   return (static_cast<int64_t>(group_key) << 32) | collective_info_->next_id++;
196 }
197 
198 }  // namespace tensorflow
199