/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradients.h"

#include <deque>
#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/while_gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

struct OutputHash {
  uint64 operator()(const Output& x) const {
    return x.hash();
  }
};

struct OutputEq {
  bool operator()(const Output& x, const Output& y) const {
    return (x.node() == y.node()) && (x.index() == y.index());
  }
};

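// Helper that builds the gradient subgraph for the requested (outputs, inputs,
// grad_inputs) triple: it walks the forward graph backwards from `outputs_`,
// accumulating incoming gradients per node output and invoking the gradient
// function registered for each op, and writes the resulting gradients for
// `inputs_` into `grad_outputs_`.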
class SymbolicGradientBuilder {
 public:
  SymbolicGradientBuilder(const Scope& scope,
                          const ops::GradOpRegistry* registry,
                          const std::vector<Output>& outputs,
                          const std::vector<Output>& inputs,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs);

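  // Builds the gradient subgraph: runs Initialize(), then processes nodes from
  // `ready_` in reverse-BFS order, summing incoming gradients and calling the
  // registered gradient function for each node, until `grad_outputs_` is
  // populated for every entry of `inputs_`.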
  Status AddGradients();

  static Output NoGradient() { return Output(nullptr, -1); }

 private:
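  // Validates the request, computes reachability and stop-backprop nodes, and
  // initializes `pending_`, `backprops_`, `input_nodes_`, and `ready_` (seeded
  // with `grad_inputs_`).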
  Status Initialize();

  // For each forward edge from `src` to `dst` in the initial/forward graph:
  // propagates the gradient `dst_grad` backwards along that edge. This adds
  // `dst_grad` to the list of pending gradients for the node associated with
  // `src`.
  Status BackpropAlongEdge(const Output& dst_grad, const Output& src);

  // Adds a node to the graph (returned in `grad`) that sums the in-bound
  // gradients to `src` (if there are more than one).
  Status SumGradients(const Output& src, Output* grad);

  // Returns true if `opname` is registered in `registry_` with no gradient
  // function, false otherwise.
  bool IsPrimitiveOpWithNoGrad(const string& opname);

  // Call the gradient function for `op`, storing the result in `grad_outputs`.
  Status CallGradFunction(const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs);

  // Returns a vector, indexed by node id, indicating whether each node in the
  // graph is reachable from `outputs_`.
  std::vector<bool> GetReachableNodes();

  // Creates the gradient subgraph for a while loop (or just stores
  // `summed_grads` if not all incoming gradients are available yet). All exit
  // nodes (which are the first nodes of a loop encountered in the backwards
  // pass) are passed to this function rather than processed normally.
  // `summed_grads` is the sum of `exit_node`'s gradients.
  Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads);

  // Gets the set of node ids at which to stop backprop. These are all elements
  // of `outputs_` that do not get transitively consumed by other `outputs_`.
  std::unordered_set<int> GetStopBackpropNodes(
      const std::vector<bool>& reachable_nodes,
      const std::unordered_set<int>& output_nodes) const;

  const Scope& scope_;
  const ops::GradOpRegistry* registry_;
  const std::vector<Output>& outputs_;
  const std::vector<Output>& inputs_;
  const std::vector<Output>& grad_inputs_;
  std::vector<Output>* grad_outputs_;

  // A vector of output endpoints which represents backpropagated gradients.
  typedef std::vector<Output> BackproppedGradients;

  // backprops_ is a map from a node output to its accumulated
  // gradients.  When a node output has accumulated all its
  // gradients, we add a node which sums them up.
  std::unordered_map<Output, BackproppedGradients, OutputHash, OutputEq>
      backprops_;

  // pending_[i] is a count-down counter for the i-th node's expected
  // backprops. When pending_[i] becomes zero, we have collected all
  // backprop gradients for all outputs of the i-th node.
  std::vector<int> pending_;

  // `ready_` holds nodes whose outputs have received all of their expected
  // backpropped gradients and are ready to be processed. Initially, for every
  // output in `outputs_`, we add the corresponding gradient from
  // `grad_inputs_`.
  std::deque<Node*> ready_;

  // Maps each Output in `inputs_` to its index in `grad_outputs_`. Used to
  // identify nodes at the backprop frontier.
  std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;

  // For each while loop in the graph, collects the summed gradients for each of
  // the loop's exit nodes. Note that unlike backprops_, this map contains the
  // output of SumGradients(), not the input (i.e. each exit node may have
  // multiple incoming gradients, but we only store the combined Output here).
  std::map<WhileContext*, std::map<Node*, Output>> while_backprops_;

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};

SymbolicGradientBuilder::SymbolicGradientBuilder(
    const Scope& scope, const ops::GradOpRegistry* registry,
    const std::vector<Output>& outputs, const std::vector<Output>& inputs,
    const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs)
    : scope_(scope),
      registry_(registry),
      outputs_(outputs),
      inputs_(inputs),
      grad_inputs_(grad_inputs),
      grad_outputs_(grad_outputs) {}

Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
                                                  const Output& src) {
  if (src.node() == nullptr) {
    return errors::Internal("Attempted to backprop along an invalid edge.");
  }
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    auto* grads = &iter->second;
    grads->push_back(dst_grad);
    if (--pending_[src.node()->id()] == 0) {
      ready_.push_back(src.node());
    }
  }
  return OkStatus();
}

std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
  std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
  std::deque<Node*> queue;
  for (const Output& out : outputs_) {
    if (!reachable_nodes[out.node()->id()]) {
      queue.push_back(out.node());
      reachable_nodes[out.node()->id()] = true;
    }
  }

  while (!queue.empty()) {
    Node* n = queue.front();
    queue.pop_front();
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      if (!reachable_nodes[e->src()->id()]) {
        queue.push_back(e->src());
        reachable_nodes[e->src()->id()] = true;
      }
    }
  }
  return reachable_nodes;
}

std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
    const std::vector<bool>& reachable_nodes,
    const std::unordered_set<int>& output_nodes) const {
  // Output nodes that get transitively consumed by other `outputs_` are stored
  // in `internal_outputs`.
  std::unordered_set<int> internal_outputs;
  std::unordered_set<Node*> visited;
  // Initialize `queue` for BFS traversal. Each entry holds an upcoming node
  // along with the last Node in `outputs_` encountered along that path. If no
  // `outputs_` node was encountered, pair.second will be nullptr.
  std::deque<std::pair<Node*, Node*>> queue;
  for (const Output& nout : inputs_) {
    auto const& pair = visited.insert(nout.node());
    if (pair.second) {
      queue.push_back(std::make_pair(nout.node(), static_cast<Node*>(nullptr)));
    }
  }
  // BFS from nodes in 'inputs_' along out edges for the entire graph. Internal
  // output nodes are recorded during the traversal. All nodes that are output
  // nodes but not internal output nodes are considered the frontier of the
  // output nodes, and thus our stop backprop nodes.
  while (!queue.empty()) {
    std::pair<Node*, Node*> p = queue.front();
    Node* n = p.first;
    queue.pop_front();
    for (const Edge* e : n->out_edges()) {
      // If a node is not reachable from outputs_, we can stop.
      if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue;

      auto const& pair = visited.insert(e->dst());
      if (pair.second) {
        int node_id = e->dst()->id();
        Node* last_output_node = p.second;
        if (output_nodes.find(node_id) != output_nodes.end()) {
          // We reached an output node.
          if (last_output_node != nullptr) {
            // We already found an output node on this path, so mark it as an
            // internal output.
            internal_outputs.insert(last_output_node->id());
          }
          // Record this newly found output node to pair with the nodes we
          // push onto the queue.
          last_output_node = e->dst();
        }
        queue.push_back(std::make_pair(e->dst(), last_output_node));
      }
    }
  }
  // Finally, we set stop_backprop_nodes to all output_nodes that aren't also
  // internal_outputs.
  std::unordered_set<int> stop_backprop_nodes;
  for (int output_node : output_nodes) {
    if (internal_outputs.find(output_node) == internal_outputs.end()) {
      stop_backprop_nodes.insert(output_node);
    }
  }
  return stop_backprop_nodes;
}

Status SymbolicGradientBuilder::Initialize() {
  if (outputs_.size() != grad_inputs_.size()) {
    return errors::InvalidArgument(
        "Must specify a gradient input for each output.");
  }
  std::vector<bool> reachable_nodes = GetReachableNodes();
  for (const Output& input : inputs_) {
    if (!reachable_nodes[input.node()->id()]) {
      return errors::InvalidArgument(
          "Cannot compute the partial derivative for node '",
          input.node()->name(),
          "' as it's unreachable from the output node(s).");
    }
  }
  grad_outputs_->clear();
  grad_outputs_->resize(inputs_.size());

  std::unordered_set<int> output_nodes;
  output_nodes.reserve(outputs_.size());
  for (size_t i = 0; i < outputs_.size(); ++i) {
    output_nodes.insert(outputs_[i].node()->id());
  }

  std::unordered_set<int> stop_backprop_nodes =
      GetStopBackpropNodes(reachable_nodes, output_nodes);

  // Populate `input_nodes_` from Outputs in `inputs_`.
  input_nodes_.reserve(inputs_.size());
  for (size_t i = 0; i < inputs_.size(); ++i) {
    input_nodes_.insert({inputs_[i], i});
  }

  // TODO(andydavis) Consider a more efficient data structure for `pending_` to
  // handle computing gradients over small subgraphs from a very large graph.
  pending_.resize(scope_.graph()->num_node_ids(), 0);
  {
    backprops_.clear();
    std::unordered_set<Node*> visited;
    std::deque<Node*> queue;
    for (const Output& nout : inputs_) {
      auto const& pair = visited.insert(nout.node());
      if (pair.second) {
        queue.push_back(nout.node());
      }
    }

    // Walk forward to figure out which endpoints need to be backpropped. A
    // node's endpoints need backprop only if one of the input (`inputs_`)
    // nodes can reach the node via data edges.
    while (!queue.empty()) {
      Node* n = queue.front();
      queue.pop_front();
      for (int i = 0; i < n->num_outputs(); ++i) {
        backprops_[{n, i}].clear();
      }
      int num_expected_backprops = 0;
      if (stop_backprop_nodes.find(n->id()) == stop_backprop_nodes.end()) {
        // Internal node: continue BFS along connected outputs.
        for (const Edge* e : n->out_edges()) {
          // If a node is not reachable from outputs_,
          // we don't expect it to receive a backpropagated gradient.
          // It will not be counted in num_expected_backprops.
          if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue;
          auto const& pair = visited.insert(e->dst());
          if (pair.second) {
            queue.push_back(e->dst());
          }
          ++num_expected_backprops;
        }
      }
      if (output_nodes.find(n->id()) != output_nodes.end()) {
        // Output node: update `num_expected_backprops` for each Output in
        // `outputs_` that references `n`.
        for (const Output& output : outputs_) {
          if (output.node() == n) {
            ++num_expected_backprops;
          }
        }
      }
      pending_[n->id()] = num_expected_backprops;
    }
  }

  {
    // Initialize backprop with `grad_inputs_`.
    const size_t num_dy = grad_inputs_.size();
    for (size_t i = 0; i < num_dy; ++i) {
      TF_RETURN_IF_ERROR(BackpropAlongEdge(grad_inputs_[i], outputs_[i]));
    }
  }
  return OkStatus();
}

Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) {
  auto iter = backprops_.find(src);
  if (iter == backprops_.end()) {
    return errors::Internal(
        "Unable to find backprop list for node.id ", src.node()->name());
  }
  const auto& grads = iter->second;
  // Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed).
  // Return any valid backpropped gradients that remain after filtering,
  // or 'NoGradient' otherwise.
  std::vector<Output> grads_to_keep;
  for (const Output& o : grads) {
    if (o == NoGradient()) continue;
    grads_to_keep.push_back(o);
  }

  if (grads_to_keep.empty()) {
    // Nothing propagated back. Return 'NoGradient'.
    *grad = NoGradient();
  } else if (grads_to_keep.size() == 1) {
    // Just one backprop edge.
    *grad = grads_to_keep[0];
  } else {
    // Otherwise, sum the backpropped gradients.
    // TODO(andydavis) Use a better accumulator here.
    *grad = ops::AddN(scope_, grads_to_keep);
  }

  return OkStatus();
}

bool SymbolicGradientBuilder::IsPrimitiveOpWithNoGrad(const string& opname) {
  ops::GradFunc grad_fn;
  Status s = registry_->Lookup(opname, &grad_fn);
  return s.ok() && (grad_fn == nullptr);
}

Status SymbolicGradientBuilder::CallGradFunction(
    const Operation& op,
    const std::vector<Output>& grad_inputs,
    std::vector<Output>* grad_outputs) {
  ops::GradFunc grad_fn;
  TF_RETURN_IF_ERROR(registry_->Lookup(op.node()->type_string(), &grad_fn));
  TF_RETURN_IF_ERROR(grad_fn(scope_, op, grad_inputs, grad_outputs));
  TF_RETURN_IF_ERROR(scope_.status());
  return OkStatus();
}

Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
                                                 const Output& summed_grads) {
  // TODO(skyewm): detect second-order gradient and return bad status
  // TODO(skyewm): handle (or at least detect) nested while loops

  // TODO(skyewm): handle NoGradient in while loop
  if (summed_grads == NoGradient()) {
    return errors::Unimplemented(
        "Missing gradient into while loop not yet implemented");
  }

  DCHECK(exit_node->IsExit());
  WhileContext* while_ctx = exit_node->while_ctx();
  DCHECK(while_ctx != nullptr);

  // Record 'summed_grads' as the backprop input associated with 'exit_node'.
  std::map<Node*, Output>& backprops = while_backprops_[while_ctx];
  DCHECK(backprops.find(exit_node) == backprops.end());
  backprops[exit_node] = summed_grads;

  // Wait until we have all exit nodes' backprops collected before processing
  // the while loop.
  // TODO(skyewm): what if not all the exit nodes are reachable?
  if (backprops.size() < while_ctx->exit_nodes().size()) return OkStatus();

  // We've seen all the exit nodes for this loop and have collected all the
  // backprops. Create the gradient graph for the while loop.
  Scope while_scope =
      scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad"));
  std::vector<Output> dy;
  for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
  std::vector<Output> dx;
  TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx));

  // Backprop along the in edges to the while loop (i.e. the inputs to the enter
  // nodes).
  DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
  for (int i = 0, end = dx.size(); i < end; ++i) {
    Node* enter_node = while_ctx->enter_nodes()[i];
    for (const Edge* e : enter_node->in_edges()) {
      if (e->IsControlEdge()) continue;
      TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
    }
  }
  return OkStatus();
}

Status SymbolicGradientBuilder::AddGradients() {
  // Initialize backprops.
  TF_RETURN_IF_ERROR(Initialize());

  // Backward propagation.
  std::vector<Output> dy;
  while (!ready_.empty()) {
    // n has collected all gradients.
    Node* n = ready_.front();
    ready_.pop_front();

    // dy[i] is the sum of i-th output's backpropped gradients.
    const int num_y = n->num_outputs();
    dy.clear();
    dy.resize(num_y, {nullptr, 0});
    std::vector<int> no_grad_dy_indices;
    for (int i = 0; i < num_y; ++i) {
      TF_RETURN_IF_ERROR(SumGradients({n, i}, &dy[i]));
      if (dy[i] == NoGradient()) {
        no_grad_dy_indices.push_back(i);
      }
      auto iter = input_nodes_.find({n, i});
      if (iter != input_nodes_.end()) {
        // Return gradients for Output in 'grad_outputs_'.
        (*grad_outputs_)[iter->second] = dy[i];
      }
    }

    // Stop backprop if none of the inputs to `n` are in `backprops_`.
    bool stop_node = true;
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      if (backprops_.find({e->src(), e->src_output()}) != backprops_.end()) {
        stop_node = false;
        break;
      }
    }

    if (stop_node) {
      continue;
    }

    // Special case: if we find an exit node, process the associated while loop.
    // Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary
    // (which updates ready_), and we skip all the regular processing below
    // after calling it.
    if (n->IsExit()) {
      DCHECK_EQ(dy.size(), 1);
      TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
      continue;
    }
    // All loop-specific control flow ops should have been handled above.
    DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();

    const int num_no_grad = no_grad_dy_indices.size();
    if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
      // No grad defined for this op, or all outputs returned 'NoGradient':
      // Backprop 'NoGradient' along the in edges.
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) continue;
        TF_RETURN_IF_ERROR(
            BackpropAlongEdge(NoGradient(), {e->src(), e->src_output()}));
      }
      continue;
    }

    if (num_no_grad > 0 && num_no_grad < num_y) {
      // The outputs of 'n' returned a mixture of valid gradients and
      // 'NoGradient'. Therefore, we need to add 'ZerosLike' nodes for each
      // 'NoGradient' output before we call the gradient function for 'n'.
      // TODO(andydavis) If static shapes are known, replace 'ZerosLike' with
      // zero-filled Constant node of appropriate shape.
      for (const int dy_index : no_grad_dy_indices) {
        dy[dy_index] = ops::ZerosLike(scope_, Output(n, dy_index));
      }
    }

    // TODO(andydavis) Add option to encapsulate grad function in
    // SymbolicGradientOp (as opposed to inlining into the graph).
    std::vector<Output> dx;
    TF_RETURN_IF_ERROR(CallGradFunction(Operation(n), dy, &dx));

    // Backprop along the in edges.
    // TODO(andydavis) Find cleaner way to map each grad output returned by
    // gradient function to the src node/output to which it should be
    // backpropped. Maybe grad functions can return a vector of Output pairs to
    // make this association explicit.
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      size_t dx_index = e->dst_input();
      if (dx_index >= dx.size()) {
        return errors::Internal(
            "Invalid gradient output index: ", dx_index, " size: ", dx.size());
      }
      TF_RETURN_IF_ERROR(
          BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()}));
    }
  }

  // Check if any input nodes still have pending gradients and have not been
  // processed yet. This happens if not all outputs of a node are in 'inputs_'.
  std::unordered_map<Node*, int> requested_grads;
  for (const Output& nout : inputs_) {
    if (pending_[nout.node()->id()] > 0) {
      DCHECK_GT(nout.node()->num_outputs(), 1);
      int idx = input_nodes_[nout];
      DCHECK(((*grad_outputs_)[idx].node() == nullptr));
      TF_RETURN_IF_ERROR(SumGradients(nout, &(*grad_outputs_)[idx]));
      ++requested_grads[nout.node()];
    }
  }
  for (const auto& p : requested_grads) {
    int num_requested_inputs = p.first->num_outputs() - pending_[p.first->id()];
    CHECK_EQ(num_requested_inputs, p.second);
  }
  return OkStatus();
}

}  // namespace

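// Usage sketch for the public entry points below (illustrative only; the ops
// and variable names in this example are assumptions and do not appear in
// this file):
//
//   Scope scope = Scope::NewRootScope();
//   auto x = ops::Placeholder(scope, DT_FLOAT);
//   auto y = ops::Square(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//   // grads[0] is dy/dx, seeded with OnesLike(y) by the second overload below.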
Status AddSymbolicGradients(const Scope& scope,
                            const std::vector<Output>& outputs,
                            const std::vector<Output>& inputs,
                            const std::vector<Output>& grad_inputs,
                            std::vector<Output>* grad_outputs) {
  SymbolicGradientBuilder builder(scope, ops::GradOpRegistry::Global(), outputs,
                                  inputs, grad_inputs, grad_outputs);
  return builder.AddGradients();
}

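// This overload seeds the backprop with OnesLike(output) for each entry in
// `outputs`, so `grad_outputs` holds unweighted sums of partial derivatives.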
Status AddSymbolicGradients(const Scope& scope,
                            const std::vector<Output>& outputs,
                            const std::vector<Output>& inputs,
                            std::vector<Output>* grad_outputs) {
  std::vector<Output> grad_inputs;
  grad_inputs.reserve(outputs.size());
  for (const Output& output : outputs) {
    grad_inputs.emplace_back(ops::OnesLike(scope, output));
  }
  return AddSymbolicGradients(scope, outputs, inputs, grad_inputs,
                              grad_outputs);
}

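// Sentinel Output (node == nullptr, index == -1) used by gradient functions to
// signal that no gradient flows along a particular input.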
Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); }

}  // end namespace tensorflow