xref: /aosp_15_r20/external/tensorflow/tensorflow/python/eager/pywrap_tfe.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
17 #define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
18 
19 // Place `<locale>` before <Python.h> to avoid build failure in macOS.
20 #include <locale>
21 
22 // The empty line above is on purpose as otherwise clang-format will
23 // automatically move <Python.h> before <locale>.
24 #include <Python.h>
25 
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
30 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
31 
32 typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
33     TFE_InputTensorHandles;
34 typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
35     TFE_OutputTensorHandles;
36 
37 // Execute a TensorFlow operation.
38 //
39 // 'device_name': Name of the device on which to execute the operation, or NULL
40 //                for automatic selection.
41 // 'op_name': Name of the TensorFlow op to execute.
42 // 'inputs': A pointer to a TFE_InputTensorHandles whose TFE_TensorHandle*'s
43 //           will be provided as input to the operation.
44 // 'attrs': A Python tuple alternating names and attr values.
45 // 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will
46 //            be placed. On success, its elements will be filled in and
47 //            the caller takes ownership of each returned TFE_TensorHandle.
48 //            'outputs' MUST be sized to be at least as large as the number
49 //            of tensors produced by the operation and will be resized to
50 //            the actual number of tensors produced.
51 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
52                     const char* op_name, TFE_InputTensorHandles* inputs,
53                     PyObject* attrs, TFE_OutputTensorHandles* outputs,
54                     TF_Status* out_status);
55 
56 // Execute a cancelable TensorFlow operation.
57 //
58 // Arguments as above (for TFE_Py_Execute), with the addition of:
59 // 'cancellation_manager': A pointer to a TFE_CancellationManager that can be
60 //                         used to cancel execution of the given operation.
61 typedef struct TFE_CancellationManager TFE_CancellationManager;
62 void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
63                               const char* op_name,
64                               TFE_InputTensorHandles* inputs, PyObject* attrs,
65                               TFE_CancellationManager* cancellation_manager,
66                               TFE_OutputTensorHandles* outputs,
67                               TF_Status* out_status);
68 
69 // Registers e as the Exception class for handling non-OK Status. Returns
70 // Py_None if registration succeeds, else raises a TypeError and returns NULL.
71 //
72 // This function is not thread-safe.
73 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
74 
75 // Registers e as the VSpace to use.
76 // `e` must be an imperative_grad.py:VSpace named tuple.
77 PyObject* TFE_Py_RegisterVSpace(PyObject* e);
78 
79 // Registers e as the Exception to be raised when the conditions of
80 // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
81 // is a signal to the calling code that it should fall back to the safer (and
82 // more complete) code path.
83 //
84 // This function is not thread-safe.
85 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
86 
87 // Registers e as the gradient_function.
88 // The registered function takes
89 // (op_name, attrs, num_inputs, inputs, outputs, output_gradients) and returns
90 // the input gradients. This function will not correctly be able to generate
91 // gradients for functional ops - the gradients for those ops are calculated
92 // through a different codepath (see function.py for additional information).
93 //
94 // This function is not thread-safe.
95 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e);
96 
97 // Registers e as the forward_gradient_function.  The registered function takes
98 // (op_name, attrs, inputs, outputs, tangents) and returns the output
99 // tangents. This function is used only for operations, not for custom gradients
100 // or functional ops.
101 //
102 // This function is not thread-safe.
103 PyObject* TFE_Py_RegisterJVPFunction(PyObject* e);
104 
105 namespace tensorflow {
106 
107 // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
108 // `exception` if not nullptr, else using the class registered via
109 // TFE_Py_RegisterExceptionClass), and returns -1.
110 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);
111 
112 }  // namespace tensorflow
113 
114 // Returns 0 if 'status' is ok. Otherwise, raises an exception (using
115 // `exception` if not nullptr, else using the class registered via
116 // TFE_Py_RegisterExceptionClass), and returns -1.
117 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
118                                   PyObject* exception);
119 
120 // Returns the string associated with the passed-in python object.
121 const char* TFE_GetPythonString(PyObject* o);
122 
123 // Returns a unique id on each call.
124 int64_t get_uid();
125 
126 // Wraps the output of get_uid as a Python Long object. Ownership is passed to
127 // the caller.
128 PyObject* TFE_Py_UID();
129 
130 // Deleter for Context objects, called from the Capsule that owns it.
131 void TFE_DeleteContextCapsule(PyObject* context);
132 
133 // Returns true if o is an instance of EagerTensor, but not a subclass. Else
134 // returns false.
135 bool EagerTensor_CheckExact(const PyObject* o);
136 
137 // Helper function to construct a new EagerTensor from a TFE_TensorHandle.
138 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
139                                 const bool is_packed = false);
140 
141 // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
142 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
143 
144 // Creates the `EagerTensor` class by subclassing `base_class` and returns the
145 // newly created type, or nullptr on error.
146 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
147 
148 // Sets `profiler` as the current profiler to receive callbacks about events
149 // on eager tensors. Currently, the only reported event is creation.
150 // `profiler` is expected to have a `created(self, eager_tensor)` method that
151 // takes the created tensor as its single argument.
152 // Previous profiler, if any, is unset and will not receive any more
153 // callbacks.
154 // To unset the profiler, pass Py_None as the value of `profiler`.
155 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
156 
157 // Creates a new tape and adds it to the active set. `persistent` and
158 // `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
159 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
160                             PyObject* watch_accessed_variables);
161 
162 // Removes the passed tape from the set of active tapes.
163 void TFE_Py_TapeSetRemove(PyObject* tape);
164 
165 // Adds the passed tape to the set of active tapes.
166 void TFE_Py_TapeSetAdd(PyObject* tape);
167 
168 // Returns true if the tape stack is empty.
169 PyObject* TFE_Py_TapeSetIsEmpty();
170 
171 // Check if any backward tape should record an operation given inputs.
172 //
173 // Does not take forward accumulators into account.
174 PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors);
175 
176 // Determine possible gradient types, taking forward accumulators into account.
177 //   - 0 if no tape will record (implies TFE_Py_TapeSetShouldRecordBackprop
178 //     is false and no forward accumulator is watching)
179 //   - 1 if first-order gradients may be requested
180 //   - 2 if higher-order gradients may be requested
181 PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors);
182 
183 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
184 void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id);
185 
186 // Stops any gradient recording on the current thread.
187 //
188 // Includes forward accumulators.
189 void TFE_Py_TapeSetStopOnThread();
190 
191 // Restarts gradient recording on the current thread.
192 void TFE_Py_TapeSetRestartOnThread();
193 
194 // Checks whether gradient recording is stopped on the current thread.
195 PyObject* TFE_Py_TapeSetIsStopped();
196 
197 // Records an operation for the purpose of gradient computation.
198 //
199 // Arguments:
200 //  - op_type is a string for the operation type, used in the backprop code
201 //  - output_tensors are a list of Python Tensor objects output by the operation
202 //  - input_tensors are a list of input Tensors to the recorded operation
203 //  - backward_function is the function to be called during backprop or
204 //    forwardprop to, given the gradients of the output tensors, produce the
205 //    gradients of the input tensors. This function is automatically transposed
206 //    during forwardprop.
207 //  - forward_function is an optional special-case for forwardprop, taking input
208 //    jvps and returning output jvps.
209 //
210 // Records an operation both for backprop (gradient tape) and forwardprop
211 // (forward accumulator). Equivalent to calling both
212 // TFE_Py_TapeSetRecordOperationBackprop and
213 // TFE_Py_TapeSetRecordOperationForwardprop.
214 PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
215                                         PyObject* output_tensors,
216                                         PyObject* input_tensors,
217                                         PyObject* backward_function,
218                                         PyObject* forward_function);
219 
220 // Records an operation only for backprop (gradient tapes).
221 //
222 // Same arguments as TFE_Py_TapeSetRecordOperation.
223 PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
224                                                 PyObject* output_tensors,
225                                                 PyObject* input_tensors,
226                                                 PyObject* backward_function);
227 
228 // Records an operation only for forwardprop (forward accumulators).
229 //
230 // Arguments:
231 //  - op_type is a string for the operation type, used in the backprop code
232 //  - output_tensors are a list of Python Tensor objects output by the operation
233 //  - input_tensors are a list of input Tensors to the recorded operation
234 //  - backward_function is the function to be called to, given the gradients of
235 //    the output tensors, produce the gradients of the input tensors. This
236 //    function is automatically transposed to produce output gradients given
237 //    input gradients.
238 //  - forwardprop_output_indices indicates any output_tensors which contain
239 //    JVPs. Typically these will have come from TFE_Py_PackJVPs. May
240 //    be None or an empty sequence if there are no JVP outputs from the
241 //    operation.
242 PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
243     PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
244     PyObject* backward_function, PyObject* forwardprop_output_indices);
245 
246 // Notifies all tapes that a variable has been accessed.
247 void TFE_Py_TapeVariableAccessed(PyObject* variable);
248 
249 // Watches the given variable object on the given tape.
250 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
251 
252 // Computes a gradient based on information recorded on the tape. `tape` must
253 // have been produced by TFE_Py_TapeSetNew. `target` and `sources` must be
254 // python lists of Tensor objects. `output_gradients` is either None or a
255 // python list of either Tensor or None, and if not None should have the same
256 // length as target.
257 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
258                               PyObject* sources, PyObject* output_gradients,
259                               PyObject* sources_raw,
260                               PyObject* unconnected_gradients,
261                               TF_Status* status);
262 
263 // Execute a tensorflow operation assuming that all provided inputs are
264 // correctly formatted (i.e. EagerTensors). If these conditions are not met,
265 // it raises the exception registered via TFE_Py_RegisterFallbackExceptionClass.
266 //
267 // The "args" PyObject* is meant to be a tuple with the following structure:
268 //  Item 1: The Python eager Context object
269 //  Item 2: op_name: Name of the TensorFlow op to execute.
270 //  Item 3: name: An optional name for the operation.
271 //  Item 4 onwards: inputs - This is a list of inputs followed by a list of
272 //        attrs. It is not necessary for type attrs to be present.
273 //
274 // Note: the device_name and op_callbacks, which were previously passed
275 // as arguments, are now read via GetEagerContextThreadLocalData().
276 //
277 // This is named _C since there doesn't seem to be any way to make it visible
278 // in the SWIG interface without renaming due to the use of the %native
279 // directive.
280 PyObject* TFE_Py_FastPathExecute_C(PyObject* args);
281 
282 // Record the gradient for a given op.
283 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
284                                 PyObject* attrs, PyObject* results,
285                                 PyObject* forward_pass_name_scope);
286 
287 // Returns all variables watched by the given tape in the order those variables
288 // were created.
289 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
290 
291 // Creates a new forward accumulator. Does not add it to the active set.
292 PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch);
293 
294 // Adds a ForwardAccumulator to the active set, meaning it will watch executed
295 // operations. It must not already be in the active set.
296 PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator);
297 // Removes a forward accumulator from the active set, meaning it will no longer
298 // be watching operations.
299 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator);
300 
301 // Tell the forward accumulator `accumulator` to watch `tensor`, with a Tensor
302 // tangent vector `tangent` of matching shape and dtype.
303 void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
304                                     PyObject* tangent);
305 
306 // Looks up the Jacobian-vector product of `tensor` in the forward accumulator
307 // `accumulator`. Returns None if no JVP is available.
308 PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
309 
310 // Temporarily push or pop transient state for accumulators in the active set.
311 //
312 // Allows an accumulator which is currently processing an operation to
313 // temporarily reset its state. This is useful when building forwardprop
314 // versions of functions, where an accumulator will trigger function building
315 // and then must process captured symbolic tensors while building it. Without
316 // pushing and popping, accumulators ignore operations executed as a direct
317 // result of their own jvp computations.
318 PyObject* TFE_Py_ForwardAccumulatorPushState();
319 PyObject* TFE_Py_ForwardAccumulatorPopState();
320 
321 // Collects state from all current forward accumulators related to `tensors`.
322 //
323 // This is useful for packing JVPs as function inputs before executing a
324 // function which computes primals and JVPs at the same time.
325 //
326 // Does not include accumulators which are currently in the process of computing
327 // a jvp (and so appear somewhere on the current execution stack) or any
328 // accumulators more deeply nested.
329 //
330 // Includes JVPs for `tensors` and any higher-order JVPs for those
331 // (recursively). Returns a two-element tuple (indices, jvps):
332 //   indices: A sequence of sequences of two-element tuples. Each forward
333 //       accumulator is represented as a sequence of tuples with (primal_index,
334 //       jvp_index). Both integers index into the concatenated `tensors + jvps`
335 //       array.
336 //   jvps: A flat list of Tensors. Best interpreted as a sequence to be
337 //       appended to `tensors`.
338 PyObject* TFE_Py_PackJVPs(PyObject* tensors);
339 
340 // Variable Watcher methods.
341 
342 // Creates a new variable watcher and adds it to the set of active variable
343 // watchers.
344 PyObject* TFE_Py_VariableWatcherNew();
345 
346 // Removes the passed variable watcher from the set of active variable watchers.
347 void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);
348 
349 // Notifies all variable watchers that a variable has been accessed.
350 void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);
351 
352 // Returns all variables watched by the given variable_watcher in the order
353 // those variables were created.
354 PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);
355 
356 // Returns an EagerTensor of dimension [len(`tensors`)] containing
357 // the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
358 // TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
359 // `tensors`. For example, if `tensors` contains tensors with shapes
360 // [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
361 // `slice_dim` equal to 1 will return [2, 5, 7].
362 // On error, returns nullptr and sets python exception.
363 // REQUIRES: `tensors` is a python list/tuple of EagerTensors
364 // REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
365 //   tensors in `tensors`.
366 PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
367 
368 // Returns the shape of this tensor's on-device representation.
369 // The shape is represented as a Python tuple of integers.
370 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
371 
372 void TFE_Py_EnableInteractivePythonLogging();
373 
374 // Sets the current Python eager Context object (defined
375 // in eager/context.py). This function must be called at least once before
376 // eager tensors are created.
377 // If an error is encountered, sets python error and returns NULL. Else, returns
378 // Py_None.
379 //
380 // Not thread-safe.
381 // TODO(mdan): Retire this - non-Python users should only need the EagerContext.
382 PyObject* TFE_Py_SetEagerContext(PyObject* py_context);
383 
384 // Returns the current eager Context object (defined in eager/context.py)
385 // that was last set using TFE_Py_SetEagerContext.
386 // If an error is encountered, sets python error and returns NULL.
387 // The returned PyObject is "new", i.e. the caller must call Py_DECREF on it at
388 // some point.
389 PyObject* GetPyEagerContext();
390 
391 // These are exposed since there is SWIG code that calls these.
392 // Returns a pre-allocated status if it exists.
393 TF_Status* GetStatus();
394 // Returns the pre-allocated status to the code.
395 void ReturnStatus(TF_Status* status);
396 
397 namespace tensorflow {
398 
399 // Returns the DataType for the specified tensor.  Returns DT_INVALID if
400 // PyObject is not a tensor.
401 DataType PyTensor_DataType(PyObject* tensor);
402 
403 // Thread-local data associated with a Python eager Context object.
404 //
405 // TODO(edloper): Consider changing device_name and scope_name to a const char*
406 // (with nullptr used for None). However, note that existing code (e.g.
407 // TFE_TensorHandleCache::Lookup) assumes that the lifetime of these strings
408 // extends beyond the point where their value is changed; so we'd need to make
409 // sure that the strings stay alive (maybe using PyUnicode_InternInPlace?)
410 struct EagerContextThreadLocalData {
411   bool is_eager = false;
412   bool invoking_op_callbacks = false;
413   tensorflow::Safe_PyObjectPtr device_name;
414   tensorflow::Safe_PyObjectPtr scope_name;
415   tensorflow::Safe_PyObjectPtr device_spec;
416   tensorflow::Safe_PyObjectPtr function_call_options;
417   tensorflow::Safe_PyObjectPtr executor;
418   tensorflow::Safe_PyObjectPtr op_callbacks;
419 };
420 
421 // Create a thread-local-data structure associated with py_eager_context.
422 // `is_eager` and `device_spec` are used to supply default values for those
423 // fields whenever a new thread-local instance is created for py_eager_context.
424 //
425 // This function assumes that the Python GIL is held (and does not perform its
426 // own locking).
427 void MakeEagerContextThreadLocalData(PyObject* py_eager_context,
428                                      PyObject* is_eager,
429                                      PyObject* device_spec);
430 
431 // Returns the thread-local instance of EagerContextThreadLocalData that is
432 // associated with the given Python Context object.  If an instance has not
433 // yet been created for `py_eager_context` in this thread, then a new one is
434 // created, and initialized with the default values specified in
435 // MakeEagerContextThreadLocalData.
436 EagerContextThreadLocalData* GetEagerContextThreadLocalData(
437     PyObject* py_eager_context);
438 
439 // Free data structures used to track py_eager_context.
440 //
441 // This frees global state associated with py_eager_context, as well as thread-
442 // local state associated with py_eager_context and the current thread. If you
443 // wish to destroy thread-local state associated with a single py_eager_context
444 // for multiple threads, then you must call this method from each thread.
445 //
446 // Thread-local state associated with eager contexts is also automatically
447 // cleaned up when the thread is destroyed.
448 //
449 // This function assumes that the Python GIL is held (and does not perform its
450 // own locking).
451 void DestroyEagerContextThreadLocalData(PyObject* py_eager_context);
452 
453 }  // namespace tensorflow
454 
455 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
456