1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements the configuration management for transfer guards, and
// the C++ functions responsible for enforcing the configured transfer guard
// levels.
18
19 #include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
20
21 #include <memory>
22 #include <optional>
23 #include <string>
24
25 #include "absl/base/attributes.h"
26 #include "pybind11/cast.h"
27 #include "pybind11/pybind11.h"
28 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
29 #include "tensorflow/compiler/xla/status.h"
30 #include "tensorflow/compiler/xla/util.h"
31
32 namespace jax {
33
34 namespace py = ::pybind11;
35
36 namespace {
37
// Process-wide transfer guard configuration. A level is taken from here only
// when the calling thread's `thread_local_state` leaves that direction unset.
// The object is heap-allocated and never deleted (presumably so it remains
// valid during static destruction at process exit).
// Protected by the GIL.
TransferGuardState& global_state = *new TransferGuardState();

// Per-thread transfer guard configuration. A level set here overrides the
// corresponding level in `global_state`. In this file the
// `explicit_device_put` / `explicit_device_get` flags are read only from this
// per-thread state, never from `global_state`. ABSL_CONST_INIT asserts that
// this thread_local is constant-initialized (no dynamic initializer per
// thread).
ABSL_CONST_INIT thread_local TransferGuardState thread_local_state;

// The transfer guard level used when neither `thread_local_state` nor
// `global_state` specifies a level for a transfer direction.
constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow;
45
46 // Returns the transfer guard action for a transfer.
GetTransferGuardAction(TransferGuardLevel guard_level,bool explicit_transfer)47 TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level,
48 bool explicit_transfer) {
49 switch (guard_level) {
50 case TransferGuardLevel::kAllow:
51 return TransferGuardAction::kAllow;
52 case TransferGuardLevel::kLog:
53 if (explicit_transfer) {
54 return TransferGuardAction::kAllow;
55 } else {
56 return TransferGuardAction::kLog;
57 }
58 case TransferGuardLevel::kDisallow:
59 if (explicit_transfer) {
60 return TransferGuardAction::kAllow;
61 } else {
62 return TransferGuardAction::kDisallow;
63 }
64 case TransferGuardLevel::kLogExplicit:
65 return TransferGuardAction::kLog;
66 case TransferGuardLevel::kDisallowExplicit:
67 return TransferGuardAction::kDisallow;
68 default:
69 // Unreachable; gracefully handle the unexpected guard level and prevent a
70 // compiler warning.
71 return TransferGuardAction::kDisallow;
72 }
73 }
74
75 // Returns the transfer guard action for a host-to-device transfer.
76 // REQUIRES: Python GIL.
GetTransferGuardActionForHostToDevice()77 TransferGuardAction GetTransferGuardActionForHostToDevice() {
78 return GetTransferGuardAction(
79 thread_local_state.host_to_device.value_or(
80 global_state.host_to_device.value_or(kDefaultGuardLevel)),
81 thread_local_state.explicit_device_put);
82 }
83
84 // Returns the transfer guard action for a device-to-device transfer.
85 // REQUIRES: Python GIL.
GetTransferGuardActionForDeviceToDevice()86 TransferGuardAction GetTransferGuardActionForDeviceToDevice() {
87 return GetTransferGuardAction(
88 thread_local_state.device_to_device.value_or(
89 global_state.device_to_device.value_or(kDefaultGuardLevel)),
90 thread_local_state.explicit_device_put);
91 }
92
93 // Returns the transfer guard action for a device-to-host transfer.
94 // REQUIRES: Python GIL.
GetTransferGuardActionForDeviceToHost()95 TransferGuardAction GetTransferGuardActionForDeviceToHost() {
96 return GetTransferGuardAction(
97 thread_local_state.device_to_host.value_or(
98 global_state.device_to_host.value_or(kDefaultGuardLevel)),
99 thread_local_state.explicit_device_get);
100 }
101
102 } // namespace
103
ApplyTransferGuardToHostToDevice(absl::FunctionRef<std::string ()> formatter)104 xla::Status ApplyTransferGuardToHostToDevice(
105 absl::FunctionRef<std::string()> formatter) {
106 switch (GetTransferGuardActionForHostToDevice()) {
107 case TransferGuardAction::kAllow:
108 break;
109 case TransferGuardAction::kLog:
110 LOG(WARNING) << "host-to-device transfer: " << formatter();
111 break;
112 case TransferGuardAction::kDisallow:
113 return xla::InvalidArgument("Disallowed host-to-device transfer: %s",
114 formatter());
115 }
116 return ::tensorflow::OkStatus();
117 }
118
ApplyTransferGuardToDeviceToDevice(absl::FunctionRef<std::string ()> formatter)119 xla::Status ApplyTransferGuardToDeviceToDevice(
120 absl::FunctionRef<std::string()> formatter) {
121 switch (GetTransferGuardActionForDeviceToDevice()) {
122 case TransferGuardAction::kAllow:
123 break;
124 case TransferGuardAction::kLog:
125 LOG(WARNING) << "device-to-device transfer: " << formatter();
126 break;
127 case TransferGuardAction::kDisallow:
128 return xla::InvalidArgument("Disallowed device-to-device transfer: %s",
129 formatter());
130 }
131 return ::tensorflow::OkStatus();
132 }
133
ApplyTransferGuardToDeviceToHost(absl::FunctionRef<std::string ()> formatter)134 xla::Status ApplyTransferGuardToDeviceToHost(
135 absl::FunctionRef<std::string()> formatter) {
136 switch (GetTransferGuardActionForDeviceToHost()) {
137 case TransferGuardAction::kAllow:
138 break;
139 case TransferGuardAction::kLog:
140 LOG(WARNING) << "device-to-host transfer: " << formatter();
141 break;
142 case TransferGuardAction::kDisallow:
143 return xla::InvalidArgument("Disallowed device-to-host transfer: %s",
144 formatter());
145 }
146 return ::tensorflow::OkStatus();
147 }
148
BuildTransferGuardSubmodule(py::module & m)149 void BuildTransferGuardSubmodule(py::module& m) {
150 py::module tglib = m.def_submodule("transfer_guard_lib",
151 "Jax transfer guard support library");
152
153 py::enum_<TransferGuardLevel> tglevel(tglib, "TransferGuardLevel");
154 tglevel.value("ALLOW", TransferGuardLevel::kAllow);
155 tglevel.value("LOG", TransferGuardLevel::kLog);
156 tglevel.value("DISALLOW", TransferGuardLevel::kDisallow);
157 tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit);
158 tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit);
159
160 py::class_<TransferGuardState> tgstate(tglib, "TransferGuardState");
161 tgstate.def_readwrite("host_to_device", &TransferGuardState::host_to_device);
162 tgstate.def_readwrite("device_to_device",
163 &TransferGuardState::device_to_device);
164 tgstate.def_readwrite("device_to_host", &TransferGuardState::device_to_host);
165 tgstate.def_readwrite("explicit_device_put",
166 &TransferGuardState::explicit_device_put);
167 tgstate.def_readwrite("explicit_device_get",
168 &TransferGuardState::explicit_device_get);
169
170 tglib.def(
171 "global_state", [&]() { return &global_state; },
172 py::return_value_policy::reference);
173 tglib.def(
174 "thread_local_state", [&]() { return &thread_local_state; },
175 py::return_value_policy::reference);
176 }
177
178 } // namespace jax
179