/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Helper functions for templated ops.
class XRTStateHelpers {
 public:
  // The Status return value allows us to use the
  // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
  // OpKernel::Compute method.
  static Status MakeLiteral(const xla::LiteralProto& proto,
                            xla::Literal* literal) {
    TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
    return OkStatus();
  }

  // ParseTupleNode is the recursive function used to parse a recursive
  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine' i.e. the
  // tuple shape where every leaf is an existing allocation. As a side-effect it
  // fills in input_vector by looking up allocations from handles in the
  // input_tensor_list as they are referenced by nodes in the proto.
  static Status ParseTupleNode(
      const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::Shape* shape, ResourceMgr* rm) {
    if (tuple_node.tuples_size() > 0) {
      // This is an internal node in the proto so descend recursively.
      xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
      std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
      *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
          xla::ShapeUtil::MakeTupleShape(subshapes);
      for (int i = 0; i < tuple_node.tuples_size(); ++i) {
        TF_RETURN_IF_ERROR(ParseTupleNode(
            tuple_node.tuples(i), input_tensor_list, input_vector,
            xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
      }
    } else {
      // This is a leaf node in the proto so look up the referenced input.
      int input_index = tuple_node.input_index();
      if (input_index < 0 || input_index >= input_vector->size()) {
        return errors::InvalidArgument("Invalid tuple input index ",
                                       input_index, ": MakeTuple has ",
                                       input_vector->size(), " inputs.");
      }
      bool release_this_input = tuple_node.release_input_handle();
      XRTTupleAllocation::ExpandedTupleInput& input =
          input_vector->at(input_index);
      if (input.allocation != nullptr &&
          (input.release_allocation_after_use || release_this_input)) {
        return errors::InvalidArgument(
            "Invalid tuple tree: input index ", input_index,
            " is repeated but release_input_handle is true.");
      }
      if (input.allocation == nullptr) {
        // We haven't dereferenced this handle yet.
        TF_RET_CHECK(
            TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
        int64_t key = input_tensor_list[input_index].scalar<int64_t>()();
        TF_ASSIGN_OR_RETURN(input.allocation,
                            XRTMemoryManager::Get(rm)->Lookup(key));
        input.release_allocation_after_use = release_this_input;
      }
    }
    return OkStatus();
  }

  // Parses an xrt::XLATupleNode proto recursively and returns the
  // corresponding ShapeTree where each leaf is an allocation corresponding to
  // a handle in input_tensor_list. The ordinal of one of the allocations is
  // returned in device_ordinal. Since it's not possible to specify an
  // xrt::XLATupleNode with no leaves, device_ordinal will always be filled in
  // by a successful call to ParseTupleTree.
  static Status ParseTupleTree(
      const xrt::XLATupleNode& tuple_tree_root,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
      int* device_ordinal, ResourceMgr* rm) {
    // First get the shape of the 'spine' of the new tuple, where every leaf is
    // an existing allocation. As a side-effect dereference the input handles
    // into allocations in input_vector.
    xla::Shape tuple_tree_shape;
    TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
                                      input_vector, &tuple_tree_shape, rm));
    // Make the shape tree of allocations where the shape is the spine and each
    // leaf is one of the allocations looked up in input_vector. Internal nodes
    // have nullptr allocations.
    *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
        tuple_tree_shape);
    tuple_shape_tree->ForEachMutableElement(
        [&](const xla::ShapeIndex& index,
            XRTTupleAllocation::ExpandedTupleInput* element) {
          if (tuple_shape_tree->IsLeaf(index)) {
            // Find the matching leaf in the proto tree.
            const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
            for (int i = 0; i < index.size(); ++i) {
              tuple_node = &tuple_node->tuples(index[i]);
            }
            // Copy the appropriate input allocation to the leaf of the
            // tuple_shape_tree.
            int input_index = tuple_node->input_index();
            *element = input_vector->at(input_index);
            CHECK(element->release_allocation_after_use ==
                  tuple_node->release_input_handle());
            // We just need to know the device_ordinal of one of the
            // allocations. We will validate later that they are all the same.
            *device_ordinal = (*element).allocation->device_ordinal();
          }
        });
    return OkStatus();
  }
};
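
// Illustrative example (comment only, not part of the build): a two-leaf
// tuple description for ParseTupleTree, written as an xrt::XLATupleNode
// textproto. Leaf input_index values refer to positions in the op's
// input_handles list (see XRTMakeTupleOp below); the second leaf also asks
// for its input handle to be released once the new tuple has been built:
//
//   tuples { input_index: 0 }
//   tuples { input_index: 1 release_input_handle: true }
//
// For this description ParseTupleTree produces a ShapeTree whose spine is a
// two-element tuple, with leaf {0} aliasing the allocation behind handle 0
// and leaf {1} marked release_allocation_after_use.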

// Op that allocates memory for a literal and transfers it to the device.
template <class DeviceAccessor>
class XRTAllocateOp : public OpKernel {
 public:
  explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTAllocateOp() override = default;
  XRTAllocateOp(const XRTAllocateOp&) = delete;
  XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell());

    const Tensor& allocation_info = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
                errors::Internal("allocation input should be a string scalar"));
    xrt::XLAAllocation allocation_proto;
    OP_REQUIRES(ctx,
                ParseFromTString(allocation_info.scalar<tstring>()(),
                                 &allocation_proto),
                errors::InvalidArgument(
                    "Unable to parse allocation input to XLAAllocation"));

    xla::Literal literal;
    OP_REQUIRES_OK(
        ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation,
                            device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }
};
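
// Illustrative sketch (client-side, not part of this kernel): the string
// scalar consumed by XRTAllocateOp above is a serialized xrt::XLAAllocation
// whose value field holds the literal to place on the device, e.g.
//
//   xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
//   xrt::XLAAllocation alloc;
//   *alloc.mutable_value() = literal.ToProto();
//   const std::string input = alloc.SerializeAsString();
//
// The op returns an int64 handle that later kernels (XRTReadLiteralOp,
// XRTReleaseAllocationOp, ...) use to refer to the device allocation.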

// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateUninitialized(
                            xla_shape_, memory_manager.get(),
                            device_ref.backend(), device_ref.device_ordinal(),
                            &allocation, device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};

// Op that allocates memory for a tensor (with optional layout) and transfers it
// to the device, returning an allocation handle.
template <class DeviceAccessor>
class XRTAllocateFromTensorOp : public OpKernel {
 public:
  explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool make_tuple = false;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
    std::vector<int64_t> minor_to_major;
    if (ctx->HasAttr("layouts")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
    }
    OP_REQUIRES(
        ctx, tf_shapes_.size() == dtypes_.size(),
        errors::InvalidArgument("shapes and dtypes must be the same length"));
    std::vector<xla::Shape> xla_shapes;
    xla_shapes.reserve(tf_shapes_.size());
    for (int i = 0; i < tf_shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
      xla_shapes.push_back(std::move(xla_shape));
    }
    if (xla_shapes.size() > 1 || make_tuple) {
      shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    } else {
      shape_.Swap(&xla_shapes.front());
    }
    if (!minor_to_major.empty()) {
      xla::Shape shape_with_layouts;
      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                             /*layout_func=*/nullptr,
                                             &shape_with_layouts));
      shape_.Swap(&shape_with_layouts);
    }
  }

  ~XRTAllocateFromTensorOp() override = default;
  XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete;
  XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateFromTensorOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell());

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == tf_shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to XRTAllocateFromTensor: ",
                    values.size(), " vs. ", tf_shapes_.size()));

    std::vector<const char*> tensors_data;
    for (size_t i = 0; i < values.size(); ++i) {
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "Input tensor type and input dtype do not match"));
      // We allow the requested on-device shape to differ from the shape of the
      // input tensor, as long as they have the same number of elements.
      OP_REQUIRES(
          ctx,
          input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(),
          errors::InvalidArgument(
              "Input tensor must have the number of elements specified "
              "in the matching input shape: ",
              input_tensor.shape().num_elements(), " vs. ",
              tf_shapes_[i].num_elements(), " at index ", i));
      tensors_data.push_back(
          static_cast<const char*>(DMAHelper::base(&input_tensor)));
    }
    // Use the buffer straight out of the input tensors to create the literal.
    xla::BorrowingLiteral literal =
        shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_)
                         : xla::BorrowingLiteral(tensors_data.front(), shape_);
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation,
                            device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  std::vector<TensorShape> tf_shapes_;
  DataTypeVector dtypes_;
  xla::Shape shape_;
};
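
// Illustrative note on the optional "layouts" attr (assumed semantics, based
// on how minor_to_major is forwarded to GetShapeWithLayout above): the attr is
// a flat list of minor-to-major dimension orders, concatenated over the
// entries of "shapes" in order. For example, with shapes=[[2,3]] a layouts
// value of {1, 0} keeps the default row-major device layout, while {0, 1}
// requests the column-major layout for that single array.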

// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
// input.
template <bool discard_, class DeviceAccessor>
class XRTSubTupleOp : public OpKernel {
 public:
  explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTSubTupleOp() override = default;
  XRTSubTupleOp(const XRTSubTupleOp&) = delete;
  XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTSubTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    const Tensor& subtuple_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
        errors::Internal("tuple index input should be an int32 vector"));
    xla::ShapeIndex shape_index;
    for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
      shape_index.push_back(subtuple_info.vec<int32>()(i));
    }

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    XRTTupleAllocation* suballocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
                                               &suballocation, !discard_));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(suballocation);
    ctx->set_output(0, output);
  }
};
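
// Illustrative example: for an allocation shaped ((f32[2], s32[]), f32[4]),
// an index vector of {0, 1} selects the s32[] leaf (element 1 of element 0)
// and the op returns a new handle referring to that sub-buffer. When the
// discard_ template parameter is true, the input handle is released before
// the sub-buffer handle is returned.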

// Op that assembles a new tuple allocation on the device from a tuple
// description and a list of existing allocation handles.
template <class DeviceAccessor>
class XRTMakeTupleOp : public OpKernel {
 public:
  explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMakeTupleOp() override = default;
  XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
  XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTMakeTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell());

    const Tensor& tuple_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
        errors::Internal("tuple description input should be a string scalar"));
    xrt::XLATupleNode tuple_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(tuple_info.scalar<tstring>()(), &tuple_proto),
        errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));

    OpInputList arg_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));

    // For each input, the allocation it corresponds to and a flag indicating
    // whether or not it should be released, i.e. discarded from the resource
    // manager. One ref on each allocation is owned by this vector, and freed on
    // exit.
    std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
        arg_list.size());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
    // the allocations. It is guaranteed that there is at least one allocation
    // in any legal tree. We validate below in XRTTupleAllocation::MakeTuple
    // that all the allocations are on the same device.
    int device_ordinal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
                            tuple_proto, arg_list, &input_vector,
                            &tuple_shape_tree, &device_ordinal, rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(
        ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* output_allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
                            memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), tuple_shape_tree,
                            &output_allocation, device_ref.allocator()));
    RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
    for (int i = 0; i < input_vector.size(); ++i) {
      if (input_vector[i].release_allocation_after_use) {
        OP_REQUIRES_OK(
            ctx, memory_manager->Release(arg_list[i].scalar<int64_t>()()));
      }
    }

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() =
        memory_manager->Register(std::move(output_ptr));
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
template <bool discard_, class DeviceAccessor>
class XRTReadLiteralOp : public OpKernel {
 public:
  explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReadLiteralOp() override = default;
  XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
  XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Literal literal(allocation->on_host_shape());
    OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
    xla::LiteralProto literal_proto = literal.ToProto();

    Tensor output(DT_STRING, TensorShape({}));
    SerializeToTString(literal_proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
  }
};
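
// Illustrative sketch (client-side, not part of this kernel): the op's string
// output is a serialized xla::LiteralProto, which can be turned back into a
// literal with the same helper used on the allocation path, e.g.
//
//   xla::LiteralProto proto;
//   // read_literal_output is a placeholder for the op's output string.
//   proto.ParseFromString(read_literal_output);
//   TF_ASSIGN_OR_RETURN(xla::Literal literal,
//                       xla::Literal::CreateFromProto(proto));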

// Op that reads a device-resident tuple to host memory and returns its data
// as one output tensor per leaf (non-tuple) buffer.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
 public:
  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }
  ~XRTReadToTensorOp() override = default;
  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadToTensorOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());

    const Tensor& handle_tensor = ctx->input(0);
    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
    // just scalars).
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Shape shape = allocation->on_host_shape();
    int output = 0;
    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
        &shape,
        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
          if (subshape->IsTuple()) return OkStatus();

          xla::PrimitiveType xla_type;
          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
              ctx->expected_output_dtype(output), &xla_type));
          if (xla_type != subshape->element_type()) {
            return errors::InvalidArgument(
                "Type mismatch between buffer type (", subshape->ToString(),
                ") and tensor type (",
                DataTypeString(ctx->expected_output_dtype(output)),
                ") for output tensor ", output);
          }

          TensorShape output_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));

          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(
              ctx->allocate_output(output, output_shape, &output_tensor));

          XRTTupleAllocation* sub;
          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
              allocation.get(), index, &sub, /*alias_parent_allocation=*/true));
          core::ScopedUnref sub_unref(sub);

          xla::MutableBorrowingLiteral literal;
          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
              &literal));
          TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));

          ++output;
          return OkStatus();
        });
    OP_REQUIRES_OK(ctx, status);
  }
  bool discard_;
  DataTypeVector dtypes_;
};

// Op that writes a new literal value into device-resident memory.
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(literal_info.scalar<tstring>()(), &literal_proto),
        errors::InvalidArgument(
            "Unable to parse literal input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards one or more handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
 public:
  explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReleaseAllocationOp() override = default;
  XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
  XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllocationOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    const Tensor& allocation_handle = ctx->input(0);
    auto flat_keys = allocation_handle.flat<int64_t>();
    for (int64_t i = 0; i < flat_keys.size(); ++i) {
      int64_t key = flat_keys(i);
      OP_REQUIRES_OK(ctx, memory_manager->Release(key));
      VLOG(2) << "Released allocation handle " << key;
    }
  }
};

// Op that discards all handles to device memory held by the memory manager.
template <class DeviceAccessor>
class XRTReleaseAllAllocationsOp : public OpKernel {
 public:
  explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  ~XRTReleaseAllAllocationsOp() override = default;
  XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete;
  XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
  }
};

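// Op that asks the memory manager to compact the device allocations it is
// tracking (see XRTMemoryManager::CompactAllocations).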
template <class DeviceAccessor>
class XRTCompactAllocationsOp : public OpKernel {
 public:
  explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTCompactAllocationsOp() override = default;
  XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete;
  XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTCompactAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
    OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
                            device_ref.backend(), device_ref.device_ordinal(),
                            device_ref.allocator()));
  }
};

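// Op that queries the target device for its free and total memory and returns
// the result as a serialized xrt::MemoryInfo string.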
template <class DeviceAccessor>
class XRTMemoryInfoOp : public OpKernel {
 public:
  explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMemoryInfoOp() override = default;
  XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
  XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    auto kernel_fn = [&]() -> Status {
      VLOG(1) << "XRTMemoryInfoOp::Compute";

      class DeviceAccessor::ScopedRef device_ref;
      TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
      TF_ASSIGN_OR_RETURN(
          se::StreamExecutor * stream_executor,
          device_ref.backend()->stream_executor(device_ref.device_ordinal()));
      int64_t mem_free = -1;
      int64_t mem_total = -1;
      if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
        VLOG(2) << "Device " << ctx->device()->name()
                << " does not expose memory information";
      }
      xrt::MemoryInfo mem_info;
      mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
      mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);

      Tensor output(DT_STRING, TensorShape({}));
      output.scalar<tstring>()() = mem_info.SerializeAsString();
      ctx->set_output(0, output);
      return OkStatus();
    };
    OP_REQUIRES_OK(ctx, kernel_fn());
  }
};
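
// Illustrative sketch (client-side, not part of this kernel): the output can
// be decoded with
//
//   xrt::MemoryInfo mem_info;
//   // memory_info_output is a placeholder for the op's output string.
//   mem_info.ParseFromString(memory_info_output);
//   VLOG(1) << "free kb: " << mem_info.kb_free()
//           << " total kb: " << mem_info.kb_total();
//
// kb_free/kb_total are -1 when the device does not report memory usage.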

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_