1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
17
18 #include <string>
19
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Error.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/Regex.h"
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/Diagnostics.h" // from @llvm-project
30 #include "mlir/IR/Location.h" // from @llvm-project
31 #include "mlir/IR/Operation.h" // from @llvm-project
32 #include "mlir/Support/LogicalResult.h" // from @llvm-project
33 #include "tensorflow/core/common_runtime/device.h"
34 #include "tensorflow/core/common_runtime/device_set.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36
37 namespace tensorflow {
38
// Name of the operation attribute that holds the runtime devices. It is
// written by AddDevicesToOp and read back by GetDevicesFromOp.
constexpr char kDevicesAttr[] = "tf.devices";
40
41 namespace {
42
43 // Parse GPU compute capability from physical device description. If compute
44 // capability is not found in device description, return a unit attribute.
ParseGpuDeviceMetadata(const Device & device,mlir::Builder * builder)45 mlir::Attribute ParseGpuDeviceMetadata(const Device& device,
46 mlir::Builder* builder) {
47 // Parse GPU device compute capability from physical device description.
48 static auto* r = new llvm::Regex("compute capability: ([0-9]+)\\.([0-9]+)");
49
50 llvm::SmallVector<llvm::StringRef, 3> cc;
51 if (r->match(device.attributes().physical_device_desc(), &cc)) {
52 return mlir::TF::GpuDeviceMetadata::get(
53 builder->getContext(), std::stoi(cc[1].str()), std::stoi(cc[2].str()));
54 }
55
56 return builder->getUnitAttr();
57 }
58
59 // Get devices from an array of string attributes.
60 // TODO(ezhulenev): Update all tests to use dictionary attribute for
61 // `tf.devices` and remove this function.
GetDevicesFromOp(mlir::Operation * op,mlir::ArrayAttr array_attr,mlir::TF::RuntimeDevices * devices)62 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
63 mlir::ArrayAttr array_attr,
64 mlir::TF::RuntimeDevices* devices) {
65 DeviceNameUtils::ParsedName device;
66
67 for (auto& kv : llvm::enumerate(array_attr)) {
68 const int idx = kv.index();
69
70 auto string_attr = kv.value().dyn_cast<mlir::StringAttr>();
71 if (!string_attr)
72 return op->emitOpError(llvm::formatv(
73 "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
74
75 if (DeviceNameUtils::ParseFullName(string_attr.getValue().str(), &device)) {
76 devices->AddDevice(device);
77 } else {
78 return op->emitOpError(
79 llvm::formatv("bad '{0}' attribute, '{1}', not a valid device",
80 kDevicesAttr, string_attr.getValue()));
81 }
82 }
83
84 return mlir::success();
85 }
86
87 // Get devices from a dictionary attribute.
GetDevicesFromOp(mlir::Operation * op,mlir::DictionaryAttr dict_attr,mlir::TF::RuntimeDevices * devices)88 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
89 mlir::DictionaryAttr dict_attr,
90 mlir::TF::RuntimeDevices* devices) {
91 DeviceNameUtils::ParsedName device;
92
93 // Parse device names and metadata from dictionary attribute.
94 for (auto& kv : dict_attr) {
95 const mlir::StringAttr name = kv.getName();
96 const mlir::Attribute attr = kv.getValue();
97
98 if (!DeviceNameUtils::ParseFullName(name.str(), &device))
99 return op->emitOpError(
100 llvm::formatv("bad '{0}' attribute, '{1}', not a valid device",
101 kDevicesAttr, name.strref()));
102
103 if (auto gpu_metadata = attr.dyn_cast<mlir::TF::GpuDeviceMetadata>()) {
104 devices->AddGpuDevice(device, gpu_metadata);
105 } else {
106 devices->AddDevice(device);
107 }
108 }
109
110 return mlir::success();
111 }
112
113 } // namespace
114
AddDevicesToOp(mlir::Operation * op,const DeviceSet * device_set)115 void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) {
116 if (!device_set) return;
117
118 mlir::MLIRContext* ctx = op->getContext();
119 mlir::Builder builder(ctx);
120
121 // Collect devices with attached metadata.
122 llvm::SmallVector<mlir::NamedAttribute, 8> devices;
123 devices.reserve(device_set->devices().size());
124
125 // For device that do not have any metadata, or if we failed to parse metadata
126 // from the DeviceSet, we add a unit attribute to the `tf.devices` attribute.
127 for (Device* device : device_set->devices()) {
128 string name = DeviceNameUtils::ParsedNameToString(device->parsed_name());
129
130 if (device->device_type() == DEVICE_GPU) {
131 auto metadata = ParseGpuDeviceMetadata(*device, &builder);
132 devices.push_back(builder.getNamedAttr(name, metadata));
133 } else {
134 auto metadata = builder.getUnitAttr();
135 devices.push_back(builder.getNamedAttr(name, metadata));
136 }
137 }
138
139 op->setAttr(kDevicesAttr, builder.getDictionaryAttr(devices));
140 }
141
GetDevicesFromOp(mlir::Operation * op,mlir::TF::RuntimeDevices * devices)142 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
143 mlir::TF::RuntimeDevices* devices) {
144 auto devices_attr = op->getAttr(kDevicesAttr);
145 if (!devices_attr) return mlir::success();
146
147 if (auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>()) {
148 return GetDevicesFromOp(op, array_attr, devices);
149
150 } else if (auto dict_attr = devices_attr.dyn_cast<mlir::DictionaryAttr>()) {
151 return GetDevicesFromOp(op, dict_attr, devices);
152 }
153
154 return op->emitOpError(
155 llvm::formatv("unsupported '{0}' attribute", kDevicesAttr));
156 }
157
GetDeviceOrdinalFromDeviceString(mlir::Location loc,llvm::StringRef device,int64_t * device_ordinal)158 mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc,
159 llvm::StringRef device,
160 int64_t* device_ordinal) {
161 DeviceNameUtils::ParsedName parsed_name;
162 if (!DeviceNameUtils::ParseFullName(
163 absl::string_view(device.data(), device.size()), &parsed_name))
164 return mlir::emitError(loc) << "invalid device '" << device << "'";
165
166 if (!parsed_name.has_id)
167 return mlir::emitError(loc) << "device '" << device << "' has no id";
168
169 *device_ordinal = parsed_name.id;
170 return mlir::success();
171 }
172
173 } // namespace tensorflow
174