xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Error.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/Regex.h"
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
30 #include "mlir/IR/Location.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
33 #include "tensorflow/core/common_runtime/device.h"
34 #include "tensorflow/core/common_runtime/device_set.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 
// Name of the operation attribute that holds the runtime devices. In this file
// it is written by `AddDevicesToOp` and read by `GetDevicesFromOp`.
constexpr char kDevicesAttr[] = "tf.devices";
40 
41 namespace {
42 
43 // Parse GPU compute capability from physical device description. If compute
44 // capability is not found in device description, return a unit attribute.
ParseGpuDeviceMetadata(const Device & device,mlir::Builder * builder)45 mlir::Attribute ParseGpuDeviceMetadata(const Device& device,
46                                        mlir::Builder* builder) {
47   // Parse GPU device compute capability from physical device description.
48   static auto* r = new llvm::Regex("compute capability: ([0-9]+)\\.([0-9]+)");
49 
50   llvm::SmallVector<llvm::StringRef, 3> cc;
51   if (r->match(device.attributes().physical_device_desc(), &cc)) {
52     return mlir::TF::GpuDeviceMetadata::get(
53         builder->getContext(), std::stoi(cc[1].str()), std::stoi(cc[2].str()));
54   }
55 
56   return builder->getUnitAttr();
57 }
58 
59 // Get devices from an array of string attributes.
60 // TODO(ezhulenev): Update all tests to use dictionary attribute for
61 // `tf.devices` and remove this function.
GetDevicesFromOp(mlir::Operation * op,mlir::ArrayAttr array_attr,mlir::TF::RuntimeDevices * devices)62 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
63                                      mlir::ArrayAttr array_attr,
64                                      mlir::TF::RuntimeDevices* devices) {
65   DeviceNameUtils::ParsedName device;
66 
67   for (auto& kv : llvm::enumerate(array_attr)) {
68     const int idx = kv.index();
69 
70     auto string_attr = kv.value().dyn_cast<mlir::StringAttr>();
71     if (!string_attr)
72       return op->emitOpError(llvm::formatv(
73           "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
74 
75     if (DeviceNameUtils::ParseFullName(string_attr.getValue().str(), &device)) {
76       devices->AddDevice(device);
77     } else {
78       return op->emitOpError(
79           llvm::formatv("bad '{0}' attribute, '{1}', not a valid device",
80                         kDevicesAttr, string_attr.getValue()));
81     }
82   }
83 
84   return mlir::success();
85 }
86 
87 // Get devices from a dictionary attribute.
GetDevicesFromOp(mlir::Operation * op,mlir::DictionaryAttr dict_attr,mlir::TF::RuntimeDevices * devices)88 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
89                                      mlir::DictionaryAttr dict_attr,
90                                      mlir::TF::RuntimeDevices* devices) {
91   DeviceNameUtils::ParsedName device;
92 
93   // Parse device names and metadata from dictionary attribute.
94   for (auto& kv : dict_attr) {
95     const mlir::StringAttr name = kv.getName();
96     const mlir::Attribute attr = kv.getValue();
97 
98     if (!DeviceNameUtils::ParseFullName(name.str(), &device))
99       return op->emitOpError(
100           llvm::formatv("bad '{0}' attribute, '{1}', not a valid device",
101                         kDevicesAttr, name.strref()));
102 
103     if (auto gpu_metadata = attr.dyn_cast<mlir::TF::GpuDeviceMetadata>()) {
104       devices->AddGpuDevice(device, gpu_metadata);
105     } else {
106       devices->AddDevice(device);
107     }
108   }
109 
110   return mlir::success();
111 }
112 
113 }  // namespace
114 
AddDevicesToOp(mlir::Operation * op,const DeviceSet * device_set)115 void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) {
116   if (!device_set) return;
117 
118   mlir::MLIRContext* ctx = op->getContext();
119   mlir::Builder builder(ctx);
120 
121   // Collect devices with attached metadata.
122   llvm::SmallVector<mlir::NamedAttribute, 8> devices;
123   devices.reserve(device_set->devices().size());
124 
125   // For device that do not have any metadata, or if we failed to parse metadata
126   // from the DeviceSet, we add a unit attribute to the `tf.devices` attribute.
127   for (Device* device : device_set->devices()) {
128     string name = DeviceNameUtils::ParsedNameToString(device->parsed_name());
129 
130     if (device->device_type() == DEVICE_GPU) {
131       auto metadata = ParseGpuDeviceMetadata(*device, &builder);
132       devices.push_back(builder.getNamedAttr(name, metadata));
133     } else {
134       auto metadata = builder.getUnitAttr();
135       devices.push_back(builder.getNamedAttr(name, metadata));
136     }
137   }
138 
139   op->setAttr(kDevicesAttr, builder.getDictionaryAttr(devices));
140 }
141 
GetDevicesFromOp(mlir::Operation * op,mlir::TF::RuntimeDevices * devices)142 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
143                                      mlir::TF::RuntimeDevices* devices) {
144   auto devices_attr = op->getAttr(kDevicesAttr);
145   if (!devices_attr) return mlir::success();
146 
147   if (auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>()) {
148     return GetDevicesFromOp(op, array_attr, devices);
149 
150   } else if (auto dict_attr = devices_attr.dyn_cast<mlir::DictionaryAttr>()) {
151     return GetDevicesFromOp(op, dict_attr, devices);
152   }
153 
154   return op->emitOpError(
155       llvm::formatv("unsupported '{0}' attribute", kDevicesAttr));
156 }
157 
GetDeviceOrdinalFromDeviceString(mlir::Location loc,llvm::StringRef device,int64_t * device_ordinal)158 mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc,
159                                                      llvm::StringRef device,
160                                                      int64_t* device_ordinal) {
161   DeviceNameUtils::ParsedName parsed_name;
162   if (!DeviceNameUtils::ParseFullName(
163           absl::string_view(device.data(), device.size()), &parsed_name))
164     return mlir::emitError(loc) << "invalid device '" << device << "'";
165 
166   if (!parsed_name.has_id)
167     return mlir::emitError(loc) << "device '" << device << "' has no id";
168 
169   *device_ordinal = parsed_name.id;
170   return mlir::success();
171 }
172 
173 }  // namespace tensorflow
174