# C++ gradients

Gradients are currently being ported from
[Python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/ops)
to C++ (in this directory).

Contributions are welcome and much appreciated; please follow the instructions
below.

1.  Create the op gradient function in `foo_grad.cc`, corresponding to the
    `foo_grad.py` file that defines the op's Python gradient (e.g. gradients
    for ops in `array_grad.py` should be written in `array_grad.cc`).

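2.  Write the op gradient function with the following signature and naming
    scheme, then register it for the op with `REGISTER_GRADIENT_OP`:

    ```
    Status OpNameGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
      // grad_inputs holds the gradients flowing into op's outputs; push one
      // gradient per op input onto grad_outputs.
      ...
      return scope.status();
    }
    REGISTER_GRADIENT_OP("OpName", OpNameGrad);
    ```
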
3.  Op gradients are implemented using the
    [C++ API](https://www.tensorflow.org/api_docs/cc/).

4.  Tests should be included in `foo_grad_test.cc`. Please see
    [`array_grad_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/gradients/array_grad_test.cc)
    for many examples. A test can be as simple as creating placeholders for the
    op's inputs and calling `RunTest`, which uses a
    [gradient checker](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/framework/gradient_checker.cc)
    to verify that the theoretical gradient matches the numeric gradient. For
    example:

    ```
    TEST_F(ArrayGradTest, IdentityGrad) {
      TensorShape shape({5, 2});
      // A 5x2 float placeholder stands in for the op's input.
      auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
      auto y = Identity(scope_, x);
      // Compares the theoretical and numeric gradients of y with respect to x.
      RunTest(x, shape, y, shape);
    }
    ```


NOTE: Some ops require features of the C++ API that are not yet implemented:

*   Ops that require `PartialTensorShape` information cannot yet be implemented.

*   Ops that require `SparseTensor` or `IndexedSlices` (currently only
    available in Python) cannot yet be implemented.

*   Possibly others.

For questions, please create an issue assigned to suharshs.