xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/transfer_guard_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements the configuration management for transfer guards.
// C++ backends are responsible for enforcing transfer guard levels.
18 
19 #include "tensorflow/compiler/xla/python/transfer_guard_lib.h"
20 
21 #include <memory>
22 #include <optional>
23 #include <string>
24 
25 #include "absl/base/attributes.h"
26 #include "pybind11/cast.h"
27 #include "pybind11/pybind11.h"
28 #include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
29 #include "tensorflow/compiler/xla/status.h"
30 #include "tensorflow/compiler/xla/util.h"
31 
32 namespace jax {
33 
34 namespace py = ::pybind11;
35 
36 namespace {
37 
38 // Protected by the GIL.
39 TransferGuardState& global_state = *new TransferGuardState();
40 
41 ABSL_CONST_INIT thread_local TransferGuardState thread_local_state;
42 
43 // The default transfer guard level.
44 constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow;
45 
46 // Returns the transfer guard action for a transfer.
GetTransferGuardAction(TransferGuardLevel guard_level,bool explicit_transfer)47 TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level,
48                                            bool explicit_transfer) {
49   switch (guard_level) {
50     case TransferGuardLevel::kAllow:
51       return TransferGuardAction::kAllow;
52     case TransferGuardLevel::kLog:
53       if (explicit_transfer) {
54         return TransferGuardAction::kAllow;
55       } else {
56         return TransferGuardAction::kLog;
57       }
58     case TransferGuardLevel::kDisallow:
59       if (explicit_transfer) {
60         return TransferGuardAction::kAllow;
61       } else {
62         return TransferGuardAction::kDisallow;
63       }
64     case TransferGuardLevel::kLogExplicit:
65       return TransferGuardAction::kLog;
66     case TransferGuardLevel::kDisallowExplicit:
67       return TransferGuardAction::kDisallow;
68     default:
69       // Unreachable; gracefully handle the unexpected guard level and prevent a
70       // compiler warning.
71       return TransferGuardAction::kDisallow;
72   }
73 }
74 
75 // Returns the transfer guard action for a host-to-device transfer.
76 // REQUIRES: Python GIL.
GetTransferGuardActionForHostToDevice()77 TransferGuardAction GetTransferGuardActionForHostToDevice() {
78   return GetTransferGuardAction(
79       thread_local_state.host_to_device.value_or(
80           global_state.host_to_device.value_or(kDefaultGuardLevel)),
81       thread_local_state.explicit_device_put);
82 }
83 
84 // Returns the transfer guard action for a device-to-device transfer.
85 // REQUIRES: Python GIL.
GetTransferGuardActionForDeviceToDevice()86 TransferGuardAction GetTransferGuardActionForDeviceToDevice() {
87   return GetTransferGuardAction(
88       thread_local_state.device_to_device.value_or(
89           global_state.device_to_device.value_or(kDefaultGuardLevel)),
90       thread_local_state.explicit_device_put);
91 }
92 
93 // Returns the transfer guard action for a device-to-host transfer.
94 // REQUIRES: Python GIL.
GetTransferGuardActionForDeviceToHost()95 TransferGuardAction GetTransferGuardActionForDeviceToHost() {
96   return GetTransferGuardAction(
97       thread_local_state.device_to_host.value_or(
98           global_state.device_to_host.value_or(kDefaultGuardLevel)),
99       thread_local_state.explicit_device_get);
100 }
101 
102 }  // namespace
103 
ApplyTransferGuardToHostToDevice(absl::FunctionRef<std::string ()> formatter)104 xla::Status ApplyTransferGuardToHostToDevice(
105     absl::FunctionRef<std::string()> formatter) {
106   switch (GetTransferGuardActionForHostToDevice()) {
107     case TransferGuardAction::kAllow:
108       break;
109     case TransferGuardAction::kLog:
110       LOG(WARNING) << "host-to-device transfer: " << formatter();
111       break;
112     case TransferGuardAction::kDisallow:
113       return xla::InvalidArgument("Disallowed host-to-device transfer: %s",
114                                   formatter());
115   }
116   return ::tensorflow::OkStatus();
117 }
118 
ApplyTransferGuardToDeviceToDevice(absl::FunctionRef<std::string ()> formatter)119 xla::Status ApplyTransferGuardToDeviceToDevice(
120     absl::FunctionRef<std::string()> formatter) {
121   switch (GetTransferGuardActionForDeviceToDevice()) {
122     case TransferGuardAction::kAllow:
123       break;
124     case TransferGuardAction::kLog:
125       LOG(WARNING) << "device-to-device transfer: " << formatter();
126       break;
127     case TransferGuardAction::kDisallow:
128       return xla::InvalidArgument("Disallowed device-to-device transfer: %s",
129                                   formatter());
130   }
131   return ::tensorflow::OkStatus();
132 }
133 
ApplyTransferGuardToDeviceToHost(absl::FunctionRef<std::string ()> formatter)134 xla::Status ApplyTransferGuardToDeviceToHost(
135     absl::FunctionRef<std::string()> formatter) {
136   switch (GetTransferGuardActionForDeviceToHost()) {
137     case TransferGuardAction::kAllow:
138       break;
139     case TransferGuardAction::kLog:
140       LOG(WARNING) << "device-to-host transfer: " << formatter();
141       break;
142     case TransferGuardAction::kDisallow:
143       return xla::InvalidArgument("Disallowed device-to-host transfer: %s",
144                                   formatter());
145   }
146   return ::tensorflow::OkStatus();
147 }
148 
BuildTransferGuardSubmodule(py::module & m)149 void BuildTransferGuardSubmodule(py::module& m) {
150   py::module tglib = m.def_submodule("transfer_guard_lib",
151                                      "Jax transfer guard support library");
152 
153   py::enum_<TransferGuardLevel> tglevel(tglib, "TransferGuardLevel");
154   tglevel.value("ALLOW", TransferGuardLevel::kAllow);
155   tglevel.value("LOG", TransferGuardLevel::kLog);
156   tglevel.value("DISALLOW", TransferGuardLevel::kDisallow);
157   tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit);
158   tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit);
159 
160   py::class_<TransferGuardState> tgstate(tglib, "TransferGuardState");
161   tgstate.def_readwrite("host_to_device", &TransferGuardState::host_to_device);
162   tgstate.def_readwrite("device_to_device",
163                         &TransferGuardState::device_to_device);
164   tgstate.def_readwrite("device_to_host", &TransferGuardState::device_to_host);
165   tgstate.def_readwrite("explicit_device_put",
166                         &TransferGuardState::explicit_device_put);
167   tgstate.def_readwrite("explicit_device_get",
168                         &TransferGuardState::explicit_device_get);
169 
170   tglib.def(
171       "global_state", [&]() { return &global_state; },
172       py::return_value_policy::reference);
173   tglib.def(
174       "thread_local_state", [&]() { return &thread_local_state; },
175       py::return_value_policy::reference);
176 }
177 
178 }  // namespace jax
179