/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Helper functions for templated ops.
class XRTStateHelpers {
 public:
  // The Status return value allows us to use the TF_ASSIGN_OR_RETURN macro,
  // which doesn't work within the body of an OpKernel::Compute method.
  static Status MakeLiteral(const xla::LiteralProto& proto,
                            xla::Literal* literal) {
    TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
    return OkStatus();
  }
  // ParseTupleNode is the recursive function used to parse a recursive
  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine' i.e.
  // the tuple shape where every leaf is an existing allocation. As a
  // side-effect it fills in input_vector by looking up allocations from
  // handles in the input_tensor_list as they are referenced by nodes in the
  // proto.
  static Status ParseTupleNode(
      const xrt::XLATupleNode& tuple_node,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::Shape* shape, ResourceMgr* rm) {
    if (tuple_node.tuples_size() > 0) {
      // This is an internal node in the proto so descend recursively.
      xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
      std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
      *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
          xla::ShapeUtil::MakeTupleShape(subshapes);
      for (int i = 0; i < tuple_node.tuples_size(); ++i) {
        TF_RETURN_IF_ERROR(ParseTupleNode(
            tuple_node.tuples(i), input_tensor_list, input_vector,
            xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
      }
    } else {
      // This is a leaf node in the proto so look up the referenced input.
      int input_index = tuple_node.input_index();
      if (input_index < 0 || input_index >= input_vector->size()) {
        return errors::InvalidArgument("Invalid tuple input index ",
                                       input_index, ": MakeTuple has ",
                                       input_vector->size(), " inputs.");
      }
      bool release_this_input = tuple_node.release_input_handle();
      XRTTupleAllocation::ExpandedTupleInput& input =
          input_vector->at(input_index);
      if (input.allocation != nullptr &&
          (input.release_allocation_after_use || release_this_input)) {
        return errors::InvalidArgument(
            "Invalid tuple tree: input index ", input_index,
            " is repeated but release_input_handle is true.");
      }
      if (input.allocation == nullptr) {
        // We haven't dereferenced this handle yet.
        TF_RET_CHECK(
            TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
        int64_t key = input_tensor_list[input_index].scalar<int64_t>()();
        TF_ASSIGN_OR_RETURN(input.allocation,
                            XRTMemoryManager::Get(rm)->Lookup(key));
        input.release_allocation_after_use = release_this_input;
      }
    }
    return OkStatus();
  }
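  // For illustration only: a minimal sketch of the xrt::XLATupleNode tree this
  // parser expects, building the tuple ((h0, h1), h2) from three hypothetical
  // input handles and releasing h1's handle after the tuple is assembled
  // (field names follow xrt.pb.h, included above).
  //
  //   xrt::XLATupleNode root;
  //   xrt::XLATupleNode* pair = root.add_tuples();      // internal node {0}
  //   pair->add_tuples()->set_input_index(0);           // leaf {0,0} <- h0
  //   xrt::XLATupleNode* leaf01 = pair->add_tuples();   // leaf {0,1} <- h1
  //   leaf01->set_input_index(1);
  //   leaf01->set_release_input_handle(true);
  //   root.add_tuples()->set_input_index(2);            // leaf {1}   <- h2
  //
  // Each leaf's input_index selects an entry of input_tensor_list (an int64
  // scalar holding an allocation handle), which ParseTupleNode resolves via
  // XRTMemoryManager::Lookup.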
  // Parses an xrt::XLATupleNode proto recursively and returns the
  // corresponding ShapeTree where each leaf is an allocation corresponding to
  // a handle in input_tensor_list. The ordinal of one of the allocations is
  // returned in device_ordinal. Since it's not possible to specify an
  // xrt::XLATupleNode with no leaves, device_ordinal will always be filled in
  // by a successful call to ParseTupleTree.
  static Status ParseTupleTree(
      const xrt::XLATupleNode& tuple_tree_root,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
      int* device_ordinal, ResourceMgr* rm) {
    // First get the shape of the 'spine' of the new tuple, where every leaf is
    // an existing allocation. As a side-effect dereference the input handles
    // into allocations in input_vector.
    xla::Shape tuple_tree_shape;
    TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
                                      input_vector, &tuple_tree_shape, rm));
    // Make the shape tree of allocations where the shape is the spine and each
    // leaf is one of the allocations looked up in input_vector. Internal nodes
    // have nullptr allocations.
    *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
        tuple_tree_shape);
    tuple_shape_tree->ForEachMutableElement(
        [&](const xla::ShapeIndex& index,
            XRTTupleAllocation::ExpandedTupleInput* element) {
          if (tuple_shape_tree->IsLeaf(index)) {
            // Find the matching leaf in the proto tree.
            const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
            for (int i = 0; i < index.size(); ++i) {
              tuple_node = &tuple_node->tuples(index[i]);
            }
            // Copy the appropriate input allocation to the leaf of the
            // tuple_shape_tree.
            int input_index = tuple_node->input_index();
            *element = input_vector->at(input_index);
            CHECK(element->release_allocation_after_use ==
                  tuple_node->release_input_handle());
            // We just need to know the device_ordinal of one of the
            // allocations. We will validate later that they are all the same.
            *device_ordinal = (*element).allocation->device_ordinal();
          }
        });
    return OkStatus();
  }
};

// Op that allocates memory for a literal and transfers it to the device.
template <class DeviceAccessor>
class XRTAllocateOp : public OpKernel {
 public:
  explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTAllocateOp() override = default;
  XRTAllocateOp(const XRTAllocateOp&) = delete;
  XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell());

    const Tensor& allocation_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
        errors::Internal("allocation input should be a string scalar"));
    xrt::XLAAllocation allocation_proto;
    OP_REQUIRES(ctx,
                ParseFromTString(allocation_info.scalar<tstring>()(),
                                 &allocation_proto),
                errors::InvalidArgument(
                    "Unable to parse allocation input to XLAAllocation"));

    xla::Literal literal;
    OP_REQUIRES_OK(
        ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(),
                            device_ref.backend(), device_ref.device_ordinal(),
                            &allocation, device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }
};
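// For illustration only: a minimal sketch of how a client might prepare the
// string-scalar input consumed by XRTAllocateOp. The literal value is
// hypothetical; the proto fields come from xrt.pb.h, and xla::LiteralUtil
// lives in xla/literal_util.h, which a client would include separately.
//
//   xla::Literal host_literal =
//       xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   xrt::XLAAllocation allocation_proto;
//   *allocation_proto.mutable_value() = host_literal.ToProto();
//   Tensor allocation_input(DT_STRING, TensorShape({}));
//   allocation_input.scalar<tstring>()() =
//       allocation_proto.SerializeAsString();
//
// The op returns an int64 scalar handle that the XRTSubTupleOp,
// XRTReadLiteralOp and XRTReleaseAllocationOp kernels below accept.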
// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateUninitialized(
                            xla_shape_, memory_manager.get(),
                            device_ref.backend(), device_ref.device_ordinal(),
                            &allocation, device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};
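// For illustration only: the dtype/shape attributes above are converted to an
// XLA shape with TensorShapeToXLAShape (tf2xla/shape_util.h, included above).
// The concrete values here are hypothetical.
//
//   xla::Shape xla_shape;
//   TF_CHECK_OK(
//       TensorShapeToXLAShape(DT_FLOAT, TensorShape({128, 256}), &xla_shape));
//   // xla_shape is now f32[128,256] with the default layout; this is the
//   // shape for which CreateUninitialized reserves device memory.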
// Op that allocates memory for a tensor (with optional layout) and transfers
// it to the device, returning an allocation handle.
template <class DeviceAccessor>
class XRTAllocateFromTensorOp : public OpKernel {
 public:
  explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool make_tuple = false;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
    std::vector<int64_t> minor_to_major;
    if (ctx->HasAttr("layouts")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
    }
    OP_REQUIRES(
        ctx, tf_shapes_.size() == dtypes_.size(),
        errors::InvalidArgument("shapes and dtypes must be the same length"));
    std::vector<xla::Shape> xla_shapes;
    xla_shapes.reserve(tf_shapes_.size());
    for (int i = 0; i < tf_shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
      xla_shapes.push_back(std::move(xla_shape));
    }
    if (xla_shapes.size() > 1 || make_tuple) {
      shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    } else {
      shape_.Swap(&xla_shapes.front());
    }
    if (!minor_to_major.empty()) {
      xla::Shape shape_with_layouts;
      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                             /*layout_func=*/nullptr,
                                             &shape_with_layouts));
      shape_.Swap(&shape_with_layouts);
    }
  }

  ~XRTAllocateFromTensorOp() override = default;
  XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete;
  XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateFromTensorOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell());

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == tf_shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to XRTAllocateFromTensor: ",
                    values.size(), " vs. ", tf_shapes_.size()));

    std::vector<const char*> tensors_data;
    for (size_t i = 0; i < values.size(); ++i) {
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "Input tensor type and input dtype do not match"));
      // We allow the requested on-device shape to differ from the shape of the
      // input tensor, as long as they have the same number of elements.
      OP_REQUIRES(
          ctx,
          input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(),
          errors::InvalidArgument(
              "Input tensor must have the number of elements specified "
              "in the matching input shape: ",
              input_tensor.shape().num_elements(), " vs. ",
              tf_shapes_[i].num_elements(), " at index ", i));
      tensors_data.push_back(
          static_cast<const char*>(DMAHelper::base(&input_tensor)));
    }
    // Use the buffer straight out of the input tensors to create the literal.
    xla::BorrowingLiteral literal =
        shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_)
                         : xla::BorrowingLiteral(tensors_data.front(), shape_);
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(),
                            device_ref.backend(), device_ref.device_ordinal(),
                            &allocation, device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  std::vector<TensorShape> tf_shapes_;
  DataTypeVector dtypes_;
  xla::Shape shape_;
};
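// For illustration only, with hypothetical attribute values: dtypes =
// {DT_FLOAT, DT_INT32}, shapes = {[2,3], [4]} and make_tuple = true produce a
// shape_ equivalent to
//
//   xla::Shape s0 = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
//   xla::Shape s1 = xla::ShapeUtil::MakeShape(xla::S32, {4});
//   xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape({s0, s1});
//
// Compute then wraps the input tensors' buffers in a xla::BorrowingLiteral,
// so no extra host-side copy is made before the transfer to the device.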
// Op that takes a tuple handle input and returns a handle to a sub-tuple of
// the input.
template <bool discard_, class DeviceAccessor>
class XRTSubTupleOp : public OpKernel {
 public:
  explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTSubTupleOp() override = default;
  XRTSubTupleOp(const XRTSubTupleOp&) = delete;
  XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTSubTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    const Tensor& subtuple_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
        errors::Internal("tuple index input should be an int32 vector"));
    xla::ShapeIndex shape_index;
    for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
      shape_index.push_back(subtuple_info.vec<int32>()(i));
    }

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    XRTTupleAllocation* suballocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
                                               &suballocation, !discard_));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = memory_manager->Register(suballocation);
    ctx->set_output(0, output);
  }
};
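// For illustration only: the second input to XRTSubTupleOp is an int32 vector
// naming an xla::ShapeIndex. For a tuple shaped ((a, b), c), the index {0, 1}
// selects b. A hypothetical host-side construction of that input:
//
//   Tensor subtuple_index(DT_INT32, TensorShape({2}));
//   subtuple_index.vec<int32>()(0) = 0;  // descend into tuple element 0
//   subtuple_index.vec<int32>()(1) = 1;  // then take its element 1
//
// When discard_ is true the parent handle is released and MakeSubBuffer is
// called with alias_parent_allocation set to false (the !discard_ argument).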
// Op that assembles a new tuple allocation on the device from a tree of
// existing allocation handles, returning a handle to the new tuple.
template <class DeviceAccessor>
class XRTMakeTupleOp : public OpKernel {
 public:
  explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMakeTupleOp() override = default;
  XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
  XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTMakeTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell());

    const Tensor& tuple_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
        errors::Internal("tuple description input should be a string scalar"));
    xrt::XLATupleNode tuple_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(tuple_info.scalar<tstring>()(), &tuple_proto),
        errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));

    OpInputList arg_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));

    // For each input, the allocation it corresponds to and a flag indicating
    // whether or not it should be released, i.e. discarded from the resource
    // manager. One ref on each allocation is owned by this vector, and freed
    // on exit.
    std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
        arg_list.size());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
    // the allocations. It is guaranteed that there is at least one allocation
    // in any legal tree. We validate below in XRTTupleAllocation::MakeTuple
    // that all the allocations are on the same device.
    int device_ordinal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
                            tuple_proto, arg_list, &input_vector,
                            &tuple_shape_tree, &device_ordinal, rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(
        ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* output_allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
                            memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), tuple_shape_tree,
                            &output_allocation, device_ref.allocator()));
    RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
    for (int i = 0; i < input_vector.size(); ++i) {
      if (input_vector[i].release_allocation_after_use) {
        OP_REQUIRES_OK(
            ctx, memory_manager->Release(arg_list[i].scalar<int64_t>()()));
      }
    }

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() =
        memory_manager->Register(std::move(output_ptr));
    ctx->set_output(0, output);
  }
};
// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
template <bool discard_, class DeviceAccessor>
class XRTReadLiteralOp : public OpKernel {
 public:
  explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReadLiteralOp() override = default;
  XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
  XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Literal literal(allocation->on_host_shape());
    OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
    xla::LiteralProto literal_proto = literal.ToProto();

    Tensor output(DT_STRING, TensorShape({}));
    SerializeToTString(literal_proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
  }
};
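// For illustration only: decoding the string output of XRTReadLiteralOp back
// into an xla::Literal on the host, reusing the MakeLiteral helper defined in
// this file. The variable read_literal_output is hypothetical.
//
//   xla::LiteralProto proto;
//   CHECK(proto.ParseFromString(read_literal_output));
//   xla::Literal result;
//   TF_CHECK_OK(XRTStateHelpers::MakeLiteral(proto, &result));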
// Op that reads a device-resident tuple to host memory and returns its leaves
// as one output tensor per (non-tuple) subshape.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
 public:
  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }
  ~XRTReadToTensorOp() override = default;
  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadToTensorOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());

    const Tensor& handle_tensor = ctx->input(0);
    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
    // just scalars.)
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Shape shape = allocation->on_host_shape();
    int output = 0;
    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
        &shape,
        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
          if (subshape->IsTuple()) return OkStatus();

          xla::PrimitiveType xla_type;
          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
              ctx->expected_output_dtype(output), &xla_type));
          if (xla_type != subshape->element_type()) {
            return errors::InvalidArgument(
                "Type mismatch between buffer type (", subshape->ToString(),
                ") and tensor type (",
                DataTypeString(ctx->expected_output_dtype(output)),
                ") for output tensor ", output);
          }

          TensorShape output_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));

          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(
              ctx->allocate_output(output, output_shape, &output_tensor));

          XRTTupleAllocation* sub;
          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
              allocation.get(), index, &sub,
              /*alias_parent_allocation=*/true));
          core::ScopedUnref sub_unref(sub);

          xla::MutableBorrowingLiteral literal;
          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
              &literal));
          TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));

          ++output;
          return OkStatus();
        });
    OP_REQUIRES_OK(ctx, status);
  }
  bool discard_;
  DataTypeVector dtypes_;
};
// Op that writes a new literal value into device-resident memory.
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64_t>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(literal_info.scalar<tstring>()(), &literal_proto),
        errors::InvalidArgument(
            "Unable to parse allocation input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64_t>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards a handle to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
 public:
  explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReleaseAllocationOp() override = default;
  XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
  XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllocationOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    const Tensor& allocation_handle = ctx->input(0);
    auto flat_keys = allocation_handle.flat<int64_t>();
    for (int64_t i = 0; i < flat_keys.size(); ++i) {
      int64_t key = flat_keys(i);
      OP_REQUIRES_OK(ctx, memory_manager->Release(key));
      VLOG(2) << "Released allocation handle " << key;
    }
  }
};
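// For illustration only: the handle input is read with flat<int64_t>(), so a
// scalar handle or a vector of handles both work. A hypothetical two-handle
// release input:
//
//   Tensor handles(DT_INT64, TensorShape({2}));
//   handles.vec<int64_t>()(0) = first_handle;
//   handles.vec<int64_t>()(1) = second_handle;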
// Op that discards all handles to device memory held by the resource manager.
template <class DeviceAccessor>
class XRTReleaseAllAllocationsOp : public OpKernel {
 public:
  explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  ~XRTReleaseAllAllocationsOp() override = default;
  XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete;
  XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
  }
};

// Op that asks the XRT memory manager to compact the device allocations it
// holds.
template <class DeviceAccessor>
class XRTCompactAllocationsOp : public OpKernel {
 public:
  explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTCompactAllocationsOp() override = default;
  XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete;
  XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTCompactAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
    OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
                            device_ref.backend(), device_ref.device_ordinal(),
                            device_ref.allocator()));
  }
};

// Op that reports the device's free and total memory as a serialized
// xrt::MemoryInfo proto in a string scalar.
template <class DeviceAccessor>
class XRTMemoryInfoOp : public OpKernel {
 public:
  explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMemoryInfoOp() override = default;
  XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
  XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    auto kernel_fn = [&]() -> Status {
      VLOG(1) << "XRTMemoryInfoOp::Compute";

      class DeviceAccessor::ScopedRef device_ref;
      TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
      TF_ASSIGN_OR_RETURN(
          se::StreamExecutor * stream_executor,
          device_ref.backend()->stream_executor(device_ref.device_ordinal()));
      int64_t mem_free = -1;
      int64_t mem_total = -1;
      if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
        VLOG(2) << "Device " << ctx->device()->name()
                << " does not expose memory information";
      }
      xrt::MemoryInfo mem_info;
      mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
      mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);

      Tensor output(DT_STRING, TensorShape({}));
      output.scalar<tstring>()() = mem_info.SerializeAsString();
      ctx->set_output(0, output);
      return OkStatus();
    };
    OP_REQUIRES_OK(ctx, kernel_fn());
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_